Using Windows in Spark to Avoid Joins

When you think of windows in Spark you might think of Spark Streaming, but windows can also be used on regular DataFrames. A window function calculates an output value for every row of a DataFrame based on a group of related rows. While optimizing some Spark code, I noticed a few places where a window function eliminates the need for a join and speeds up the code. A common case is when an aggregation is performed on a DataFrame and the result is then joined back to the original DataFrame. Let’s take a look at an example.

import util.Random
import org.apache.spark.sql.functions._

val maxX = 500000
val nrow = maxX*10
val randomList = Seq.fill(nrow)((Random.nextInt(maxX), Random.nextInt, Random.nextInt))
val df = sc.parallelize(randomList).toDF("x", "y", "z")

val startTime = System.currentTimeMillis()

val dfAgg = df.groupBy("x").agg(max("y").as("max_y"))
val dfJoined = df.join(dfAgg, "x")

val outputParq = "joinAggTest.parq"
dfJoined.write.parquet(outputParq)

val endTime = System.currentTimeMillis()

println("Aggregation and join took " + (endTime-startTime)/(1000.0) + " seconds")

First, a list of nrow tuples is created. Each tuple consists of three random integers, the first of which is constrained to be non-negative and less than maxX. The list is then converted to a DataFrame with three columns: “x”, “y”, and “z”. Next, the rows are grouped by “x”, and for each value of “x” the maximum value of “y” is computed and saved to the column “max_y”. The resulting DataFrame, dfAgg, is joined back to the original DataFrame on “x”. Since nrow is 10 times larger than maxX, each value of “x” is repeated 10 times on average. The resulting DataFrame is shown below.

An alternative method for doing the same thing using a window function is displayed below.

import util.Random
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

val maxX = 500000
val nrow = maxX*10
val randomList = Seq.fill(nrow)((Random.nextInt(maxX), Random.nextInt, Random.nextInt))
val df = sc.parallelize(randomList).toDF("x", "y", "z")

val startTime = System.currentTimeMillis()

val window = Window.partitionBy("x")
val dfWindow = df.withColumn("max_y", max("y").over(window))

val outputParq = "WindowTest.parq"
dfWindow.write.parquet(outputParq)

val endTime = System.currentTimeMillis()

println("Window took " + (endTime-startTime)/(1000.0) + " seconds")

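Here partitionBy determines how rows are grouped into window partitions: all rows with the same value of “x” belong to the same one. Because no ordering or frame is specified, the frame for each row is its entire window partition, so max is applied over every row that shares that “x”. This is a logical grouping; partitionBy is not simply distributing the rows to different physical partitions.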
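To make the grouping semantics concrete, here is a minimal sketch in plain Scala collections (no Spark) of what the two approaches compute. The names WindowSemantics, Row, aggThenJoin and windowStyle are hypothetical and introduced only for illustration; the sketch mimics the logical result, not how Spark executes either plan.

```scala
object WindowSemantics {
  // Hypothetical stand-in for one row of the DataFrame with columns x, y, z.
  final case class Row(x: Int, y: Int, z: Int)

  // Aggregate-then-join: build a lookup of max(y) per x, then attach it to every row.
  def aggThenJoin(rows: Seq[Row]): Seq[(Row, Int)] = {
    val maxByX: Map[Int, Int] =
      rows.groupBy(_.x).map { case (x, group) => x -> group.map(_.y).max }
    rows.map(r => (r, maxByX(r.x)))
  }

  // Window-style: within each group of rows sharing x, every row receives the group's max(y).
  def windowStyle(rows: Seq[Row]): Seq[(Row, Int)] =
    rows.groupBy(_.x).values.flatMap { group =>
      val groupMax = group.map(_.y).max
      group.map(r => (r, groupMax))
    }.toSeq
}
```

Both functions return one output row per input row, each paired with the maximum “y” of its “x” group; only the row order may differ.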

I ran these two versions of the code with different values of maxX, and therefore of nrow, on Databricks using 30 Standard_DS14_v2 workers. The plot of the time taken by the two methods as a function of nrow is displayed below.

With 20,000,000 rows, the window function is 1.9 times faster than the aggregation and join; with 100,000,000 rows, it is 1.4 times faster. A window function may not always be the best method, so both approaches should be tested when optimizing Spark code.
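Since the comparison depends on data size and cluster, it can help to wrap each candidate in a small timing helper instead of repeating the start/end bookkeeping. The sketch below is one possible way to do that; Timing and timed are hypothetical names, not part of Spark.

```scala
object Timing {
  // Runs `block`, prints the elapsed wall-clock time, and returns the result with the seconds taken.
  def timed[A](label: String)(block: => A): (A, Double) = {
    val start = System.nanoTime()
    val result = block
    val seconds = (System.nanoTime() - start) / 1e9
    println(s"$label took $seconds seconds")
    (result, seconds)
  }
}
```

With the DataFrames above, a call might look like `Timing.timed("Window") { dfWindow.write.parquet(outputParq) }`. Because Spark evaluates lazily, the block needs to contain an action such as a write or a count; otherwise only the construction of the plan is timed.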

The DAG for the aggregate-and-join method is displayed below.

The DAG for the window method is displayed below.
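To inspect the plans yourself, one option is to print the physical plan of each DataFrame with explain(). The sketch below assumes a local SparkSession and a tiny made-up dataset; ComparePlans and build are hypothetical names. On such small data Spark may pick a different join strategy than it does on a cluster, so the output is only indicative.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.max

object ComparePlans {
  // Builds both variants on a tiny, made-up dataset (not the benchmark data).
  def build(spark: SparkSession): (DataFrame, DataFrame) = {
    import spark.implicits._
    val df = Seq((1, 10, 0), (1, 30, 1), (2, 5, 2)).toDF("x", "y", "z")
    val dfJoined = df.join(df.groupBy("x").agg(max("y").as("max_y")), "x")
    val dfWindow = df.withColumn("max_y", max("y").over(Window.partitionBy("x")))
    (dfJoined, dfWindow)
  }

  def main(args: Array[String]): Unit = {
    // Assumption: a local session, used here for illustration only.
    val spark = SparkSession.builder().master("local[*]").appName("compare-plans").getOrCreate()
    val (dfJoined, dfWindow) = build(spark)
    dfJoined.explain() // physical plan of the aggregate-and-join variant
    dfWindow.explain() // physical plan of the window variant
    spark.stop()
  }
}
```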

The following code confirms that the two methods give the same results:

import util.Random
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

val maxX = 500000
val nrow = maxX*10
val randomList = Seq.fill(nrow)((Random.nextInt(maxX), Random.nextInt, Random.nextInt))
val df = sc.parallelize(randomList).toDF("x", "y", "z").persist

val startTime = System.currentTimeMillis()

val dfAgg = df.groupBy("x").agg(max("y").as("max_y"))
val dfJoined = df.join(dfAgg, "x")

println("dfJoined.count= " + dfJoined.count)

val endTime = System.currentTimeMillis()

println("Aggregation and join took " + (endTime-startTime)/(1000.0) + " seconds")

val startTimeWindow = System.currentTimeMillis()

val window = Window.partitionBy("x")
val dfWindow = df.withColumn("max_y", max("y").over(window))

println("dfWindow.count= " + dfWindow.count)

val endTimeWindow = System.currentTimeMillis()

println("Window took " + (endTimeWindow-startTimeWindow)/(1000.0) + " seconds")

// Rows of dfJoined that do not appear in dfWindow; an empty result means there are none
val diff = dfJoined.except(dfWindow)
diff.show

df.unpersist

The result of running the above code is displayed below.
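Note that except returns the distinct rows of the first DataFrame that are absent from the second, so an empty result covers only one direction and ignores duplicate counts. A stricter check, sketched here under the assumption of Spark 2.4 or later (where exceptAll and Dataset.isEmpty are available), compares in both directions with duplicates preserved. TwoWayDiff, sameRows and demo are hypothetical names, and the data is a tiny made-up sample.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.max

object TwoWayDiff {
  // True when both DataFrames hold the same rows, duplicates included.
  // Columns are matched by position, so both sides need the same column order and types.
  def sameRows(a: DataFrame, b: DataFrame): Boolean =
    a.exceptAll(b).isEmpty && b.exceptAll(a).isEmpty

  // Compares the two methods on a small sample that contains a duplicated row.
  def demo(spark: SparkSession): Boolean = {
    import spark.implicits._
    val df = Seq((1, 10, 0), (1, 30, 1), (1, 30, 1), (2, 5, 2)).toDF("x", "y", "z")
    val dfJoined = df.join(df.groupBy("x").agg(max("y").as("max_y")), "x")
    val dfWindow = df.withColumn("max_y", max("y").over(Window.partitionBy("x")))
    sameRows(dfJoined, dfWindow)
  }

  def main(args: Array[String]): Unit = {
    // Assumption: a local session, used here for illustration only.
    val spark = SparkSession.builder().master("local[*]").appName("two-way-diff").getOrCreate()
    println(demo(spark))
    spark.stop()
  }
}
```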
