Points of Interest: User Defined Aggregate Functions In Spark Dataframes


By Paul Brenner, Data Scientist, PlaceIQ

 

Why hello there friend. Here at PlaceIQ there is nothing that the Data Science team loves more than Scala, Spark, and Zeppelin notebooks [1]. From there, things get murkier. Some of our data scientists fear the glorious shining light of the future (DataFrames) and still swim primarily in the muddy pools of the past (RDDs). No thank you. This post is exclusively about DataFrames.

 

According to a quick sampling of the internet, the normal thing to do is to write an introductory post that is broadly applicable to anyone starting out with a technology. Guess what? This post isn’t that. If you need an intro to DataFrames you are just going to have to go here, here, maybe here, or I don’t know, basically anywhere else on the Google except this post. Instead I want to dive headfirst into the dark scary depths and introduce a topic that is not well documented: User Defined Aggregate Functions aka UDAFs.

 

UDAFs are functions that can be called during a groupBy to calculate… something… about the rows in each group. The benefit of learning to write UDAFs is obvious: it allows you to use UDAFs [2]. Actually, for most people that is not the benefit of learning to write UDAFs. Honestly, if you came here to learn about UDAFs because you are trying to use groupBy and want to do something more than a simple count or sum of the rows, then stop everything: go to the org.apache.spark.sql.functions docs page and search for aggregate functions, because what you are trying to do might already be a built-in function. Actually, I know people are probably skimming, so this calls for a step-by-step list to make things clear.

 

Steps to UDAFs

 

  1. Identify a use case where you are using groupBy and want to calculate something about the rows in each group. Maybe you want to find the size, max, approximate number of distinct items, a list of each item combined into a sequence, first item, or the correlation coefficient.
  2. Go to the org.apache.spark.sql.functions docs page, hit command-F, and type in “Aggregate Function”.
  3. Check all the hits and see if any of these built in functions solve your problem.
  4. They probably do, so seriously, should you even still be reading this far down in the steps? Everything I mentioned in step 1 is already covered by a built-in function.
  5. Oh, you are fancy and want to do something that isn’t covered, huh? Well, friend, welcome to UDAFs.

 

Not using UDAFs

 

If you followed the steps above and realized that what you want to do is actually covered by a built in function then: congratulations, you are about to get a lot more work done today than the poor fools who actually need to write UDAFs. If you are anything like me then you just want to see an example so that you can copy and paste and be on your way. Well, here you go:

val myDF = Seq(("A",1),
               ("A",2),
               ("A",3),
               ("B",2),
               ("B",3),
               ("B",4)).toDF("stupid_letter","dumb_number")

val result = myDF.groupBy("stupid_letter").agg(max("dumb_number").as("max_of_group"))
z.show(result) //oh, what, you don't use zeppelin? you don't know the magic of z.show? You just use result.show()? No friend, No. Stop that.
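
And if your use case was one of the other things from step 1 (distinct counts, collected lists, first items, correlations), those are built-ins too. Here is a rough sketch against the same myDF; the alias names are made up, and corr really wants two different numeric columns, which this toy DataFrame doesn’t have:

import org.apache.spark.sql.functions._

val moreResults = myDF.groupBy("stupid_letter").agg(
  count("dumb_number").as("group_size"),
  approx_count_distinct("dumb_number").as("approx_distinct"),
  collect_list("dumb_number").as("all_the_numbers"),
  first("dumb_number").as("first_number"),
  corr("dumb_number", "dumb_number").as("pointless_correlation")) // correlating a column with itself is silly, but corr is there when you have two
z.show(moreResults)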

This time actually using UDAFs

 

Well, somehow despite my best efforts you still want to learn about UDAFs. They are actually pretty cool! Even if you don’t need them that often, learning to write UDAFs means learning a lot of fun details about working with DataFrames that you probably have been avoiding. So let’s dig in with, of course, a completely oversimplified example: collecting the contents of two columns into a list. You could do this other ways, but easing in will make things more clear. First, the code!

val AggPairsToList = new UserDefinedAggregateFunction {
  // Input Data Type Schema
  def inputSchema: StructType = StructType(
    Array(StructField("id", StringType),StructField("feature", StringType)))

  // Intermediate Schema
  def bufferSchema = StructType(
    Array(
    StructField("pair_entry", ArrayType(
        StructType(Array(
            StructField("id", StringType),
            StructField("feature", StringType)))))))

  // Returned Data Type
  def dataType = ArrayType(
    StructType(Array(
        StructField("id", StringType),
        StructField("feature", StringType))))

  // Self-explaining
  def deterministic = true

  // This function is called whenever key changes
  def initialize(buffer: MutableAggregationBuffer) = {
    buffer.update(0,Array.empty[(String,String)])
  }

  // Iterate over each entry of a group
  def update(buffer: MutableAggregationBuffer, input: Row) = {
    var tempArray = new ListBuffer[Tuple2[String,String]]()
    tempArray ++= buffer.getAs[List[Tuple2[String,String]]](0)
    val inputValues : Tuple2[String,String] = (input.getString(0),input.getString(1))
    tempArray += inputValues
    buffer.update(0,tempArray)
  }

  // Merge two partial aggregates
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    var tempArray = new ListBuffer[Tuple2[String,String]]()
    tempArray ++= buffer1.getAs[List[Tuple2[String,String]]](0)
    tempArray ++= buffer2.getAs[List[Tuple2[String,String]]](0)
    buffer1.update(0,tempArray)
  }

  // Called after all the entries are exhausted.
  def evaluate(buffer: Row) = {
    var tempArray = new ListBuffer[Tuple2[String,String]]()
    tempArray ++= buffer.getAs[WrappedArray[Row]](0)
                        .filter(x => x(0) != null && x(1) != null)
                        .map{case Row(id:String, result:String) => (id,result)}
    tempArray.take(1000)
  }
}
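
Before breaking that down piece by piece, here is a rough sketch of how you would actually call it. The DataFrame and column names are made up for illustration, and the imports are listed just below:

// Hypothetical usage: group some (id, feature) pairs by a key column
val pairsDF = Seq(("A", "id1", "featureX"),
                  ("A", "id2", "featureY"),
                  ("B", "id3", "featureZ")).toDF("group_key", "id", "feature")

// a UserDefinedAggregateFunction can be applied directly to Columns inside agg
val collected = pairsDF.groupBy("group_key")
                       .agg(AggPairsToList(col("id"), col("feature")).as("pairs"))
z.show(collected)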

First, shout outs to the giants whose shoulders I stood on to write this:

 

Ok, now to take things step by step:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StructType, StructField, StringType, IntegerType, 
                                   DoubleType, LongType, ArrayType}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, 
                                         UserDefinedAggregateFunction}
import org.apache.spark.sql.Row

import scala.collection.mutable.WrappedArray
import scala.collection.mutable.{ListBuffer,ArrayBuffer}

 

Don’t you hate when people write example code and then don’t include imports and then you have to dig out which package they are in? The worst! Well I’m not the worst, I want you to have these imports.

 

val AggPairsToList = new UserDefinedAggregateFunction {
    // Input Data Type Schema
    def inputSchema: StructType = StructType(
        Array(
            StructField("id", StringType),
            StructField("feature", StringType)))

    // Intermediate Schema
    def bufferSchema: StructType  = StructType(
        Array(
            StructField("pair_entry", ArrayType(
                StructType(Array(
                    StructField("id", StringType),
                    StructField("feature", StringType)))))))

    // Returned Data Type
    def dataType = ArrayType(
        StructType(Array(
            StructField("id", StringType),
            StructField("feature", StringType))))

Ok, now we are getting to something interesting! Look how ugly that is! Here we are defining inputSchema (what our data looks like coming in), bufferSchema (what the data looks like while we are holding onto it during aggregation), and dataType (what we return). These look pretty scary. How about some words about each:

 

inputSchema

 

Your input data usually comes slapped together in some sort of a StructType. Really all you need here is to replace the … in this: StructType(Array( ... )).
The ... is going to be some comma-separated StructFields. If you just have a simple Long coming in, then go ahead and make up a name for it like “sweetLongName” and then your ... is just

 

StructField("sweetLongName", LongType)

 

If you have two columns coming in, perhaps two strings, then your ... would look like mine:

 

StructField("id", StringType),StructField("feature", StringType)

 

That doesn’t seem so bad: all you are doing is giving each item a name and a type that Spark is familiar with (like StringType, LongType, or ArrayType).

 

bufferSchema

 

This one is only slightly more complicated. We still use StructType(Array( ... )) but the ... is a smidge more complicated. Now instead of a StringType or a LongType we have an ArrayType. This is convenient because the bufferSchema is the schema we use while the groupBy is running. Basically the data point from each row in the group gets added to a buffer with this schema. If you were just finding the sum then you wouldn’t need to hang onto each individual value throughout the groupBy process, you would just need to hang onto the running sum, so your bufferSchema would just be our familiar friend:

 

StructField("sweetLongName", LongType)

 

But for this function we actually want to hold onto each individual value, so we collect them in an array… and these are Spark DataFrames, so why not collect them in the same format they are coming in:

 

StructField("id", StringType),StructField("feature", StringType)

 

We are going to need an array of these so this time our ... is:

 

StructField("pair_entry", ArrayType(StructType(Array(StructField("id", StringType),StructField("feature", StringType)))))

 

dataType (Returned Data Type)

 

Luckily, if you were doing something like just taking the max of longs then your output dataType would be simple:

 

LongType

 

In this case we want to return an array of the items that we’re collecting so we just return the type portion of the StructField in our ... above:

 

ArrayType(StructType(Array(StructField("id", StringType),StructField("feature", StringType))))
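
For contrast, here is a rough sketch of what these three pieces would look like for the hypothetical running-sum-of-longs UDAF mentioned above (runningSum is just a made-up buffer field name):

// Hypothetical running-sum UDAF: all three schema pieces collapse to single Longs
def inputSchema: StructType = StructType(Array(StructField("sweetLongName", LongType)))
def bufferSchema: StructType = StructType(Array(StructField("runningSum", LongType)))
def dataType = LongType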

 

Reality

 

Honestly, this looks really annoying and it absolutely would be if you were writing these entirely from a blank page each time. Don’t do that. Once you get the hang of it, the majority of writing a UDAF is copying and pasting and then tweaking the names/data types inside the above fields. Now it is time to actually start ingesting data. Onto the next code block!

 

// Self-explaining
def deterministic = true

// This function is called whenever key changes
def initialize(buffer: MutableAggregationBuffer) = {
  buffer.update(0,Array.empty[(String,String)])
}

// Iterate over each entry of a group
def update(buffer: MutableAggregationBuffer, input: Row) = {
  var tempArray = new ListBuffer[Tuple2[String,String]]()
  tempArray ++= buffer.getAs[List[Tuple2[String,String]]](0)
  val inputValues : Tuple2[String,String] = (input.getString(0),input.getString(1))
  tempArray += inputValues
  buffer.update(0,tempArray)
}

// Merge two partial aggregates
def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
  var tempArray = new ListBuffer[Tuple2[String,String]]()
  tempArray ++= buffer1.getAs[List[Tuple2[String,String]]](0)
  tempArray ++= buffer2.getAs[List[Tuple2[String,String]]](0)
  buffer1.update(0,tempArray)
}

 

I really like that in the sample code I followed deterministic = true is labeled as “self explaining”. Basically, we aren’t going to go into this mess here. If you want your UDAF not to be deterministic then you are going to have to dig deeper than this (already too long) post.

 

// This function is called whenever key changes
def initialize(buffer: MutableAggregationBuffer) = {
  buffer.update(0,Array.empty[(String,String)])
}

So here we are at initialization. The UDAF just hit a new group it hasn’t seen before and it needs to know how to set things up. Well, we are storing everything in a MutableAggregationBuffer, so all we need to do is stick an empty version of the data type we are working with into the first element of the buffer. For this step, when you are creating your own UDAF, all you really need to do is play with:

Array.empty[(String,String)]

 

In my case I’m aggregating my data into an Array (not into a Long like in the sum function mentioned above) and inside the Array is a Tuple2 holding the two Strings. Maybe you just want to keep an Array of Longs? Then all you would need is
Array.empty[Long]

See! Just adjust to suit your needs.
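
And sticking with that hypothetical running-sum UDAF, initialize would just start the single Long at zero, roughly:

// Hypothetical running-sum UDAF: the buffer starts at zero
def initialize(buffer: MutableAggregationBuffer) = {
  buffer.update(0, 0L)
}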

 

// Iterate over each entry of a group
def update(buffer: MutableAggregationBuffer, input: Row) = {
  var tempArray = new ListBuffer[Tuple2[String,String]]()
  tempArray ++= buffer.getAs[List[Tuple2[String,String]]](0)
  val inputValues : Tuple2[String,String] = (input.getString(0),input.getString(1))
  tempArray += inputValues
  buffer.update(0,tempArray)
}

// Merge two partial aggregates
def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
  var tempArray = new ListBuffer[Tuple2[String,String]]()
  tempArray ++= buffer1.getAs[List[Tuple2[String,String]]](0)
  tempArray ++= buffer2.getAs[List[Tuple2[String,String]]](0)
  buffer1.update(0,tempArray)
}

 

Now there are two questions that need to be answered: how does the UDAF add a new row of data into whatever it is currently holding onto (update), and how does it merge the buffer of data it has collected on one executor with the buffer of data on another (merge)? Turns out, unless you have a fancy tweak you want to make here, this again is just playing with datatypes. In each case we first create a tempArray of type ListBuffer, parameterized with whatever data type is coming in. In this case it is a Tuple2 of Strings. Next we take the data that is already in the buffer, pull it out into an easy-to-manipulate format, and throw it into tempArray:

 

tempArray ++= buffer.getAs[List[Tuple2[String,String]]](0)

 

Have you used getAs much with Spark DataFrames? If you have a Row, sometimes it can be really annoying to get data out of it. getAs is your friend here. Note that the type parameter is a plain Scala type (String, Long, and so on), not a Spark DataType. All you need is:

 

row.getAs[TheTypeYouWant]("column_name")   // or by position: row.getAs[TheTypeYouWant](columnIndex)

 

In this case everything is just in the first column so I end with (0). Since my datatype is really simple for input I can just use .getString(0) and .getString(1) to get the first and second columns. Nice. Combine that all into tempArray and then replace the contents of buffer. Done updating… and really done merging too… is there anything surprising in there? I don’t see it.
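
For contrast, the hypothetical running-sum versions of update and merge skip all the ListBuffer juggling. A rough sketch (ignoring null inputs for simplicity):

// Hypothetical running-sum UDAF: add each incoming value to the running total...
def update(buffer: MutableAggregationBuffer, input: Row) = {
  buffer.update(0, buffer.getLong(0) + input.getLong(0))
}

// ...and combine two partial totals into one
def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
  buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
}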
Last step!

 

// Called after all the entries are exhausted.
def evaluate(buffer: Row) = {
  var tempArray = new ListBuffer[Tuple2[String,String]]()
  tempArray ++= buffer.getAs[WrappedArray[Row]](0)
                      .filter(x => x(0) != null && x(1) != null)
                      .map{case Row(id:String, result:String) => (id,result)}
  tempArray.take(1000)
}

So once all the rows are exhausted and all the data is merged, the evaluate function lets you manipulate your buffer and return something. The only challenging part here is extracting your data so that you can work with it. If we ignore the filter (which should be straightforward) we have

 

tempArray ++= buffer.getAs[WrappedArray[Row]](0).map{case Row(id:String, result:String) => (id,result)}

 

Weird. So our data type is a WrappedArray[Row]. Sorry, that is just what you get. As long as you .getAs[the correct data type](0), your data will be freed from its wrapped prison and you can go on with your life. Even easier, you can just pattern match each Row(id: String, result: String) into a Tuple and now everything is super simple. Basically in this example I’m doing NOTHING with the data besides extracting it (oh, and taking only 1000 rows because I don’t want to blow up my results later on). Usually all the fun happens in this evaluate stage though! Let’s see one more example where something fun happens!
In this UDAF I have a bunch of flatmapped data. I am grouping on an id and then want to return 3 demographic buckets based on key (demographic category) / value (demographic score) pairs. So assume the input is something like:

 

id, demographic category1-1, .5
id, demographic category1-2, .1
id, demographic category1-3, .8
id, demographic category2-1, 0.7
id, demographic category2-2, 0.3
id, demographic category3, 0.1
id, demographic category 4 which I don't care about, 0.8

and I want to just get (based on some secret formula):

id, demographic category 1, 3
id, demographic category 2, 1
id, demographic category 3, 5

So here is the code:

val AggDemographicBins = new UserDefinedAggregateFunction {
  // Input Data Type Schema
  def inputSchema: StructType = StructType(
      Array(
          StructField("segment_name", StringType),
          StructField("score", DoubleType)))

  // Intermediate Schema
  def bufferSchema = StructType(
      Array(
          StructField("demographic_and_score", ArrayType(
              StructType(Array(
                  StructField("segment_name", StringType),
                  StructField("score", DoubleType)))))))

  // Returned Data Type .
  case class demographicBuckets(Age:Integer, Income: Integer, Ethnicity:Integer)
  def dataType = StructType(
      Array(
            StructField("AgeBucket",IntegerType),
            StructField("IncomeBucket",IntegerType),
            StructField("EthnicityBucket",IntegerType)))

  // Self-explaining
  def deterministic = true

  // This function is called whenever key changes
  def initialize(buffer: MutableAggregationBuffer) = {
    buffer.update(0,Array.empty[(String,Double)])
  }

  // Iterate over each entry of a group
  def update(buffer: MutableAggregationBuffer, input: Row) = {
    var deviceArray = new ListBuffer[Tuple2[String,Double]]()
    deviceArray ++= buffer.getAs[List[Tuple2[String,Double]]](0)
    val inputValues : Tuple2[String,Double] = (input.getString(0),input.getDouble(1))
    deviceArray += inputValues
    buffer.update(0,deviceArray)
  }

  // Merge two partial aggregates
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    var deviceArray = new ListBuffer[Tuple2[String,Double]]()
    deviceArray ++= buffer1.getAs[List[Tuple2[String,Double]]](0)
    deviceArray ++= buffer2.getAs[List[Tuple2[String,Double]]](0)
    buffer1.update(0,deviceArray)
  }

  def bucketDemographicsFunc(demographic: ListBuffer[String], score: ListBuffer[Double]): (Integer,Integer,Integer) = {
      val demoWScore = demographic.zip(score)

      val ageBucket1 = Seq("Segments->Demographic->Age->GROUP1")
      val ageBucket3 = Seq("Segments->Demographic->Age->GROUP2")
      val ageBucket4 = Seq("Segments->Demographic->Age->GROUP3")
      val ageBucket2 = Seq("Segments->Demographic->Age->GROUP4")
      val ageBucket5 = Seq("Segments->Demographic->Age->GROUP5")

      val ageBucket1score = (demoWScore.filter({case (x,_) => 
            ageBucket1.contains(x)}).map(_._2).sum,1)
      val ageBucket2score = (demoWScore.filter({case (x,_) => 
            ageBucket2.contains(x)}).map(_._2).sum,2)
      val ageBucket3score = (demoWScore.filter({case (x,_) => 
            ageBucket3.contains(x)}).map(_._2).sum,3)
      val ageBucket4score = (demoWScore.filter({case (x,_) => 
            ageBucket4.contains(x)}).map(_._2).sum,4)
      val ageBucket5score = (demoWScore.filter({case (x,_) => 
            ageBucket5.contains(x)}).map(_._2).sum,5)

      val finalAgeBucket = Seq(ageBucket1score,
                               ageBucket2score,
                               ageBucket3score,
                               ageBucket4score,
                               ageBucket5score)
                           .maxBy(_._1) match{
          case (score,bucket) if score == 0 => 0
          case (score,bucket) if score > 0 => bucket
      }

      val incomeBucket1 = Seq("Segments->Demographic->Income->GROUP1")
      val incomeBucket2 = Seq("Segments->Demographic->Income->GROUP2")
      val incomeBucket3 = Seq("Segments->Demographic->Income->GROUP3")
      val incomeBucket4 = Seq("Segments->Demographic->Income->GROUP4")
      val incomeBucket5 = Seq("Segments->Demographic->Income->GROUP55",
                              "Segments->Demographic->Income->GROUP56")

      val incomeBucket1score = (demoWScore.filter({case (x,_) => 
          incomeBucket1.contains(x)}).map(_._2).sum,1)
      val incomeBucket2score = (demoWScore.filter({case (x,_) => 
          incomeBucket2.contains(x)}).map(_._2).sum,2)
      val incomeBucket3score = (demoWScore.filter({case (x,_) => 
          incomeBucket3.contains(x)}).map(_._2).sum,3)
      val incomeBucket4score = (demoWScore.filter({case (x,_) => 
          incomeBucket4.contains(x)}).map(_._2).sum,4)
      val incomeBucket5score = (demoWScore.filter({case (x,_) => 
          incomeBucket5.contains(x)}).map(_._2).sum,5)

      val finalIncomeBucket = Seq(incomeBucket1score,
                                  incomeBucket2score,
                                  incomeBucket3score,
                                  incomeBucket4score,
                                  incomeBucket5score)
                              .maxBy(_._1) match{
          case (score,bucket) if score == 0 => 0
          case (score,bucket) if score > 0 => bucket
      }

      val finalEthnicityBucket = if(
               demoWScore.filter({case (x,_) => 
               x == "Segments->Demographic->Race->GROUPEVALUATOR"})
               .length == 0
         ){0}
         else {
               demoWScore.filter({case (x,_) => 
                   x == "Segments->Demographic->Race->GROUPEVALUATOR"})(0)
               match{
                     case (demo,score) if score < 5.0 => 1
                     case (demo,score) if score < 7.0 => 2
                     case (demo,score) if score < 8.5 => 3
                     case (demo,score) if score < 9.5 => 4
                     case (demo,score) if score >= 9.5 => 5
        }
     }
      (finalAgeBucket,finalIncomeBucket,finalEthnicityBucket)
  }

  // Called after all the entries are exhausted.
  def evaluate(buffer: Row) = {
    var demoAndScoreArray = new ListBuffer[Tuple2[String,Double]]()
    demoAndScoreArray ++= buffer.getAs[WrappedArray[Row]](0)
        .filter(x => x(0) != null && x(1) != null)
        .map{case Row(demographic: String, score: Double) => (demographic, score)}

    val buckets = bucketDemographicsFunc(
        demoAndScoreArray.map(x => x._1),
        demoAndScoreArray.map(x => x._2))
    demographicBuckets(buckets._1, buckets._2, buckets._3)
  }
}

 

Wow! That is more code! But if you look, there are actually only minor tweaks to the datatypes for input, buffer, and return. Initialize, merge, and update are basically the same. The only difference is that I created a big (ugly) bucketDemographicsFunc. Forgive how non-compact it is; this code was written to be quickly modified instead of to look pretty. If you look at evaluate it should look familiar too! I extract the data into a tuple, same as before, pass it into a function, and then return it in a fancy case class. That is kind of fun; take a look at that if you want to return multiple columns, but we aren’t going to talk about it here. What you should see here is that once everything in your group is aggregated you can just toss it into a function and have it spit out whatever result you want. Incredible! Now basically ALL of your amazing programming skills can be put to use on grouped data.
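
For completeness, calling this one looks just like calling the first UDAF; the only new wrinkle is that the returned struct can be unpacked with dot notation. A rough sketch, where flatDF is a made-up name for the flatmapped input described above (columns id, segment_name, and score):

val bucketed = flatDF.groupBy("id")
                     .agg(AggDemographicBins(col("segment_name"), col("score")).as("buckets"))
                     .select(col("id"),
                             col("buckets.AgeBucket"),
                             col("buckets.IncomeBucket"),
                             col("buckets.EthnicityBucket"))
z.show(bucketed)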

 

Think of the possibilities!
And then probably realize that whatever you were trying to do could have been done with a built-in function. Go look at that doc page again.

 


  1. This is 100% a lie, we also love Hint Water, Lofted Coffee, Kind bars… really all the free food from the kitchen… and those are just the things we love in the kitchen. Think of all the great stuff outside that we could love… like uhhhhh…. Fresh Air? Mountains? I don’t know  ↩
  2. The astute reader will see that this definition is 100% useless  ↩