【分散処理】PySpark ~ UDFの各定義方法でのサンプル ~

■ はじめに

https://dk521123.hatenablog.com/entry/2020/05/20/195621

で、PySpark の UDF (User Defined Function) 定義方法
について、扱ったが
Udacity(ユダシティ)の Freeコース「Spark」で
別の方法を取り扱っていた。

https://www.udacity.com/course/learn-spark-at-udacity--ud2002

今回は、整理も含めてメモしておく

目次

【1】udf関数から取り込む
【2】デコレータを利用する方法
 補足:デコレータでの使用上の注意
【3】spark.udf.register() で登録する

【1】udf関数から取り込む

from pyspark import SparkContext
from pyspark.sql import SparkSession

from pyspark.sql.types import StructType
from pyspark.sql.types import StructField
from pyspark.sql.types import StringType
from pyspark.sql.types import IntegerType
# ユーザ定義関数 UDF をインポート
from pyspark.sql.functions import udf

# 定義したい関数
def say_hello(name):
  return "Hello, {}!".format(name)

spark_context = SparkContext()
spark = SparkSession(spark_context)

rdd = spark_context.parallelize([
  (1, 'Mike', 32, 'Sales'),
  (2, 'Tom', 20, 'IT'),
  (3, 'Sam', 32, 'Sales'),
  (4, 'Kevin', 30, 'Human resources'),
  (5, 'Bob', 30, 'IT'),
  (6, 'Alice', 20, 'Banking'),
  (7, 'Carol', 30, 'IT'),
])
schema = StructType([
  StructField('id', IntegerType(), False),
  StructField('name', StringType(), False),
  StructField('age', IntegerType(), False),
  StructField('job', StringType(), False),
])
data_frame = spark.createDataFrame(rdd, schema)
# 関数say_helloを定義。その名前をhello_worldと命名。
hello_world = udf(say_hello, StringType())

# UDFを使用
data_frame.select(hello_world('name')).show()

出力結果

+---------------+
|say_hello(name)|
+---------------+
|   Hello, Mike!|
|    Hello, Tom!|
|    Hello, Sam!|
|  Hello, Kevin!|
|    Hello, Bob!|
|  Hello, Alice!|
|  Hello, Carol!|
+---------------+

【2】デコレータを利用する方法

from pyspark import SparkContext
from pyspark.sql import SparkSession

from pyspark.sql.types import StructType
from pyspark.sql.types import StructField
from pyspark.sql.types import StringType
from pyspark.sql.types import IntegerType
from pyspark.sql.functions import udf
from pyspark.sql.functions import col

# デコレータ
@udf(returnType=StringType())
def say_hello(name):
  return "Hello, {}!".format(name)


spark_context = SparkContext()
spark = SparkSession(spark_context)

rdd = spark_context.parallelize([
  (1, 'Mike', 32, 'Sales'),
  (2, 'Tom', 20, 'IT'),
  (3, 'Sam', 32, 'Sales'),
  (4, 'Kevin', 30, 'Human resources'),
  (5, 'Bob', 30, 'IT'),
  (6, 'Alice', 20, 'Banking'),
  (7, 'Carol', 30, 'IT'),
])
schema = StructType([
  StructField('id', IntegerType(), False),
  StructField('name', StringType(), False),
  StructField('age', IntegerType(), False),
  StructField('job', StringType(), False),
])
data_frame = spark.createDataFrame(rdd, schema)

# 出力結果は例1と同じなので省略
data_frame.select(say_hello(col("name"))).show()

補足:デコレータでの使用上の注意

PySpark ~ UDF の使用上の注意 ~
https://dk521123.hatenablog.com/entry/2021/05/20/095706

より抜粋
~~~~~~~~~~~~~
「2)デコレータを利用する方法」は、使用する関数が完全にUDFとして使用する
関数を通常で使用しても、返却される値は、Objectで返ってしまうため。
~~~~~~~~~~~~~
詳細は、上記の関連記事を参照のこと。

【3】spark.udf.register() で登録する

import datetime
from pyspark import SparkContext
from pyspark.sql import SparkSession

from pyspark.sql.types import StructType
from pyspark.sql.types import StructField
from pyspark.sql.types import StringType
from pyspark.sql.types import IntegerType
from pyspark.sql.functions import udf
from pyspark.sql.functions import col

def to_age(input_date):
  birth_date = datetime.datetime.strptime(input_date, '%Y-%m-%d')
  today = datetime.date.today()
  return today.year - birth_date.year - \
    ((today.month, today.day) < (birth_date.month, birth_date.day))

spark_context = SparkContext()
spark = SparkSession(spark_context)

rdd = spark_context.parallelize([
  (1, 'Mike', '1976-11-11', 'Sales'),
  (2, 'Tom', '2001-07-21', 'IT'),
  (3, 'Sam', '1956-03-05', 'Sales'),
  (4, 'Kevin', '2011-08-14', 'Human resources'),
  (5, 'Bob', '2003-12-28', 'IT'),
  (6, 'Alice', '2000-05-16', 'Banking'),
  (7, 'Carol', '1982-06-09', 'IT'),
])
schema = StructType([
  StructField('id', IntegerType(), False),
  StructField('name', StringType(), False),
  StructField('birth_date', StringType(), False),
  StructField('job', StringType(), False),
])
data_frame = spark.createDataFrame(rdd, schema)
data_frame.createOrReplaceTempView("employee")

spark.udf.register("get_age", lambda x: to_age(x))
df = spark.sql('''
SELECT
  *,
  get_age(birth_date) AS age
FROM
  employee 
''')
df.show()

出力結果

+---+-----+----------+---------------+---+
| id| name|birth_date|            job|age|
+---+-----+----------+---------------+---+
|  1| Mike|1976-11-11|          Sales| 44|
|  2|  Tom|2001-07-21|             IT| 19|
|  3|  Sam|1956-03-05|          Sales| 65|
|  4|Kevin|2011-08-14|Human resources|  9|
|  5|  Bob|2003-12-28|             IT| 17|
|  6|Alice|2000-05-16|        Banking| 21|
|  7|Carol|1982-06-09|             IT| 38|
+---+-----+----------+---------------+---+

関連記事

PySpark ~ 環境構築編 ~
https://dk521123.hatenablog.com/entry/2019/11/14/221126
PySpark ~ 入門編 ~
https://dk521123.hatenablog.com/entry/2021/04/03/004254
PySpark ~ ユーザ定義関数 UDF 編 ~
https://dk521123.hatenablog.com/entry/2020/05/20/195621
PySpark ~ UDF の使用上の注意 ~
https://dk521123.hatenablog.com/entry/2021/05/20/095706
PySparkで入力ファイル名を取得するには
https://dk521123.hatenablog.com/entry/2021/04/12/145133
Apache Hadoop ~ 入門編 ~
https://dk521123.hatenablog.com/entry/2019/09/15/100727