Writing Unit Tests for Apache Spark using Memory Streams


In this post, we will look at how we can leverage Apache Spark’s memory streams for unit testing.


What is it?

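Apache Spark’s MemoryStream is a concrete streaming source that keeps its data in memory and supports reading in micro-batch stream processing. It lives in the org.apache.spark.sql.execution.streaming package and is mainly intended for testing.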
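To make the definition concrete, here is a minimal sketch, assuming Spark 2.4 or 3.x on the classpath and a local SparkSession. It pushes a few integers through a MemoryStream and, unlike the rest of this post, reads them back from Spark’s built-in "memory" table sink rather than a custom ForeachWriter. The object name MemoryStreamBasics and the table name numbers_out are illustrative.

```scala
import org.apache.spark.sql.{SQLContext, SparkSession}
import org.apache.spark.sql.execution.streaming.MemoryStream

object MemoryStreamBasics {
  // Pushes integers through a MemoryStream and reads them back from the "memory" sink.
  def roundTrip(spark: SparkSession): Seq[Int] = {
    import spark.implicits._                            // provides Encoder[Int]
    implicit val sqlCtx: SQLContext = spark.sqlContext  // MemoryStream.apply needs an implicit SQLContext

    val numbers = MemoryStream[Int]
    numbers.addData(1, 2, 3) // first chunk of data
    numbers.addData(4, 5)    // more data appended to the same stream

    val query = numbers.toDS()
      .writeStream
      .format("memory")         // Spark's in-memory table sink
      .queryName("numbers_out") // name of the in-memory table to query afterwards
      .outputMode("append")
      .start()

    try {
      query.processAllAvailable() // wait until every row added so far has been processed
      spark.table("numbers_out").as[Int].collect().toSeq.sorted
    } finally {
      query.stop()
    }
  }
}
```

Each addData call adds a new offset to the stream, and processAllAvailable blocks until the query has caught up with all of them, which is what makes the read-back deterministic in a test.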

Let’s jump into it

We will use a memory stream to write some test data into memory as a stream. We will then apply a transformation and compare the result with the expected data.

case class Person(id: Int, salary: Int)

val inputData = Seq(Person(1, 4), Person(2, 8), Person(3, 7), Person(4, 89))
val expectedData = Seq(Person(2, 8), Person(4, 89))

From inputData and expectedData above, you may have guessed that we are simply going to filter the Person records, keeping only those with an even id.
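One cheap way to double-check expectedData is to apply the same predicate to a plain Scala collection, with no Spark involved. The sketch below (the object name EvenIdCheck is hypothetical) treats ordinary collections as a reference oracle for the Spark transformation.

```scala
// Person is redeclared here so the snippet compiles on its own.
case class Person(id: Int, salary: Int)

object EvenIdCheck {
  val inputData: Seq[Person] = Seq(Person(1, 4), Person(2, 8), Person(3, 7), Person(4, 89))
  val expectedData: Seq[Person] = Seq(Person(2, 8), Person(4, 89))

  // The same predicate the Spark transformation will use: keep only even ids.
  def keepEvenIds(people: Seq[Person]): Seq[Person] = people.filter(_.id % 2 == 0)
}
```

If keepEvenIds(inputData) does not equal expectedData, the test data itself is wrong, which is worth knowing before debugging a streaming query.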

But before that, let’s define how we want to store the data in memory. For testing purposes, we use a mutable ListBuffer as the sink. We also keep the buffers in a map keyed by a test name, so that the same variable can be reused across different test cases.

private val data = new mutable.HashMap[String, mutable.ListBuffer[Person]]()

We also have to implement our own writer by extending Spark’s built-in ForeachWriter, which we pass to the query when writing to the sink.

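  import org.apache.spark.sql.ForeachWriter

  class InMemoryStoreWriter[T](
      key: String,
      rowExtractor: T => Person) extends ForeachWriter[T] {

    // Called once per partition and epoch; returning true means "process this partition"
    override def open(partitionId: Long, epochId: Long): Boolean = true

    // Called for every row: convert it to a Person and store it under our test key
    override def process(row: T): Unit = {
      InMemoryStoreWriter.addValue(key, rowExtractor(row))
    }

    // Nothing to clean up for an in-memory store
    override def close(errorOrNull: Throwable): Unit = {}
  }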

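What we are doing here is appending every new entry in the stream to our ListBuffer. Note that the buffers live in the InMemoryStoreWriter companion object below, which is a JVM-wide singleton, so this approach only works when Spark runs in local mode, where the driver and executors share one JVM.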
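To see the order in which Spark invokes these methods, you can drive a ForeachWriter by hand without starting a query. In this sketch the RecordingWriter class and ForeachWriterLifecycle object are hypothetical and only need spark-sql on the classpath. The call order follows Spark’s documented contract for one partition of one epoch: open, then process for each row if open returned true, then close.

```scala
import org.apache.spark.sql.ForeachWriter
import scala.collection.mutable

// Hypothetical writer that only records which lifecycle method was called, and with what.
class RecordingWriter extends ForeachWriter[String] {
  val calls: mutable.ListBuffer[String] = mutable.ListBuffer.empty[String]

  override def open(partitionId: Long, epochId: Long): Boolean = {
    calls += s"open($partitionId,$epochId)"
    true // returning false tells Spark to skip process() for this partition and epoch
  }

  override def process(value: String): Unit = {
    calls += s"process($value)"
  }

  override def close(errorOrNull: Throwable): Unit = {
    val outcome = Option(errorOrNull).map(_.getMessage).getOrElse("no error")
    calls += s"close($outcome)"
  }
}

object ForeachWriterLifecycle {
  // Calls the writer in the same order Spark does for one partition of one epoch.
  def simulatePartition(rows: Seq[String]): List[String] = {
    val writer = new RecordingWriter
    if (writer.open(0L, 0L)) rows.foreach(writer.process)
    writer.close(null) // null means the partition finished without an error
    writer.calls.toList
  }
}
```

Driving a writer this way is a quick unit test for the writer alone, separate from the end-to-end streaming test.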

  import scala.collection.mutable
  import scala.collection.mutable.ListBuffer

  object InMemoryStoreWriter {
    private val data = new mutable.HashMap[String, mutable.ListBuffer[Person]]()

    def addValue(key: String, value: Person): Option[ListBuffer[Person]] = {
      data.synchronized {
        val values = data.getOrElse(key, new mutable.ListBuffer[Person]())
        values += value // append the new row to this key's buffer
        data.put(key, values)
      }
    }

    def getValues(key: String): mutable.ListBuffer[Person] = data.getOrElse(key, ListBuffer.empty)
  }
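The synchronized block matters because Spark may call process from several tasks at the same time. The following stdlib-only sketch, with the hypothetical names KeyedStore and ConcurrentWrites, uses a simplified copy of the store above and simulates four concurrent writers on two test keys with Futures instead of Spark tasks. It also shows the benefit of keying the buffers: each test key ends up with only its own rows.

```scala
import scala.collection.mutable
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

case class Person(id: Int, salary: Int)

// Simplified, hypothetical copy of the keyed in-memory store.
object KeyedStore {
  private val data = new mutable.HashMap[String, mutable.ListBuffer[Person]]()

  // All reads and writes lock the map, so concurrent writers cannot lose rows.
  def addValue(key: String, value: Person): Unit = data.synchronized {
    data.getOrElseUpdate(key, mutable.ListBuffer.empty[Person]) += value
  }

  // Returns an immutable snapshot so callers never see a buffer that is still changing.
  def getValues(key: String): List[Person] = data.synchronized {
    data.get(key).map(_.toList).getOrElse(Nil)
  }
}

object ConcurrentWrites {
  // Four tasks each add 250 rows; even tasks use "test-a", odd tasks use "test-b".
  def run(): (Int, Int) = {
    val tasks = (0 until 4).map { task =>
      Future {
        val key = if (task % 2 == 0) "test-a" else "test-b"
        (0 until 250).foreach(i => KeyedStore.addValue(key, Person(task * 1000 + i, i)))
      }
    }
    Await.result(Future.sequence(tasks), 30.seconds)
    (KeyedStore.getValues("test-a").size, KeyedStore.getValues("test-b").size)
  }
}
```

Without the lock, concurrent appends to an unsynchronized HashMap and ListBuffer can silently drop rows, which would show up as flaky test failures.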

Time for Adding UTs

Our writer is now ready. Let’s add our transformation function.

  import org.apache.spark.sql.Dataset

  // Keep only the people whose id is even
  def testEvenTransform(ds: Dataset[Person]): Dataset[Person] = ds.filter(_.id % 2 == 0)
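Because testEvenTransform only takes a Dataset[Person], nothing ties it to streaming, so a quick batch check can run before the streaming test. This is a sketch assuming a local SparkSession. The object name EvenTransformBatchCheck is hypothetical, and the function is redeclared so the snippet compiles on its own.

```scala
import org.apache.spark.sql.{Dataset, SparkSession}

case class Person(id: Int, salary: Int)

object EvenTransformBatchCheck {
  // Same transformation as in the post.
  def testEvenTransform(ds: Dataset[Person]): Dataset[Person] = ds.filter(_.id % 2 == 0)

  // Runs the transformation on a static (batch) Dataset and collects the result.
  def run(spark: SparkSession): Seq[Person] = {
    import spark.implicits._ // provides Encoder[Person] for toDS()
    val input = Seq(Person(1, 4), Person(2, 8), Person(3, 7), Person(4, 89)).toDS()
    testEvenTransform(input).collect().toSeq.sortBy(_.id)
  }
}
```

The streaming test is still worth having, since it exercises the writer and the stream wiring, but the batch check isolates the transformation logic.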

Putting it together

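      import org.apache.spark.sql.execution.streaming.MemoryStream
      import sparkSession.implicits._ // provides the implicit Encoder[Person] that MemoryStream needs

      val inputData = Seq(Person(1, 4), Person(2, 8), Person(3, 7), Person(4, 89))
      val expectedData = Seq(Person(2, 8), Person(4, 89))

      val testKey = "person-test-1"
      val inputStream: MemoryStream[Person] = new MemoryStream[Person](1, sparkSession.sqlContext)
      inputStream.addData(inputData) // This is how we add entries to the memory stream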

After adding our data to the stream, we apply our transformation function. It can be anything (for example, deduplication). The transformed data is then written to memory through our ForeachWriter.

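      val query = testEvenTransform(inputStream.toDS()).writeStream
        .foreach(new InMemoryStoreWriter[Person](testKey, data => data)) // Writing to memory
        .start()
      query.processAllAvailable() // block until every row added so far has been processed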

Now, let’s read back the result stored in memory and compare it with the expected data.

      val transformedData = InMemoryStoreWriter.getValues(testKey)

      transformedData.sortBy(_.id) shouldBe expectedData.sortBy(_.id) 
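For convenience, the pieces above can be assembled into one self-contained sketch. It assumes Spark 2.4 or 3.x (for the open(partitionId, epochId) signature) and local mode, so the ForeachWriter and its companion store share a JVM with the test. EvenIdStreamingCheck is a hypothetical name, and this getValues returns an immutable copy rather than the live buffer.

```scala
import org.apache.spark.sql.{Dataset, ForeachWriter, SQLContext, SparkSession}
import org.apache.spark.sql.execution.streaming.MemoryStream
import scala.collection.mutable

case class Person(id: Int, salary: Int)

// Same writer/store pattern as in the post, condensed so the snippet stands alone.
class InMemoryStoreWriter[T](key: String, rowExtractor: T => Person) extends ForeachWriter[T] {
  override def open(partitionId: Long, epochId: Long): Boolean = true
  override def process(row: T): Unit = InMemoryStoreWriter.addValue(key, rowExtractor(row))
  override def close(errorOrNull: Throwable): Unit = {}
}

object InMemoryStoreWriter {
  private val data = new mutable.HashMap[String, mutable.ListBuffer[Person]]()

  def addValue(key: String, value: Person): Unit = data.synchronized {
    data.getOrElseUpdate(key, mutable.ListBuffer.empty[Person]) += value
  }

  def getValues(key: String): List[Person] = data.synchronized {
    data.get(key).map(_.toList).getOrElse(Nil)
  }
}

object EvenIdStreamingCheck {
  def testEvenTransform(ds: Dataset[Person]): Dataset[Person] = ds.filter(_.id % 2 == 0)

  // Feeds inputData through a MemoryStream, applies the transform, and returns
  // what the ForeachWriter stored under testKey, sorted by id.
  def run(spark: SparkSession, testKey: String, inputData: Seq[Person]): List[Person] = {
    import spark.implicits._                            // Encoder[Person]
    implicit val sqlCtx: SQLContext = spark.sqlContext  // needed by MemoryStream.apply

    val inputStream = MemoryStream[Person]
    inputStream.addData(inputData)

    val query = testEvenTransform(inputStream.toDS())
      .writeStream
      .foreach(new InMemoryStoreWriter[Person](testKey, identity))
      .start()
    try query.processAllAvailable() // wait for all added data to reach the writer
    finally query.stop()

    InMemoryStoreWriter.getValues(testKey).sortBy(_.id)
  }
}
```

Using a fresh testKey per test case keeps results from different tests apart even though they share the same companion object.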

That’s it: we have used memory both as a streaming source and as a sink, and tested one of our transformation functions. Use cases will vary, so it is up to you how you implement this and what you want to test.

Reference code: https://github.com/knoldus/dedupe_spark_sample

References: