Extend Spark with your own Python logic
Sometimes the 300+ built-in functions from Module 9 aren't enough — you need custom business logic, a machine learning model, or a complex calculation that only Python can express. This module teaches you how to write, register, and optimize User Defined Functions (UDFs) — and just as importantly, when not to use them.
Scalar UDFs, Registration & Usage
A scalar Python UDF takes one row's worth of column values and returns a single value — just like a normal Python function, but wrapped so Spark can call it for every row.
You write a plain Python function, wrap it with F.udf(), and Spark calls it once per row across every partition, in parallel.
Typical use cases: applying a custom business rule that can't be expressed with when/otherwise, calling Python code that has no Spark equivalent (e.g. the geopy library or a custom hashing scheme), or running a pre-trained ML model's .predict() method per row.
Creating a UDF with F.udf(): the signature is F.udf(function, returnType). Always pass the return type explicitly. If you omit it, PySpark assumes StringType(), and a plain UDF does not infer its return type from Python type hints.
```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import StringType, DoubleType

spark = SparkSession.builder.appName("Module12").getOrCreate()

data = [(1, "alice smith", 52000.0),
        (2, "BOB JONES", 61000.0),
        (3, "Carol Lee", 73000.0)]
df = spark.createDataFrame(data, ["id", "name", "salary"])

# Step 1: plain Python function
def title_case(name: str) -> str:
    return name.title() if name else None

# Step 2: wrap it as a UDF; always pass the return type explicitly
title_case_udf = F.udf(title_case, StringType())

# Step 3: use it like any column expression
df.withColumn("name_clean", title_case_udf(F.col("name"))).show()
```
| id | name | salary | name_clean |
|---|---|---|---|
| 1 | alice smith | 52000.0 | Alice Smith |
| 2 | BOB JONES | 61000.0 | Bob Jones |
| 3 | Carol Lee | 73000.0 | Carol Lee |
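Because Step 1 is an ordinary Python function, you can check its logic locally before any Spark job runs. The minimal sketch below calls the same title_case body on a few sample values, including the None that Spark passes for a null. Note that the empty string also maps to None here, because "" is falsy. That may or may not be the behavior you want in the cleaned column.

```python
def title_case(name):
    # Same body as the Step 1 function: falsy input (None or "") becomes None
    return name.title() if name else None

samples = ["alice smith", "BOB JONES", "Carol Lee", None, ""]
cleaned = [title_case(s) for s in samples]
print(cleaned)  # ['Alice Smith', 'Bob Jones', 'Carol Lee', None, None]
```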
@udf Decorator Style
You can also apply @F.udf(returnType) as a decorator directly on the function definition. This is purely stylistic: both forms produce the same kind of callable UDF.
```python
@F.udf(returnType=DoubleType())
def apply_bonus(salary: float) -> float:
    # Custom business rule: 10% bonus, capped at 7000
    bonus = salary * 0.10
    return round(min(bonus, 7000.0), 2)

df.withColumn("bonus", apply_bonus(F.col("salary"))).show()
# +---+-----------+-------+------+
# |id |name       |salary |bonus |
# +---+-----------+-------+------+
# |1  |alice smith|52000.0|5200.0|
# |2  |BOB JONES  |61000.0|6100.0|
# |3  |Carol Lee  |73000.0|7000.0|   ← capped
# +---+-----------+-------+------+
```
To call a UDF from SQL (spark.sql(...) or selectExpr), register it on the session with spark.udf.register(name, function, returnType). This creates a temporary function available for the lifetime of the SparkSession.
```python
# Register the same Python function for use in SQL
spark.udf.register("title_case_sql", title_case, StringType())
df.createOrReplaceTempView("employees")

spark.sql("""
    SELECT id, title_case_sql(name) AS name_clean, salary
    FROM employees
""").show()

# Also usable inside selectExpr / DataFrame expressions:
df.selectExpr("id", "title_case_sql(name) as name_clean").show()
```
F.udf(fn, type) → returns a UDF usable only through the DataFrame API (withColumn, select, filter).
spark.udf.register("name", fn, type) → registers the function for SQL (spark.sql, selectExpr) and also returns a UDF you can use in withColumn.
Spark passes None into your Python function whenever the column value is null. Your function must handle this explicitly, or the job fails at runtime with a TypeError or AttributeError, which surfaces as a confusing executor-side Python exception. UDFs can also accept multiple columns as multiple arguments.
```python
@F.udf(returnType=StringType())
def full_label(name, salary):
    # ALWAYS guard against None: Spark passes it for nulls
    if name is None or salary is None:
        return "UNKNOWN"
    return f"{name.upper()} earns ${salary:,.2f}"

df.withColumn("label", full_label("name", "salary")).show(truncate=False)
```
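Forgetting the None check makes the task fail with an executor-side Python exception, for example an AttributeError for None.upper() or a TypeError when formatting None. PySpark surfaces it wrapped in a Py4JJavaError (or, in newer versions, a PythonException), and the Python traceback is buried inside a long JVM stack trace, which makes it harder to debug. Always null-check first.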
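You can reproduce this failure mode without Spark. The sketch below calls an unguarded copy of the full_label body with a None in each position, then calls the guarded version. The name full_label_unguarded is hypothetical, used only for illustration. A null name raises AttributeError and a null salary raises TypeError.

```python
def full_label_unguarded(name, salary):
    # No None check: this is what fails on the executor when a null arrives
    return f"{name.upper()} earns ${salary:,.2f}"

def full_label(name, salary):
    # Guarded version, same body as the UDF above
    if name is None or salary is None:
        return "UNKNOWN"
    return f"{name.upper()} earns ${salary:,.2f}"

errors = []
for args in [(None, 52000.0), ("alice", None)]:
    try:
        full_label_unguarded(*args)
    except (AttributeError, TypeError) as exc:
        errors.append(type(exc).__name__)

print(errors)                        # ['AttributeError', 'TypeError']
print(full_label(None, 52000.0))     # UNKNOWN
print(full_label("alice", 52000.0))  # ALICE earns $52,000.00
```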
UDF Internals: JVM ↔ Python Communication
This is the most important conceptual section in the module — understanding exactly what happens "under the hood" when a Python UDF runs explains every performance characteristic you'll see later.
When you call df.filter(...) or use a built-in function, PySpark just builds a logical plan description and sends it to the JVM via Py4J (a library that lets Python call Java objects). The JVM does all the actual data processing; your Python process never touches the data.
A Python UDF breaks this model. The JVM cannot execute your Python function because it has no Python interpreter. So for each partition it processes, the executor must hand the data to a separate Python worker process (started fresh or reused), and the JVM and that process exchange data over a local socket.
| Step | What Happens | Cost |
|---|---|---|
| 1. Launch worker | Executor JVM spawns (or reuses) a Python worker process and connects to it over a local socket | Process startup (amortized if reused) |
| 2. Serialize rows | JVM rows for the relevant columns are pickled (Python's serialization format) and written to the socket | CPU + memory for every row |
| 3. Deserialize in Python | Python worker reads bytes, unpickles into Python objects (int, str, etc.) | CPU per row |
| 4. Run your function | Your UDF executes, row by row, in the Python interpreter (GIL-bound, no JIT) | Pure Python speed |
| 5. Serialize result | Return values are pickled again and written back over the socket | CPU + memory per row |
| 6. Deserialize in JVM | JVM reads bytes back, converts into InternalRow / UnsafeRow format | CPU per row |
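A toy, single-process model of steps 2 to 6 makes the shape of the work concrete. Values are serialized with Python's pickle module, deserialized into Python objects, passed through the function one call at a time, and then serialized back. Treat this only as an illustration. Real Spark uses its own pickler on the JVM side and moves the bytes over a local socket between two processes, rather than within one interpreter as here.

```python
import pickle

def title_case(name):
    return name.title() if name else None

jvm_rows = ["alice smith", "BOB JONES", "Carol Lee"]

to_python = pickle.dumps(jvm_rows)           # step 2: serialize on the "JVM" side
received = pickle.loads(to_python)           # step 3: deserialize into Python str objects
results = [title_case(v) for v in received]  # step 4: one Python call per value
to_jvm = pickle.dumps(results)               # step 5: serialize the results
back_in_jvm = pickle.loads(to_jvm)           # step 6: the "JVM" reads them back

calls = len(received)
print(back_in_jvm, calls)  # ['Alice Smith', 'Bob Jones', 'Carol Lee'] 3
```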
Compare this with a built-in such as F.upper(), which is compiled Java/Scala bytecode running directly on Spark's in-memory UnsafeRow binary format: zero serialization, zero process hops.
```python
# Built-in: Catalyst sees "upper" and can optimize freely
df.filter(F.upper(F.col("name")) == "ALICE").explain()
# == Physical Plan ==  (abbreviated)
# *(1) Filter (upper(name) = ALICE)    ← fused into codegen

# UDF: Catalyst sees an opaque Python UDF node
df.filter(title_case_udf(F.col("name")) == "Alice").explain()
# == Physical Plan ==  (abbreviated)
# *(2) Filter (pythonUDF0#34 = Alice)
# +- BatchEvalPython [title_case(name)] ...    ← separate operator, not fused into codegen
```
Whenever you see BatchEvalPython (plain Python UDFs) or ArrowEvalPython (scalar Pandas UDFs) in an explain() plan, that's your signal: Spark is crossing the process boundary at that point in the plan. Other Pandas UDF flavors show up as nodes such as FlatMapGroupsInPandas or MapInPandas. Minimizing the number and "width" (rows × columns) of these nodes is a key optimization goal.
Python workers are reused across tasks by default (spark.python.worker.reuse=true). Even so, each reused worker still pays the per-row pickle/unpickle cost described above. Reuse only avoids repeated process-startup overhead, not serialization overhead.
| Config | Purpose |
|---|---|
| spark.python.worker.reuse | Reuse Python worker processes across tasks (default: true) |
| spark.executorEnv.PYTHONPATH | Ensure custom modules are importable inside the worker |
| spark.python.worker.memory | Memory per Python worker during Python-side (RDD) aggregations before spilling to disk |
Scalar Pandas UDF
Pandas UDFs (also called "vectorized UDFs") operate on entire columns of data as pandas Series at once, instead of one Python value at a time — using Apache Arrow to move whole batches efficiently.
A scalar Pandas UDF, declared with @F.pandas_udf(returnType), receives one or more pandas.Series (a batch of column values, typically thousands of rows) and must return a pandas.Series of the same length. Internally, Spark groups rows into Arrow record batches, hands a whole batch to your function, and gets a whole batch back, instead of making one Python call per value.
Your function takes pd.Series in and returns pd.Series out. Inside, you can use vectorized pandas/NumPy operations instead of Python loops. This is where the real speed-up comes from, on top of the reduced serialization cost.
```python
import pandas as pd
from pyspark.sql.types import DoubleType

data = [(1, 52000.0), (2, 61000.0), (3, 73000.0), (4, 45000.0)]
df = spark.createDataFrame(data, ["id", "salary"])

@F.pandas_udf(DoubleType())
def pandas_bonus(salary: pd.Series) -> pd.Series:
    # Vectorized: operates on the WHOLE batch at once, no Python loop
    bonus = salary * 0.10
    return bonus.clip(upper=7000.0).round(2)

df.withColumn("bonus", pandas_bonus(F.col("salary"))).show()
# +---+-------+------+
# |id |salary |bonus |
# +---+-------+------+
# |1  |52000.0|5200.0|
# |2  |61000.0|6100.0|
# |3  |73000.0|7000.0|
# |4  |45000.0|4500.0|
# +---+-------+------+
```
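A scalar Pandas UDF body is plain pandas code, so you can check it against the row-at-a-time rule from apply_bonus on a hand-built batch. The sketch below confirms that both produce the same values. The names bonus_row and bonus_vectorized are illustrative. The vectorized version makes one call per batch instead of one per row.

```python
import pandas as pd

def bonus_row(salary: float) -> float:
    # Row-at-a-time rule, same as apply_bonus
    return round(min(salary * 0.10, 7000.0), 2)

def bonus_vectorized(salary: pd.Series) -> pd.Series:
    # Whole-batch rule, same as the pandas_bonus body
    return (salary * 0.10).clip(upper=7000.0).round(2)

batch = pd.Series([52000.0, 61000.0, 73000.0, 45000.0])
looped = pd.Series([bonus_row(s) for s in batch])
vectorized = bonus_vectorized(batch)
print(vectorized.tolist())  # [5200.0, 6100.0, 7000.0, 4500.0]
```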
With several input columns, each one arrives as its own pd.Series, and pandas/NumPy lets you combine them with normal vectorized arithmetic and comparisons.
```python
import numpy as np
import pandas as pd
from pyspark.sql.types import StringType

@F.pandas_udf(StringType())
def salary_band(salary: pd.Series, bonus: pd.Series) -> pd.Series:
    total = salary + bonus
    # np.select is the vectorized equivalent of when/otherwise:
    # the first condition that is True picks the label
    conditions = [total < 50000, total < 70000]
    choices = ["Junior", "Mid"]
    return pd.Series(np.select(conditions, choices, default="Senior"))

result = (df.withColumn("bonus", pandas_bonus("salary"))
            .withColumn("band", salary_band("salary", "bonus")))
result.show()
```
Pandas UDFs are called exactly like plain UDFs (withColumn, select, even spark.udf.register). The difference is entirely in the function body: Series → Series with vectorized ops, instead of value → value with a Python loop hidden inside Spark's row iteration.
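np.select evaluates its conditions in order and picks the first one that is True, which is how a when/otherwise chain resolves. The sketch below runs the salary_band logic on a plain pandas batch to show the mapping. The name salary_band_body is illustrative. Passing index=salary.index is a small defensive habit that keeps the result aligned with the input batch.

```python
import numpy as np
import pandas as pd

def salary_band_body(salary: pd.Series, bonus: pd.Series) -> pd.Series:
    total = salary + bonus
    conditions = [total < 50000, total < 70000]  # checked in order, first True wins
    choices = ["Junior", "Mid"]
    return pd.Series(np.select(conditions, choices, default="Senior"), index=salary.index)

salary = pd.Series([52000.0, 61000.0, 73000.0, 45000.0])
bonus = pd.Series([5200.0, 6100.0, 7000.0, 4500.0])
bands = salary_band_body(salary, bonus)
print(bands.tolist())  # ['Mid', 'Mid', 'Senior', 'Junior']
```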
Grouped Map UDF — applyInPandas
Sometimes you need to transform an entire group of rows together — not just one row, and not a simple aggregate. Grouped Map UDFs hand each group to your function as its own pandas DataFrame.
Use df.groupBy(...).applyInPandas(func, schema) when your transformation needs to see all rows of a group at once. Examples include normalizing values relative to the group's mean, fitting a small per-group model, ranking within a group using pandas logic, or any operation that's naturally "split-apply-combine."
Your function receives a pd.DataFrame for one group's rows and must return a pd.DataFrame. The output schema (column names and types) must be declared up front via a StructType or DDL string, because Spark needs it before any data has been processed. Each group is loaded in full into a single worker's memory, so very large or skewed groups can cause out-of-memory errors.
Think of groupBy().applyInPandas() as sorting mail into per-department bins, then handing each entire bin to that department's specialist (your pandas function). The specialist processes it however they like, and the results are merged back into one outgoing pile. Each department works independently and can return a different number of items than it received.
```python
import pandas as pd
from pyspark.sql.types import (StructType, StructField, StringType, DoubleType,
                               IntegerType, LongType)

emp_data = [("Eng", 1, 90000.0), ("Eng", 2, 110000.0), ("Eng", 3, 130000.0),
            ("Sales", 4, 60000.0), ("Sales", 5, 75000.0)]
emp_df = spark.createDataFrame(emp_data, ["dept", "id", "salary"])

# Output schema MUST be declared up front
# (Python ints are inferred as LongType, so "id" is declared as LongType)
out_schema = StructType([
    StructField("dept", StringType()),
    StructField("id", LongType()),
    StructField("salary", DoubleType()),
    StructField("salary_zscore", DoubleType())
])

def zscore_within_group(pdf: pd.DataFrame) -> pd.DataFrame:
    # pdf = ALL rows for ONE department, as a pandas DataFrame
    mean = pdf["salary"].mean()
    std = pdf["salary"].std()  # sample standard deviation (ddof=1)
    pdf["salary_zscore"] = ((pdf["salary"] - mean) / std).round(3)
    return pdf

result = emp_df.groupBy("dept").applyInPandas(zscore_within_group, schema=out_schema)
result.orderBy("dept", "id").show()
```
| dept | id | salary | salary_zscore |
|---|---|---|---|
| Eng | 1 | 90000.0 | -1.0 |
| Eng | 2 | 110000.0 | 0.0 |
| Eng | 3 | 130000.0 | 1.0 |
| Sales | 4 | 60000.0 | -0.707 |
| Sales | 5 | 75000.0 | 0.707 |
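zscore_within_group only touches pandas, so you can exercise the per-group logic locally by mimicking split-apply-combine with pandas' own groupby. This is a local sketch, not how Spark distributes groups. Two details are worth noticing. pandas' .std() is the sample standard deviation (ddof=1), which matches Spark's F.stddev. A one-row group would therefore produce NaN.

```python
import pandas as pd

def zscore_within_group(pdf: pd.DataFrame) -> pd.DataFrame:
    mean = pdf["salary"].mean()
    std = pdf["salary"].std()  # sample std (ddof=1)
    pdf["salary_zscore"] = ((pdf["salary"] - mean) / std).round(3)
    return pdf

emp = pd.DataFrame({
    "dept": ["Eng", "Eng", "Eng", "Sales", "Sales"],
    "id": [1, 2, 3, 4, 5],
    "salary": [90000.0, 110000.0, 130000.0, 60000.0, 75000.0],
})

# Split by dept, apply the function to each group, combine the results
result = pd.concat(
    [zscore_within_group(group.copy()) for _, group in emp.groupby("dept")],
    ignore_index=True,
)
print(result["salary_zscore"].tolist())  # [-1.0, 0.0, 1.0, -0.707, 0.707]
```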
Grouped Aggregate Pandas UDF
Like F.avg() or F.sum(), but custom: a Grouped Aggregate Pandas UDF takes a Series per group and reduces it down to a single scalar value — usable inside .agg().
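It is declared with the same decorator, @F.pandas_udf(returnType), but with a Series → scalar signature. PySpark reads the function's type hints (pd.Series in, a single value such as float out) to classify it as a grouped-aggregate UDF. You then call it inside .agg(...) after a groupBy to reduce each group's Series to one value.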
This is the simplest way in modern PySpark to write a "custom aggregate function" without dealing with the full UDAF lifecycle (covered in 12.4).
If F.avg() is a calculator with a built-in "average" button, a Grouped Aggregate Pandas UDF is like handing the calculator's input tape to a statistician. The statistician applies whatever formula you taught them and hands back one number, once per group.
Spark has no built-in weighted_avg function. Here we compute one (the average value weighted by hours) per department, using a Grouped Aggregate Pandas UDF directly inside .agg().
```python
sales_data = [("Eng", 100.0, 8), ("Eng", 200.0, 4), ("Eng", 150.0, 6),
              ("Sales", 300.0, 10), ("Sales", 100.0, 2)]
sales_df = spark.createDataFrame(sales_data, ["dept", "value", "hours"])

@F.pandas_udf(DoubleType())
def weighted_avg(value: pd.Series, weight: pd.Series) -> float:
    # Receives ALL values for one group as Series, returns ONE float
    return float((value * weight).sum() / weight.sum())

sales_df.groupBy("dept").agg(
    weighted_avg(F.col("value"), F.col("hours")).alias("weighted_value")
).show()
# +-----+--------------+
# |dept |weighted_value|
# +-----+--------------+
# |Eng  |138.888...    |   (2500 / 18)
# |Sales|266.666...    |   (3200 / 12)
# +-----+--------------+
```
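You can verify the expected values by hand. Eng is (100×8 + 200×4 + 150×6) / 18 = 2500 / 18 ≈ 138.89, and Sales is (300×10 + 100×2) / 12 = 3200 / 12 ≈ 266.67. The sketch below reproduces them with plain pandas, applying the same body per group. The name weighted_avg_body is illustrative. For this particular statistic, the built-in expression F.sum(F.col("value") * F.col("hours")) / F.sum("hours") gives the same result without a UDF, which is worth checking before reaching for one.

```python
import pandas as pd

def weighted_avg_body(value: pd.Series, weight: pd.Series) -> float:
    # Same body as the weighted_avg Pandas UDF
    return float((value * weight).sum() / weight.sum())

sales = pd.DataFrame({
    "dept": ["Eng", "Eng", "Eng", "Sales", "Sales"],
    "value": [100.0, 200.0, 150.0, 300.0, 100.0],
    "hours": [8, 4, 6, 10, 2],
})

per_dept = {dept: weighted_avg_body(g["value"], g["hours"])
            for dept, g in sales.groupby("dept")}
print(per_dept)
```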
The same @F.pandas_udf decorator can produce either a scalar Pandas UDF or a grouped-aggregate Pandas UDF. The function's type hints decide which. pd.Series → pd.Series makes a scalar UDF for .withColumn()/.select(). pd.Series → a single value (e.g. float) makes a grouped-aggregate UDF for .groupBy().agg().
| Need | Recommendation |
|---|---|
| sum, avg, count, min, max, stddev | Built-in — already optimized, no UDF needed |
| Weighted average, custom percentile logic, domain-specific statistic | Grouped Aggregate Pandas UDF |
| Complex multi-step stateful aggregation with merge semantics | Full UDAF (12.4) — more control, more code |
Iterator UDF & Map Iterator UDF
When a UDF needs to do expensive setup work once — like loading a large ML model — before processing many batches, Iterator-style Pandas UDFs are the right tool.
An Iterator of Series UDF has the signature Iterator[pd.Series] → Iterator[pd.Series]. Any code before the loop over the iterator runs once per call, which means once per partition (task), not once per batch.
The body of the for batch in iterator loop runs once per Arrow batch.
```python
from typing import Iterator

@F.pandas_udf(DoubleType())
def predict_with_model(batch_iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
    # --- runs ONCE per partition (task), not once per batch ---
    # load_pretrained_model() is a placeholder for your own loader,
    # e.g. one that reads a 500MB model file
    model = load_pretrained_model()

    # --- runs once per Arrow batch ---
    for feature_batch in batch_iter:
        predictions = model.predict(feature_batch.to_numpy().reshape(-1, 1))
        yield pd.Series(predictions)

scored = df.withColumn("prediction", predict_with_model(F.col("salary")))
```
For multiple input columns, the signature becomes Iterator[Tuple[pd.Series, ...]] → Iterator[pd.Series], and you unpack the tuple inside the loop.
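You can simulate both patterns without Spark by feeding a Python iterator of batches into the function directly. In the sketch below, load_fake_model is a hypothetical stand-in for an expensive loader, and a log records how often it runs. It runs once for the whole iterator, however many batches follow. Each batch is a (salary, bonus) tuple of Series, unpacked inside the loop as in the multi-column form.

```python
from typing import Iterator, Tuple
import pandas as pd

load_log = []  # records each (hypothetical) model load

def load_fake_model():
    # Hypothetical stand-in for an expensive model load
    load_log.append("loaded")
    return lambda features: features * 0.5  # toy "model"

def predict(batch_iter: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.Series]:
    model = load_fake_model()         # setup: once per iterator (per partition in Spark)
    for salary, bonus in batch_iter:  # loop body: once per batch
        yield model(salary + bonus)

batches = [
    (pd.Series([52000.0, 61000.0]), pd.Series([5200.0, 6100.0])),
    (pd.Series([73000.0]), pd.Series([7000.0])),
]
outputs = [s.tolist() for s in predict(iter(batches))]
print(len(load_log), outputs)  # 1 [[28600.0, 33550.0], [40000.0]]
```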
mapInPandas: DataFrame-to-DataFrame Iterator
df.mapInPandas(func, schema) is the DataFrame-level cousin: func has the signature Iterator[pd.DataFrame] → Iterator[pd.DataFrame]. Unlike applyInPandas, there's no grouping. Each partition's data arrives as a stream of pandas DataFrame batches, and you can freely change the number of output columns and rows (filter, expand, reshape).
```python
from typing import Iterator
import pandas as pd

def filter_and_flag(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    for pdf in iterator:
        pdf = pdf[pdf["salary"] > 50000].copy()  # row filter
        pdf["high_earner"] = True                # new column
        yield pdf

# "id" comes from Python ints, so it is a BIGINT (LongType) column
out_schema = "id BIGINT, salary DOUBLE, high_earner BOOLEAN"
df.mapInPandas(filter_and_flag, schema=out_schema).show()
```
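The same kind of local check works for mapInPandas. Pass an iterator of pandas DataFrames and collect what comes out. Here, two hand-made batches stand in for one partition's Arrow batches. The row count changes from 4 to 3, something a scalar UDF can never do.

```python
from typing import Iterator
import pandas as pd

def filter_and_flag(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    for pdf in iterator:
        pdf = pdf[pdf["salary"] > 50000].copy()  # row filter
        pdf["high_earner"] = True                # new column
        yield pdf

batches = [
    pd.DataFrame({"id": [1, 2], "salary": [52000.0, 61000.0]}),
    pd.DataFrame({"id": [3, 4], "salary": [73000.0, 45000.0]}),
]
out = pd.concat(list(filter_and_flag(iter(batches))), ignore_index=True)
print(out)
```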
| You need... | Use |
|---|---|
| One value out per row, vectorized math | @pandas_udf scalar (Series → Series) |
| Transform/reshape an entire group together | groupBy().applyInPandas() |
| Custom aggregate value per group | @pandas_udf used in .agg() (Series → scalar) |
| Expensive one-time setup (ML model, connection) | Iterator of Series UDF |
| Arbitrary row filter/reshape, no grouping, with setup cost | mapInPandas |
Temporary & Permanent SQL UDFs
Once a UDF is registered, it can be called inside spark.sql() just like a built-in SQL function. Spark also supports persisting UDFs into the metastore as permanent catalog functions.
spark.udf.register
spark.udf.register("fn_name", python_fn, returnType) creates a session-scoped temporary function. It lives only as long as the SparkSession (in a Databricks notebook, that notebook's Spark session) and is neither visible to other sessions nor persisted anywhere.
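```python
from pyspark.sql.types import IntegerType

def years_of_service(hire_year: int) -> int:
    # Null-safe: Spark passes None for a null hire_year
    return 2026 - hire_year if hire_year is not None else None

spark.udf.register("years_of_service", years_of_service, IntegerType())

spark.sql("""
    SELECT id, hire_year, years_of_service(hire_year) AS tenure
    FROM employees_with_hire_year
""").show()
```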
```sql
CREATE FUNCTION my_db.years_of_service
AS 'com.company.udf.YearsOfServiceUDF'
USING JAR 's3://my-bucket/udf-jars/company-udfs.jar';
```
CREATE FUNCTION ... USING JAR is designed for JVM (Java/Scala) UDFs packaged as JARs — it doesn't directly support arbitrary Python functions, because the metastore stores a class reference the JVM can load. For Python logic you want available "permanently," the common patterns are: (1) register it at the start of every session/notebook via a shared init script, or (2) on platforms with native support (e.g. Unity Catalog Python UDFs on Databricks), define a SQL-callable Python function with CREATE FUNCTION ... LANGUAGE PYTHON.
```sql
-- Unity Catalog Python UDF (persisted, SQL-callable, governed)
CREATE OR REPLACE FUNCTION main.default.years_of_service(hire_year INT)
RETURNS INT
LANGUAGE PYTHON
AS $$
return 2026 - hire_year
$$;

SELECT id, years_of_service(hire_year) AS tenure FROM employees_with_hire_year;
```
```python
# listFunctions() also returns hundreds of built-ins, so filter to the names of interest
for fn in spark.catalog.listFunctions():
    if fn.name in ("title_case_sql", "years_of_service"):
        print(fn.name, "| temporary:", fn.isTemporary, "|", fn.className)

# SQL equivalents
spark.sql("SHOW USER FUNCTIONS").show(truncate=False)
spark.sql("DESCRIBE FUNCTION years_of_service").show(truncate=False)

# Remove a temporary function
spark.sql("DROP TEMPORARY FUNCTION IF EXISTS years_of_service")
```
If a SQL function call behaves unexpectedly, check SHOW USER FUNCTIONS for a stale temporary registration left over from earlier code.
User Defined Aggregate Functions
A UDAF is fundamentally different from a scalar UDF: instead of transforming one row into one value, it combines many rows into one value through a defined lifecycle — and that lifecycle is what lets Spark parallelize it correctly.
Every aggregate, including built-ins such as sum, avg, and count, works the same conceptual way, because Spark's data is split across many partitions/executors:
| Stage | What Happens | Example for "average" |
|---|---|---|
| initialize | Create an empty "buffer" (running state) per group | buffer = {sum: 0, count: 0} |
| update | For each row in a partition, fold it into the buffer | buffer.sum += x; buffer.count += 1 |
| merge | Combine two partial buffers from different partitions/executors | merged.sum = b1.sum + b2.sum; merged.count = b1.count + b2.count |
| finish / evaluate | Convert the final merged buffer into the output value | return buffer.sum / buffer.count |
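You can sketch the four stages in plain Python for the average example, with three hard-coded lists standing in for partitions. The buffer carries (total, count) rather than a partial average so that merge stays correct. Averaging the three partition averages (15, 30, 50) would give 31.67, while merging buffers gives the true 35.0. The function and class names here are illustrative, not a Spark API.

```python
from dataclasses import dataclass
from functools import reduce
from typing import Optional

@dataclass(frozen=True)
class AvgBuffer:
    total: float = 0.0
    count: int = 0

def initialize() -> AvgBuffer:
    return AvgBuffer()

def update(buf: AvgBuffer, x: float) -> AvgBuffer:
    return AvgBuffer(buf.total + x, buf.count + 1)

def merge(b1: AvgBuffer, b2: AvgBuffer) -> AvgBuffer:
    return AvgBuffer(b1.total + b2.total, b1.count + b2.count)

def finish(buf: AvgBuffer) -> Optional[float]:
    return buf.total / buf.count if buf.count else None  # empty group -> null

partitions = [[10.0, 20.0], [30.0], [40.0, 50.0, 60.0]]

# update: fold each partition's rows into its own buffer
partials = [reduce(update, rows, initialize()) for rows in partitions]
# merge the partial buffers, then finish
result = finish(reduce(merge, partials, initialize()))

naive = sum(sum(r) / len(r) for r in partitions) / len(partitions)
print(result, round(naive, 2))  # 35.0 31.67
```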
Spark's full typed UDAF API (Aggregator, with zero/reduce/merge/finish methods) is a Scala/Java API. In Python, you have two practical options:
| Approach | How it Maps to the Lifecycle | Best For |
|---|---|---|
| Grouped Aggregate Pandas UDF (12.2.3) | Spark handles update/merge internally by collecting each group's Series before calling your function once — your function is the "finish" step over a full group's data | Most custom aggregates; simplest code |
| Aggregator (Scala) called from PySpark via a registered JVM function | You implement zero, reduce, merge, finish explicitly | Maximum performance, reusable across SQL/DataFrame/streaming, when a JVM dev is available |
Suppose you need a custom range aggregate (max − min in one step). Here's a Grouped Aggregate Pandas UDF that implements it. Conceptually, Spark handles update and merge for us by delivering the whole group's Series, and our function body is the "finish" logic.
```python
readings = [("sensor1", 10.2), ("sensor1", 15.8), ("sensor1", 9.1),
            ("sensor2", 100.0), ("sensor2", 102.5)]
sensor_df = spark.createDataFrame(readings, ["sensor_id", "value"])

@F.pandas_udf(DoubleType())
def value_range(values: pd.Series) -> float:
    # "finish" step: receives the full group as a Series
    return float(values.max() - values.min())

sensor_df.groupBy("sensor_id").agg(
    value_range("value").alias("value_range")
).show()
# +---------+-----------------+
# |sensor_id|value_range      |
# +---------+-----------------+
# |sensor1  |6.700000000000001|
# |sensor2  |2.5              |
# +---------+-----------------+
# (15.8 - 9.1 is not exactly 6.7 in binary floating point)
```
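The Pandas UDF above needs the whole group in memory before "finish" runs. For comparison, the sketch below writes the same range statistic against the explicit four-stage lifecycle in plain Python. Here the buffer is just (lo, hi), and partial buffers merge cheaply. This is roughly the shape a Scala Aggregator would take, shown in Python as an illustration only. The split of sensor1's readings across two partitions is hypothetical. For this statistic, built-ins also suffice: F.max("value") - F.min("value") inside .agg() needs no UDF.

```python
import math
from functools import reduce

def initialize():
    return (math.inf, -math.inf)  # (lo, hi) for an empty buffer

def update(buf, x):
    return (min(buf[0], x), max(buf[1], x))

def merge(b1, b2):
    return (min(b1[0], b2[0]), max(b1[1], b2[1]))

def finish(buf):
    lo, hi = buf
    return hi - lo if lo <= hi else None  # empty group -> null

# sensor1's readings split across two hypothetical partitions
sensor1_partitions = [[10.2, 15.8], [9.1]]
partials = [reduce(update, rows, initialize()) for rows in sensor1_partitions]
sensor1_range = finish(reduce(merge, partials, initialize()))
print(sensor1_range)  # 6.700000000000001
```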
Apache Arrow & Vectorized Execution
Apache Arrow is the technology that makes Pandas UDFs dramatically faster than row-based Python UDFs. Understanding what it is — and isn't — clarifies why 12.2's UDFs perform the way they do.
Spark converts rows from its internal UnsafeRow format into Arrow record batches (columnar chunks, typically thousands of rows) and streams them to the Python worker. Both sides use the same Arrow columnar layout, so the Python side can turn each batch into pyarrow and pandas objects without unpickling values one by one, often with little or no copying for simple numeric columns.
"Vectorized execution" then refers to your UDF operating on whole NumPy/pandas arrays at once (e.g. series * 2) using SIMD-friendly, compiled C code under pandas/NumPy, instead of a Python-level for loop over individual values.
| Aspect | Plain Python UDF | Pandas UDF (Arrow) |
|---|---|---|
| Transfer unit | Row by row, pickled | Batches of rows, Arrow columnar format |
| Python-side type | Individual Python objects (int, str) | pandas.Series / pandas.DataFrame |
| Your code style | Python loop (implicit, per row) | Vectorized pandas/NumPy operations |
| Typical speedup | baseline | often 3×–100× depending on operation |
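Keep two things separate. Pandas UDFs always use Arrow. The spark.sql.execution.arrow.pyspark.enabled setting governs Arrow for toPandas() and createDataFrame(pandas_df) conversions. The batch size controls the trade-off between memory usage and per-batch overhead.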
```python
# Use Arrow for toPandas() / createDataFrame(pandas_df) conversions
# (default: false in open-source Spark; some platforms, e.g. Databricks, enable it)
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

# Maximum number of rows per Arrow batch sent to the Python worker (default: 10000)
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "10000")

# If an Arrow-based toPandas()/createDataFrame() conversion fails (e.g. an unsupported
# type) and this is true, Spark falls back to the slower non-Arrow conversion
# with only a warning, instead of failing the job
spark.conf.set("spark.sql.execution.arrow.pyspark.fallback.enabled", "true")
```
Setting spark.sql.execution.arrow.pyspark.fallback.enabled=true is convenient for robustness, because a toPandas() or createDataFrame() call won't crash on an unsupported Arrow type. The cost is that it can silently degrade performance back to the slow non-Arrow conversion path. In performance-critical pipelines, set it to false in lower environments. That surfaces Arrow incompatibilities during testing, rather than as a slow job in production.
Pandas UDFs are still opaque to Catalyst (scalar ones show up as ArrowEvalPython in the plan), so there is no predicate pushdown or codegen fusion through them. Also, some complex nested types (deeply nested structs/maps) have historically had partial or version-dependent Arrow support, so always test with your real schema.
UDF Performance & Built-in Alternatives
Bringing it all together: how to measure UDF overhead yourself, and a practical decision framework for choosing the right tool — built-in, Pandas UDF, or plain UDF — every time.
| Cost Source | Plain Python UDF | Pandas UDF | Built-in Function |
|---|---|---|---|
| Serialization (pickle per row) | Yes — major cost | No — Arrow columnar | None |
| Process boundary (JVM ↔ Python) | Yes | Yes | None |
| Catalyst optimization (pushdown, codegen) | Blocked at the UDF (opaque) | Blocked at the UDF (opaque) | Full |
| Python-level loop overhead | Per-row Python | Vectorized | N/A |
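To measure the overhead yourself, time each approach on the same data. Trigger an action that actually computes the new column, since actions, not transformations, trigger computation (see Module 4). A bare .count() is risky here: the optimizer can prune a column that nothing reads, so the UDF may never run. Writing to the built-in "noop" sink (Spark 3.0+) computes every column and then discards the output.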
```python
import time
import pandas as pd
from pyspark.sql import functions as F
from pyspark.sql.types import StringType

big_df = spark.range(5_000_000).withColumn(
    "name", F.concat(F.lit("user_"), F.col("id").cast("string")))

def time_it(label, df):
    # The "noop" sink computes every column, then discards the rows, so the new
    # column (and any UDF behind it) cannot be pruned away as it could with .count()
    t0 = time.time()
    df.write.format("noop").mode("overwrite").save()
    print(f"{label:<11}", round(time.time() - t0, 2), "sec")

# 1. Built-in function
time_it("Built-in:", big_df.withColumn("upper_name", F.upper("name")))

# 2. Plain Python UDF
plain_upper = F.udf(lambda s: s.upper(), StringType())
time_it("Plain UDF:", big_df.withColumn("upper_name", plain_upper("name")))

# 3. Pandas UDF
@F.pandas_udf(StringType())
def pandas_upper(s: pd.Series) -> pd.Series:
    return s.str.upper()

time_it("Pandas UDF:", big_df.withColumn("upper_name", pandas_upper("name")))
```
Indicative results (they vary with cluster size, data, and Spark version): built-in F.upper() is the baseline (1×). The Pandas UDF is roughly 2×–5× slower than the built-in. The plain Python UDF is roughly 5×–20× slower. The absolute gap grows with row count, because the serialization cost is paid on every row while built-ins add almost nothing per row.
Using explain() to Spot UDF Overhead
Look for BatchEvalPython (plain UDFs) or ArrowEvalPython (Pandas UDFs) nodes. Check whether they appear several times, or early in a larger pipeline before filters have reduced the data (worse), or just once near the output (less bad).
```python
big_df.withColumn("upper_name", plain_upper("name")).explain()
# == Physical Plan ==
# *(2) Project [id, name, pythonUDF0#42 AS upper_name]
# +- BatchEvalPython [<lambda>(name)], [pythonUDF0#42]
#    +- *(1) Project [id, concat(user_, cast(id as string)) AS name]
#       +- *(1) Range (0, 5000000, step=1)
```
| Step | Question | If Yes |
|---|---|---|
| 1 | Does a built-in function or combination (Module 9, 11) already do this? | Use it — stop here |
| 2 | Can the logic be expressed as vectorized pandas/NumPy operations? | Use a Pandas UDF |
| 3 | Does it need expensive one-time setup (model load, connection)? | Iterator of Series / mapInPandas |
| 4 | Is it a custom aggregate? | Grouped Aggregate Pandas UDF |
| 5 | None of the above (rare, tiny dataset, prototype) | Plain Python UDF — but profile it |
Cheat Sheet
Quick reference for every UDF type covered in this module.
| Task | Code |
|---|---|
| Define + use | F.udf(fn, returnType) then df.withColumn("c", my_udf("col")) |
| Decorator form | @F.udf(returnType) above def my_fn(...) |
| Register for SQL | spark.udf.register("name", fn, returnType) |
| Null safety | Always check if x is None first inside the function body |
| Type | Signature | Used Via |
|---|---|---|
| Scalar | Series → Series | withColumn / select |
| Grouped Map | pd.DataFrame → pd.DataFrame | groupBy().applyInPandas(fn, schema) |
| Grouped Aggregate | Series → scalar | groupBy().agg(fn(...)) |
| Iterator of Series | Iterator[Series] → Iterator[Series] | withColumn / select (setup runs once) |
| Map Iterator | Iterator[pd.DataFrame] → Iterator[pd.DataFrame] | df.mapInPandas(fn, schema) |
| Task | Code |
|---|---|
| Temp SQL function | spark.udf.register("name", fn, type) |
| Permanent JVM function | CREATE FUNCTION name AS '...' USING JAR '...' |
| Permanent Python (Unity Catalog) | CREATE FUNCTION name(...) RETURNS ... LANGUAGE PYTHON AS $$ ... $$ |
| List functions | spark.catalog.listFunctions() / SHOW USER FUNCTIONS |
| Drop temp function | DROP TEMPORARY FUNCTION IF EXISTS name |
| Stage | Role |
|---|---|
| initialize | Create empty per-group buffer |
| update | Fold each row into the buffer |
| merge | Combine partial buffers across partitions |
| finish | Convert final buffer to output value |
| Python shortcut | Grouped Aggregate Pandas UDF — Spark handles update/merge, you write "finish" |
| Config | Purpose |
|---|---|
| spark.sql.execution.arrow.pyspark.enabled | Use Arrow for toPandas()/createDataFrame(pandas_df) (default: false in open-source Spark; Pandas UDFs always use Arrow) |
| spark.sql.execution.arrow.maxRecordsPerBatch | Rows per Arrow batch (default: 10000) |
| spark.sql.execution.arrow.pyspark.fallback.enabled | Fall back to non-Arrow conversion on Arrow errors (default: true) |
| Priority | Choice |
|---|---|
| 1 (best) | Built-in functions (Module 9/11) |
| 2 | Pandas UDF (scalar / grouped / iterator) |
| 3 | Grouped Aggregate Pandas UDF for custom aggregates |
| 4 (last resort) | Plain Python UDF — profile and document why |