【分散処理】PySpark ~ UDF の使用上の注意 ~

■ はじめに

https://dk521123.hatenablog.com/entry/2020/05/20/195621

の続き。

PySpark の UDF (User Defined Function) で
ミスった点や注意点などをあげておく。

目次

【1】メモリ消費について
【2】デコレータによる実装方法に関する注意点
【3】引数に関する注意点
 補足:型変換(キャスト)したい場合
【4】トラブルが発生した際のデバッグについて

【1】メモリ消費について

* 避けられるなら使わない方法で実装した方が無難って話。

理由
https://aws.amazon.com/jp/blogs/news/optimize-memory-management-in-aws-glue/

より抜粋
~~~~~~~~
PySpark ユーザー定義関数 (UDF):
PySpark UDF を使用すると、エグゼキュータのメモリにコストがかかる場合があります。
これは、Spark エグゼキュータ JVM と Python インタープリター間でデータを交換するときに、
データをシリアル化/非シリアル化する必要があるためです。
 Python インタープリターは、シリアル化されたデータを Spark エグゼキュータの
オフヒープメモリで処理する必要があります。
大きなレコードやネストされたレコードがあるデータセットの場合、
または複雑な UDF を使用している場合、この処理は大量のオフヒープメモリを消費し、
Yarn のオーバヘッドメモリの超過が原因で OOM 例外が発生する可能性があります。
その場合、次のようなエラーメッセージを受け取ります。

ERROR YarnClusterScheduler: Lost executor 1 on ip-xxx:
Container killed by YARN for exceeding memory limits.5.5 GB of 5.5 GB physical memory used.
Consider boosting spark.yarn.executor.memoryOverhead
~~~~~~~~

【2】デコレータによる実装方法に関する注意点

デコレータによる実装方法 や「1)udf関数から取り込む」など
については、以下の関連記事を参照のこと。

https://dk521123.hatenablog.com/entry/2020/05/20/195621
https://dk521123.hatenablog.com/entry/2021/05/27/100132

結論から言うと、
1)udf関数から取り込む
2)デコレータを利用する方法
の使い分けは、以下がいいのかなっと。

「1)udf関数から取り込む」は、使用する関数が通常でも使いたい場合
「2)デコレータを利用する方法」は、使用する関数が完全にUDFとして使用する場合

理由

「2)デコレータを利用する方法」の関数を通常で使用しても
返却される値は、Objectで返ってしまうため。
(以下のサンプルを参照のこと)

サンプル

from pyspark import SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType
from pyspark.sql.types import StructField
from pyspark.sql.types import StringType
from pyspark.sql.functions import input_file_name
from pyspark.sql.functions import udf

import os

# ファイル名のみ取得する関数
@udf(returnType=StringType())
def to_file_name(full_path):
  return os.path.basename(full_path)

spark_context = SparkContext()
spark = SparkSession(spark_context)

# ファイルをRDDとして取得
rdd = spark_context.textFile("./test.txt")
schema = StructType([
  StructField('id', StringType(), False),
  StructField('name', StringType(), False),
  StructField('heigh', StringType(), False),
  StructField('birth_date', StringType(), False),
])
temp_rdd = rdd.map(lambda row: row.replace('\r', '').replace('\0', '').strip())

data_frame = spark.read.csv(temp_rdd, sep="|", header=False)
df = data_frame.filter(data_frame["_c0"] == "100").select("_c1") \
  .withColumn("file_name", to_file_name(input_file_name()))
df.show()

# ★注目★ str ではなく object で返ってきてしまうため
# 出力結果 => Column<b'to_file_name(s3://xxx/yyy/xxx.csv)'>
print(to_file_name("s3://xxx/yyy/xxx.csv"))

test.txt

100|Mike|179.6|2018-03-20 10:41:20
101|Sam|167.9|2018-03-03 11:32:34
102|Kevin|189.2|2018-01-28 20:20:11
103|Mike|179.6|2018-03-20 10:41:20
104|Sam|167.9|2018-03-03 11:32:34
105|Kevin|189.2|2018-01-28 20:20:11
106|Mike|179.6|2018-03-20 10:41:20
107|Sam|167.9|2018-03-03 11:32:34
108|Kevin|189.2|2018-01-28 20:20:11

【3】引数に関する注意点

https://dk521123.hatenablog.com/entry/2021/04/12/145133

の応用として、ファイル名の一部の文字列を抽出するコード
(例:ファイル名「hello_world_20191012.csv」 => 「20191012」を取得したい)
を実装した場合、以下の「NGコード(一部)」のように呼び出したら、
エラー「AnalysisException」が発生した。

NGコード(一部)

# 呼び出し元のUDF
@udf(returnType=StringType())
def get_file_date(full_path, pattern):
  file_name = os.path.basename(full_path)
  # e.g 'hello_world_20210512.csv' => ['20210512']
  results = re.findall(pattern, file_name)
  return results[0]

# 第二引数「pattern」に対して、文字列「r'\d{8}'」を設定
#  => 結論からいうと、これじゃーダメ。
data_frame = data_frame \
  .withColumn(
    'file_date', get_file_date(input_file_name(), r'\d{8}'))

例外内容(一部抜粋)

pyspark.sql.utils.AnalysisException:
"cannot resolve '`\\d{8}`' given input columns: [id, name, job];;\n'Project [id#19, name#20, job#21,
 get_file_date(input_file_name(), '\\d{8}) AS file_date#26]\n+- LogicalRDD [id#19, name#20, job#21], false\n"

解決方法

https://stackoverflow.com/questions/48443892/pyspark-using-udf-with-arguments-to-create-a-new-column/48443893

を参考にした。
 => 引数には、let() を使って呼び出す。(以下「OKコード(フルコード)」を参照)

OKコード(フルコード)

from pyspark import SparkContext
from pyspark.sql import SparkSession

from pyspark.sql.types import StructType
from pyspark.sql.types import StructField
from pyspark.sql.types import StringType
from pyspark.sql.functions import input_file_name
from pyspark.sql.functions import udf
from pyspark.sql.functions import lit

import os
import re


@udf(returnType=StringType())
def get_file_date(full_path, pattern):
  file_name = os.path.basename(full_path)
  # e.g 'hello_world_20210512.csv' => ['20210512']
  results = re.findall(pattern, file_name)
  return results[0]

spark_context = SparkContext()
spark = SparkSession(spark_context)

data_frame = spark.read.csv(
  "hello_world_20210512.csv", header=False)

data_frame = data_frame \
  .withColumnRenamed("_c0", "id") \
  .withColumnRenamed("_c1", "name") \
  .withColumnRenamed("_c2", "job")

# ★注目★
# 第二引数「pattern」に対して、「lit(r'\d{8}')」を設定 => OK
data_frame = data_frame \
  .withColumn(
    'file_date', get_file_date(input_file_name(), lit(r'\d{8}')))

data_frame.show()

入力ファイル「hello_world_20210512.csv

x0001,Mike,Sales
x0002,Tom,IT
x0003,Sam,Sales
x0004,Kevin,Human resources
x0005,Bob,IT
x0006,Alice,Banking
x0007,Carol,IT
x0008,Tom,Banking
x0009,Mike,IT
x0010,Bob,Sales

出力結果 (file_date に注目)

+-----+-----+---------------+---------+
|   id| name|            job|file_date|
+-----+-----+---------------+---------+
|x0001| Mike|          Sales| 20210512|
|x0002|  Tom|             IT| 20210512|
|x0003|  Sam|          Sales| 20210512|
|x0004|Kevin|Human resources| 20210512|
|x0005|  Bob|             IT| 20210512|
|x0006|Alice|        Banking| 20210512|
|x0007|Carol|             IT| 20210512|
|x0008|  Tom|        Banking| 20210512|
|x0009| Mike|             IT| 20210512|
|x0010|  Bob|          Sales| 20210512|
+-----+-----+---------------+---------+

補足:型変換(キャスト)したい場合

# cast([データ型]Type())を使う
data_frame = data_frame \
  .withColumn(
    'xxxxx', get_xxxx(input_file_name(), lit(value).cast(StringType())))

【4】トラブルが発生した際のデバッグについて

* UDF内でエラーが起こると結構分かりづらいエラーになる可能性があるので
 怪しかったら、一時的にでも、print()文などを埋め込んで
 原因を調査するのもいいかも。
 => ただ、AWS Glue 内だと(デフォルト設定では?)ログ出力されなかったので、
  ローカル環境を作って、デバッグを行った。

サンプル(デバッグ時)

@udf(returnType=StringType())
def get_file_date(full_path, pattern):
  print("******************************************")
  print(f"Enter get_file_date. Parameters are [{full_path}], [{pattern}]")
  print("******************************************")
  file_name = os.path.basename(full_path)
  # e.g 'hello_world_20210512.csv' => ['20210512']
  results = re.findall(pattern, file_name)
  # ★ここで len(result) == 0 だったので気が付かなかった★
  return results[0]

# ・・・略・・・

data_frame = spark.read.csv(
  "hello_world_20210512.csv", header=False)

schema = StructType([
  StructField('id', StringType(), True),
  StructField('name', StringType(), True),
  StructField('job', StringType(), True),
])
# 新たに create しているからか原因なのか input_file_name() で
# ファイルパスが取得できずに空になる
#(ただ、このコードが環境によっては大丈夫なこともある、、、)
data_frame = spark.createDataFrame(data_frame.rdd, schema)

# ・・・略・・・

data_frame = data_frame \
  .withColumn(
    'file_date', get_file_date(input_file_name(), lit(r'\d{8}')))

# ・・・略・・・

関連記事

PySpark ~ 環境構築編 ~
https://dk521123.hatenablog.com/entry/2019/11/14/221126
PySpark ~ 入門編 ~
https://dk521123.hatenablog.com/entry/2021/04/03/004254
PySpark ~ ユーザ定義関数 UDF 編 ~
https://dk521123.hatenablog.com/entry/2020/05/20/195621
PySparkで入力ファイル名を取得するには
https://dk521123.hatenablog.com/entry/2021/04/12/145133