查看原文
其他

Spark SQL自定义函数UDF、UDAF聚合函数以及开窗函数的使用

点击上方蓝色字体,选择“设为星标

回复”资源“获取更多资源

大数据技术与架构点击右侧关注,大数据开发领域最强公众号!

暴走大数据点击右侧关注,暴走大数据!

一、UDF的使用


1、Spark SQL自定义函数就是可以通过scala写一个类,然后在SparkSession上注册一个函数并对应这个类,然后在SQL语句中就可以使用该函数了,首先定义UDF函数,那么创建一个SqlUdf类,并且继承UDF1或UDF2等等,UDF后边的数字表示了当调用函数时会传入进来有几个参数,最后一个R则表示返回的数据类型,如下图所示:


2、这里选择继承UDF2,如下代码所示:

package com.udf import org.apache.spark.sql.api.java.UDF2 class SqlUDF extends UDF2[String,Integer,String] { override def call(t1: String, t2: Integer): String = { t1+"_udf_test_"+t2 }}

3、然后在SparkSession生成的对象上通过sparkSession.udf.register进行注册,如下代码所示:

val conf=new SparkConf().setAppName("AppUdf").setMaster("local")val sparkSession=SparkSession.builder().config(conf).getOrCreate()//指定函数名为:splicing_t1_t2 此函数名只有通过udf.register注册过之后才能够被使用,第二个参数是继承与UDF的类//第三个参数是返回类型sparkSession.udf.register("splicing_t1_t2",new SqlUDF,DataTypes.StringType)

4、生成模拟数据,并注册一个临时表,如下代码所示:

var rows=Seq[Row]() val random=new Random() for(i <- 0 until 10){ val name="name"+i val age=random.nextInt(30)%15+15 val row=Row(name,age) rows +:=row } val rowsRDD=sparkSession.sparkContext.parallelize(rows) val schema=DataTypes.createStructType(Array[StructField]( DataTypes.createStructField("name",DataTypes.StringType,true), DataTypes.createStructField("age",DataTypes.IntegerType,true)) ) val df=sparkSession.createDataFrame(rowsRDD,schema) df.createOrReplaceTempView("person") df.show()

输出 结果如下图所示:

5、在sql语句中使用自定义函数splicing_t1_t2,然后将函数的返回结果定义一个别名name_age,如下代码所示:

val sql="SELECT name,age,splicing_t1_t2(name,age) name_age FROM person"sparkSession.sql(sql).show()

输出结果如下:

6、由此可以看到在自定义的UDF类中,想如何操作都可以了,完整代码如下;

