MODULE 15 Partitioning

How Spark Divides and Conquers Data

Partitioning is one of the most impactful concepts in Spark performance. It controls how your data is split across the cluster, how many parallel tasks run, whether shuffles happen, and how much work each executor does. Most Spark performance tuning starts here.

✂️
Split Data
Partitions divide data into independent chunks that run in parallel on different executors.
Parallelism
More partitions = more parallelism. But too many = scheduling overhead. Balance is key.
🔍
Pruning
Smart partitioning lets Spark skip entire partitions — reading only what you need.
⚖️
Balance
Even data distribution avoids slow tasks caused by data skew on a single partition.
Partitioning in the Spark Pipeline
📥 Read Data ✂️ Partition ⚡ Parallel Tasks 🔀 Shuffle (if needed) ✅ Result
📌 Golden Rule
A common rule of thumb: target 2–4 partitions per CPU core available in your cluster. With 100 cores, aim for 200–400 partitions. Each partition should hold roughly 128MB–256MB of data.
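The rule of thumb is plain arithmetic. Here is a minimal pure-Python sketch that computes both targets. The helper names are illustrative and are not Spark APIs.

```python
import math

MB = 1024 * 1024
GB = 1024 * MB


def core_based_range(total_cores: int, low: int = 2, high: int = 4) -> tuple:
    """Rule of thumb: 2-4 partitions per CPU core."""
    return total_cores * low, total_cores * high


def size_based_count(total_bytes: int, target_bytes: int = 128 * MB) -> int:
    """Partitions needed so that each holds at most ~target_bytes."""
    return max(1, math.ceil(total_bytes / target_bytes))


print(core_based_range(100))                # (200, 400)
print(size_based_count(10 * GB))            # 80 partitions of ~128MB
print(size_based_count(10 * GB, 256 * MB))  # 40 partitions of ~256MB
```

The two targets answer different questions. The core-based range keeps every core busy, and the size-based count keeps individual partitions from growing too large.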
15.1

Spark Partitions

Understanding the difference between logical and physical partitions, and how Spark decides how many partitions to create.

🧩
Logical Partitions
Concept
What is a Logical Partition?
A logical partition is Spark's internal view of a data chunk. It's a virtual division of your dataset — a slice that Spark knows about at the planning stage, before any actual computation. Each logical partition corresponds to one task in Spark's execution engine.
🍕 Analogy
Think of a large pizza (your data). Cutting it into 8 slices creates 8 logical partitions. Each person (executor core) gets one slice to eat in parallel. More slices = more people can eat at once, but making the slices too thin wastes time cutting.
When you call df.rdd.getNumPartitions(), you're asking "how many logical slices does Spark think this data has right now?"
python
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("partitions").getOrCreate()

# Create a simple DataFrame
data = [(1, "Alice"), (2, "Bob"), (3, "Charlie"),
        (4, "Dave"), (5, "Eve")]
df = spark.createDataFrame(data, ["id", "name"])

# Check logical partitions
print(f"Number of logical partitions: {df.rdd.getNumPartitions()}")
# Output: Number of logical partitions: 8  (depends on default parallelism)
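A 5-row DataFrame can report 8 partitions. The pure-Python sketch below shows why by splitting a local collection into contiguous slices. It assumes index-based slicing, where slice i covers positions i·n/N up to (i+1)·n/N. This is similar in spirit to how Spark slices a local sequence, but it is an illustration, not PySpark's exact code path.

```python
def slice_positions(length: int, num_slices: int) -> list:
    """Contiguous (start, end) index ranges, one per logical partition."""
    return [((i * length) // num_slices, ((i + 1) * length) // num_slices)
            for i in range(num_slices)]


rows = [(1, "Alice"), (2, "Bob"), (3, "Charlie"), (4, "Dave"), (5, "Eve")]
partitions = [rows[start:end] for start, end in slice_positions(len(rows), 8)]

print(len(partitions))               # 8 logical partitions
print([len(p) for p in partitions])  # [0, 1, 0, 1, 1, 0, 1, 1]: three are empty
```

When there are more slices than rows, some partitions are simply empty, and each one still becomes a task.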
What is a Physical Partition?
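A physical partition is how data is actually laid out in storage: the files and blocks on disk (or in memory) and, for partitioned tables, the folders that hold them. Spark does not map files one-to-one onto tasks when it reads them. It splits large files into chunks of up to spark.sql.files.maxPartitionBytes (128MB by default) and packs small files together. As a result, one read partition (task) can cover part of a file or several whole files. When you write data with df.write.partitionBy("date"), you're creating physical folder partitions on disk.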
📂 Logical vs Physical
Logical: How Spark splits data in memory for processing (tasks).
Physical: How data is actually stored on disk (files/folders).

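Reading 10 Parquet files gives you 10 physical files, but not necessarily 10 tasks. Spark may pack several small files into one read partition or split one large file into several. The outcome depends on file sizes, spark.sql.files.maxPartitionBytes, spark.sql.files.openCostInBytes, and the cluster's parallelism.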
python
# Reading CSV: large files are split and small files packed into read partitions
df_csv = spark.read.csv("s3://bucket/data/*.csv", header=True)
print(df_csv.rdd.getNumPartitions())  # e.g., 12 (depends on file sizes and maxPartitionBytes)

# Reading Parquet: planned the same way; each row group is read by exactly one task
df_parquet = spark.read.parquet("s3://bucket/data/sales.parquet")
print(df_parquet.rdd.getNumPartitions())  # e.g., 4

# Writing physical partitions to disk (assumes df has a "country" column)
df.write.partitionBy("country").parquet("s3://bucket/output/")
# Creates: output/country=US/, output/country=IN/, output/country=UK/ ...
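The number of read partitions (tasks) you get from a set of files depends on their sizes. Below is a simplified pure-Python sketch of the "next fit decreasing" packing that Spark's file sources are modeled on:

1. Split large files into chunks of at most maxPartitionBytes.
2. Sort the chunks largest first.
3. Add each chunk to the current partition until the next one would overflow it.

The sketch charges an assumed 4MB open cost (the default of spark.sql.files.openCostInBytes) per chunk. It treats every file as splittable. It also ignores the fact that Spark lowers the split size when the total data per core is small, so real counts can differ.

```python
MB = 1024 * 1024


def plan_read_partitions(file_sizes, max_split_bytes=128 * MB, open_cost=4 * MB):
    """Return a list of read partitions, each a list of chunk sizes in bytes."""
    # 1. Split every file into chunks of at most max_split_bytes.
    chunks = []
    for size in file_sizes:
        for offset in range(0, size, max_split_bytes):
            chunks.append(min(max_split_bytes, size - offset))

    # 2. Largest chunks first, then "next fit": close the current partition
    #    when the next chunk would push it past max_split_bytes.
    chunks.sort(reverse=True)
    partitions, current, current_size = [], [], 0
    for chunk in chunks:
        if current and current_size + chunk > max_split_bytes:
            partitions.append(current)
            current, current_size = [], 0
        current.append(chunk)
        current_size += chunk + open_cost  # each chunk also "costs" one file open
    if current:
        partitions.append(current)
    return partitions


print(len(plan_read_partitions([10 * MB] * 10)))  # 2: ten small files packed together
print(len(plan_read_partitions([1024 * MB])))     # 8: one 1GB file split into 128MB chunks
```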
Partition Count — How Spark Decides
Spark determines the initial partition count from several sources:
Source | Default Partition Count | Config
Reading files (HDFS/S3) | Up to 128MB per partition (large files split, small files packed) | spark.sql.files.maxPartitionBytes
After shuffle (groupBy, join) | 200 (AQE may coalesce them at runtime) | spark.sql.shuffle.partitions
parallelize() from collection | Default parallelism | spark.default.parallelism
repartition(N) | N (you specify) | N/A
python
# Default shuffle partitions = 200. Change based on data size:
spark.conf.set("spark.sql.shuffle.partitions", "50")
# For small data: reduce to 50 (avoids 200 tiny tasks)
# For large data (TB scale): increase to 1000+

# See current value
print(spark.conf.get("spark.sql.shuffle.partitions"))

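# Check partition count after groupBy (triggers shuffle; assumes df has a "country" column)
result = df.groupBy("country").count()
print(result.rdd.getNumPartitions())  # 50 here (the value set above); 200 at the default.
                                      # With AQE enabled, Spark may coalesce these to fewer.

# Rule of thumb: partition size = total data size / num partitions
# Target: 128MB - 256MB per partition
# Example: 10GB data → ~40 to 80 partitions (10GB / 256MB = 40, 10GB / 128MB = 80)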
⚠️ Common Mistake
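The default spark.sql.shuffle.partitions = 200 is a fixed value that does not adapt to your data. For a 100MB dataset, 200 shuffle partitions means 200 tiny tasks of about 0.5MB each, which is wasteful. Tune this per job. On Spark 3.2+, AQE (enabled by default) can also coalesce small shuffle partitions at runtime.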
15.2

Repartition

repartition() is a wide transformation that performs a full shuffle to redistribute data evenly across a new number of partitions.

🔀
Hash Repartition
By Column
How Hash Repartition Works
Hash repartition is what repartition() does when you pass one or more columns. Spark hashes the column values (Spark SQL uses a Murmur3 hash) and takes the result modulo the number of partitions to assign each row to a partition. This ensures rows with the same key always go to the same partition.
🗄️ Analogy
Think of a filing cabinet with 10 drawers. Every document gets a hash number (e.g., based on customer name). Document "Alice" → hash 42 → drawer 42 % 10 = drawer 2. All "Alice" documents always go to drawer 2. This groups related data together.
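The drawer arithmetic can be sketched in pure Python. This is an illustration, not Spark's code. Spark SQL hashes with Murmur3 and uses a non-negative modulo (pmod). This sketch swaps in zlib.crc32 as a stand-in deterministic hash, so its partition numbers will differ from Spark's. The property that matters is the same: equal keys always land in the same partition.

```python
import zlib
from collections import defaultdict


def hash_partition(key: str, num_partitions: int) -> int:
    """Assign a key to a partition: hash(key) mod num_partitions."""
    h = zlib.crc32(key.encode("utf-8"))  # deterministic, unlike Python's salted hash()
    return h % num_partitions            # crc32 is non-negative, so plain % acts like pmod


rows = [(1, "Alice", "US"), (2, "Bob", "UK"), (3, "Charlie", "US"),
        (4, "Dave", "IN"), (5, "Eve", "UK"), (6, "Frank", "US")]

partitions = defaultdict(list)
for row in rows:
    partitions[hash_partition(row[2], 4)].append(row)

for pid in sorted(partitions):
    print(pid, [name for _, name, _ in partitions[pid]])
```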
python
from pyspark.sql import SparkSession
from pyspark.sql.functions import spark_partition_id

spark = SparkSession.builder.appName("repartition").getOrCreate()

data = [(1, "Alice", "US"), (2, "Bob", "UK"), (3, "Charlie", "US"),
        (4, "Dave", "IN"), (5, "Eve", "UK"), (6, "Frank", "US")]
df = spark.createDataFrame(data, ["id", "name", "country"])

# ---- Repartition by number only (round-robin, no key hashing) ----
df_repartitioned = df.repartition(4)
print(df_repartitioned.rdd.getNumPartitions())  # 4

# ---- Hash repartition by column ----
# All rows with same "country" go to the same partition
df_by_country = df.repartition(4, "country")

# See which partition each row landed in
df_by_country.withColumn("partition_id", spark_partition_id()) \
              .orderBy("country") \
              .show()
# Output: All "US" rows → same partition_id
#         All "UK" rows → same partition_id
#         All "IN" rows → same partition_id
#         (two countries may also share a partition if their hashes collide)

# ---- Multi-column hash repartition ----
df_multi = df.repartition(8, "country", "name")
# Rows with same (country, name) combo → same partition
💡 When to Use Hash Repartition
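Use it when you'll do joins or aggregations on a specific column. If both sides of a join are hash-partitioned on the join key into the same number of partitions, Spark can reuse that partitioning instead of shuffling again during the join. The repartition is itself a shuffle, so this pays off mainly when the same repartitioned (typically cached) DataFrame feeds several joins or aggregations.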
📊
Range Repartition
Ordered
How Range Repartition Works
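Range repartition (via repartitionByRange()) first samples the data to estimate its distribution. It then picks range boundaries so that each partition holds a contiguous range of values. Rows are not sorted inside each partition, so call sortWithinPartitions() if you also need that. Because the boundaries come from a sample, results can differ slightly between runs.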
📚 Analogy
Like dividing a dictionary into volumes: Vol 1 = A–F, Vol 2 = G–M, Vol 3 = N–S, Vol 4 = T–Z. Each volume covers a range of letters. Reading in order is very efficient.
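Here is a pure-Python sketch of the idea behind repartitionByRange():

1. Sample the values.
2. Take evenly spaced cut points from the sorted sample as boundaries.
3. Send each value to the first range whose upper boundary it does not exceed.

Spark's RangePartitioner follows the same outline but uses a more elaborate sampling scheme. The sample size and seed here are arbitrary assumptions.

```python
import bisect
import random


def range_bounds(values, num_partitions, sample_size=100, seed=42):
    """Estimate num_partitions - 1 boundaries from a random sample."""
    rng = random.Random(seed)
    sample = sorted(rng.sample(values, min(sample_size, len(values))))
    return [sample[(i * len(sample)) // num_partitions] for i in range(1, num_partitions)]


def range_partition(value, bounds):
    """Partition id = number of boundaries strictly below the value."""
    return bisect.bisect_left(bounds, value)


amounts = [100, 200, 50, 400, 350, 75, 275, 450]
bounds = range_bounds(amounts, 3)
print(bounds)  # [100, 350]: the sample here is the whole (tiny) dataset

partitions = {}
for amount in amounts:
    partitions.setdefault(range_partition(amount, bounds), []).append(amount)
print(dict(sorted(partitions.items())))  # {0: [100, 50, 75], 1: [200, 350, 275], 2: [400, 450]}
```

The values inside each partition keep their input order. Range partitioning decides which range a row belongs to, but it does not sort rows within a partition.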
python
from pyspark.sql.functions import spark_partition_id, col

data = [(1, 100), (2, 200), (3, 50), (4, 400), (5, 350),
        (6, 75), (7, 275), (8, 450)]
df = spark.createDataFrame(data, ["id", "amount"])

# Range repartition: divides "amount" into 3 sorted ranges
df_range = df.repartitionByRange(3, "amount")

df_range.withColumn("partition_id", spark_partition_id()) \
        .orderBy("amount") \
        .show()

# Approximate output:
# partition 0: amount 50, 75, 100        ← low values
# partition 1: amount 200, 275           ← mid values
# partition 2: amount 350, 400, 450      ← high values

# Also supports descending order
df_range_desc = df.repartitionByRange(3, col("amount").desc())

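# Useful before writing range-organized data to files
# (sortWithinPartitions adds the per-file sort that repartitionByRange alone does not)
df_range.sortWithinPartitions("amount").write.parquet("s3://bucket/sorted-output/")
# Each file covers a non-overlapping range of amounts → Parquet min/max statistics
# let later range filters skip data (better predicate pushdown on reads)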
ℹ️ Range vs Hash
Hash: Groups same key values together. Best for joins/aggregations.
Range: Groups ranges of values. Best for ordered/sorted output and range queries.
🔄
Round Robin Repartition
Even Distribution
How Round Robin Works
When you call repartition(N) without specifying a column, Spark uses round-robin partitioning: it deals rows out to the partitions in rotation, like dealing cards. Each input partition starts at a pseudo-random output partition and then cycles through the rest (k, k+1, ..., N-1, 0, 1, ...). With reasonably large inputs, this produces partitions with near-equal row counts. However, it doesn't group any related rows together.
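Here is a pure-Python sketch of round-robin assignment, loosely modeled on Spark's behavior. Each input partition starts at a pseudo-random output partition and then deals its rows out in rotation. The seed is an arbitrary assumption for repeatability.

```python
import random


def round_robin(input_partitions, num_partitions, seed=0):
    """Deal rows to num_partitions outputs in rotation, one input partition at a time."""
    rng = random.Random(seed)
    outputs = [[] for _ in range(num_partitions)]
    for part in input_partitions:
        position = rng.randrange(num_partitions)  # pseudo-random start per input partition
        for row in part:
            outputs[position].append(row)
            position = (position + 1) % num_partitions
    return outputs


# One input partition with 10 rows → sizes differ by at most 1
print(sorted(len(p) for p in round_robin([list(range(10))], 4)))  # [2, 2, 3, 3]

# Many tiny input partitions → each starts at its own spot, so balance is looser
print([len(p) for p in round_robin([[i] for i in range(8)], 4)])
```

With a single large input partition, the counts are as even as possible. With many tiny inputs, each contributing a row or two from its own starting point, the counts can drift.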
python
# Round robin — NO column specified
# Rows distributed evenly, no grouping by key
df_rr = df.repartition(4)  # 4 partitions, balanced row count

# Check balance across partitions
df_rr.withColumn("partition_id", spark_partition_id()) \
     .groupBy("partition_id") \
     .count() \
     .orderBy("partition_id") \
     .show()
# Possible output: near-equal counts (with tiny inputs like this one, the exact
# counts depend on how the input was partitioned)
# +------------+-----+
# |partition_id|count|
# +------------+-----+
# |           0|    2|
# |           1|    2|
# |           2|    2|
# |           3|    2|
# +------------+-----+

# Use case: fixing skewed partitions before a write
df_skewed.repartition(100).write.parquet("output/")
# Creates 100 balanced files on disk
Method | Column Required? | Groups Related Rows? | Best For
repartition(N) | No | No (round-robin) | Balancing row counts
repartition(N, col) | Yes | Yes (hash) | Joins, aggregations
repartitionByRange(N, col) | Yes | Yes (range) | Sorted writes, range scans
15.3

Coalesce

coalesce() reduces the number of partitions without a full shuffle — making it much cheaper than repartition() when you only need fewer partitions.

🗜️
Narrow Transformation — No Shuffle
Efficient
How Coalesce Works
coalesce() is a narrow transformation: it doesn't shuffle data across the network. Instead, each output partition is computed by one task that reads several existing partitions, preferring partitions that already live on the same executor. Think of it as "combining adjacent files" rather than redistributing rows.
📋 Analogy
You have 100 sticky notes spread across 10 piles (partitions). coalesce(3) picks up nearby piles and stacks them together into 3 stacks — no need to re-sort or move notes across the room. It's fast because you're just combining existing stacks locally.
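Here is a simplified pure-Python sketch of what coalesce(N) does to partition boundaries. It groups existing partitions into N contiguous groups and never splits or re-sorts rows. Spark's default coalescer also weighs data locality when choosing groups, which this sketch ignores. The partition sizes below are made-up numbers for illustration.

```python
def coalesce_groups(num_partitions: int, target: int) -> list:
    """Group partition ids 0..num_partitions-1 into at most `target` contiguous groups."""
    target = min(target, num_partitions)  # coalesce never increases the partition count
    return [list(range(i * num_partitions // target, (i + 1) * num_partitions // target))
            for i in range(target)]


print(coalesce_groups(10, 3))   # [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
print(coalesce_groups(4, 100))  # [[0], [1], [2], [3]]: still 4 partitions

# Merged sizes simply add up, so uneven inputs stay uneven (hypothetical row counts)
sizes = [5, 5, 5, 5, 5, 5, 90, 90, 5, 5]
print([sum(sizes[p] for p in group) for group in coalesce_groups(10, 3)])  # [15, 15, 190]
```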
python
# Start with many partitions
df_big = spark.range(1000000).repartition(100)
print(df_big.rdd.getNumPartitions())  # 100

# Coalesce down to 10 — NO shuffle, just merges partitions
df_small = df_big.coalesce(10)
print(df_small.rdd.getNumPartitions())  # 10

# Classic use case: before writing to avoid many small files
df_result = df.groupBy("country").count()
# After groupBy there are spark.sql.shuffle.partitions (200) partitions, many of
# them empty, unless AQE has already coalesced them
df_result.coalesce(5).write.parquet("output/")
# Creates at most 5 output files instead of up to 200
Coalesce vs Repartition — When to Use Which
This is one of the most common interview questions. The key difference:
Feature | coalesce() | repartition()
Shuffle | ❌ No shuffle | ✅ Full shuffle
Speed | Fast | Slower
Can increase partitions? | ❌ No (only decrease) | ✅ Yes
Data balance | May be uneven | Even distribution
Transformation type | Narrow | Wide
Best for | Reducing before write | Redistribution, fixing skew
python
# ✅ GOOD: coalesce when reducing partitions (no shuffle needed)
df.coalesce(5).write.parquet("output/")

# ❌ WRONG: coalesce cannot increase partitions
df_4_partitions.coalesce(100)  # Still stays at 4! Ignored.

# ✅ GOOD: repartition when increasing or needing balanced distribution
df_4_partitions.repartition(100)  # Now has 100 balanced partitions

# ✅ GOOD: repartition when fixing skew
df_skewed.repartition(200).write.parquet("output/")

# Rule: "Want fewer? Use coalesce. Want more or need balance? Use repartition."
⚠️ Coalesce Trap
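If you coalesce too aggressively (e.g., from 200 to 1), all data is processed by a single task on one executor. You lose parallelism, and that task may run out of memory. Because coalesce adds no shuffle boundary, the upstream work in the same stage also runs with only that many tasks. If that is too slow, use repartition(N) instead: it adds a shuffle but keeps the upstream stage parallel. For writing, 5–20 partitions is often a good target.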
15.4

Partition Pruning

Partition pruning is Spark's ability to skip reading partitions that can't contain the data you need. It is often the biggest performance win from proper partitioning.

✂️
Static Partition Pruning
Planning Time
What is Static Partition Pruning?
Static partition pruning happens at query planning time (before execution). When you filter on a partition column with a literal value (a fixed value, not another column), Spark removes the irrelevant partition directories from the scan entirely.
📁 Analogy
You have folders: year=2022/, year=2023/, year=2024/. You query WHERE year = 2024. Static pruning means Spark never opens the 2022 and 2023 folders. It skips roughly two-thirds of the data (if the years are similar in size) before reading any data files.
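Static pruning boils down to comparing literal filter values with the key=value segments in folder names. Below is a pure-Python sketch of that comparison. The helper names are hypothetical. Spark also casts partition values to typed columns, while this sketch compares strings.

```python
def parse_partition_path(path: str) -> dict:
    """'year=2024/month=3' -> {'year': '2024', 'month': '3'}"""
    return dict(segment.split("=", 1) for segment in path.strip("/").split("/"))


def prune(paths, **literal_filters):
    """Keep only folders whose partition values match every literal filter."""
    return [p for p in paths
            if all(parse_partition_path(p).get(name) == str(value)
                   for name, value in literal_filters.items())]


folders = ["year=2022/month=1", "year=2023/month=1",
           "year=2024/month=1", "year=2024/month=3"]

print(prune(folders, year=2024))           # ['year=2024/month=1', 'year=2024/month=3']
print(prune(folders, year=2024, month=3))  # ['year=2024/month=3']
```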
python
from pyspark.sql.functions import col

# Write partitioned data to S3/HDFS (df_sales: a sales DataFrame with year and month columns)
df_sales.write \
    .partitionBy("year", "month") \
    .parquet("s3://bucket/sales/")

# On disk, it creates (integer values have no zero-padding):
# s3://bucket/sales/year=2022/month=1/part-*.parquet
# s3://bucket/sales/year=2022/month=2/part-*.parquet
# s3://bucket/sales/year=2023/month=1/part-*.parquet
# ... etc.

# Now read with a LITERAL filter → Static Pruning kicks in
df = spark.read.parquet("s3://bucket/sales/")

# ONLY reads year=2024/ folders and skips all other years
df_2024 = df.filter(col("year") == 2024)

# Even better: multi-level pruning
# Only reads year=2024/month=3/, a single folder
df_march = df.filter((col("year") == 2024) & (col("month") == 3))

# Verify pruning happened in the explain plan
df_march.explain("formatted")
# Look for a PartitionFilters entry on the Parquet scan that includes
# (year = 2024) and (month = 3); this means pruning is working
Dynamic Partition Pruning (DPP)
Runtime Optimization
What is Dynamic Partition Pruning?
Dynamic Partition Pruning (DPP) is more advanced — it happens at runtime. When you join a large partitioned fact table with a small dimension table, Spark first scans the dimension table, collects the relevant keys, and then uses those keys to prune partitions in the fact table on the fly.
ℹ️ DPP Requirement
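DPP is controlled by spark.sql.optimizer.dynamicPartitionPruning.enabled (on by default since Spark 3.0) and does not require AQE. It applies when all of the following hold: (1) the large table is partitioned on the join column, (2) the other side of the join carries a filter that selects a subset of keys, and (3) by default, Spark can reuse a broadcast of that smaller side (spark.sql.optimizer.dynamicPartitionPruning.reuseBroadcastOnly=true).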
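DPP runs in two phases. The pure-Python sketch below shows both, with made-up in-memory tables standing in for the fact and dimension tables. Phase 1 evaluates the dimension filter to get the surviving join keys. Phase 2 scans only the fact partitions whose partition value is one of those keys. In Spark, the key set is produced at runtime from the (usually broadcast) dimension side. Here it is just a Python set.

```python
# Hypothetical fact table: partition value (date) -> rows stored in that folder
fact_partitions = {
    "2024-03-31": [("2024-03-31", 120.0)],
    "2024-10-01": [("2024-10-01", 80.0), ("2024-10-01", 45.5)],
    "2024-12-24": [("2024-12-24", 300.0)],
}

# Hypothetical dimension table
dim_dates = [
    {"date": "2024-03-31", "quarter": "Q1"},
    {"date": "2024-10-01", "quarter": "Q4"},
    {"date": "2024-12-24", "quarter": "Q4"},
]

# Phase 1: run the dimension-side filter and collect the join keys that survive
q4_dates = {d["date"] for d in dim_dates if d["quarter"] == "Q4"}

# Phase 2: scan only fact partitions whose value is in the key set;
# the join then runs on these rows only
scanned = {date: rows for date, rows in fact_partitions.items() if date in q4_dates}
result = [row for rows in scanned.values() for row in rows]

print(sorted(scanned))  # ['2024-10-01', '2024-12-24']: the Q1 folder is never read
print(len(result))      # 3
```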
python
from pyspark.sql.functions import col

# Optional: AQE works alongside DPP but is not required for it
spark.conf.set("spark.sql.adaptive.enabled", "true")

# Large fact table partitioned by date
fact_sales = spark.read.parquet("s3://bucket/sales/")
# Partition structure: sales/date=2024-01-01/, sales/date=2024-01-02/, ...

# Small dimension table — just dates we care about
dim_dates = spark.read.parquet("s3://bucket/dim_dates/") \
                 .filter(col("quarter") == "Q4")
# dim_dates now has ONLY Q4 dates

# JOIN: Spark first scans dim_dates, builds a list of Q4 dates,
# THEN only reads the matching partitions from fact_sales
result = fact_sales.join(dim_dates, on="date", how="inner")

result.explain("formatted")
# Look for "dynamicpruningexpression" in the fact table's PartitionFilters
# This means DPP is working: fact_sales only reads Q4 partitions

# Config to tune DPP
spark.conf.set("spark.sql.optimizer.dynamicPartitionPruning.enabled", "true")
spark.conf.set("spark.sql.optimizer.dynamicPartitionPruning.useStats", "true")
✅ Real Impact
In star-schema queries with selective dimension filters, DPP can skip most of the fact table. For example, if the dimension filter keeps only 5% of the dates, a scan that previously read 10TB of fact data might read only about 500GB.
15.5

Partition Strategy Design

Choosing the right partition column is a critical design decision. The wrong choice wastes storage and makes queries slow. The right choice enables dramatic pruning.

📅
Date Partitioning
Most Common
Date Partitioning Strategy
Date is the most common partition column in data engineering, because many analytical queries filter by a date range. Partitioning by date lets Spark skip all irrelevant time periods.
⚠️ Avoid Partitioning on Raw Timestamp
Never partition on a full timestamp like 2024-01-15 14:32:01 — each unique second becomes its own folder! Extract just year/month/day.
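A quick pure-Python check shows why raw timestamps make poor partition columns. For one year of hourly events (hypothetical data), it counts how many folders each strategy would create. The folder-name format mirrors the examples below: date=YYYY-MM-DD, with unpadded month values.

```python
from datetime import datetime, timedelta

start = datetime(2023, 1, 1)
events = [start + timedelta(hours=h) for h in range(365 * 24)]  # hourly, all of 2023


def folder(ts: datetime, strategy: str) -> str:
    """Folder name a row with timestamp `ts` would land in under each strategy."""
    if strategy == "timestamp":
        return f"ts={ts.isoformat()}"
    if strategy == "date":
        return f"date={ts.date().isoformat()}"
    if strategy == "year_month":
        return f"year={ts.year}/month={ts.month}"
    raise ValueError(strategy)


for strategy in ("timestamp", "date", "year_month"):
    print(strategy, len({folder(ts, strategy) for ts in events}))
# timestamp 8760
# date 365
# year_month 12
```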
python
from pyspark.sql.functions import col, year, month, dayofmonth, to_date

# Assume df has a "transaction_timestamp" column

# ---- Strategy 1: Single date column (good for daily queries) ----
df_with_date = df.withColumn("date", to_date(col("transaction_timestamp")))
df_with_date.write.partitionBy("date").parquet("s3://bucket/sales/")
# Creates: sales/date=2024-01-15/, sales/date=2024-01-16/, ...
# 365 folders per year — manageable

# ---- Strategy 2: Year + Month (good for monthly queries) ----
df_ym = df.withColumn("year", year(col("transaction_timestamp"))) \
           .withColumn("month", month(col("transaction_timestamp")))
df_ym.write.partitionBy("year", "month").parquet("s3://bucket/sales/")
# Creates: sales/year=2024/month=1/, sales/year=2024/month=2/, ...
# 12 folders per year, each with month's data

# ---- Strategy 3: Year + Month + Day (most granular) ----
df_ymd = df.withColumn("year", year(col("transaction_timestamp"))) \
            .withColumn("month", month(col("transaction_timestamp"))) \
            .withColumn("day", dayofmonth(col("transaction_timestamp")))
df_ymd.write.partitionBy("year", "month", "day").parquet("output/")
# Allows pruning at year, month, or day level (still ~365 folders per year)

# Query with pruning at multiple levels
df = spark.read.parquet("s3://bucket/sales/")
df.filter((col("year") == 2024) & (col("month") == 3))  # Only March 2024
🌍
Region / Category Partitioning
Low Cardinality
When to Partition by Region or Category
Partition columns should have low cardinality (few distinct values). As a rule of thumb, aim for tens to a few hundred distinct values, with each value still holding a substantial amount of data. Region, country, department, and status are good candidates. Never partition by high-cardinality columns like user_id (millions of folders!).
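The easiest way to check the cardinality rule is to divide the table size by the number of distinct partition values. Here is a small sketch using hypothetical numbers (a 500GB table) and assuming the data is spread evenly across values.

```python
def avg_folder_mb(total_gb: float, distinct_values: int) -> float:
    """Average data per partition folder, assuming an even spread across values."""
    return total_gb * 1024 / distinct_values


print(avg_folder_mb(500, 5))          # 102400.0 MB per region folder
print(avg_folder_mb(500, 1_000_000))  # 0.512 MB per user_id folder: far too small
```

A large folder is fine, because Spark writes many files into it. A tiny folder, on the other hand, means tiny files and heavy metadata overhead.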
python
# ✅ GOOD — Region has 4-5 values
df.write.partitionBy("region").parquet("output/")
# Creates: region=US/, region=EU/, region=APAC/, region=LATAM/
# 4 partitions — lean and efficient

# ✅ GOOD — Product category has 10-20 values
df.write.partitionBy("category").parquet("output/")

# ❌ BAD — user_id has millions of distinct values
df.write.partitionBy("user_id").parquet("output/")
# Creates MILLIONS of tiny folders — kills metadata operations!

# Rule of thumb: aim for roughly 128MB - 1GB of data per partition folder
# If most folders hold < 10MB, the column has too many distinct values for your data volume
🔗
Hybrid Partitioning
Best Practice
Combining Multiple Partition Columns
Most production tables use multiple partition columns to enable pruning at multiple levels. The order of the columns defines the directory hierarchy, from the top-level folders (leftmost column) to the innermost folders (rightmost column).
python
# Hybrid: Date + Region — enables 4 query patterns with pruning
df_ymd.write \
    .partitionBy("year", "month", "region") \
    .parquet("s3://bucket/sales/")

# Directory structure:
# year=2024/month=1/region=US/  ← 4 files
# year=2024/month=1/region=EU/  ← 3 files
# year=2024/month=2/region=US/  ← 4 files
# ...

# Query pattern 1: All of 2024 → prunes to year=2024/ only
df.filter(col("year") == 2024)

# Query pattern 2: Jan 2024 → prunes year + month
df.filter((col("year") == 2024) & (col("month") == 1))

# Query pattern 3: US in Jan 2024 → prunes all three levels
df.filter((col("year") == 2024) &
          (col("month") == 1) &
          (col("region") == "US"))

# Query pattern 4: All US data → Spark evaluates region = "US" against every
# partition folder and reads only the region=US files from each year/month
df.filter(col("region") == "US")
# Note: pruning works on any partition column, not only the first one
💡 Partition Column Order Matters
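Column order sets the directory layout, not which columns can be pruned: Spark checks filters against the values of every partition column, wherever it sits in the path. Put the column you most often filter or manage data by (usually the date) first. That way, operations such as listing, backfilling, or deleting one year's data map to a single top-level folder.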
15.6

Skew-Aware Partitioning

Data skew is when some partitions have far more data than others, causing a few slow "straggler" tasks while the others finish quickly. The three strategies below address it.

🧂
Salting — Fixing Hot Keys
Most Common Fix
What is Data Skew?
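Data skew happens when certain key values appear much more often than others. For example, suppose 80% of your sales records are from "US" and you join on country. The task that handles the "US" key receives 80% of the rows, so it can take many times longer than the others and stall the whole stage. For simple aggregations such as count or sum, Spark's map-side partial aggregation already shrinks this effect. Skew hurts most in joins and in window functions partitioned by the skewed key.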
🍕 Analogy
Imagine dividing a city's pizza orders by borough. Manhattan gets 10,000 orders, Brooklyn gets 8,000, but the Bronx gets 50,000 orders (skew!). One delivery driver (executor) handles 50,000 deliveries while others idle — the job can't finish until that driver is done.
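Salting adds a random number (the "salt") to the key, splitting the hot key into multiple sub-keys that can be processed in parallel. The small table on the other side of the join is replicated once per salt value, so every salted key still finds its match.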
python
from pyspark.sql.functions import col, floor, rand, concat_ws, lit, explode, array, count

# Problem: "US" key is heavily skewed
df_sales.groupBy("country").count().show()
# country=US: 9,000,000 rows  ← HOT KEY
# country=UK:   500,000 rows
# country=IN:   200,000 rows

# ---- Salting solution ----
SALT_FACTOR = 10  # Split into 10 sub-keys

# Step 1: Add salt to the skewed table
df_salted = df_sales.withColumn(
    "salted_key",
    concat_ws("_", col("country"), (floor(rand() * SALT_FACTOR)).cast("string"))
)
# Now "US" becomes "US_0", "US_1", ..., "US_9": 10 sub-keys

# Step 2: Replicate the small table across salt values
df_lookup = spark.createDataFrame([
    ("US", "United States"), ("UK", "United Kingdom")
], ["country", "country_name"])

salt_range = [str(i) for i in range(SALT_FACTOR)]
df_lookup_salted = df_lookup.withColumn("salt", explode(array(*[lit(s) for s in salt_range]))) \
                             .withColumn("salted_key", concat_ws("_", col("country"), col("salt"))) \
                             .drop("country", "salt")  # avoid a duplicate "country" column after the join

# Step 3: Join on salted key (inner join: countries missing from df_lookup, such as IN, are dropped)
result = df_salted.join(df_lookup_salted, on="salted_key", how="inner")

# Step 4: Group by original country (remove salt)
final = result.groupBy("country", "country_name").agg(count("*").alias("cnt"))
# The join work for "US" is now split across up to 10 tasks in parallel
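To see the effect on the hot key, the salting steps can be simulated in pure Python. This minimal sketch uses hypothetical row counts:

1. Salt each row's key into one of 10 sub-keys.
2. Replicate the lookup table once per salt value.
3. Check that the join still finds every match while the largest single key shrinks.

```python
import random
from collections import Counter

SALT_FACTOR = 10
rng = random.Random(7)

# Hypothetical skewed fact rows: 900 US, 50 UK, 50 IN
sales = ["US"] * 900 + ["UK"] * 50 + ["IN"] * 50
lookup = {"US": "United States", "UK": "United Kingdom", "IN": "India"}

# Step 1: salt the skewed side: "US" -> "US_0" ... "US_9"
salted = [f"{country}_{rng.randrange(SALT_FACTOR)}" for country in sales]

# Step 2: replicate the small side once per salt value
salted_lookup = {f"{country}_{salt}": name
                 for country, name in lookup.items()
                 for salt in range(SALT_FACTOR)}

# Step 3: join on the salted key (every salted key has a match)
joined = [salted_lookup[key] for key in salted]

print(Counter(sales).most_common(1))     # [('US', 900)]: one key, one huge task
print(max(Counter(salted).values()))     # rows in the busiest salted sub-key
print(Counter(joined)["United States"])  # 900: no rows lost by salting
```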
📊
Range Partitioning for Skew
Numeric Skew
Using Range Partitioning to Handle Skew
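For numeric or ordered data with a skewed value distribution, range partitioning samples the data and picks boundaries so that each range holds roughly the same number of rows. It does not split the value range into equal-width intervals. However, it cannot split a single repeated value across partitions, so it does not help when one exact key dominates. Use salting or AQE for that case.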
python
from pyspark.sql.functions import spark_partition_id

# Data: revenue column is skewed
# Most orders are $1-$100, but a few are $1M+ (outliers)

# Equal-width buckets (e.g. fixed $100K bands) would put almost every row in the first band.
# Hash partitioning spreads distinct revenue values, but rows in each partition
# have no ordering, and any heavily repeated value still lands in one partition
df.repartition(10, "revenue")

# Range partitioning samples the data and creates buckets with similar row counts
df.repartitionByRange(10, "revenue")  # ✅ Each partition has a similar row count

# Spark estimates the boundaries from a sample (illustrative values):
# partition 0: revenue $0    - $45
# partition 1: revenue $45   - $89
# ...each with ~10% of total rows

# Verify balance
df.repartitionByRange(10, "revenue") \
  .withColumn("pid", spark_partition_id()) \
  .groupBy("pid").count().orderBy("pid").show()
# Should show roughly equal counts per partition
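Here is a pure-Python comparison of equal-width buckets and quantile-based boundaries, using hypothetical skewed revenue data: 990 orders of $1 to $990 plus 10 orders above $1M. The quantile boundaries are taken at evenly spaced positions in the sorted data. Using every value instead of a sample keeps the sketch deterministic. Spark estimates the same kind of boundaries from a sample.

```python
import bisect
from collections import Counter

revenue = list(range(1, 991)) + [1_000_000 + i for i in range(10)]
N = 10

# Equal-width buckets over [min, max]
lo, hi = min(revenue), max(revenue)
width = (hi - lo) / N
equal_width = Counter(min(int((v - lo) / width), N - 1) for v in revenue)

# Quantile boundaries: values at positions 100, 200, ..., 900 of the sorted data
ordered = sorted(revenue)
bounds = [ordered[(i * len(ordered)) // N] for i in range(1, N)]
quantile = Counter(bisect.bisect_left(bounds, v) for v in revenue)

print(equal_width[0])             # 990: nearly everything in the first bucket
print(sorted(quantile.values()))  # [99, 100, 100, 100, 100, 100, 100, 100, 100, 101]
```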
🔁
AQE Skew Join Handling
Automatic
Automatic Skew Handling with AQE
Spark 3.0+ includes Adaptive Query Execution (AQE) which can automatically detect and handle skewed partitions during joins. It splits large skewed partitions into smaller ones and replicates the corresponding partition from the other side.
python
# Enable AQE skew join optimization
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")

# Tuning: a partition is "skewed" if it's larger than:
# max(256MB, median_size * skewedPartitionFactor)
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB")

# With AQE: Spark automatically splits skewed partitions during joins
# No code changes needed! Just enable the config above.
result = df_large.join(df_small, on="country")
# AQE detects that the "US" partition exceeds 5x the median (and 256MB) → splits it → handles the pieces in parallel

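# Databricks Runtime also offers a SKEW hint for explicit control; open-source
# Apache Spark does not recognize it and ignores unknown hints with a warning
result = df_large.hint("SKEW", "country") \
                 .join(df_small, on="country")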
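The skew rule in the comments above can be sketched in pure Python. A partition counts as skewed when it is larger than both skewedPartitionFactor × median and skewedPartitionThresholdInBytes. This is a simplified model, not AQE's code. In particular, splitting each skewed partition into chunks of an assumed 64MB target (the default of spark.sql.adaptive.advisoryPartitionSizeInBytes) glosses over how Spark actually picks split points.

```python
import math
import statistics

MB = 1024 * 1024


def skew_threshold(sizes, factor=5.0, threshold_bytes=256 * MB):
    """A partition is skewed if it exceeds max(threshold, factor * median)."""
    return max(threshold_bytes, factor * statistics.median(sizes))


def plan_splits(sizes, target_bytes=64 * MB, **kwargs):
    """Number of pieces an AQE-style split would use for each partition."""
    limit = skew_threshold(sizes, **kwargs)
    return [math.ceil(size / target_bytes) if size > limit else 1 for size in sizes]


# Hypothetical shuffle partition sizes: nine 100MB partitions and one 2GB "US" partition
sizes = [100 * MB] * 9 + [2048 * MB]
print(plan_splits(sizes))  # [1, 1, 1, 1, 1, 1, 1, 1, 1, 32]

# Small data: 200MB is 20x the median, but still under the 256MB floor
print(plan_splits([10 * MB] * 9 + [200 * MB]))  # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
```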
✅ Recommendation
AQE (spark.sql.adaptive.enabled) is on by default since Spark 3.2. On Spark 3.0–3.1, enable it explicitly. Together with skewJoin, it handles skew automatically in most join cases. Only resort to manual salting for extreme, persistent skew.
15.7

Custom Partitioners

When built-in hash and range partitioning don't fit your use case, you can write custom partitioning logic at the RDD level.

🛠️
partitionBy() — DataFrame-Level Partitioning
Storage Partition
Writing with partitionBy()
df.write.partitionBy("col") creates physical folder-based partitions on storage. This is for on-disk partitioning, not the in-memory partitioning used during processing. It controls how data is laid out for future reads.
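Here is a pure-Python sketch of the layout that partitionBy() produces. Rows are grouped by their partition-column values, and each group goes into a key=value folder path. The partition columns are dropped from the stored rows because the path already encodes them. The row data and the helper name are hypothetical.

```python
from collections import defaultdict


def partitioned_layout(rows, partition_cols):
    """Map 'col=value/...' folder paths to the rows stored inside them."""
    layout = defaultdict(list)
    for row in rows:
        path = "/".join(f"{c}={row[c]}" for c in partition_cols)
        stored = {k: v for k, v in row.items() if k not in partition_cols}
        layout[path].append(stored)
    return dict(layout)


rows = [
    {"id": 1, "amount": 10.0, "year": 2024, "month": 3},
    {"id": 2, "amount": 25.0, "year": 2024, "month": 3},
    {"id": 3, "amount": 7.5, "year": 2024, "month": 4},
]
for path, stored in partitioned_layout(rows, ["year", "month"]).items():
    print(path, stored)
# year=2024/month=3 [{'id': 1, 'amount': 10.0}, {'id': 2, 'amount': 25.0}]
# year=2024/month=4 [{'id': 3, 'amount': 7.5}]
```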
python
from pyspark.sql.functions import col

# Basic write-time partition
df.write \
  .partitionBy("year", "month") \
  .mode("overwrite") \
  .parquet("s3://bucket/sales/")

# The "year" and "month" columns are REMOVED from the data files
# They become the folder names instead
# s3://bucket/sales/year=2024/month=3/part-0000.parquet

# Control number of files PER PARTITION using repartition
# Without this, you might get many small files
df.repartition(col("year"), col("month")) \
  .write \
  .partitionBy("year", "month") \
  .mode("overwrite") \
  .parquet("output/")
# Now exactly 1 file per year/month combo (ideal for small-medium data)

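# For large data, spread each year/month over several files by adding a random
# bucket (0-4) to the repartition key. repartition(5, col("year"), col("month"))
# would NOT do this: every year/month combo would still hash to a single partition.
from pyspark.sql.functions import floor, rand

df.repartition(col("year"), col("month"), floor(rand() * 5)) \
  .write \
  .partitionBy("year", "month") \
  .mode("overwrite") \
  .parquet("output/")
# Up to 5 files per partition folder (fewer if two buckets hash to the same partition)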
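The repartition before the write controls file counts because each write task emits one file for every partition folder it has rows for. The pure-Python simulation below uses hypothetical data (100 rows per month for 2024) to compare three ways of assigning rows to tasks. It uses zlib.crc32 as a stand-in for Spark's hash. It also ignores settings such as spark.sql.files.maxRecordsPerFile that can split files further.

```python
import random
import zlib
from collections import Counter

rows = [{"year": 2024, "month": m} for m in range(1, 13) for _ in range(100)]


def files_per_folder(assign_task):
    """One output file per distinct (task, year/month folder) pair."""
    pairs = {(assign_task(i, r), (r["year"], r["month"])) for i, r in enumerate(rows)}
    return Counter(folder for _task, folder in pairs)


def stable_hash(*parts):
    return zlib.crc32("|".join(map(str, parts)).encode())


rng = random.Random(1)

# Round-robin over 50 tasks: every task sees every month
round_robin = files_per_folder(lambda i, r: i % 50)
# Hash on (year, month) over 200 tasks: each month lives in exactly one task
by_key = files_per_folder(lambda i, r: stable_hash(r["year"], r["month"]) % 200)
# Hash on (year, month, random bucket 0-4): each month spread over up to 5 tasks
salted = files_per_folder(
    lambda i, r: stable_hash(r["year"], r["month"], rng.randrange(5)) % 200)

print(set(round_robin.values()))  # {50}
print(set(by_key.values()))       # {1}
print(max(salted.values()))       # never more than 5
```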
⚙️
Custom Partition Logic at RDD Level
Advanced
Writing a Custom Partitioner (RDD API)
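The DataFrame API only offers hash, range, and round-robin partitioning. For truly custom logic (e.g., partition by business rule, alphabetical range, or custom hash), you need to drop to the RDD API. In Scala, you extend org.apache.spark.Partitioner. In PySpark, you pass a function to rdd.partitionBy(numPartitions, partitionFunc) that maps each key to an integer, and PySpark takes that integer modulo numPartitions.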
python
from pyspark import SparkContext
from pyspark.rdd import RDD

# Custom partitioner: partition customers by first letter of name
# A-H → partition 0, I-P → partition 1, Q-Z → partition 2

class AlphaPartitioner:
    def __init__(self, num_partitions):
        self.num_partitions = num_partitions

    def __call__(self, key):
        # key is the name (string)
        first_letter = key[0].upper()
        if first_letter <= 'H':
            return 0
        elif first_letter <= 'P':
            return 1
        else:
            return 2

sc = spark.sparkContext
data = [("Alice", 100), ("Bob", 200), ("Zara", 300),
        ("Mike", 150), ("Kate", 250), ("Yuki", 175)]

# Create pair RDD (key, value)
rdd = sc.parallelize(data)  # RDD of (name, amount) tuples

# Apply custom partitioner via partitionBy
partitioner = AlphaPartitioner(3)
rdd_partitioned = rdd.partitionBy(3, partitioner)

# Inspect partitions
for i, partition in enumerate(rdd_partitioned.glom().collect()):
    print(f"Partition {i}: {partition}")
# Partition 0: [('Alice', 100), ('Bob', 200)]       ← A-H
# Partition 1: [('Mike', 150), ('Kate', 250)]        ← I-P
# Partition 2: [('Zara', 300), ('Yuki', 175)]        ← Q-Z
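PySpark's rdd.partitionBy() places each (key, value) pair in bucket partitionFunc(key) % numPartitions. A function that returns out-of-range numbers therefore still works, but it wraps around. Below is a pure-Python sketch of that bucketing step using the same first-letter rule. It is not the real shuffle, and the helper names are hypothetical.

```python
def alpha_rule(name: str) -> int:
    """Same rule as AlphaPartitioner: A-H -> 0, I-P -> 1, Q-Z -> 2."""
    first = name[0].upper()
    return 0 if first <= "H" else 1 if first <= "P" else 2


def bucket_pairs(pairs, num_partitions, partition_func):
    """Group (key, value) pairs by partition_func(key) % num_partitions."""
    buckets = [[] for _ in range(num_partitions)]
    for key, value in pairs:
        buckets[partition_func(key) % num_partitions].append((key, value))
    return buckets


data = [("Alice", 100), ("Bob", 200), ("Zara", 300),
        ("Mike", 150), ("Kate", 250), ("Yuki", 175)]

print(bucket_pairs(data, 3, alpha_rule))
# [[('Alice', 100), ('Bob', 200)], [('Mike', 150), ('Kate', 250)], [('Zara', 300), ('Yuki', 175)]]

# With only 2 partitions, Q-Z names (rule value 2) wrap around into partition 0
print(bucket_pairs(data, 2, alpha_rule))
# [[('Alice', 100), ('Bob', 200), ('Zara', 300), ('Yuki', 175)], [('Mike', 150), ('Kate', 250)]]
```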
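Real-World Example: Business-Rule Partitioner
A more practical custom partitioner that routes each transaction type to its own partition. It does not balance data: partition sizes follow your transaction-type mix, so one busy type can still create a large partition.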
python
from pyspark.sql.functions import spark_partition_id

# Real use case: Route transactions to partitions by transaction_type
# We have: PURCHASE, REFUND, TRANSFER, DEPOSIT, WITHDRAWAL
# Business rule: each type gets its own partition for downstream processing

def business_partitioner(transaction_type):
    mapping = {
        "PURCHASE": 0,
        "REFUND": 1,
        "TRANSFER": 2,
        "DEPOSIT": 3,
        "WITHDRAWAL": 4
    }
    return mapping.get(transaction_type, 0)  # Unknown types share partition 0 with PURCHASE

sc = spark.sparkContext
NUM_PARTITIONS = 5

# Convert DataFrame to RDD pair (key, whole_row)
rdd = df_transactions.rdd.map(lambda row: (row.transaction_type, row))

# Partition by business rule
rdd_partitioned = rdd.partitionBy(NUM_PARTITIONS, business_partitioner)

# Convert back to DataFrame
df_partitioned = rdd_partitioned.map(lambda kv: kv[1]).toDF(df_transactions.schema)

# Verify: check which rows are in which partition
df_partitioned.withColumn("pid", spark_partition_id()) \
              .groupBy("transaction_type", "pid") \
              .count() \
              .orderBy("pid") \
              .show()
# All PURCHASE → partition 0
# All REFUND   → partition 1
# etc.
⚠️ Custom Partitioner Trade-off
Custom partitioners require dropping to RDD API, losing DataFrame optimizations (Catalyst optimizer). Only use when the business requirement genuinely cannot be expressed with hash/range partitioning.
Module 15 · Reference

Partitioning Cheat Sheet

Quick reference for all Module 15 partition APIs and best practices.

Check Partitions
df.rdd.getNumPartitions()
df.withColumn("pid", spark_partition_id()) \
  .groupBy("pid").count().show()
Repartition Types
df.repartition(N) # round-robin
df.repartition(N, "col") # hash
df.repartitionByRange(N,"col") # range
Coalesce
df.coalesce(N) # narrow, no shuffle
# Only reduces partitions
# Faster than repartition
Partition Pruning
# DPP (on by default in Spark 3.0+)
spark.conf.set(
"spark.sql.optimizer.dynamicPartitionPruning.enabled","true")
# Use literal filters on partition cols
Write Partitions
df.write \
  .partitionBy("year","month") \
  .mode("overwrite") \
  .parquet("s3://bucket/output/")
Fix Skew (AQE)
spark.conf.set(
"spark.sql.adaptive.enabled","true")
spark.conf.set(
"spark.sql.adaptive.skewJoin.enabled","true")
Shuffle Partitions
# Default = 200 (too high for small data)
spark.conf.set(
"spark.sql.shuffle.partitions","50")
Salting Pattern
SALT = 10
df.withColumn("sk",
concat_ws("_","key",
(floor(rand()*SALT)).cast("string")))
DECISION GUIDE
Situation | Solution
Need more partitions | repartition(N)
Need fewer partitions (before write) | coalesce(N)
Repeated joins on one column, avoid reshuffling | repartition(N, "join_col") once, then reuse (e.g. cache) before the joins
Data skew in joins | Enable AQE skewJoin or use salting
Too many small files on disk | coalesce(5–20) before write
Queries always filter by date | partitionBy("year","month","day") on write
Default 200 shuffle partitions too high | Set spark.sql.shuffle.partitions lower
Need sorted output | repartitionByRange(N, "col") plus sortWithinPartitions("col")
Partition column has millions of values | ❌ Choose a lower-cardinality column instead
Module 15 · Practice

Quick Quiz

Test your understanding of Spark Partitioning concepts.

Q1: You have a DataFrame with 200 partitions after a groupBy. Before writing to Parquet, you want just 10 files with minimal overhead. What's the best approach?
Q2: You partition a sales table by (year, month) on disk. A query filters WHERE year = 2024 AND month = 3. What happens?
Q3: df.coalesce(100) is called on a DataFrame that has 4 partitions. What happens?
Q4: Your join on "country" column is very slow. Spark UI shows one task takes 5 minutes, all others finish in 10 seconds. The "US" key has 90% of rows. What is this called and how do you fix it?
Q5: What is the default value of spark.sql.shuffle.partitions and when should you change it?