1 Official Documentation

Spark UDAF official documentation

1.1 Definition

User-Defined Aggregate Functions (UDAFs) are user-programmable routines that act on multiple rows at once and return a single aggregated value as a result. This documentation lists the classes that are required for creating and registering UDAFs. It also contains examples that demonstrate how to define and register UDAFs in Scala and invoke them in Spark SQL.

In short, a UDAF is a user-written function that operates on multiple input rows and returns a single aggregated value. The rest of this section covers the classes needed to create and register a UDAF, with Scala examples showing how to define and register one and invoke it from Spark SQL.

1.2 Aggregator[-IN, BUF, OUT]

A base class for user-defined aggregations, which can be used in Dataset operations to take all of the elements of a group and reduce them to a single value.
In other words, in Dataset operations the Aggregator takes all the elements of a group (or selected fields of those rows) and reduces them to a single value. It must declare the following:

IN: the input type of the UDAF

BUF: the type of the intermediate buffer holding partial aggregation results

OUT: the output type of the final aggregated value

  • bufferEncoder: Encoder[BUF]

Specifies the Encoder for the intermediate value type.

  • finish(reduction: BUF): OUT

Transforms the output of the reduction into the final result value.

  • merge(b1: BUF, b2: BUF): BUF

Merges two intermediate values. (A merged result may itself be merged again.)

  • outputEncoder: Encoder[OUT]

Specifies the Encoder for the final output value type.

  • reduce(b: BUF, a: IN): BUF

Aggregates input value a into the current intermediate value. For performance, the function may modify b and return it instead of constructing a new object for b.

  • zero: BUF

The initial value of the intermediate result for this aggregation. It should satisfy the property that b + zero = b for any b.
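Putting these members together, the contract can be sketched as follows (a simplified outline of org.apache.spark.sql.expressions.Aggregator, not the exact source):

abstract class Aggregator[-IN, BUF, OUT] extends Serializable {
  def zero: BUF                        // initial buffer value
  def reduce(b: BUF, a: IN): BUF       // fold one input into the buffer
  def merge(b1: BUF, b2: BUF): BUF     // combine two partial buffers
  def finish(reduction: BUF): OUT      // produce the final result
  def bufferEncoder: Encoder[BUF]      // how Spark serializes the buffer
  def outputEncoder: Encoder[OUT]      // how Spark serializes the output
}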

1.3 Examples

1.3.1 Type-Safe UDAF

User-defined aggregations for strongly typed Datasets revolve around the Aggregator abstract class. For example, a type-safe user-defined average can look like:
That is, for strongly typed Datasets a UDAF extends the Aggregator abstract class directly. For example, an average function:

package Function.UDAF.Class

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

case class Employee(name: String, salary: Long)
case class Average(var sum: Long, var count: Long)

object MyAverage extends Aggregator[Employee,Average,Double] {

  //Initial value of the intermediate buffer; must satisfy: for any b, b + zero = b
  override def zero:Average = Average(0L,0L)

  //Fold one input value into the buffer; for performance, the buffer is
  //mutated in place and returned instead of constructing a new object
  override def reduce(buffer: Average, employee: Employee): Average = {
    buffer.sum += employee.salary
    buffer.count += 1
    buffer
  }

  //Merge two intermediate buffers
  override def merge(buffer1: Average, buffer2: Average): Average = {
    buffer1.count += buffer2.count
    buffer1.sum += buffer2.sum
    buffer1
  }

  //Transform the final buffer into the output value
  override def finish(reduction: Average): Double = {
    reduction.sum.toDouble / reduction.count
  }

  //Specifies the Encoder for the intermediate result type
  //Encoder: Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
  override def bufferEncoder: Encoder[Average] = {
    //An encoder for Scala's product type (tuples, case classes, etc).
    Encoders.product
  }

  //Specifies the Encoder for the final output type
  override def outputEncoder: Encoder[Double] = {
    Encoders.scalaDouble
  }

}

package Function.UDAF.Class

import org.apache.spark.sql.{DataFrame, Dataset, SparkSession, TypedColumn}

object TestMyAverage {
  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().master("local[*]").appName("SparkStudy").getOrCreate()

    import spark.implicits._
    //Create a DataFrame from a Seq of tuples
    val dataframe: DataFrame = spark.createDataFrame(Seq(
      ("Michael", 3000),
      ("Andy", 4500),
      ("Justin", 3500),
      ("Berta", 4000)
    )).toDF("name", "salary")

    val dataSet: Dataset[Employee] = dataframe.as[Employee]
    dataSet.show()

    //Convert the function to a `TypedColumn` and give it a name
    val average_salary: TypedColumn[Employee, Double] = MyAverage.toColumn.name("average_salary")

    dataSet.select(average_salary).show()
//    +--------------+
//    |average_salary|
//    +--------------+
//    |        3750.0| 
//    +--------------+

  }

}
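The Aggregator-based TypedColumn also composes with typed grouping. A minimal sketch, reusing dataSet and average_salary from main above (each name appears once in the sample data, so every group's "average" is simply that employee's salary):

//Per-group aggregation with the same TypedColumn
dataSet.groupByKey(_.name).agg(average_salary).show()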




1.3.2 Untyped UDAF

The example above targets a Dataset, which has a known element type. Defining a UDAF for a Row-based DataFrame differs slightly.
Typed aggregations, as described above, may also be registered as untyped aggregating UDFs for use with DataFrames. For example, a user-defined average for untyped DataFrames can look like:
The same average UDAF, applied to a DataFrame:

Spark 2.x

Extend UserDefinedAggregateFunction (deprecated since Spark 3.0)

package Function.UDAF.Class

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}


//In Spark 2.x, a DataFrame UDAF extends UserDefinedAggregateFunction
object MyAverageDF2 extends UserDefinedAggregateFunction{

  //Schema of the input arguments
  override def inputSchema: StructType = StructType(StructField("salary",LongType) :: Nil)

  //Schema of the aggregation buffer: (count, sum)
  override def bufferSchema: StructType = StructType(StructField("count",LongType) :: StructField("sum",LongType) :: Nil)

  //Data type of the final result
  override def dataType: DataType = DoubleType

  //Whether the function is deterministic: the same input always produces the same result
  override def deterministic: Boolean = true

  //Initial values of the buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    //count
    buffer(0) = 0L
    //sum
    buffer(1) = 0L
  }

  //Fold one input row into the buffer, skipping nulls
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      //count
      buffer(0) = buffer.getLong(0) + 1L
      //sum: input is a Row, so read the concrete value with getLong
      buffer(1) = buffer.getLong(1) + input.getLong(0)
    }
  }

  //Merge two buffers; the result is accumulated into buffer1
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  //Transform the buffer into the final output
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(1).toDouble / buffer.getLong(0)
  }
}

    //Using the UDAF from a DataFrame: register it, then query it via Spark SQL
    spark.udf.register("myAverageDF2", MyAverageDF2)
    dataframe.createOrReplaceTempView("dataframe")

    spark.sql(
      """
        |
        |select myAverageDF2(salary)
        |from dataframe
        |
        |""".stripMargin).show()
//      +-------------------------------------+
//      |myaveragedf2$(CAST(salary AS BIGINT))|
//      +-------------------------------------+
//      |                               3750.0|
//      +-------------------------------------+
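A UserDefinedAggregateFunction can also be applied through the DataFrame API directly, since the object itself is callable on Columns. A minimal sketch (col comes from org.apache.spark.sql.functions; dataframe is the DataFrame from 1.3.1):

//Direct DataFrame-API usage, no temp view or SQL required
dataframe.agg(MyAverageDF2(col("salary")).as("average_salary")).show()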

Spark 3.x

Extend Aggregator and register it via functions.udaf

package Function.UDAF.Class

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

//From Spark 3.x on, use this approach for DataFrame UDAFs
object MyAverageDF extends Aggregator[Long, Average, Double] {
  //Initial buffer value; must satisfy: for any b, b + zero = b
  override def zero: Average = Average(0L,0L)

  //Fold an input value into the buffer; mutate the buffer in place rather than creating a new one
  override def reduce(buffer: Average, data: Long): Average = {
    buffer.sum += data
    buffer.count += 1
    buffer
  }
  //Merge two buffers; a merged buffer may be merged again
  override def merge(buffer1: Average, buffer2: Average): Average = {
    buffer1.count += buffer2.count
    buffer1.sum += buffer2.sum
    buffer1
  }
  //Transform the buffer into the final output
  override def finish(buffer: Average): Double = {
    buffer.sum.toDouble / buffer.count
  }

  //Encoder for the intermediate buffer type
  override def bufferEncoder: Encoder[Average] = {
    Encoders.product
  }

  //Encoder for the final output type
  override def outputEncoder: Encoder[Double] = {
    Encoders.scalaDouble
  }
}

// Register the function so it can be invoked from Spark SQL (Spark 3.x);
// functions here is org.apache.spark.sql.functions
spark.udf.register("myAverage", functions.udaf(MyAverageDF))
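Once registered, the call site looks the same as in Spark 2.x. A minimal sketch, assuming the "dataframe" temp view from the example above still exists:

//Invoke the Spark 3.x UDAF from SQL
spark.sql("select myAverage(salary) as average_salary from dataframe").show()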
