MODULE 11 Advanced DataFrame Transformations

Unlock the full power of
complex data shapes

Real-world data is rarely flat. This module teaches you how to handle nested structures (structs, arrays, maps), apply lambda-style Higher Order Functions, pivot and unpivot data, and produce multi-level aggregations with rollup, cube, and grouping sets.

🗺️ Why This Module Matters
After Module 10 you know how to compute across rows using windows. This module teaches you how to compute inside columns — transforming arrays, structs, and maps — and also how to reshape your entire DataFrame from wide to tall or tall to wide.
🏗️
Nested Data
StructType, ArrayType, MapType — work with complex column shapes
💥
Explode & Flatten
Turn arrays into rows and collapse nested arrays
λ
Higher Order Functions
transform, filter, aggregate, exists — lambda functions on arrays
🔄
Pivot / Unpivot
Reshape DataFrames: wide ↔ tall transformations
📦
Rollup / Cube
Multi-level subtotals — the SQL GROUP BY extensions
11.1 — NESTED DATA

StructType Columns

A StructType column is like a mini-DataFrame embedded inside a column — it contains named sub-fields, each with its own type.

🏗️
What is a StructType column?
Understanding Structs
A Struct is a nested record inside a column. Instead of having separate columns for first_name and last_name, you can have a single name column of type StructType with sub-fields first and last.

You access sub-fields using dot notation: df["name.first"] or col("name.first").
🏠 Analogy
Think of a Struct like a folder inside a filing cabinet. The filing cabinet is your DataFrame, each drawer is a column, but some drawers contain sub-folders (structs) instead of plain papers.
python
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pyspark.sql import functions as F

spark = SparkSession.builder.appName("Module11").getOrCreate()

# Define a schema with a nested StructType
schema = StructType([
    StructField("id", IntegerType()),
    StructField("name", StructType([          # Nested struct
        StructField("first", StringType()),
        StructField("last",  StringType())
    ])),
    StructField("age", IntegerType())
])

data = [
    (1, ("Alice", "Smith"), 30),
    (2, ("Bob",   "Jones"), 25),
    (3, ("Carol", "Lee"),   28),
]

df = spark.createDataFrame(data, schema)
df.printSchema()
# root
#  |-- id: integer
#  |-- name: struct
#  |    |-- first: string
#  |    |-- last: string
#  |-- age: integer

# Access sub-fields with dot notation
df.select("id", "name.first", "name.last", "age").show()
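Output (each sub-field column is named after its leaf field, e.g. first rather than name.first)
+---+-----+-----+---+
| id|first| last|age|
+---+-----+-----+---+
|  1|Alice|Smith| 30|
|  2|  Bob|Jones| 25|
|  3|Carol|  Lee| 28|
+---+-----+-----+---+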
Creating Structs with struct() function
You can create a struct on-the-fly from existing columns using F.struct(). This is useful when you want to combine multiple columns into a single nested column.
python
# Flat DataFrame
flat_data = [(1, "Alice", "Smith", 30),
             (2, "Bob", "Jones", 25)]
flat_df = spark.createDataFrame(flat_data, ["id", "first", "last", "age"])

# Combine first + last into a struct column called "name"
nested_df = flat_df.withColumn(
    "name",
    F.struct(
        F.col("first").alias("first"),
        F.col("last").alias("last")
    )
).drop("first", "last")

nested_df.printSchema()
nested_df.show(truncate=False)
# +---+---+------------------+
# |id |age|name              |
# +---+---+------------------+
# |1  |30 |{Alice, Smith}    |
# |2  |25 |{Bob, Jones}      |
# +---+---+------------------+
Flattening a Struct with getField() and *
To go from nested → flat, use col("name.*") to expand all sub-fields, or col("name").getField("first") for a specific field.
python
# Expand all sub-fields of "name" struct into separate columns
flat_again = nested_df.select("id", "age", "name.*")
flat_again.show()
# +---+---+-----+-----+
# | id|age|first| last|
# +---+---+-----+-----+
# |  1| 30|Alice|Smith|
# |  2| 25|  Bob|Jones|
# +---+---+-----+-----+

# Or get a specific sub-field
nested_df.select(F.col("name").getField("first").alias("first_name")).show()
💡 Key Rule
Use dot notation "name.first" in string-based selects. Use .getField() when chaining Column expressions. Use "name.*" to expand all sub-fields at once.
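To make the expansion rule concrete, here is a small plain-Python model (not PySpark code) of what select("id", "age", "name.*") does to each row. It assumes rows are dicts and nested structs are nested dicts; expand_struct is a hypothetical helper used only for illustration.

```python
# Plain-Python model (not Spark code) of select("id", "age", "name.*").
# Rows are dicts; expand_struct is a hypothetical helper for illustration.

def expand_struct(row, keep, struct_col):
    """Copy the kept columns, then promote each struct sub-field to a top-level column."""
    out = {c: row[c] for c in keep}
    out.update(row[struct_col])  # sub-fields appear in their declared order
    return out


nested_rows = [
    {"id": 1, "age": 30, "name": {"first": "Alice", "last": "Smith"}},
    {"id": 2, "age": 25, "name": {"first": "Bob", "last": "Jones"}},
]
flat_rows = [expand_struct(r, ["id", "age"], "name") for r in nested_rows]
print(flat_rows[0])  # {'id': 1, 'age': 30, 'first': 'Alice', 'last': 'Smith'}
```

One difference from Spark: if a sub-field shares a name with a kept column, Spark returns both columns, while this dict-based model would overwrite one of them.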

ArrayType Columns

An ArrayType column stores a list of values in a single cell — like a Python list inside your DataFrame.

📋
Creating and working with Arrays
Creating ArrayType columns
Arrays can be created with F.array() to combine columns, or they naturally appear when reading JSON/nested data. Use F.array_contains(), F.size(), F.sort_array() etc. to work with them.
python
from pyspark.sql.types import ArrayType

# Data where "skills" is already a list
data = [
    (1, "Alice", ["Python", "SQL", "Spark"]),
    (2, "Bob",   ["Java", "Scala"]),
    (3, "Carol", ["Python", "R", "SQL", "Tableau"]),
]
df = spark.createDataFrame(data, ["id", "name", "skills"])
df.printSchema()
# |-- skills: array (nullable = true)
# |    |-- element: string

df.show(truncate=False)
# +---+-----+-----------------------------+
# |id |name |skills                       |
# +---+-----+-----------------------------+
# |1  |Alice|[Python, SQL, Spark]         |
# |2  |Bob  |[Java, Scala]                |
# |3  |Carol|[Python, R, SQL, Tableau]    |
Array utility functions
PySpark has a rich set of built-in array functions in pyspark.sql.functions:
Function                 | What it does                       | Example result
size(col)                | Number of elements                 | 3
array_contains(col, val) | True if value in array             | True/False
sort_array(col)          | Sort the array (ascending)         | [Python, SQL, Spark]
array_distinct(col)      | Remove duplicates                  | [Python, SQL]
array_union(a, b)        | Union of two arrays, no duplicates | [a, b, c, d]
array_intersect(a, b)    | Common elements, no duplicates     | [Python, SQL]
array_except(a, b)       | Elements in a not in b             | [Spark]
array_remove(col, val)   | Remove all occurrences of a value  | [Python, SQL]
element_at(col, idx)     | Get element by index (1-based)     | "Python"
slice(col, start, len)   | Sub-array (start is 1-based)       | [SQL, Spark]
concat(a, b)             | Concatenate two arrays             | [a, b, c, d, e]
python
df.select(
    "name",
    F.size("skills").alias("num_skills"),
    F.array_contains("skills", "Python").alias("knows_python"),
    F.sort_array("skills").alias("sorted_skills"),
    F.element_at("skills", 1).alias("first_skill")
).show(truncate=False)
Output
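+-----+----------+------------+-------------------------+-----------+
|name |num_skills|knows_python|sorted_skills            |first_skill|
+-----+----------+------------+-------------------------+-----------+
|Alice|3         |true        |[Python, SQL, Spark]     |Python     |
|Bob  |2         |false       |[Java, Scala]            |Java       |
|Carol|4         |true        |[Python, R, SQL, Tableau]|Python     |
+-----+----------+------------+-------------------------+-----------+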
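The set-style functions in the table are easiest to reason about with concrete inputs. The sketch below is a plain-Python model (not PySpark code) of a few of them, assuming plain lists stand in for array columns. The helper names mirror pyspark.sql.functions, except slice_, which avoids shadowing Python's built-in slice; keeping results in first-occurrence order is a choice of this model.

```python
# Plain-Python models (not Spark code) of a few array functions from the table.

def array_union(a, b):
    # Elements of a, then elements of b not seen yet; duplicates removed
    return list(dict.fromkeys(a + b))

def array_intersect(a, b):
    # Elements of a that also occur in b; duplicates removed
    return [x for x in dict.fromkeys(a) if x in b]

def array_except(a, b):
    # Elements of a that do not occur in b; duplicates removed
    return [x for x in dict.fromkeys(a) if x not in b]

def element_at(arr, idx):
    # 1-based index; a negative index counts from the end (-1 = last element).
    # Out-of-range handling (null vs. error) depends on Spark's ANSI mode and is not modeled.
    if idx == 0:
        raise ValueError("SQL array indices start at 1")
    return arr[idx - 1] if idx > 0 else arr[idx]

def slice_(arr, start, length):
    # 1-based start (negative counts from the end), then up to `length` elements
    if start == 0:
        raise ValueError("SQL array indices start at 1")
    begin = start - 1 if start > 0 else len(arr) + start
    return arr[begin:begin + length]


alice = ["Python", "SQL", "Spark"]
carol = ["Python", "R", "SQL", "Tableau"]
print(array_union(alice, carol))      # ['Python', 'SQL', 'Spark', 'R', 'Tableau']
print(array_intersect(alice, carol))  # ['Python', 'SQL']
print(array_except(alice, carol))     # ['Spark']
print(element_at(alice, 1), element_at(alice, -1))  # Python Spark
print(slice_(alice, 2, 2))            # ['SQL', 'Spark']
```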

MapType Columns

A MapType column stores key-value pairs — like a Python dictionary — inside a single cell.

🗺️
Creating and querying Maps
What is MapType?
MapType stores key → value pairs. Use F.create_map() to build a map from columns, and F.map_keys() / F.map_values() to extract parts of it.
🧪 Example
You have exam scores: {"math": 90, "science": 85, "english": 78} stored in one column per student.
python
from pyspark.sql.types import MapType

data = [
    (1, "Alice", {"math": 90, "science": 85, "english": 78}),
    (2, "Bob",   {"math": 70, "science": 92}),
    (3, "Carol", {"math": 88, "english": 95}),
]
df = spark.createDataFrame(data, ["id", "name", "scores"])

# Access a specific key
df.select(
    "name",
    F.col("scores")["math"].alias("math_score"),
    F.map_keys("scores").alias("subjects"),
    F.map_values("scores").alias("marks"),
    F.size(F.map_keys("scores")).alias("num_subjects")
).show(truncate=False)
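Output (map_keys() and map_values() return unordered arrays, so the entries may appear in a different order)
+-----+----------+------------------------+------------+------------+
|name |math_score|subjects                |marks       |num_subjects|
+-----+----------+------------------------+------------+------------+
|Alice|90        |[math, science, english]|[90, 85, 78]|3           |
|Bob  |70        |[math, science]         |[70, 92]    |2           |
|Carol|88        |[math, english]         |[88, 95]    |2           |
+-----+----------+------------------------+------------+------------+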
create_map() — build a Map from columns
Use F.create_map(key_col, value_col, ...) to create a map from flat columns. Keys and values must alternate.
python
flat = spark.createDataFrame([
    ("Alice", 90, 85),
    ("Bob",   70, 92),
], ["name", "math", "science"])

# Build a map {"math": val, "science": val} per row
result = flat.withColumn(
    "scores",
    F.create_map(
        F.lit("math"),    F.col("math"),
        F.lit("science"), F.col("science")
    )
)
result.select("name", "scores").show(truncate=False)
# +-----+---------------------------+
# |name |scores                     |
# +-----+---------------------------+
# |Alice|{math -> 90, science -> 85}|
# |Bob  |{math -> 70, science -> 92}|
# +-----+---------------------------+
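Because keys and values alternate, an odd number of arguments cannot form a map. The plain-Python model below (not PySpark code; the create_map helper is hypothetical) shows the pairing. The duplicate-key check reflects Spark 3.x's default spark.sql.mapKeyDedupPolicy=EXCEPTION; treat it as an assumption to verify on your Spark version.

```python
# Plain-Python model (not Spark code) of F.create_map(k1, v1, k2, v2, ...).

def create_map(*args):
    """Arguments alternate key, value, key, value, ..."""
    if len(args) % 2 != 0:
        raise ValueError("create_map expects an even number of arguments")
    keys, values = args[0::2], args[1::2]
    if len(set(keys)) != len(keys):
        # Assumption: Spark 3.x rejects duplicate keys by default (mapKeyDedupPolicy=EXCEPTION)
        raise ValueError("duplicate map key")
    return dict(zip(keys, values))


row = {"name": "Alice", "math": 90, "science": 85}
scores = create_map("math", row["math"], "science", row["science"])
print(scores)  # {'math': 90, 'science': 85}
```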
11.2 — EXPLODE & FLATTEN

explode() & explode_outer()

Convert an array or map column into multiple rows — one row per element.

💥
explode() — arrays and maps into rows
explode() on Arrays
explode() takes each element of an array and creates a new row for it. The other columns are duplicated. Rows whose array is null or empty produce no output rows at all.
💥 Analogy
Imagine each student's skills list as a balloon. explode() pops the balloon and creates one row per skill — Alice with 3 skills becomes 3 rows.
python
data = [
    (1, "Alice", ["Python", "SQL", "Spark"]),
    (2, "Bob",   ["Java", "Scala"]),
    (3, "Dave",  None),              # null array
]
df = spark.createDataFrame(data, ["id", "name", "skills"])

# explode: null arrays are dropped
df.select("id", "name", F.explode("skills").alias("skill")).show()
Output (Dave is dropped — null array)
+---+-----+------+
| id| name| skill|
+---+-----+------+
|  1|Alice|Python|
|  1|Alice|   SQL|
|  1|Alice| Spark|
|  2|  Bob|  Java|
|  2|  Bob| Scala|
+---+-----+------+
explode_outer() — keep null rows
explode_outer() works like explode but keeps rows with null or empty arrays, setting the exploded column to null.
python
# explode_outer: null arrays produce a row with null skill
df.select("id", "name", F.explode_outer("skills").alias("skill")).show()
Output (Dave now appears with null skill)
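+---+-----+------+
| id| name| skill|
+---+-----+------+
|  1|Alice|Python|
|  1|Alice|   SQL|
|  1|Alice| Spark|
|  2|  Bob|  Java|
|  2|  Bob| Scala|
|  3| Dave|  null|
+---+-----+------+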
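The difference between the two functions comes down to what happens when the array is null or empty. Here is a plain-Python model (not PySpark code) of both, assuming rows are dicts; explode and its parameters are hypothetical helpers that mimic the Spark functions.

```python
# Plain-Python model (not Spark code) of explode() vs explode_outer() on an array column.

def explode(rows, col, alias, outer=False):
    out = []
    for row in rows:
        base = {k: v for k, v in row.items() if k != col}  # the other columns are duplicated
        items = row[col]
        if items:  # non-null, non-empty array: one output row per element
            out.extend({**base, alias: item} for item in items)
        elif outer:  # explode_outer: a null or empty array still yields one row, with null
            out.append({**base, alias: None})
        # plain explode: a null or empty array yields no rows at all
    return out


people = [
    {"id": 1, "name": "Alice", "skills": ["Python", "SQL", "Spark"]},
    {"id": 2, "name": "Bob", "skills": ["Java", "Scala"]},
    {"id": 3, "name": "Dave", "skills": None},
]
print(len(explode(people, "skills", "skill")))              # 5
print(explode(people, "skills", "skill", outer=True)[-1])  # {'id': 3, 'name': 'Dave', 'skill': None}
```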
explode() on Maps
When used on a MapType column, explode() creates one row per map entry with two columns, named key and value by default. Pass two names to .alias() to rename them, as below.
python
scores_data = [
    ("Alice", {"math": 90, "science": 85}),
    ("Bob",   {"math": 70, "english": 88}),
]
scores_df = spark.createDataFrame(scores_data, ["name", "scores"])

scores_df.select(
    "name",
    F.explode("scores").alias("subject", "score")
).show()
# +-----+-------+-----+
# |name |subject|score|
# +-----+-------+-----+
# |Alice|math   |90   |
# |Alice|science|85   |
# |Bob  |math   |70   |
# |Bob  |english|88   |

posexplode()

Like explode(), but also gives you the position (index) of each element in the array.

🔢
posexplode — explode with index
posexplode() usage
posexplode() returns two columns: pos (the 0-based index) and col (the value). This is useful when you need to know the original position of an element.
python
data = [("Alice", ["Python", "SQL", "Spark"]),
        ("Bob",   ["Java", "Scala"])]
df = spark.createDataFrame(data, ["name", "skills"])

df.select(
    "name",
    F.posexplode("skills").alias("pos", "skill")
).show()
Output
+-----+---+------+
| name|pos| skill|
+-----+---+------+
|Alice|  0|Python|
|Alice|  1|   SQL|
|Alice|  2| Spark|
|  Bob|  0|  Java|
|  Bob|  1| Scala|
+-----+---+------+
💡 posexplode_outer
Just like explode has explode_outer, posexplode has posexplode_outer(), which keeps a row for a null or empty array, with both pos and col set to null.
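posexplode() is essentially explode() combined with Python's enumerate(). The plain-Python model below (not PySpark code; the posexplode helper is hypothetical) uses Spark's default output names pos and col and assumes rows are dicts.

```python
# Plain-Python model (not Spark code) of posexplode() / posexplode_outer().

def posexplode(rows, col, outer=False):
    out = []
    for row in rows:
        base = {k: v for k, v in row.items() if k != col}
        items = row[col]
        if items:
            # enumerate() supplies the 0-based position, like Spark's pos column
            out.extend({**base, "pos": i, "col": x} for i, x in enumerate(items))
        elif outer:  # posexplode_outer: null/empty array -> one row with pos and col null
            out.append({**base, "pos": None, "col": None})
    return out


data = [{"name": "Alice", "skills": ["Python", "SQL", "Spark"]},
        {"name": "Bob", "skills": ["Java", "Scala"]},
        {"name": "Dave", "skills": None}]
for r in posexplode(data, "skills"):
    print(r)
```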

flatten() & inline()

Collapse nested arrays into a single-level array, or expand an array of structs into columns.

📐
flatten() and inline()
flatten() — arrays of arrays → single array
F.flatten() collapses a nested array (array of arrays) into a single flat array. It only goes one level deep.
python
data = [
    ("Alice", [["Python", "SQL"], ["Spark", "Hadoop"]]),
    ("Bob",   [["Java"], ["Scala", "Kotlin"]]),
]
df = spark.createDataFrame(data, ["name", "skill_groups"])

df.select(
    "name",
    F.flatten("skill_groups").alias("all_skills")
).show(truncate=False)
# +-----+------------------------------+
# |name |all_skills                    |
# |Alice|[Python, SQL, Spark, Hadoop]  |
# |Bob  |[Java, Scala, Kotlin]         |
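The one-level rule is easiest to see with a three-level input. The plain-Python model below (not PySpark code; flatten is a hypothetical helper) also mirrors Spark returning null when the outer array or any inner array is null, which is an assumption worth checking against your Spark version.

```python
# Plain-Python model (not Spark code) of F.flatten().

def flatten(nested):
    """Concatenate the inner arrays, removing exactly one level of nesting."""
    if nested is None or any(inner is None for inner in nested):
        return None  # assumed to mirror Spark: a null outer or inner array gives null
    return [x for inner in nested for x in inner]


print(flatten([["Python", "SQL"], ["Spark", "Hadoop"]]))
# ['Python', 'SQL', 'Spark', 'Hadoop']
print(flatten([[[1, 2]], [[3]]]))  # only one level removed: [[1, 2], [3]]
print(flatten([None, [4, 5]]))     # None
```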
inline() — array of structs → multiple rows with columns
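inline() explodes an array of structs and expands each struct field into a separate column, all in one step. F.inline() was added to the Python API in PySpark 3.4; on older versions the same SQL generator is available through df.selectExpr("name", "inline(purchases)").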
python
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, ArrayType

schema = StructType([
    StructField("name", StringType()),
    StructField("purchases", ArrayType(StructType([
        StructField("item", StringType()),
        StructField("price", IntegerType())
    ])))
])

data = [
    ("Alice", [("Laptop", 1200), ("Mouse", 25)]),
    ("Bob",   [("Keyboard", 80)]),
]
df = spark.createDataFrame(data, schema)

# inline expands array-of-structs into rows + columns
df.select("name", F.inline("purchases")).show()
# +-----+--------+-----+
# |name |item    |price|
# +-----+--------+-----+
# |Alice|Laptop  |1200 |
# |Alice|Mouse   |25   |
# |Bob  |Keyboard|80   |
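Conceptually, inline() is explode() followed by "struct.*" expansion. The plain-Python model below (not PySpark code) assumes rows are dicts and each struct is a dict; the inline helper is hypothetical. As with explode(), a null or empty array yields no rows here; Spark also provides inline_outer() for keeping such rows.

```python
# Plain-Python model (not Spark code) of inline() on an array-of-structs column.

def inline(rows, col):
    out = []
    for row in rows:
        base = {k: v for k, v in row.items() if k != col}
        for struct in row[col] or []:  # null/empty arrays produce no rows, as with explode()
            out.append({**base, **struct})  # each struct field becomes its own column
    return out


customers = [
    {"name": "Alice", "purchases": [{"item": "Laptop", "price": 1200},
                                   {"item": "Mouse", "price": 25}]},
    {"name": "Bob", "purchases": [{"item": "Keyboard", "price": 80}]},
]
for r in inline(customers, "purchases"):
    print(r)
```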
11.3 — HIGHER ORDER FUNCTIONS

transform()

Apply a lambda function to every element in an array column — like Python's map(), but for DataFrame columns.

λ
transform() — map over array elements
What is a Higher Order Function (HOF)?
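Higher Order Functions take another function (lambda) as an argument. They let you process arrays without exploding them into rows first. Spark SQL has supported most of them since Spark 2.4, and pyspark.sql.functions exposes them with Python lambdas since PySpark 3.1. On older versions, write the SQL lambda form inside F.expr(), e.g. F.expr("transform(scores, x -> x * 2)").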
λ Analogy
transform(array, x → x * 2) is like writing a for-loop over the array: "for each element x, return x*2". But it runs in a distributed, optimized way inside Spark.
transform() — apply a function to every element
F.transform(col, lambda x: expr) applies the lambda to each element and returns a new array of the same size.
python
data = [
    ("Alice", [10, 20, 30]),
    ("Bob",   [5,  15, 25]),
]
df = spark.createDataFrame(data, ["name", "scores"])

# Double every score in the array
df.select(
    "name",
    F.transform("scores", lambda x: x * 2).alias("doubled")
).show(truncate=False)
# +-----+------------+
# |name |doubled     |
# +-----+------------+
# |Alice|[20, 40, 60]|
# |Bob  |[10, 30, 50]|
# +-----+------------+

# Uppercase every string in an array
str_data = [("Alice", ["python", "sql"]), ("Bob", ["java"])]
str_df = spark.createDataFrame(str_data, ["name", "skills"])

str_df.select(
    "name",
    F.transform("skills", lambda x: F.upper(x)).alias("skills_upper")
).show(truncate=False)
# |Alice|[PYTHON, SQL]|
# |Bob  |[JAVA]       |
transform() with index (i, x)
You can also use a two-argument lambda (x, i) where i is the element's index (0-based).
python
# Add the index to each score: score + position
df.select(
    "name",
    F.transform("scores", lambda x, i: x + i).alias("score_plus_idx")
).show(truncate=False)
# Alice: [10+0, 20+1, 30+2] = [10, 21, 32]
# Bob:   [5+0,  15+1, 25+2] = [5, 16, 27]
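The semantics of both lambda forms can be modeled in plain Python (this is not PySpark code). The sketch below assumes a Python list stands in for the array column and, like PySpark's own wrapper, picks the one- or two-argument form by inspecting the lambda's parameter count; the transform helper here is hypothetical.

```python
import inspect

# Plain-Python model (not Spark code) of F.transform().

def transform(arr, fn):
    """fn takes (x) or (x, i); the output has the same length as the input."""
    if arr is None:
        return None  # a null array stays null
    if len(inspect.signature(fn).parameters) == 2:
        return [fn(x, i) for i, x in enumerate(arr)]  # i is the 0-based index
    return [fn(x) for x in arr]


print(transform([10, 20, 30], lambda x: x * 2))           # [20, 40, 60]
print(transform([10, 20, 30], lambda x, i: x + i))        # [10, 21, 32]
print(transform(["python", "sql"], lambda s: s.upper()))  # ['PYTHON', 'SQL']
```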

filter() — Higher Order

Keep only elements in an array that satisfy a condition — like Python's filter(), but on array columns.

🔽
filter() HOF — filter array elements
F.filter() on arrays
F.filter(array_col, lambda x: condition) returns a new array containing only elements for which the condition is True. Elements failing the condition are removed.
⚠️ Name Conflict
F.filter() is the Higher Order Function. Don't confuse it with DataFrame.filter() which filters rows. They're different things!
python
data = [
    ("Alice", [10, 55, 30, 80, 15]),
    ("Bob",   [5,  70, 90, 12]),
]
df = spark.createDataFrame(data, ["name", "scores"])

# Keep only scores >= 50
df.select(
    "name",
    F.filter("scores", lambda x: x >= 50).alias("high_scores")
).show(truncate=False)
# +-----+-----------+
# |name |high_scores|
# |Alice|[55, 80]   |
# |Bob  |[70, 90]   |

# Filter strings starting with 'P'
str_data = [("Alice", ["Python", "SQL", "Pandas", "Spark"])]
str_df = spark.createDataFrame(str_data, ["name", "skills"])

str_df.select(
    F.filter("skills", lambda x: x.startswith("P")).alias("p_skills")
).show(truncate=False)
# [Python, Pandas]
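In plain Python the same operation is a list comprehension with a condition. The model below (not PySpark code; array_filter is a hypothetical name chosen to avoid clashing with Python's built-in filter) drops an element when the predicate returns a false or None result, matching how Spark drops elements whose condition evaluates to null.

```python
# Plain-Python model (not Spark code) of the F.filter() higher-order function.

def array_filter(arr, pred):
    """Keep the elements of arr for which pred is true."""
    if arr is None:
        return None
    return [x for x in arr if pred(x)]  # a false or None result drops the element


print(array_filter([10, 55, 30, 80, 15], lambda x: x >= 50))  # [55, 80]
print(array_filter(["Python", "SQL", "Pandas", "Spark"], lambda s: s.startswith("P")))
# ['Python', 'Pandas']
```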

aggregate()

Reduce an array to a single value using an accumulator — like Python's reduce().

aggregate() HOF — fold an array into one value
aggregate(array, zero, merge_func, finish_func?)
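F.aggregate(col, zero, merge, finish) works like reduce() (in the PySpark signature, zero is named initialValue):
  • zero: the starting value of the accumulator; its type becomes the accumulator's type
  • merge (lambda acc, x: ...): how to combine the accumulator with each element; it must return the same type as zero
  • finish (optional, lambda acc: ...): transforms the final accumulator into the result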
python
data = [
    ("Alice", [10, 20, 30]),
    ("Bob",   [5, 15, 25]),
]
df = spark.createDataFrame(data, ["name", "scores"])

# Sum all scores in the array
df.select(
    "name",
    F.aggregate(
        "scores",
        F.lit(0).cast("long"),       # start at 0, cast to bigint to match the array's element type
        lambda acc, x: acc + x      # add each element
    ).alias("total")
).show()
# Alice: 10+20+30 = 60
# Bob:   5+15+25  = 45

# Compute average using finish function
df.select(
    "name",
    F.aggregate(
        "scores",
        F.struct(F.lit(0).cast("long").alias("sum"), F.lit(0).alias("cnt")),  # sum is bigint, like x
        lambda acc, x: F.struct(
            (acc["sum"] + x).alias("sum"),
            (acc["cnt"] + 1).alias("cnt")
        ),
        lambda acc: acc["sum"] / acc["cnt"]   # finish: divide
    ).alias("avg_score")
).show()
# Alice: 60/3 = 20.0
# Bob:   45/3 = 15.0
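The zero/merge/finish contract maps directly onto functools.reduce. Here is a plain-Python model (not PySpark code; the aggregate helper is hypothetical) reproducing both examples above, with a (sum, count) tuple standing in for the struct accumulator.

```python
from functools import reduce

# Plain-Python model (not Spark code) of F.aggregate().

def aggregate(arr, zero, merge, finish=lambda acc: acc):
    """Fold left starting from zero, then apply finish once to the final accumulator."""
    return finish(reduce(merge, arr, zero))


scores = [10, 20, 30]
total = aggregate(scores, 0, lambda acc, x: acc + x)
avg = aggregate(
    scores,
    (0, 0),                                   # accumulator = (running sum, count)
    lambda acc, x: (acc[0] + x, acc[1] + 1),
    lambda acc: acc[0] / acc[1],              # finish: sum / count
)
print(total, avg)  # 60 20.0
```

One thing Python does not check: in Spark the merge lambda must return the same type as the zero value, which is why the starting value's type matters.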

exists() & forall()

Test whether any or all elements in an array satisfy a condition — returning a Boolean.

exists() and forall()
exists() — does any element match?
F.exists(col, lambda x: condition) returns True if at least one element in the array satisfies the condition.
python
data = [
    ("Alice", [10, 55, 30]),
    ("Bob",   [5,  15, 25]),
]
df = spark.createDataFrame(data, ["name", "scores"])

# Does each student have at least one score >= 50?
df.select(
    "name",
    F.exists("scores", lambda x: x >= 50).alias("has_high_score")
).show()
# Alice: True  (55 >= 50)
# Bob:   False (none >= 50)
forall() — do ALL elements match?
F.forall(col, lambda x: condition) returns True only if every element satisfies the condition.
python
# Are ALL scores >= 10?
df.select(
    "name",
    F.forall("scores", lambda x: x >= 10).alias("all_above_10")
).show()
# Alice: True  (10, 55, 30 all >= 10)
# Bob:   False (5 < 10)
Function | Returns True when...       | Equivalent to...
exists() | At least 1 element matches | Python's any()
forall() | Every element matches      | Python's all()
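Since exists() and forall() correspond to any() and all(), a plain-Python model (not PySpark code; the helper names are hypothetical) is a one-liner each. It also shows the empty-array case: no element can match, so exists() is false, while forall() is vacuously true. Spark additionally applies SQL three-valued logic when the predicate evaluates to null, which this sketch ignores.

```python
# Plain-Python model (not Spark code) of F.exists() and F.forall().

def exists(arr, pred):
    # True if at least one element satisfies pred (Python's any)
    return any(pred(x) for x in arr)

def forall(arr, pred):
    # True only if every element satisfies pred (Python's all)
    return all(pred(x) for x in arr)


print(exists([10, 55, 30], lambda x: x >= 50))  # True
print(forall([5, 15, 25], lambda x: x >= 10))   # False
print(exists([], lambda x: True), forall([], lambda x: False))  # False True
```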

zip_with()

Merge two arrays element-by-element using a custom function — like Python's zip() + map() combined.

🤝
zip_with() — combine two arrays
zip_with() usage
F.zip_with(array1, array2, lambda x, y: expr) pairs elements at the same index from both arrays and applies the function to each pair, producing a new array.
python
data = [
    ("Alice", [80, 90, 70], [85, 88, 72]),  # midterm, final
    ("Bob",   [60, 75],      [65, 80]),
]
df = spark.createDataFrame(data, ["name", "midterm", "final"])

# Average of midterm and final for each subject
df.select(
    "name",
    F.zip_with(
        "midterm", "final",
        lambda m, f: (m + f) / 2
    ).alias("avg_scores")
).show(truncate=False)
# Alice: [(80+85)/2, (90+88)/2, (70+72)/2] = [82.5, 89.0, 71.0]
# Bob:   [(60+65)/2, (75+80)/2]             = [62.5, 77.5]
💡 Note
If arrays have different lengths, zip_with uses null for missing elements in the shorter array.
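The padding rule matches Python's itertools.zip_longest, which fills the shorter side with None. In the plain-Python model below (not PySpark code; the helper names are hypothetical), the combining function has to handle that None itself, just as arithmetic with null yields null in Spark.

```python
from itertools import zip_longest

# Plain-Python model (not Spark code) of F.zip_with().

def zip_with(a, b, fn):
    """Pair elements by index, padding the shorter array with None, then apply fn."""
    return [fn(x, y) for x, y in zip_longest(a, b)]

def avg_or_none(m, f):
    # Arithmetic involving null yields null in Spark, so propagate None here
    return None if m is None or f is None else (m + f) / 2


print(zip_with([80, 90, 70], [85, 88, 72], avg_or_none))  # [82.5, 89.0, 71.0]
print(zip_with([60, 75, 50], [65, 80], avg_or_none))      # [62.5, 77.5, None]
```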
11.4 — PIVOT & UNPIVOT

pivot()

Turn row values into column headers — transforming a "tall" DataFrame into a "wide" one.

↔️
pivot() — tall to wide transformation
What is Pivoting?
Pivoting converts unique values in one column into new column headers. You group by some columns, pivot on a category column, and aggregate a value column.
🔄 Analogy
Imagine a sales table with rows: (Region, Quarter, Sales). Pivoting on Quarter makes Q1, Q2, Q3, Q4 become separate columns — one row per region.
python
# Tall (long) format
data = [
    ("Alice", "Math",    90),
    ("Alice", "Science", 85),
    ("Alice", "English", 78),
    ("Bob",   "Math",    70),
    ("Bob",   "Science", 92),
    ("Bob",   "English", 88),
]
df = spark.createDataFrame(data, ["student", "subject", "score"])

# pivot: group by student, pivot on subject, aggregate score
pivot_df = df.groupBy("student").pivot("subject").agg(F.first("score"))
pivot_df.show()
Output (wide format)
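+-------+-------+----+-------+
|student|English|Math|Science|
+-------+-------+----+-------+
|  Alice|     78|  90|     85|
|    Bob|     88|  70|     92|
+-------+-------+----+-------+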
Specifying pivot values explicitly (performance tip)
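By default, Spark must first discover the distinct pivot values, which costs an extra job over the data (and fails if there are more than spark.sql.pivotMaxValues distinct values, 10,000 by default). Passing the values explicitly skips that step, which matters most on large datasets.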
python
# Specify pivot values explicitly — avoids extra data scan
pivot_df = df.groupBy("student") \
             .pivot("subject", ["Math", "Science", "English"]) \
             .agg(F.first("score"))
pivot_df.show()
💡 Performance Rule
Always pass the list of pivot values when you know them in advance. Spark won't need to make an extra pass over the data to discover them.
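To see why an explicit list saves work, here is a plain-Python model (not PySpark code) of groupBy().pivot().agg(first(...)), assuming rows are dicts; pivot_first is a hypothetical helper. Without values it needs a separate pass just to collect the distinct pivot values; the output columns follow either the sorted distinct values or the order you pass.

```python
# Plain-Python model (not Spark code) of groupBy(g).pivot(p, values).agg(first(v)).

def pivot_first(rows, group_col, pivot_col, value_col, values=None):
    if values is None:
        # No explicit list: a separate pass over the data just to discover the values
        values = sorted({r[pivot_col] for r in rows})
    table, seen = {}, set()
    for r in rows:  # the main aggregation pass
        key, cat = r[group_col], r[pivot_col]
        out = table.setdefault(key, {group_col: key, **dict.fromkeys(values)})
        if cat in values and (key, cat) not in seen:
            seen.add((key, cat))
            out[cat] = r[value_col]  # first() keeps the first value seen for the cell
    return list(table.values())


scores = [
    {"student": "Alice", "subject": "Math", "score": 90},
    {"student": "Alice", "subject": "Science", "score": 85},
    {"student": "Alice", "subject": "English", "score": 78},
    {"student": "Bob", "subject": "Math", "score": 70},
    {"student": "Bob", "subject": "Science", "score": 92},
    {"student": "Bob", "subject": "English", "score": 88},
]
print(pivot_first(scores, "student", "subject", "score"))
print(pivot_first(scores, "student", "subject", "score", ["Math", "Science", "English"]))
```

Pivot values that are missing from an explicit list are simply ignored, which is also how Spark treats them.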
Pivot with multiple aggregations
You can pivot with multiple aggregations. PySpark then names each column <pivot value>_<aggregation alias>, for example Q1_total.
python
sales_data = [
    ("North", "Q1", 100), ("North", "Q1", 120),
    ("North", "Q2", 150), ("South", "Q1", 80),
    ("South", "Q2", 90),  ("South", "Q2", 95),
]
sales = spark.createDataFrame(sales_data, ["region", "quarter", "sales"])

# Pivot with sum AND count
sales.groupBy("region") \
     .pivot("quarter", ["Q1", "Q2"]) \
     .agg(F.sum("sales").alias("total"), F.count("sales").alias("cnt")) \
     .show()
# Columns: region | Q1_total | Q1_cnt | Q2_total | Q2_cnt

unpivot() / stack()

The reverse of pivot — turn column headers back into row values (wide → tall).

↕️
unpivot / stack — wide to tall
unpivot() — PySpark 3.4+ native method
In PySpark 3.4+, DataFrame.unpivot(ids, values, variableColumnName, valueColumnName) is available natively.
python
# Wide format DataFrame
wide_data = [("Alice", 90, 85, 78),
             ("Bob",   70, 92, 88)]
wide_df = spark.createDataFrame(wide_data, ["student", "Math", "Science", "English"])

# PySpark 3.4+ native unpivot
tall_df = wide_df.unpivot(
    ids=["student"],                          # columns to keep as-is
    values=["Math", "Science", "English"],  # columns to melt
    variableColumnName="subject",             # new key column name
    valueColumnName="score"                   # new value column name
)
tall_df.show()
Output (tall/long format)
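+-------+-------+-----+
|student|subject|score|
+-------+-------+-----+
|  Alice|   Math|   90|
|  Alice|Science|   85|
|  Alice|English|   78|
|    Bob|   Math|   70|
|    Bob|Science|   92|
|    Bob|English|   88|
+-------+-------+-----+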
stack() — SQL-based unpivot (all PySpark versions)
For older PySpark versions, use the SQL stack(n, 'key1', col1, 'key2', col2, ...) generator inside selectExpr (or F.expr). It turns the listed key/column pairs into n rows per input row.
python
# Works in all PySpark versions
tall_df2 = wide_df.selectExpr(
    "student",
    "stack(3, 'Math', Math, 'Science', Science, 'English', English) as (subject, score)"
)
tall_df2.show()
# Same output as unpivot() above
Method                   | PySpark Version | Notes
df.unpivot()             | 3.4+            | Native, clean API
selectExpr("stack(...)") | All versions    | SQL expression, slightly verbose
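Both approaches produce one output row per (input row, melted column) pair. A plain-Python model (not PySpark code; the unpivot helper is hypothetical) makes that explicit. Spark additionally requires the melted columns to share a common data type, which this sketch does not check.

```python
# Plain-Python model (not Spark code) of df.unpivot(...) / stack(...).

def unpivot(rows, ids, values, variable_name, value_name):
    """Each input row yields one output row per melted column."""
    return [
        {**{c: r[c] for c in ids}, variable_name: col, value_name: r[col]}
        for r in rows
        for col in values
    ]


wide = [{"student": "Alice", "Math": 90, "Science": 85, "English": 78},
        {"student": "Bob", "Math": 70, "Science": 92, "English": 88}]
tall = unpivot(wide, ["student"], ["Math", "Science", "English"], "subject", "score")
for r in tall:
    print(r)
```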
11.5 — ADVANCED AGGREGATIONS

rollup()

Compute subtotals at multiple hierarchy levels in a single pass — like a spreadsheet subtotal feature.

🏔️
rollup() — hierarchical subtotals
What is rollup?
rollup(A, B, C) computes aggregations for every prefix combination of the group-by columns:
  • (A, B, C) — most granular
  • (A, B) — subtotal for A+B
  • (A) — subtotal for A only
  • () — grand total
Null values in the result indicate a subtotal/grand total row.
🏔️ Analogy
Think of a company hierarchy: Country → Region → City. Rollup gives you sales for each City, then each Region (summing cities), then each Country (summing regions), then the grand total — all at once.
python
data = [
    ("India",   "South", "Bangalore", 500),
    ("India",   "South", "Chennai",   300),
    ("India",   "North", "Delhi",     700),
    ("USA",     "West",  "LA",        800),
    ("USA",     "East",  "NYC",       900),
]
df = spark.createDataFrame(data, ["country", "region", "city", "sales"])

# rollup: totals at every hierarchical level
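rollup_df = df.rollup("country", "region", "city") \
              .agg(F.sum("sales").alias("total_sales")) \
              .orderBy(F.col("country").asc_nulls_last(),  # ascending sorts put nulls first by default;
                       F.col("region").asc_nulls_last(),   # nulls last puts each subtotal row right
                       F.col("city").asc_nulls_last())     # after the rows it summarizes

rollup_df.show()
Output (null = subtotal/grand total row)
+-------+------+---------+-----------+
|country|region|     city|total_sales|
+-------+------+---------+-----------+
|  India| North|    Delhi|        700|
|  India| North|     null|        700|  ← North subtotal
|  India| South|Bangalore|        500|
|  India| South|  Chennai|        300|
|  India| South|     null|        800|  ← South subtotal
|  India|  null|     null|       1500|  ← India total
|    USA|  East|      NYC|        900|
|    USA|  East|     null|        900|  ← East subtotal
|    USA|  West|       LA|        800|
|    USA|  West|     null|        800|  ← West subtotal
|    USA|  null|     null|       1700|  ← USA total
|   null|  null|     null|       3200|  ← Grand total
+-------+------+---------+-----------+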
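The prefix rule can be reproduced in a few lines of plain Python (not PySpark code). The model below assumes rows are dicts and uses a hypothetical rollup_sum helper; it aggregates once per prefix of the key list and reports rolled-up columns as None. Because it keys results by value tuples, genuine nulls in the data would collide with subtotal rows, the same ambiguity Spark's grouping() function addresses.

```python
from collections import defaultdict

# Plain-Python model (not Spark code) of df.rollup(*keys).agg(F.sum(measure)).

def rollup_sum(rows, keys, measure):
    result = {}
    for n in range(len(keys), -1, -1):  # e.g. (A, B, C), (A, B), (A,), ()
        prefix = keys[:n]
        totals = defaultdict(int)
        for r in rows:
            # Columns outside the prefix are rolled up and reported as None (null)
            group = tuple(r[k] if k in prefix else None for k in keys)
            totals[group] += r[measure]
        result.update(totals)
    return result


sales_rows = [
    {"country": "India", "region": "South", "city": "Bangalore", "sales": 500},
    {"country": "India", "region": "South", "city": "Chennai", "sales": 300},
    {"country": "India", "region": "North", "city": "Delhi", "sales": 700},
    {"country": "USA", "region": "West", "city": "LA", "sales": 800},
    {"country": "USA", "region": "East", "city": "NYC", "sales": 900},
]
totals = rollup_sum(sales_rows, ["country", "region", "city"], "sales")
for group, total in totals.items():
    print(group, total)
```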

cube()

Like rollup, but computes aggregations for every possible combination of group-by columns, not just prefixes.

🎲
cube() — all combinations of subtotals
rollup vs cube
rollup(A, B) gives: (A,B), (A), (). It's hierarchical — only prefixes.

cube(A, B) gives: (A,B), (A), (B), (). It generates every subset — including cross-cutting combinations like just B totals (without A).
Grouping          | rollup(country, region) | cube(country, region)
(country, region) | Yes                     | Yes
(country) only    | Yes                     | Yes
(region) only     | No                      | Yes
Grand total ()    | Yes                     | Yes
python
data = [
    ("India", "Online",  500),
    ("India", "Offline", 300),
    ("USA",   "Online",  800),
    ("USA",   "Offline", 400),
]
df = spark.createDataFrame(data, ["country", "channel", "sales"])

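cube_df = df.cube("country", "channel") \
            .agg(F.sum("sales").alias("total_sales")) \
            .orderBy(F.col("country").asc_nulls_last(),  # nulls last: subtotal rows follow
                     F.col("channel").asc_nulls_last())  # the detail rows they summarize

cube_df.show()
Output
+-------+-------+-----------+
|country|channel|total_sales|
+-------+-------+-----------+
|  India|Offline|        300|
|  India| Online|        500|
|  India|   null|        800|  ← India total
|    USA|Offline|        400|
|    USA| Online|        800|
|    USA|   null|       1200|  ← USA total
|   null|Offline|        700|  ← All Offline (🆕 cube-only)
|   null| Online|       1300|  ← All Online (🆕 cube-only)
|   null|   null|       2000|  ← Grand total
+-------+-------+-----------+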
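The difference between the two operators is just which grouping sets they generate. The plain-Python sketch below (not PySpark code; the helper names are hypothetical) enumerates them with itertools: rollup uses the n + 1 prefixes of the column list, cube uses all 2^n subsets.

```python
from itertools import combinations

# Plain-Python sketch (not Spark code) of the grouping sets behind rollup() and cube().

def rollup_sets(cols):
    """Every prefix of cols, longest first."""
    return [tuple(cols[:n]) for n in range(len(cols), -1, -1)]

def cube_sets(cols):
    """Every subset of cols, largest first."""
    return [c for n in range(len(cols), -1, -1) for c in combinations(cols, n)]


print(rollup_sets(["country", "channel"]))
# [('country', 'channel'), ('country',), ()]
print(cube_sets(["country", "channel"]))
# [('country', 'channel'), ('country',), ('channel',), ()]
```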

Grouping Sets

The most flexible option: specify exactly which combinations of columns you want subtotals for. rollup and cube are both special cases of it.

🎯
GROUPING SETS — custom combinations
GROUPING SETS via SQL
PySpark 3.x has no DataFrame .groupingSets() method, so register the DataFrame as a temp view with df.createOrReplaceTempView() and write GROUP BY GROUPING SETS in spark.sql(). This gives you precise control over which group combinations are computed.
python
# Register the DataFrame as a SQL temp view
df.createOrReplaceTempView("sales")

# GROUPING SETS: only compute (country,channel) and (country) and grand total
# — skip the (channel) only combination that cube would include
result = spark.sql("""
  SELECT country, channel, SUM(sales) AS total_sales
  FROM sales
  GROUP BY GROUPING SETS (
    (country, channel),   -- most granular
    (country),            -- subtotal by country only
    ()                    -- grand total
  )
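  ORDER BY country ASC NULLS LAST, channel ASC NULLS LAST
""")
result.show()
Output (only the combinations you asked for)
+-------+-------+-----------+
|country|channel|total_sales|
+-------+-------+-----------+
|  India|Offline|        300|
|  India| Online|        500|
|  India|   null|        800|
|    USA|Offline|        400|
|    USA| Online|        800|
|    USA|   null|       1200|
|   null|   null|       2000|
+-------+-------+-----------+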
💡 Comparison Summary
rollup(A,B) = GROUPING SETS ((A,B),(A),())
cube(A,B) = GROUPING SETS ((A,B),(A),(B),())
GROUPING SETS = you decide exactly which sets
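These equivalences can be checked with a small plain-Python model (not PySpark code) of GROUP BY GROUPING SETS with a SUM. The grouping_sets_sum helper is hypothetical; it runs one aggregation per listed set and fills the key columns outside each set with None. With the sets from the SQL query above it returns the same seven rows; passing every prefix instead reproduces rollup, and passing every subset reproduces cube.

```python
from collections import defaultdict

# Plain-Python model (not Spark code) of GROUP BY GROUPING SETS (...) with SUM(measure).

def grouping_sets_sum(rows, keys, sets, measure):
    out = []
    for gset in sets:  # one aggregation per listed grouping set
        totals = defaultdict(int)
        for r in rows:
            # Key columns outside this grouping set come back as None (null)
            totals[tuple(r[k] if k in gset else None for k in keys)] += r[measure]
        out.extend((*group, total) for group, total in totals.items())
    return out


sales = [
    {"country": "India", "channel": "Online", "sales": 500},
    {"country": "India", "channel": "Offline", "sales": 300},
    {"country": "USA", "channel": "Online", "sales": 800},
    {"country": "USA", "channel": "Offline", "sales": 400},
]
keys = ["country", "channel"]
result = grouping_sets_sum(sales, keys, [("country", "channel"), ("country",), ()], "sales")
for row in result:
    print(row)
```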

grouping() & grouping_id()

Distinguish real nulls from subtotal nulls in rollup/cube results — identify which level of aggregation each row belongs to.

🏷️
grouping() and grouping_id()
The Problem: Real null vs Subtotal null
After a rollup or cube, a null in "country" normally means country was aggregated away, i.e. the row is a subtotal or grand total. But if your data already contains null values in "country", you can't tell the two kinds of null apart. grouping() and grouping_id() solve this.
grouping(col) — is this column "nulled out" by rollup/cube?
F.grouping("country") returns 1 if the column was "nulled out" for this row (i.e., it's a subtotal row), or 0 if it's a genuine group value.
python
rollup_df = df.rollup("country", "channel") \
              .agg(
                  F.sum("sales").alias("total"),
                  F.grouping("country").alias("is_country_subtotal"),
                  F.grouping("channel").alias("is_channel_subtotal")
              )
rollup_df.show()
# is_country_subtotal = 1 only on the grand-total row (country was aggregated away)
# is_channel_subtotal = 1 on the per-country subtotal rows and on the grand total
grouping_id() — a bitmask of which columns are subtotals
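F.grouping_id("country", "channel") returns an integer bitmask with one bit per column, the first column being the most significant bit; a bit is 1 when that column is aggregated away:
  • 0 (binary 00) = both columns are real group values → most granular row
  • 1 (binary 01) = channel is subtotal, country is real → country subtotal
  • 2 (binary 10) = country is subtotal, channel is real → channel-only subtotal (produced by cube, never by rollup)
  • 3 (binary 11) = both are subtotals → grand total row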
python
rollup_df = df.rollup("country", "channel") \
              .agg(
                  F.sum("sales").alias("total"),
                  F.grouping_id("country", "channel").alias("gid")
              )

# Filter to show only grand total rows (gid = 3)
rollup_df.filter(F.col("gid") == 3).show()

# Label rows using grouping_id
rollup_df.withColumn(
    "row_type",
    F.when(F.col("gid") == 0, "Detail")
     .when(F.col("gid") == 1, "Country Total")
     .when(F.col("gid") == 3, "Grand Total")
).show()
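The bit layout is mechanical enough to compute by hand. In this plain-Python model (not PySpark code; the helper names are hypothetical), each key contributes one bit, the first key being the most significant, and a bit is 1 when that key is aggregated away; grouping(col) is simply that single bit.

```python
# Plain-Python model (not Spark code) of grouping() and grouping_id().

def grouping(col, grouping_set):
    # 1 if col is aggregated away in this grouping set, 0 if it holds a real group value
    return 0 if col in grouping_set else 1

def grouping_id(keys, grouping_set):
    gid = 0
    for k in keys:  # the first key ends up as the most significant bit
        gid = (gid << 1) | grouping(k, grouping_set)
    return gid


keys = ["country", "channel"]
for gset in [("country", "channel"), ("country",), ("channel",), ()]:
    print(gset, grouping_id(keys, gset))
# ('country', 'channel') 0
# ('country',) 1
# ('channel',) 2
# () 3
```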
💡 Use Case
Use grouping_id() to filter or label specific aggregation levels in rollup/cube output — for example, to build Excel-style reports with subtotal rows clearly identified.
MODULE 11 — REFERENCE

Module 11 Cheat Sheet

Everything in one place — nested data, HOFs, pivot, and advanced aggregations.

📋
Complete Quick Reference
Nested Data
Operation           | Code
Access struct field | col("name.first") or col("name").getField("first")
Expand struct       | select("name.*")
Create struct       | F.struct(col("a"), col("b"))
Array size          | F.size("skills")
Array contains      | F.array_contains("skills", "Python")
Array element       | F.element_at("skills", 1) (1-indexed)
Map value           | col("scores")["math"]
Map keys            | F.map_keys("scores")
Create map          | F.create_map(lit("k"), col("v"))
Explode & Flatten
Function              | Behaviour
explode(col)          | Array/Map → rows. Drops null and empty inputs.
explode_outer(col)    | Array/Map → rows. A null or empty input becomes one null row.
posexplode(col)       | Like explode, plus a 0-based position column.
posexplode_outer(col) | posexplode + keeps null/empty input as a null row.
flatten(col)          | Array<Array<T>> → Array<T> (one level).
inline(col)           | Array<Struct> → rows + columns.
Higher Order Functions (HOF)
Function  | Signature                                | What it does
transform | (col, x → expr)                          | Map over array elements
filter    | (col, x → bool)                          | Keep elements where true
aggregate | (col, zero, (acc,x) → acc, acc → result) | Reduce array to one value
exists    | (col, x → bool)                          | True if any element matches
forall    | (col, x → bool)                          | True if all elements match
zip_with  | (a, b, (x,y) → expr)                     | Element-wise combine two arrays
Pivot & Unpivot
Operation             | Code
Pivot (tall → wide)   | df.groupBy("id").pivot("cat", ["v1","v2"]).agg(F.sum("val"))
Unpivot (wide → tall) | df.unpivot(["id"], ["col1","col2"], "key", "val") (PySpark 3.4+)
Stack (wide → tall)   | df.selectExpr("id", "stack(2,'k1',c1,'k2',c2) as (key,val)")
Advanced Aggregations Decision Guide
Need                                     | Use
Hierarchical subtotals (A→AB→ABC→total)  | rollup(A,B,C)
Every possible combination of subtotals  | cube(A,B,C)
Only specific combinations               | GROUPING SETS((A,B),(A),...) via SQL
Identify if a null is subtotal or real   | grouping(col)
Label aggregation level as integer       | grouping_id(A,B)
MODULE 11 — QUIZ

Test Your Knowledge

5 questions covering nested data, HOFs, pivot, and advanced aggregations.