package com.udf import org.apache.spark.SparkConfimport org.apache.spark.sql.{Row, SparkSession}import org.apache.spark.sql.types.{DataTypes, StructField} import scala.util.Random object AppUdf { def main(args:Array[String]):Unit={ val conf=new SparkConf().setAppName("AppUdf").setMaster("local") val sparkSession=SparkSession.builder().config(conf).getOrCreate() //指定函数名为:splicing_t1_t2 此函数名只有通过udf.register注册过之后才能够被使用,第二个参数是继承与UDF的类 //第三个参数是返回类型 sparkSession.udf.register("splicing_t1_t2",new SqlUDF,DataTypes.StringType) var rows=Seq[Row]() val random=new Random() for(i <- 0 until 10){ val name="name"+i val age=random.nextInt(30)%15+15 val row=Row(name,age) rows +:=row } val rowsRDD=sparkSession.sparkContext.parallelize(rows) val schema=DataTypes.createStructType(Array[StructField]( DataTypes.createStructField("name",DataTypes.StringType,true), DataTypes.createStructField("age",DataTypes.IntegerType,true)) ) val df=sparkSession.createDataFrame(rowsRDD,schema) df.createOrReplaceTempView("person") val sql="SELECT name,age,splicing_t1_t2(name,age) name_age FROM person" sparkSession.sql(sql).show() sparkSession.close() }}

二、无类型的用户自定于聚合函数:UserDefinedAggregateFunction

1、它是一个接口,需要实现的方法有:

class AvgAge extends UserDefinedAggregateFunction { //设置输入数据的类型,指定输入数据的字段与类型,它与在生成表时创建字段时的方法相同 override def inputSchema: StructType = ??? //指定缓冲数据的字段与类型 override def bufferSchema: StructType = ??? //指定数据的返回类型 override def dataType: DataType = ??? //指定是否是确定性,对输入数据进行一致性检验,是一个布尔值,当为true时,表示对于同样的输入会得到同样的输出 override def deterministic: Boolean = ??? //initialize用户初始化缓存数据 override def initialize(buffer: MutableAggregationBuffer): Unit = ??? //当有新的输入数据时,update就会更新缓存变量 override def update(buffer: MutableAggregationBuffer, input: Row): Unit = ??? //将更新的缓存变量进行合并,有可能每个缓存变量的值都不在一个节点上,最终是要将所有节点的值进行合并才行 override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = ??? //一个计算方法,用于计算我们的最终结果 override def evaluate(buffer: Row): Any = ???}

这是一个计算平均年龄的自定义聚合函数,实现代码如下所示:

package com.udf import java.math.BigDecimal import org.apache.spark.sql.Rowimport org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType} /** * 用于计算平均年龄的聚合函数 */class AvgAge extends UserDefinedAggregateFunction { /** * 设置输入数据的类型,指定输入数据的字段与类型,它与在生成表时创建字段时的方法相同 * 比如计算平均年龄,输入的是age这一列的数据,注意此处的age名称可以随意命名 * @return */ override def inputSchema: StructType = DataTypes.createStructType(Array[StructField](DataTypes.createStructField("age",DataTypes.IntegerType,true))) /** * 指定缓冲数据的字段与类型,相当于中间变量 * 由于要计算平均值,首先要计算出总和与个数才能计算平均值,因此需要进来一个值就要累加并计数才能计算出平均值 * 所以要定义两个变量作为累加和以及计数的变量 * @return */ override def bufferSchema: StructType = DataTypes.createStructType(Array[StructField]( DataTypes.createStructField("sum",DataTypes.DoubleType,true), DataTypes.createStructField("count",DataTypes.IntegerType,true) )) //指定数据的返回类型,由于平均值是double类型,因此定义DoubleType override def dataType: DataType = DataTypes.DoubleType /** * 设置该函数是否为幂等函数 * 幂等函数:即只要输入的数据相同,结果一定相同 * true表示是幂等函数,false表示不是 * @return */ override def deterministic: Boolean = true /** * initialize用于初始化缓存变量的值,也就是初始化bufferSchema函数中定义的两个变量的值sum,count * 其中buffer(0)就表示sum值,buffer(1)就表示count的值,如果还有第3个,则使用buffer(3)表示 * @param buffer */ override def initialize(buffer: MutableAggregationBuffer): Unit = { buffer.update(0,0.0) //或使用buffer(0)=0.0 buffer.update(1,0) //或使用buffer(1)=0 } /** * 当有一行数据进来时就会调用update一次,有多少行就会调用多少次,input就表示在调用自定义函数中有多少个参数,最终会将 * 这些参数生成一个Row对象,在使用时可以通过input.getString或inpu.getLong等方式获得对应的值 * 缓冲中的变量sum,count使用buffer(0)或buffer.getDouble(0)的方式获取到 * @param buffer * @param input */ override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { val sum=buffer.getDouble(0) val count=buffer.getInt(1) buffer.update(0,sum+input.getInt(0).toDouble) buffer.update(1,count+1) } /** * 将更新的缓存变量进行合并,有可能每个缓存变量的值都不在一个节点上,最终是要将所有节点的值进行合并才行 * 其中buffer1是本节点上的缓存变量,而buffer2是从其他节点上过来的缓存变量然后转换为一个Row对象,然后将buffer2 * 中的数据合并到buffer1中去即可 * @param buffer1 * @param buffer2 */ override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { val sum1=buffer1.getDouble(0) val count1=buffer1.getInt(1) val sum2=buffer2.getDouble(0) val count2=buffer2.getInt(1) buffer1.update(0,sum1+sum2) buffer1.update(1,count1+count2) } /** * 一个计算方法,用于计算我们的最终结果,也就相当于返回值 * @param buffer * @return */ override def evaluate(buffer: Row): Any = { val bd = new BigDecimal(buffer.getDouble(0)/buffer.getInt(1).toDouble) bd.setScale(2, BigDecimal.ROUND_HALF_UP).doubleValue//保留两位小数 }}

2、注册该类,并指定到一个自定义函数中,如下图所示:

3、在表中加一列字段id,通过GROUP BY进行分组计算,如

4、在sql语句中使用group_age_avg,如下图所示:

输出结果如下图所示:

5、完整代码如下:

package com.udf import org.apache.spark.SparkConfimport org.apache.spark.sql.{Row, SparkSession}import org.apache.spark.sql.types.{DataTypes, StructField} import scala.util.Random object AppUdf { def main(args:Array[String]):Unit={ val conf=new SparkConf().setAppName("AppUdf").setMaster("local") val sparkSession=SparkSession.builder().config(conf).getOrCreate() //指定函数名为:splicing_t1_t2 此函数名只有通过udf.register注册过之后才能够被使用,第二个参数是继承与UDF的类 //第三个参数是返回类型 sparkSession.udf.register("splicing_t1_t2",new SqlUDF,DataTypes.StringType) //UDAF不用设置返回类型,因此使用两个参数即可 sparkSession.udf.register("group_age_avg",new AvgAge) var rows=Seq[Row]() val random=new Random() for(i <- 0 until 10){ val name="name"+i val age=random.nextInt(30)%15+15 val row=Row(random.nextInt(2),name,age) rows +:=row } val rowsRDD=sparkSession.sparkContext.parallelize(rows) val schema=DataTypes.createStructType(Array[StructField]( DataTypes.createStructField("id",DataTypes.IntegerType,true), DataTypes.createStructField("name",DataTypes.StringType,true), DataTypes.createStructField("age",DataTypes.IntegerType,true)) ) val df=sparkSession.createDataFrame(rowsRDD,schema) df.createOrReplaceTempView("person") df.show() val sql="SELECT id, group_age_avg(age) avg_age FROM person GROUP BY id" sparkSession.sql(sql).show() sparkSession.close() }}

三、类型安全的用户自定于聚合函数:Aggregator


1、它是一个接口,需要继承与Aggregator,而Aggregator有3个参数,分别是IN,BUF,OUT,IN表示输入的值是什么,可以是一个自定类对象包含多个值,也可以是单个值,BUF就是需要用来缓存值使用的,如果需要缓存多个值也需要定义一个对象,而返回值也可以是一个对象返回多个值,需要实现的方法有:

package com.udf import org.apache.spark.sql.Encoderimport org.apache.spark.sql.expressions.Aggregator case class DataBuf(var sum:Double,var count:Int)object AvgAgeAggregator extends Aggregator[Int,DataBuf,Double]{ /** * 相当于UserDefinedAggregateFunction中的initialize函数,用于初始化DataBuf对象的值,此DataBuf是自定义类型的 * @return */ override def zero: DataBuf = ??? /** * reduce函数相当于UserDefinedAggregateFunction中的update函数,当有新的数据a时,更新中间数据b * @param b * @param a * @return */ override def reduce(b: DataBuf, a: Int): DataBuf = ??? /** * merge函数相当于UserDefinedAggregateFunction中的merge函数,对两个值进行 合并, * 因为有可能每个缓存变量的值都不在一个节点上,最终是要将所有节点的值进行合并才行,将b2中的值合并到b1中 * @param b1 * @param b2 * @return */ override def merge(b1: DataBuf, b2: DataBuf): DataBuf = ??? /** * finish相当于UserDefinedAggregateFunction中的evaluate,是一个计算方法,用于计算我们的最终结果,也就相当于返回值 * 返回值可以是一个对象 * @param reduction * @return */ override def finish(reduction: DataBuf): Double = ??? /** * 缓冲数据编码方式 * @return */ override def bufferEncoder: Encoder[DataBuf] = ??? /** * 最终数据输出编码方式 * @return */ override def outputEncoder: Encoder[Double] = ???}

2、具体实现如下代码所示:

package com.udf import java.math.BigDecimal import org.apache.spark.sql.{Encoder, Encoders}import org.apache.spark.sql.expressions.Aggregatorcase class DataBuf(var sum:Double,var count:Int)object AvgAgeAggregator extends Aggregator[Int,DataBuf,Double]{ /** * 相当于UserDefinedAggregateFunction中的initialize函数,用于初始化DataBuf对象的值,此DataBuf是自定义类型的 * @return */ override def zero: DataBuf = DataBuf(0.0,0) /** * reduce函数相当于UserDefinedAggregateFunction中的update函数,当有新的数据a时,更新中间数据b * @param b * @param a * @return */ override def reduce(b: DataBuf, a: Int): DataBuf = { b.count+=1 b.sum+=a.toDouble b } /** * merge函数相当于UserDefinedAggregateFunction中的merge函数,对两个值进行 合并, * 因为有可能每个缓存变量的值都不在一个节点上,最终是要将所有节点的值进行合并才行,将b2中的值合并到b1中 * @param b1 * @param b2 * @return */ override def merge(b1: DataBuf, b2: DataBuf): DataBuf = { b1.sum+=b2.sum b1.count+=b2.count b1 } /** * finish相当于UserDefinedAggregateFunction中的evaluate,是一个计算方法,用于计算我们的最终结果,也就相当于返回值 * 返回值可以是一个对象 * @param reduction * @return */ override def finish(reduction: DataBuf): Double = { val bd = new BigDecimal(reduction.sum/reduction.count.toDouble) bd.setScale(2, BigDecimal.ROUND_HALF_UP).doubleValue//保留两位小数 } /** * 缓冲数据编码方式,如果Encoder中指定的类型时对象,则设置为product,如果是具体的类型,则需设置为具体的类型 * @return */ override def bufferEncoder: Encoder[DataBuf] = Encoders.product /** * 最终数据输出编码方式,如果Encoder中指定的类型,则设置为具体的类型,比如Double则设置为scalaDouble * @return */ override def outputEncoder: Encoder[Double] = Encoders.scalaDouble}

3、而使用此聚合函数就不能通过注册函数来使用了,需要通过Dataset对象的select来使用,如下图所示:


执行结果如下图所示:



因此无类型的用户自定于聚合函数:UserDefinedAggregateFunction和类型安全的用户自定于聚合函数:Aggregator之间的区别是


(1)UserDefinedAggregateFunction不能够带类型而Aggregator是可以带类型的。


(2)使用方法不同UserDefinedAggregateFunction通过注册可以在DataFram的sql语句中使用,而Aggregator必须是在Dataset上使用。


四、开窗函数的使用


1、在Spark 1.5.x版本以后,在Spark SQL和DataFrame中引入了开窗函数,其中比较常用的开窗函数就是row_number该函数的作用是根据表中字段进行分组,然后根据表中的字段排序;其实就是根据其排序顺序,给组中的每条记录添加一个序号;且每组的序号都是从1开始,可利用它的这个特性进行分组取top-n。它是放在select子句中的,其格式为:

ROW_NUMBER() OVER (PARTITION BY area ORDER BY click_count DESC) rank

首先可以,在SELECT查询时,使用row_number()函数,其次row_number()函数后面先跟上OVER关键字,然后括号中,是PARTITION BY,也就是说根据哪个字段进行分组,其次是可以用ORDER BY进行组内排序, 然后row_number()就可以给每个组内的行,一个组内行号,然后rank就是每一组的行号


2、使用方法的sql语句为:

SELECT id,name,age,row_number() OVER (PARTITION BY id ORDER BY age) rank FROM person ORDER BY id desc,rank desc

意思是在sql语句中加一个rank字段,该字段记录了以id为分组,在组内按照age升序排序,并记录行号,最后先按照id降序排序,如果id相同则按照rank降序排序

3、代码如下:

package com.udf import org.apache.spark.SparkConfimport org.apache.spark.sql.{Row, SparkSession}import org.apache.spark.sql.types.{DataTypes, StructField} import scala.util.Random object AppUdf { def main(args:Array[String]):Unit={ val conf=new SparkConf().setAppName("AppUdf").setMaster("local") val sparkSession=SparkSession.builder().config(conf).getOrCreate() //指定函数名为:splicing_t1_t2 此函数名只有通过udf.register注册过之后才能够被使用,第二个参数是继承与UDF的类 //第三个参数是返回类型 sparkSession.udf.register("splicing_t1_t2",new SqlUDF,DataTypes.StringType) //UDAF不用设置返回类型,因此使用两个参数即可 sparkSession.udf.register("group_age_avg",new AvgAge) var rows=Seq[Row]() val random=new Random() for(i <- 0 until 10){ val name="name"+i val age=random.nextInt(30)%15+15 val row=Row(random.nextInt(2),name,age) rows +:=row } val rowsRDD=sparkSession.sparkContext.parallelize(rows) val schema=DataTypes.createStructType(Array[StructField]( DataTypes.createStructField("id",DataTypes.IntegerType,true), DataTypes.createStructField("name",DataTypes.StringType,true), DataTypes.createStructField("age",DataTypes.IntegerType,true)) ) val df=sparkSession.createDataFrame(rowsRDD,schema) df.createOrReplaceTempView("person") df.show() val sql="SELECT id,name,age,row_number() OVER (PARTITION BY id ORDER BY age) rank FROM person ORDER BY id desc,rank desc" sparkSession.sql(sql).show() sparkSession.close() }}

输出结果如下:

欢迎点赞+收藏+转发朋友圈素质三连


文章不错?点个【在看】吧! 👇


: . Video Mini Program Like ,轻点两下取消赞 Wow ,轻点两下取消在看

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存