Spark必知必会 | Spark SQL自定义函数UDF、UDAF聚合函数以及开窗函数的使用
记得关注我们, 一起成长哦
1、Spark SQL自定义函数就是可以通过scala写一个类,然后在SparkSession上注册一个函数并对应这个类,然后在SQL语句中就可以使用该函数了,首先定义UDF函数,那么创建一个SqlUdf类,并且继承UDF1或UDF2等等,UDF后边的数字表示了当调用函数时会传入进来有几个参数,最后一个R则表示返回的数据类型,如下图所示:
2、这里选择继承UDF2,如下代码所示:
package com.udf
import org.apache.spark.sql.api.java.UDF2
class SqlUDF extends UDF2[String,Integer,String] {
override def call(t1: String, t2: Integer): String = {
t1+"_udf_test_"+t2
}
}
3、然后在SparkSession生成的对象上通过sparkSession.udf.register进行注册,如下代码所示:
val conf=new SparkConf().setAppName("AppUdf").setMaster("local")
val sparkSession=SparkSession.builder().config(conf).getOrCreate()
//指定函数名为:splicing_t1_t2 此函数名只有通过udf.register注册过之后才能够被使用,第二个参数是继承与UDF的类
//第三个参数是返回类型
sparkSession.udf.register("splicing_t1_t2",new SqlUDF,DataTypes.StringType)
4、生成模拟数据,并注册一个临时表,如下代码所示:
var rows=Seq[Row]()
val random=new Random()
for(i <- 0 until 10){
val name="name"+i
val age=random.nextInt(30)%15+15
val row=Row(name,age)
rows +:=row
}
val rowsRDD=sparkSession.sparkContext.parallelize(rows)
val schema=DataTypes.createStructType(Array[StructField](
DataTypes.createStructField("name",DataTypes.StringType,true),
DataTypes.createStructField("age",DataTypes.IntegerType,true))
)
val df=sparkSession.createDataFrame(rowsRDD,schema)
df.createOrReplaceTempView("person")
df.show()
输出结果如下图所示:
5、在sql语句中使用自定义函数splicing_t1_t2,然后将函数的返回结果定义一个别名name_age,如下代码所示:
val sql="SELECT name,age,splicing_t1_t2(name,age) name_age FROM person"
sparkSession.sql(sql).show()
输出结果如下:
6、由此可以看到在自定义的UDF类中,想如何操作都可以了,完整代码如下;
package com.udf
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{DataTypes, StructField}
import scala.util.Random
object AppUdf {
def main(args:Array[String]):Unit={
val conf=new SparkConf().setAppName("AppUdf").setMaster("local")
val sparkSession=SparkSession.builder().config(conf).getOrCreate()
//指定函数名为:splicing_t1_t2 此函数名只有通过udf.register注册过之后才能够被使用,第二个参数是继承与UDF的类
//第三个参数是返回类型
sparkSession.udf.register("splicing_t1_t2",new SqlUDF,DataTypes.StringType)
var rows=Seq[Row]()
val random=new Random()
for(i <- 0 until 10){
val name="name"+i
val age=random.nextInt(30)%15+15
val row=Row(name,age)
rows +:=row
}
val rowsRDD=sparkSession.sparkContext.parallelize(rows)
val schema=DataTypes.createStructType(Array[StructField](
DataTypes.createStructField("name",DataTypes.StringType,true),
DataTypes.createStructField("age",DataTypes.IntegerType,true))
)
val df=sparkSession.createDataFrame(rowsRDD,schema)
df.createOrReplaceTempView("person")
val sql="SELECT name,age,splicing_t1_t2(name,age) name_age FROM person"
sparkSession.sql(sql).show()
sparkSession.close()
}
}
UserDefinedAggregateFunction -
1、它是一个接口,需要实现的方法有:
class AvgAge extends UserDefinedAggregateFunction {
//设置输入数据的类型,指定输入数据的字段与类型,它与在生成表时创建字段时的方法相同
override def inputSchema: StructType = ???
//指定缓冲数据的字段与类型
override def bufferSchema: StructType = ???
//指定数据的返回类型
override def dataType: DataType = ???
//指定是否是确定性,对输入数据进行一致性检验,是一个布尔值,当为true时,表示对于同样的输入会得到同样的输出
override def deterministic: Boolean = ???
//initialize用户初始化缓存数据
override def initialize(buffer: MutableAggregationBuffer): Unit = ???
//当有新的输入数据时,update就会更新缓存变量
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = ???
//将更新的缓存变量进行合并,有可能每个缓存变量的值都不在一个节点上,最终是要将所有节点的值进行合并才行
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = ???
//一个计算方法,用于计算我们的最终结果
override def evaluate(buffer: Row): Any = ???
}
这是一个计算平均年龄的自定义聚合函数,实现代码如下所示:
package com.udf
import java.math.BigDecimal
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}
/**
* 用于计算平均年龄的聚合函数
*/
class AvgAge extends UserDefinedAggregateFunction {
/**
* 设置输入数据的类型,指定输入数据的字段与类型,它与在生成表时创建字段时的方法相同
* 比如计算平均年龄,输入的是age这一列的数据,注意此处的age名称可以随意命名
* @return
*/
override def inputSchema: StructType = DataTypes.createStructType(Array[StructField](DataTypes.createStructField("age",DataTypes.IntegerType,true)))
/**
* 指定缓冲数据的字段与类型,相当于中间变量
* 由于要计算平均值,首先要计算出总和与个数才能计算平均值,因此需要进来一个值就要累加并计数才能计算出平均值
* 所以要定义两个变量作为累加和以及计数的变量
* @return
*/
override def bufferSchema: StructType = DataTypes.createStructType(Array[StructField](
DataTypes.createStructField("sum",DataTypes.DoubleType,true),
DataTypes.createStructField("count",DataTypes.IntegerType,true)
))
//指定数据的返回类型,由于平均值是double类型,因此定义DoubleType
override def dataType: DataType = DataTypes.DoubleType
/**
* 设置该函数是否为幂等函数
* 幂等函数:即只要输入的数据相同,结果一定相同
* true表示是幂等函数,false表示不是
* @return
*/
override def deterministic: Boolean = true
/**
* initialize用于初始化缓存变量的值,也就是初始化bufferSchema函数中定义的两个变量的值sum,count
* 其中buffer(0)就表示sum值,buffer(1)就表示count的值,如果还有第3个,则使用buffer(3)表示
* @param buffer
*/
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer.update(0,0.0) //或使用buffer(0)=0.0
buffer.update(1,0) //或使用buffer(1)=0
}
/**
* 当有一行数据进来时就会调用update一次,有多少行就会调用多少次,input就表示在调用自定义函数中有多少个参数,最终会将
* 这些参数生成一个Row对象,在使用时可以通过input.getString或inpu.getLong等方式获得对应的值
* 缓冲中的变量sum,count使用buffer(0)或buffer.getDouble(0)的方式获取到
* @param buffer
* @param input
*/
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
val sum=buffer.getDouble(0)
val count=buffer.getInt(1)
buffer.update(0,sum+input.getInt(0).toDouble)
buffer.update(1,count+1)
}
/**
* 将更新的缓存变量进行合并,有可能每个缓存变量的值都不在一个节点上,最终是要将所有节点的值进行合并才行
* 其中buffer1是本节点上的缓存变量,而buffer2是从其他节点上过来的缓存变量然后转换为一个Row对象,然后将buffer2
* 中的数据合并到buffer1中去即可
* @param buffer1
* @param buffer2
*/
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
val sum1=buffer1.getDouble(0)
val count1=buffer1.getInt(1)
val sum2=buffer2.getDouble(0)
val count2=buffer2.getInt(1)
buffer1.update(0,sum1+sum2)
buffer1.update(1,count1+count2)
}
/**
* 一个计算方法,用于计算我们的最终结果,也就相当于返回值
* @param buffer
* @return
*/
override def evaluate(buffer: Row): Any = {
val bd = new BigDecimal(buffer.getDouble(0)/buffer.getInt(1).toDouble)
bd.setScale(2, BigDecimal.ROUND_HALF_UP).doubleValue//保留两位小数
}
}
2、注册该类,并指定到一个自定义函数中,如下图所示:
3、在表中加一列字段id,通过GROUP BY进行分组计算,如
4、在sql语句中使用group_age_avg,如下图所示:
输出结果如下图所示:
5、完整代码如下:
package com.udf
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{DataTypes, StructField}
import scala.util.Random
object AppUdf {
def main(args:Array[String]):Unit={
val conf=new SparkConf().setAppName("AppUdf").setMaster("local")
val sparkSession=SparkSession.builder().config(conf).getOrCreate()
//指定函数名为:splicing_t1_t2 此函数名只有通过udf.register注册过之后才能够被使用,第二个参数是继承与UDF的类
//第三个参数是返回类型
sparkSession.udf.register("splicing_t1_t2",new SqlUDF,DataTypes.StringType)
//UDAF不用设置返回类型,因此使用两个参数即可
sparkSession.udf.register("group_age_avg",new AvgAge)
var rows=Seq[Row]()
val random=new Random()
for(i <- 0 until 10){
val name="name"+i
val age=random.nextInt(30)%15+15
val row=Row(random.nextInt(2),name,age)
rows +:=row
}
val rowsRDD=sparkSession.sparkContext.parallelize(rows)
val schema=DataTypes.createStructType(Array[StructField](
DataTypes.createStructField("id",DataTypes.IntegerType,true),
DataTypes.createStructField("name",DataTypes.StringType,true),
DataTypes.createStructField("age",DataTypes.IntegerType,true))
)
val df=sparkSession.createDataFrame(rowsRDD,schema)
df.createOrReplaceTempView("person")
df.show()
val sql="SELECT id, group_age_avg(age) avg_age FROM person GROUP BY id"
sparkSession.sql(sql).show()
sparkSession.close()
}
}
1、它是一个接口,需要继承与Aggregator,而Aggregator有3个参数,分别是IN,BUF,OUT,IN表示输入的值是什么,可以是一个自定类对象包含多个值,也可以是单个值,BUF就是需要用来缓存值使用的,如果需要缓存多个值也需要定义一个对象,而返回值也可以是一个对象返回多个值,需要实现的方法有:
package com.udf
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.expressions.Aggregator
case class DataBuf(var sum:Double,var count:Int)
object AvgAgeAggregator extends Aggregator[Int,DataBuf,Double]{
/**
* 相当于UserDefinedAggregateFunction中的initialize函数,用于初始化DataBuf对象的值,此DataBuf是自定义类型的
* @return
*/
override def zero: DataBuf = ???
/**
* reduce函数相当于UserDefinedAggregateFunction中的update函数,当有新的数据a时,更新中间数据b
* @param b
* @param a
* @return
*/
override def reduce(b: DataBuf, a: Int): DataBuf = ???
/**
* merge函数相当于UserDefinedAggregateFunction中的merge函数,对两个值进行 合并,
* 因为有可能每个缓存变量的值都不在一个节点上,最终是要将所有节点的值进行合并才行,将b2中的值合并到b1中
* @param b1
* @param b2
* @return
*/
override def merge(b1: DataBuf, b2: DataBuf): DataBuf = ???
/**
* finish相当于UserDefinedAggregateFunction中的evaluate,是一个计算方法,用于计算我们的最终结果,也就相当于返回值
* 返回值可以是一个对象
* @param reduction
* @return
*/
override def finish(reduction: DataBuf): Double = ???
/**
* 缓冲数据编码方式
* @return
*/
override def bufferEncoder: Encoder[DataBuf] = ???
/**
* 最终数据输出编码方式
* @return
*/
override def outputEncoder: Encoder[Double] = ???
}
2、具体实现如下代码所示:
package com.udf
import java.math.BigDecimal
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
case class DataBuf(var sum:Double,var count:Int)
object AvgAgeAggregator extends Aggregator[Int,DataBuf,Double]{
/**
* 相当于UserDefinedAggregateFunction中的initialize函数,用于初始化DataBuf对象的值,此DataBuf是自定义类型的
* @return
*/
override def zero: DataBuf = DataBuf(0.0,0)
/**
* reduce函数相当于UserDefinedAggregateFunction中的update函数,当有新的数据a时,更新中间数据b
* @param b
* @param a
* @return
*/
override def reduce(b: DataBuf, a: Int): DataBuf = {
b.count+=1
b.sum+=a.toDouble
b
}
/**
* merge函数相当于UserDefinedAggregateFunction中的merge函数,对两个值进行 合并,
* 因为有可能每个缓存变量的值都不在一个节点上,最终是要将所有节点的值进行合并才行,将b2中的值合并到b1中
* @param b1
* @param b2
* @return
*/
override def merge(b1: DataBuf, b2: DataBuf): DataBuf = {
b1.sum+=b2.sum
b1.count+=b2.count
b1
}
/**
* finish相当于UserDefinedAggregateFunction中的evaluate,是一个计算方法,用于计算我们的最终结果,也就相当于返回值
* 返回值可以是一个对象
* @param reduction
* @return
*/
override def finish(reduction: DataBuf): Double = {
val bd = new BigDecimal(reduction.sum/reduction.count.toDouble)
bd.setScale(2, BigDecimal.ROUND_HALF_UP).doubleValue//保留两位小数
}
/**
* 缓冲数据编码方式,如果Encoder中指定的类型时对象,则设置为product,如果是具体的类型,则需设置为具体的类型
* @return
*/
override def bufferEncoder: Encoder[DataBuf] = Encoders.product
/**
* 最终数据输出编码方式,如果Encoder中指定的类型,则设置为具体的类型,比如Double则设置为scalaDouble
* @return
*/
override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
3、而使用此聚合函数就不能通过注册函数来使用了,需要通过Dataset对象的select来使用,如下图所示:
执行结果如下图所示:
因此无类型的用户自定于聚合函数:UserDefinedAggregateFunction和类型安全的用户自定于聚合函数:Aggregator之间的区别是
(1)UserDefinedAggregateFunction不能够带类型而Aggregator是可以带类型的。
(2)使用方法不同UserDefinedAggregateFunction通过注册可以在DataFram的sql语句中使用,而Aggregator必须是在Dataset上使用。
四、开窗函数的使用
1、在Spark 1.5.x版本以后,在Spark SQL和DataFrame中引入了开窗函数,其中比较常用的开窗函数就是row_number该函数的作用是根据表中字段进行分组,然后根据表中的字段排序;其实就是根据其排序顺序,给组中的每条记录添加一个序号;且每组的序号都是从1开始,可利用它的这个特性进行分组取top-n。它是放在select子句中的,其格式为:
ROW_NUMBER() OVER (PARTITION BY area ORDER BY click_count DESC) rank
首先可以,在SELECT查询时,使用row_number()函数,其次row_number()函数后面先跟上OVER关键字,然后括号中,是PARTITION BY,也就是说根据哪个字段进行分组,其次是可以用ORDER BY进行组内排序, 然后row_number()就可以给每个组内的行,一个组内行号,然后rank就是每一组的行号
2、使用方法的sql语句为:
SELECT id,name,age,row_number() OVER (PARTITION BY id ORDER BY age) rank FROM person ORDER BY id desc,rank desc
意思是在sql语句中加一个rank字段,该字段记录了以id为分组,在组内按照age升序排序,并记录行号,最后先按照id降序排序,如果id相同则按照rank降序排序
3、代码如下:
package com.udf
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{DataTypes, StructField}
import scala.util.Random
object AppUdf {
def main(args:Array[String]):Unit={
val conf=new SparkConf().setAppName("AppUdf").setMaster("local")
val sparkSession=SparkSession.builder().config(conf).getOrCreate()
//指定函数名为:splicing_t1_t2 此函数名只有通过udf.register注册过之后才能够被使用,第二个参数是继承与UDF的类
//第三个参数是返回类型
sparkSession.udf.register("splicing_t1_t2",new SqlUDF,DataTypes.StringType)
//UDAF不用设置返回类型,因此使用两个参数即可
sparkSession.udf.register("group_age_avg",new AvgAge)
var rows=Seq[Row]()
val random=new Random()
for(i <- 0 until 10){
val name="name"+i
val age=random.nextInt(30)%15+15
val row=Row(random.nextInt(2),name,age)
rows +:=row
}
val rowsRDD=sparkSession.sparkContext.parallelize(rows)
val schema=DataTypes.createStructType(Array[StructField](
DataTypes.createStructField("id",DataTypes.IntegerType,true),
DataTypes.createStructField("name",DataTypes.StringType,true),
DataTypes.createStructField("age",DataTypes.IntegerType,true))
)
val df=sparkSession.createDataFrame(rowsRDD,schema)
df.createOrReplaceTempView("person")
df.show()
val sql="SELECT id,name,age,row_number() OVER (PARTITION BY id ORDER BY age) rank FROM person ORDER BY id desc,rank desc"
sparkSession.sql(sql).show()
sparkSession.close()
}
}
输出结果如下: