Efficient Range-Joins With Spark 2.0

If you've ever worked with Spark on any kind of time-series analysis, you've probably reached the point where you need to join two DataFrames based on the time difference between timestamp fields.

For the purpose of this post, let's assume we have a DataFrame with events data, and another one with measurements (couldn't be more generic than that...). Both have timestamp fields (eventTime and measurementTime), and we want to join every event with the measurements that were recorded in the hour before it.

A naive approach (just specifying this as the range condition) results in a full cartesian product followed by a filter that enforces the condition (tested using Spark 2.0). This has a horrible effect on performance, especially once the DataFrames hold more than a few hundred thousand records.

While the Spark developers are working on a more generic solution (see the GitHub issue here), there are still use-cases where we can greatly improve performance with the join strategies that are currently available. One of them is the one described above (joining events to measurements from the hour before them), and I believe it's a very common one. In this post I'll briefly go over the implementation that worked for me; if your use-case is different, you can probably tweak it so that it addresses yours too.

The Data

To keep everything simple, we'll work with the following dataframes. You can obviously work with your own classes as long as they have a timestamp or a numeric field. This example will use timestamps.
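The original definitions are not reproduced here, so the following is only a minimal sketch of what they might look like. The field names (eid, eventTime, eventType, mid, measurementTime, value) are the ones used in the rest of this post; the case class names, the 10-second spacing, the `nowMillis` and `seed` parameters, and the bodies of generateEvents and generateMeasurements are assumptions made for illustration.

```scala
import java.sql.Timestamp

object SampleData {
  // Hypothetical record types; field names follow the columns used in this post.
  case class Event(eid: Long, eventTime: Timestamp, eventType: String)
  case class Measurement(mid: Long, measurementTime: Timestamp, value: Double)

  private val eventTypes = Vector("LoginEvent", "PurchaseEvent")
  private val stepMillis = 10000L // assumed spacing: one record every 10 seconds

  // Generates n events, stepping back in time from `nowMillis`.
  def generateEvents(n: Int, nowMillis: Long = System.currentTimeMillis()): Seq[Event] =
    (1 to n).map { i =>
      Event(i.toLong, new Timestamp(nowMillis - i * stepMillis), eventTypes(i % eventTypes.size))
    }

  // Generates n measurements with pseudo-random values in [0, 1).
  def generateMeasurements(n: Int,
                           nowMillis: Long = System.currentTimeMillis(),
                           seed: Long = 42L): Seq[Measurement] = {
    val rnd = new scala.util.Random(seed)
    (1 to n).map { i =>
      Measurement(i.toLong, new Timestamp(nowMillis - i * stepMillis), rnd.nextDouble())
    }
  }
}
```

If the generators return plain sequences like these, they need to be converted before joining: in a Spark shell, `import spark.implicits._` followed by `SampleData.generateEvents(10).toDF` should produce a DataFrame with the three event columns.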

The Naive Approach

If we choose to just join the DataFrames and specify the range condition, we'd get the following:

    import org.apache.spark.sql.functions.lit
    import org.apache.spark.unsafe.types.CalendarInterval

    var events = generateEvents(1000000)
    var measurements = generateMeasurements(1000000)

    // An example with a timestamp field would look like this:
    val res = events.join(measurements,
      (measurements("measurementTime") > events("eventTime") - CalendarInterval.fromString("interval 30 seconds")) &&
      (measurements("measurementTime") <= events("eventTime")))

    // With a numeric field (took the id as an example, this is obviously useless):
    val res = events.join(measurements,
      (measurements("mid") > events("eid") - lit(2)) &&
      (measurements("mid") <= events("eid")))

    res.explain
    // run something like `res.count` to make Spark actually perform the join.

The important thing to look at is how Spark plans to perform the join we've defined on our two DataFrames. The explain command, both for the timestamp field join and for the numeric field join, gives us this:

The first row is the key: it indicates that Spark is going to resolve our request by performing a cartesian product of the two DataFrames. Notice that if the number of records in one of the DataFrames is small enough, Spark will be able to broadcast it to all machines and perform a BroadcastNestedLoopJoin with a BroadcastExchange instead. This is better, but it isn't a real solution, since we want to work with large data sets.
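When experimenting on a small sample, it can help to switch automatic broadcasting off so that the plan resembles what you would get on large inputs. A minimal sketch, assuming a local session: `spark.sql.autoBroadcastJoinThreshold` is a standard Spark SQL setting, while the app name and master URL are placeholders.

```scala
import org.apache.spark.sql.SparkSession

// Assumes a local session for experimentation; in spark-shell, `spark` already exists.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("range-join-demo") // placeholder name
  .getOrCreate()

// -1 turns off automatic broadcast joins, so small inputs should no longer be broadcast.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1L)
```

With this set, `explain` on the naive join should show the cartesian product even for small test DataFrames.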

The Bucketing, Double-Joining and Filtering Approach

Now let's take advantage of our less generic use-case. We know that we're only interested in measurements that happened up to 60 minutes before the event, so every event only needs to be matched against its local (time-based) neighborhood, and a full cartesian product is just a waste of computing effort. What we'd like to do is group records together and join only groups of records that are close in time.

Let's start by grouping the records in both DataFrames into 60-minute windows:

If we look at a specific event record, no matter where it's 'located' in its time window (the first minute or the last minute of the 60-minute window), we can guarantee that all its matching measurements will be linked either to the same window or to the one before it:
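To see why two windows are enough, here is a small self-contained check in plain Scala (no Spark involved). Times are plain integers in minutes, `windowOf` and `matchingWindows` are hypothetical helpers that mimic grouping by integer division, and the window width equals the look-back range, which is the assumption the argument relies on.

```scala
object WindowCheck {
  val width = 60L // window width in minutes, equal to the look-back range

  // Index of the window a point in time falls into (floorDiv also handles negative times).
  def windowOf(t: Long): Long = Math.floorDiv(t, width)

  // For an event at time t, the window indices of every measurement time in (t - width, t].
  def matchingWindows(t: Long): Set[Long] =
    ((t - width + 1) to t).map(windowOf).toSet
}
```

An event at minute 125 sits in window 2, and `WindowCheck.matchingWindows(125L)` contains only windows 1 and 2. An event at minute 179, the last minute of window 2, only needs window 2 itself.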

What we basically want to do is group all events into time windows, and link every measurement to its own window and the one after it (marked in red). Then we can join on the window column and filter for the exact 60 minutes before the event (as the two windows together cover more than that). The same technique can be applied to numeric fields as well (grouping into windows is just integer division). The following code does that for both cases:

    import scala.util.{Try, Success, Failure}
    import org.apache.spark.sql.{Column, DataFrame}
    import org.apache.spark.sql.functions.{floor, lit, window}
    import org.apache.spark.sql.types.{NumericType, TimestampType}

    def range_join_dfs[U, V](df1: DataFrame, rangeField1: Column,
                             df2: DataFrame, rangeField2: Column,
                             rangeBack: Any): Try[DataFrame] = {
      // check that both fields are from the same (and the correct) type
      (df1.schema(rangeField1.toString).dataType, df2.schema(rangeField2.toString).dataType, rangeBack) match {
        case (x1: TimestampType, x2: TimestampType, rb: String) => true
        case (x1: NumericType, x2: NumericType, rb: Number) => true
        case _ => return Failure(new IllegalArgumentException("rangeField1 and rangeField2 must both be either numeric or timestamps. If they are timestamps, rangeBack must be a string, if numerics, rangeBack must be numeric"))
      }

      // returns the "window grouping" function for timestamp/numeric.
      // Timestamps will return the start of the grouping window
      // Numeric will do integers division
      def getWindowStartFunction(df: DataFrame, field: Column) = {
        df.schema(field.toString).dataType match {
          case d: TimestampType => window(field, rangeBack.asInstanceOf[String])("start")
          case d: NumericType => floor(field / lit(rangeBack))
          case _ => throw new IllegalArgumentException("field must be either of NumericType or TimestampType")
        }
      }

      // returns the difference between windows and a numeric representation of "rangeBack"
      // if rangeBack is numeric - the window diff is 1 and the numeric representation is rangeBack itself
      // if it's timestamp - the CalendarInterval can be used for both jumping between windows and filtering at the end
      def getPrevWindowDiffAndRangeBackNumeric(rangeBack: Any) = rangeBack match {
        case rb: Number => (1, rangeBack)
        case rb: String => {
          val interval = rb match {
            case rb if rb.startsWith("interval") => org.apache.spark.unsafe.types.CalendarInterval.fromString(rb)
            case _ => org.apache.spark.unsafe.types.CalendarInterval.fromString("interval " + rb)
          }
          //( interval.months * (60*60*24*31) ) + ( interval.microseconds / 1000000 )
          (interval, interval)
        }
        case _ => throw new IllegalArgumentException("rangeBack must be either of NumericType or TimestampType")
      }

      // get windowstart functions for rangeField1 and rangeField2
      val rf1WindowStart = getWindowStartFunction(df1, rangeField1)
      val rf2WindowStart = getWindowStartFunction(df2, rangeField2)
      val (prevWindowDiff, rangeBackNumeric) = getPrevWindowDiffAndRangeBackNumeric(rangeBack)

      // actual joining logic starts here
      val windowedDf1 = df1.withColumn("windowStart", rf1WindowStart)
      // every df2 record is linked to its own window and to the one after it
      val windowedDf2 = df2.withColumn("windowStart", rf2WindowStart)
        .union(df2.withColumn("windowStart", rf2WindowStart + lit(prevWindowDiff)))

      // equi-join on the window, then keep only the pairs inside the exact range
      val res = windowedDf1.join(windowedDf2, "windowStart")
        .filter((rangeField2 > rangeField1 - lit(rangeBackNumeric)) && (rangeField2 <= rangeField1))
        .drop(windowedDf1("windowStart"))
        .drop(windowedDf2("windowStart"))

      Success(res)
    }

As you can see, most of it is just the handling of both timestamps and numerics. The logic itself is pretty straightforward.
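Stripped of the Spark and type-handling details, the core idea can be sketched on plain Scala collections. This is an illustration only, not the function above: `naiveJoin`, `bucketedJoin` and the use of bare `Long` values for both sides are assumptions made to keep the example small.

```scala
object BucketedJoinSketch {
  // Naive version: compare every event with every measurement.
  def naiveJoin(events: Seq[Long], measurements: Seq[Long], rangeBack: Long): Set[(Long, Long)] =
    (for {
      e <- events
      m <- measurements
      if m > e - rangeBack && m <= e
    } yield (e, m)).toSet

  // Bucketed version: an equality lookup on the window index, then the exact range filter.
  def bucketedJoin(events: Seq[Long], measurements: Seq[Long], rangeBack: Long): Set[(Long, Long)] = {
    def windowOf(t: Long): Long = Math.floorDiv(t, rangeBack)

    // Attach every measurement to its own window and to the one after it.
    val byWindow: Map[Long, Seq[Long]] =
      measurements
        .flatMap(m => Seq(windowOf(m) -> m, (windowOf(m) + 1) -> m))
        .groupBy(_._1)
        .map { case (w, pairs) => w -> pairs.map(_._2) }

    (for {
      e <- events
      m <- byWindow.getOrElse(windowOf(e), Seq.empty)
      if m > e - rangeBack && m <= e
    } yield (e, m)).toSet
  }
}
```

Both functions should return the same pairs; the bucketed one simply looks at far fewer candidates. For instance, with a single event at 10, measurements at 8, 9, 10 and 11, and a rangeBack of 2, both versions keep only the pairs (10, 9) and (10, 10).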

Let's look at the execution plan now:

    var events = generateEvents(10000000).toDF
    var measurements = generateMeasurements(10000000).toDF

    // you can either join by timestamp fields
    var res = range_join_dfs(events, events("eventTime"), measurements, measurements("measurementTime"), "60 minutes")

    // or by numeric fields (again, id was taken here just for the purpose of the example)
    var res = range_join_dfs(events, events("eid"), measurements, measurements("mid"), 2)

    res match {
      case Failure(ex) => print(ex)
      case Success(df) => df.explain
    }
    // and run something like `res.count` to actually perform anything.

When we run this (with relatively large DataFrames to avoid broadcasting optimizations), we get the following execution plan (some text was truncated to keep it readable...):

A Sanity Check

The events DataFrame contained 10 events (one every 10 seconds). The measurements DataFrame also contained 10 measurements, roughly 10 seconds apart. Below is the result of the join for rangeBack="30 seconds" (rows were truncated):

| eid | eventTime | eventType | mid | measurementTime | value |
|---|---|---|---|---|---|
| 3 | 18:24:28 | LoginEvent | 6 | 18:24:02 | 0.12131363910425985 |
| 3 | 18:24:28 | LoginEvent | 5 | 18:24:09 | 0.12030715258495939 |
| 3 | 18:24:28 | LoginEvent | 4 | 18:24:21 | 0.7604318153406678 |
| 4 | 18:24:18 | LoginEvent | 7 | 18:23:52 | 0.6037143578435027 |
| 4 | 18:24:18 | LoginEvent | 6 | 18:24:02 | 0.12131363910425985 |
| 4 | 18:24:18 | LoginEvent | 5 | 18:24:09 | 0.12030715258495939 |
| 5 | 18:24:08 | PurchaseEvent | 8 | 18:23:39 | 0.1435668838975337 |
| 5 | 18:24:08 | PurchaseEvent | 7 | 18:23:52 | 0.6037143578435027 |
| 5 | 18:24:08 | PurchaseEvent | 6 | 18:24:02 | 0.12131363910425985 |
| ... | ... | ... | ... | ... | ... |

Benchmarking

In order to estimate the performance boost, I launched a Google Dataproc cluster of 4 regular machines (plus a master) and tried different sizes of DataFrames. The results are below:

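| Naive Approach: # Rows | Naive Approach: Time (secs) | Efficient Approach: # Rows | Efficient Approach: Time (secs) |
|---|---|---|---|
| 10K | 7.85 | 10K | 1.57 |
| 20K | 7.81 | 50K | 2.48 |
| 50K | 43.40 | 100K | 3.08 |
| 80K | 97.73 | 500K | 14.09 |
| 130K | 265.43 | 1M | 13.06 |
| 200K | 736.58 | 2M | 17.49 |
| | | 5M | 85.52 |
| | | 10M | 189.39 |
| | | 50M | 252.44 |
| | | 100M | 438.15 |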

The results are pretty impressive. I could join DataFrames with 50M records each in around the same time it took to join DataFrames of 130K records with the naive approach. I also tried to join larger DataFrames with the naive approach, but since I'm billed by the minute and it started to take hours, I just gave up...

Notice the x-axis is logarithmic. The performance boost is actually around 3 orders of magnitude.

What if my use-case is different?

Well, that depends on what exactly you mean by different. If, instead of the last 60 minutes, you want to join anything from 120 to 60 minutes before the event, you can just change which windows the measurements are attached to. If you want to join to future measurements, you'd attach each measurement to its own window and the previous one instead of its own and the next one. All those changes are pretty easy to make.

If, however, you want to stretch the limits and join records where the rangeBack parameter isn't constant (let's say it depends on some field of the event), then you're out of luck with this approach, but I hope it at least gave you some ideas...
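For the constant-range variations mentioned above, the only thing that changes is which windows each measurement gets attached to. Here is a rough sketch of how those window offsets could be computed; `offsets`, `fromBack` and `toBack` are hypothetical names, times are plain numbers in the same unit as the window width, and the result may include a window that turns out to be unnecessary, which is harmless because the exact range filter still runs after the join.

```scala
object WindowOffsets {
  // Offsets k such that a measurement in window w is attached to windows w + k.
  // Covers events that should see measurements between `fromBack` and `toBack`
  // time units before them (negative values mean "after the event").
  def offsets(fromBack: Long, toBack: Long, width: Long): Seq[Long] = {
    require(width > 0 && fromBack >= toBack, "width must be positive and fromBack >= toBack")
    val first = Math.floorDiv(toBack, width)
    val last = -Math.floorDiv(-fromBack, width) // ceiling division
    (first to last).toList
  }
}
```

With a width of 60, `offsets(60, 0, 60)` gives 0 and 1 (the case covered in this post), `offsets(120, 60, 60)` gives 1 and 2, and `offsets(0, -60, 60)`, meaning the hour after the event, gives -1 and 0.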

Hope that helps, and I actually really hope Spark devs will support range joins through Catalyst so we don't need to hack our way to efficient joins.