How Spark Divides and Conquers Data
Partitioning is the single most impactful concept in Spark performance. It controls how your data is split across the cluster, how many parallel tasks run, whether shuffles happen, and how much work each executor does. Master this and you master Spark performance.
Spark Partitions
Understanding the difference between logical and physical partitions, and how Spark decides how many partitions to create.
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("partitions").getOrCreate()
# Create a simple DataFrame
data = [(1, "Alice"), (2, "Bob"), (3, "Charlie"),
(4, "Dave"), (5, "Eve")]
df = spark.createDataFrame(data, ["id", "name"])
# Check logical partitions
print(f"Number of logical partitions: {df.rdd.getNumPartitions()}")
# Output: Number of logical partitions: 8 (depends on default parallelism)
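Logical: How Spark splits a dataset in memory into units of parallel work (one task per partition).
Physical: How data is actually stored on disk (files/folders).
Example: Reading 10 Parquet files of about 128MB each → 10 physical partitions → Spark creates roughly 10 logical partitions (tasks).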
# Reading CSV: each file chunk = one logical partition
df_csv = spark.read.csv("s3://bucket/data/*.csv", header=True)
print(df_csv.rdd.getNumPartitions()) # e.g., 12 (one per file or per 128MB block)
# Reading Parquet: splits are sized by spark.sql.files.maxPartitionBytes; a row group is never split across tasks
df_parquet = spark.read.parquet("s3://bucket/data/sales.parquet")
print(df_parquet.rdd.getNumPartitions()) # e.g., 4
# Writing physical partitions to disk
df.write.partitionBy("country").parquet("s3://bucket/output/")
# Creates: output/country=US/, output/country=IN/, output/country=UK/ ...
| Source | Default Partition Count | Config |
|---|---|---|
| Reading files (HDFS/S3) | One per 128MB block | spark.sql.files.maxPartitionBytes |
| After shuffle (groupBy, join) | 200 | spark.sql.shuffle.partitions |
| parallelize() from collection | Default parallelism | spark.default.parallelism |
| repartition(N) | N (you specify) | N/A |
# Default shuffle partitions = 200. Change based on data size:
spark.conf.set("spark.sql.shuffle.partitions", "50")
# For small data: reduce to 50 (avoids 200 tiny tasks)
# For large data (TB scale): increase to 1000+
# See current value
print(spark.conf.get("spark.sql.shuffle.partitions"))
# Check partition count after groupBy (triggers shuffle); assumes df has a "country" column
result = df.groupBy("country").count()
print(result.rdd.getNumPartitions()) # 200 by default; can be fewer when AQE coalesces small shuffle partitions
# Rule of thumb: partition size = total data size / num partitions
# Target: 128MB - 256MB per partition
# Example: 10GB data → ~40 to 80 partitions
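The sizing rule above is plain arithmetic, so it can be sketched without Spark. The helper name `suggest_partition_count` and its defaults are illustrative assumptions, not a Spark API.

```python
import math

MiB = 1024 ** 2
GiB = 1024 ** 3

def suggest_partition_count(total_bytes, target_partition_bytes=128 * MiB):
    """Return how many partitions keep each one near the target size."""
    if total_bytes <= 0:
        return 1
    return max(1, math.ceil(total_bytes / target_partition_bytes))

# 10 GiB of data at both ends of the 128MB-256MB target range
upper = suggest_partition_count(10 * GiB, 128 * MiB)
lower = suggest_partition_count(10 * GiB, 256 * MiB)
print(lower, upper)  # 40 80
```

The result could then be passed to `spark.conf.set("spark.sql.shuffle.partitions", str(n))`; treat it as a starting point to tune, not a fixed answer.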
Repartition
repartition() is a wide transformation that performs a full shuffle to redistribute data evenly across a new number of partitions.
When you pass one or more columns to repartition(), Spark computes a hash of those column values and uses modulo arithmetic to assign each row to a partition. This ensures rows with the same key always go to the same partition.
from pyspark.sql import SparkSession
from pyspark.sql.functions import spark_partition_id
spark = SparkSession.builder.appName("repartition").getOrCreate()
data = [(1, "Alice", "US"), (2, "Bob", "UK"), (3, "Charlie", "US"),
(4, "Dave", "IN"), (5, "Eve", "UK"), (6, "Frank", "US")]
df = spark.createDataFrame(data, ["id", "name", "country"])
# ---- Repartition by number only (round-robin, no key hashing) ----
df_repartitioned = df.repartition(4)
print(df_repartitioned.rdd.getNumPartitions()) # 4
# ---- Hash repartition by column ----
# All rows with same "country" go to the same partition
df_by_country = df.repartition(4, "country")
# See which partition each row landed in
df_by_country.withColumn("partition_id", spark_partition_id()) \
.orderBy("country") \
.show()
# Output: All "US" rows → same partition_id
# All "UK" rows → same partition_id
# All "IN" rows → same partition_id
# ---- Multi-column hash repartition ----
df_multi = df.repartition(8, "country", "name")
# Rows with same (country, name) combo → same partition
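The hash-and-modulo assignment described above can be modelled in a few lines of plain Python. This is only a sketch: Spark uses a Murmur3 hash internally, while `zlib.crc32` is swapped in here as a stand-in, so the partition numbers will not match Spark's. The function name `assign_partition` is hypothetical.

```python
import zlib

def assign_partition(key_values, num_partitions):
    """Map a tuple of key values to a partition index via hash + modulo."""
    encoded = "|".join(str(v) for v in key_values).encode("utf-8")
    return zlib.crc32(encoded) % num_partitions

rows = [(1, "Alice", "US"), (2, "Bob", "UK"), (3, "Charlie", "US"),
        (4, "Dave", "IN"), (5, "Eve", "UK"), (6, "Frank", "US")]

# Partition by the "country" field (index 2), like repartition(4, "country")
placement = {}
for row in rows:
    placement.setdefault(row[2], set()).add(assign_partition((row[2],), 4))

# Every country maps to exactly one partition index
print({country: len(pids) for country, pids in placement.items()})
# {'US': 1, 'UK': 1, 'IN': 1}
```

Two different countries may still share a partition, because the hash only guarantees that equal keys land together, not that distinct keys are kept apart.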
Range partitioning (repartitionByRange()) samples the data first to understand its distribution, then divides it into ranges so each partition holds a contiguous range of values. The partitions are ordered relative to one another, but rows inside a partition are not sorted unless you also call sortWithinPartitions().
from pyspark.sql.functions import spark_partition_id, col
data = [(1, 100), (2, 200), (3, 50), (4, 400), (5, 350),
(6, 75), (7, 275), (8, 450)]
df = spark.createDataFrame(data, ["id", "amount"])
# Range repartition: divides "amount" into 3 sorted ranges
df_range = df.repartitionByRange(3, "amount")
df_range.withColumn("partition_id", spark_partition_id()) \
.orderBy("amount") \
.show()
# Approximate output:
# partition 0: amount 50, 75, 100 ← low values
# partition 1: amount 200, 275 ← mid values
# partition 2: amount 350, 400, 450 ← high values
# Also supports descending order
df_range_desc = df.repartitionByRange(3, col("amount").desc())
# Useful before writing data that will be read with range filters
df_range.sortWithinPartitions("amount").write.parquet("s3://bucket/sorted-output/")
# Each file covers a narrow, sorted range of amounts → Parquet min/max statistics let readers skip files
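To make the sample-then-split idea concrete, here is a plain-Python sketch that derives range boundaries from the same eight amounts and assigns each row with a binary search. The boundary rule is a simple quantile pick chosen for illustration; Spark's own sampling and boundary selection differ, so the exact split may not match what Spark prints. Both function names are hypothetical.

```python
from bisect import bisect_left

def range_boundaries(sample, num_partitions):
    """Pick num_partitions - 1 upper bounds from a sorted sample."""
    ordered = sorted(sample)
    n = len(ordered)
    return [ordered[max(0, (i * n) // num_partitions - 1)]
            for i in range(1, num_partitions)]

def range_partition(value, bounds):
    """Index of the first range whose upper bound is >= value."""
    return bisect_left(bounds, value)

amounts = [100, 200, 50, 400, 350, 75, 275, 450]
bounds = range_boundaries(amounts, 3)
print(bounds)  # [75, 275]

partitions = {}
for amount in amounts:
    partitions.setdefault(range_partition(amount, bounds), []).append(amount)

for pid in sorted(partitions):
    print(pid, partitions[pid])
# 0 [50, 75]
# 1 [100, 200, 275]
# 2 [400, 350, 450]
```

In this model the ranges never overlap, but rows inside a range keep their arrival order (see partition 2).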
Range: Groups ranges of values. Best for ordered/sorted output and range queries.
When you call repartition(N) without specifying a column, Spark uses round-robin: it deals rows out to the partitions in turn. Row 1 → partition 0, Row 2 → partition 1, ..., Row N+1 → partition 0 again. This creates near-evenly balanced partitions but doesn't group any related rows together.
# Round robin — NO column specified
# Rows distributed evenly, no grouping by key
df_rr = df.repartition(4) # 4 partitions, balanced row count
# Check balance across partitions
df_rr.withColumn("partition_id", spark_partition_id()) \
.groupBy("partition_id") \
.count() \
.orderBy("partition_id") \
.show()
# Output: Each partition has ~equal number of rows
# +------------+-----+
# |partition_id|count|
# +------------+-----+
# | 0| 2|
# | 1| 1|
# | 2| 2|
# | 3| 1|
# +------------+-----+
# Use case: fixing skewed partitions before a write
df_skewed.repartition(100).write.parquet("output/")
# Creates 100 balanced files on disk
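Round-robin distribution is easy to model in plain Python. The sketch below deals six rows into four partitions and counts the result. It is a simplified model: the fixed `start` offset is an assumption made for readability, and Spark does not promise which partition the first row goes to.

```python
from collections import Counter

def round_robin(rows, num_partitions, start=0):
    """Deal rows to partitions in turn, like cards around a table."""
    return [(start + i) % num_partitions for i, _ in enumerate(rows)]

rows = ["Alice", "Bob", "Charlie", "Dave", "Eve", "Frank"]
assignment = round_robin(rows, 4)
counts = Counter(assignment)

print(assignment)              # [0, 1, 2, 3, 0, 1]
print(sorted(counts.items()))  # [(0, 2), (1, 2), (2, 1), (3, 1)]
```

Partition sizes differ by at most one row here, and no attention is paid to any key, which is why this mode balances well but never co-locates related rows.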
| Method | Column Required? | Groups Related Rows? | Best For |
|---|---|---|---|
| repartition(N) | No | No (round-robin) | Balancing row counts |
| repartition(N, col) | Yes | Yes (hash) | Joins, aggregations |
| repartitionByRange(N, col) | Yes | Yes (range) | Sorted writes, range scans |
Coalesce
coalesce() reduces the number of partitions without a full shuffle — making it much cheaper than repartition() when you only need fewer partitions.
# Start with many partitions
df_big = spark.range(1000000).repartition(100)
print(df_big.rdd.getNumPartitions()) # 100
# Coalesce down to 10 — NO shuffle, just merges partitions
df_small = df_big.coalesce(10)
print(df_small.rdd.getNumPartitions()) # 10
# Classic use case: before writing to avoid many small files
df_result = df.groupBy("country").count()
# After groupBy, shuffle partitions = 200 (many empty!)
df_result.coalesce(5).write.parquet("output/")
# Creates only 5 output files, not 200
| Feature | coalesce() | repartition() |
|---|---|---|
| Shuffle | ❌ No shuffle | ✅ Full shuffle |
| Speed | Fast | Slower |
| Can increase partitions? | ❌ No (only decrease) | ✅ Yes |
| Data balance | May be uneven | Even distribution |
| Transformation type | Narrow | Wide |
| Best for | Reducing before write | Redistribution, fixing skew |
# ✅ GOOD: coalesce when reducing partitions (no shuffle needed)
df.coalesce(5).write.parquet("output/")
# ❌ WRONG: coalesce cannot increase partitions
df_4_partitions.coalesce(100) # Still stays at 4! Ignored.
# ✅ GOOD: repartition when increasing or needing balanced distribution
df_4_partitions.repartition(100) # Now has 100 balanced partitions
# ✅ GOOD: repartition when fixing skew
df_skewed.repartition(200).write.parquet("output/")
# Rule: "Want fewer? Use coalesce. Want more or need balance? Use repartition."
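The "may be uneven" versus "even distribution" trade-off in the table above can be seen in a small plain-Python model. It is a deliberate simplification: Spark's coalesce also considers data locality when grouping partitions, and both function names here are hypothetical.

```python
def coalesce_sizes(partition_sizes, target):
    """Merge neighbouring partitions into `target` groups without splitting any."""
    if target >= len(partition_sizes):
        return list(partition_sizes)  # coalesce never increases the count
    merged = [0] * target
    for index, size in enumerate(partition_sizes):
        merged[index * target // len(partition_sizes)] += size
    return merged

def repartition_sizes(partition_sizes, target):
    """Redistribute all rows evenly, as a full shuffle would."""
    base, extra = divmod(sum(partition_sizes), target)
    return [base + (1 if i < extra else 0) for i in range(target)]

sizes = [900, 10, 10, 10, 10, 10, 10, 40]  # rows per partition, one is hot

print(coalesce_sizes(sizes, 2))     # [930, 70]
print(repartition_sizes(sizes, 2))  # [500, 500]
print(coalesce_sizes(sizes, 100))   # [900, 10, 10, 10, 10, 10, 10, 40]
```

Merging whole partitions keeps the hot one intact, so the skew survives a coalesce; only a shuffle evens it out.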
Partition Pruning
Partition pruning is Spark's ability to skip reading partitions that can't contain the data you need. This is the biggest performance win from proper partitioning.
Suppose your data is stored in folders year=2022/, year=2023/, year=2024/, and you query WHERE year = 2024. Static pruning means Spark never even opens the 2022 and 2023 folders. If the years are similar in size, it skips about two-thirds of the data before reading a single byte!
# Write partitioned data to S3/HDFS
df_sales.write \
.partitionBy("year", "month") \
.parquet("s3://bucket/sales/")
# On disk, it creates:
# s3://bucket/sales/year=2022/month=01/data.parquet
# s3://bucket/sales/year=2022/month=02/data.parquet
# s3://bucket/sales/year=2023/month=01/data.parquet
# ... etc.
from pyspark.sql.functions import col
# Now read with a LITERAL filter → Static Pruning kicks in
df = spark.read.parquet("s3://bucket/sales/")
# ONLY reads year=2024/ folders — skips all other years!
df_2024 = df.filter(col("year") == 2024)
# Even better — multi-level pruning
# Only reads year=2024/month=03/ — one folder!
df_march = df.filter((col("year") == 2024) & (col("month") == 3))
# Verify pruning happened in the explain plan
df_march.explain("formatted")
# Look for: "PartitionFilters: [(year = 2024), (month = 3)]"
# This means pruning is working!
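Static pruning amounts to comparing the filter against folder names before any file is opened. The plain-Python sketch below models that selection step over a hypothetical list of folders; it is an illustration of the idea, not Spark's file index implementation, and `parse_partition_path` and `prune` are made-up names.

```python
def parse_partition_path(path):
    """Turn 'year=2024/month=03' into {'year': '2024', 'month': '03'}."""
    return dict(segment.split("=", 1)
                for segment in path.strip("/").split("/") if "=" in segment)

def prune(paths, **filters):
    """Keep only the folders whose partition values match every filter."""
    kept = []
    for path in paths:
        values = parse_partition_path(path)
        if all(int(values[name]) == wanted for name, wanted in filters.items()):
            kept.append(path)
    return kept

folders = [f"year={y}/month={m:02d}"
           for y in (2022, 2023, 2024) for m in range(1, 13)]

print(len(folders))                        # 36
print(len(prune(folders, year=2024)))      # 12
print(prune(folders, year=2024, month=3))  # ['year=2024/month=03']
```

A filter on a non-partition column cannot be answered from folder names, which is why only partition columns benefit from this kind of pruning.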
# AQE is optional here: DPP has its own switch (see the end of this snippet) and is on by default in Spark 3.x
spark.conf.set("spark.sql.adaptive.enabled", "true")
# Large fact table partitioned by date
fact_sales = spark.read.parquet("s3://bucket/sales/")
# Partition structure: sales/date=2024-01-01/, sales/date=2024-01-02/, ...
# Small dimension table — just dates we care about
dim_dates = spark.read.parquet("s3://bucket/dim_dates/") \
.filter(col("quarter") == "Q4")
# dim_dates now has ONLY Q4 dates
# JOIN: Spark first scans dim_dates, builds a list of Q4 dates,
# THEN only reads the matching partitions from fact_sales
result = fact_sales.join(dim_dates, on="date", how="inner")
result.explain("formatted")
# Look for: "dynamicpruningexpression" in the PartitionFilters of the fact_sales scan
# This means DPP is working — fact_sales only reads Q4 partitions!
# Config to tune DPP
spark.conf.set("spark.sql.optimizer.dynamicPartitionPruning.enabled", "true")
spark.conf.set("spark.sql.optimizer.dynamicPartitionPruning.useStats", "true")
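The two-step logic in the comments above (filter the small side first, then use its keys to choose fact partitions) can be modelled without Spark. The calendar data below is generated purely for illustration, and the variable names are assumptions.

```python
from datetime import date, timedelta

def quarter_of(d):
    """Return 'Q1'..'Q4' for a calendar date."""
    return f"Q{(d.month - 1) // 3 + 1}"

# Hypothetical dimension table: one (date, quarter) row per day of 2024
all_days = [date(2024, 1, 1) + timedelta(days=i) for i in range(366)]
dim_dates = [(d, quarter_of(d)) for d in all_days]

# Fact table layout: one partition folder per day
fact_partitions = [f"date={d.isoformat()}" for d in all_days]

# Step 1: filter the small dimension side and collect its join keys
q4_keys = {d.isoformat() for d, quarter in dim_dates if quarter == "Q4"}

# Step 2: apply those keys as a runtime filter on the fact partitions
to_scan = [p for p in fact_partitions if p.split("=", 1)[1] in q4_keys]

print(len(fact_partitions), len(to_scan))  # 366 92
print(to_scan[0], to_scan[-1])             # date=2024-10-01 date=2024-12-31
```

The filter values are only known at run time, after the dimension side has been evaluated, which is what distinguishes this from static pruning with a literal.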
Partition Strategy Design
Choosing the right partition column is a critical design decision. The wrong choice wastes storage and makes queries slow. The right choice enables dramatic pruning.
Never partition by a raw timestamp such as 2024-01-15 14:32:01: each unique second becomes its own folder! Extract just year/month/day.
from pyspark.sql.functions import col, year, month, dayofmonth, to_date
# Assume df has a "transaction_timestamp" column
# ---- Strategy 1: Single date column (good for daily queries) ----
df_with_date = df.withColumn("date", to_date(col("transaction_timestamp")))
df_with_date.write.partitionBy("date").parquet("s3://bucket/sales/")
# Creates: sales/date=2024-01-15/, sales/date=2024-01-16/, ...
# 365 folders per year — manageable
# ---- Strategy 2: Year + Month (good for monthly queries) ----
df_ym = df.withColumn("year", year(col("transaction_timestamp"))) \
.withColumn("month", month(col("transaction_timestamp")))
df_ym.write.partitionBy("year", "month").parquet("s3://bucket/sales/")
# Creates: sales/year=2024/month=1/, sales/year=2024/month=2/, ...
# 12 folders per year, each with month's data
# ---- Strategy 3: Year + Month + Day (most granular) ----
df_ymd = df.withColumn("year", year(col("transaction_timestamp"))) \
    .withColumn("month", month(col("transaction_timestamp"))) \
    .withColumn("day", dayofmonth(col("transaction_timestamp")))
df_ymd.write.partitionBy("year", "month", "day").parquet("output/")
# Best: allows pruning at year, month, OR day level
# Query with pruning at multiple levels
df = spark.read.parquet("s3://bucket/sales/")
df.filter((col("year") == 2024) & (col("month") == 3)) # Only March 2024
# ✅ GOOD — Region has 4-5 values
df.write.partitionBy("region").parquet("output/")
# Creates: region=US/, region=EU/, region=APAC/, region=LATAM/
# 4 partitions — lean and efficient
# ✅ GOOD — Product category has 10-20 values
df.write.partitionBy("category").parquet("output/")
# ❌ BAD — user_id has millions of distinct values
df.write.partitionBy("user_id").parquet("output/")
# Creates MILLIONS of tiny folders — kills metadata operations!
# Rule: target partition size 128MB - 1GB each
# If a partition is < 10MB, you have too many distinct values
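The size rule above can be turned into a quick check before committing to a partition column. This sketch assumes rows are spread evenly across the distinct values, which real data rarely satisfies exactly; the function name, the messages, and the 2 GiB table size are all hypothetical.

```python
MiB = 1024 ** 2
GiB = 1024 ** 3

def assess_partition_column(total_bytes, distinct_values):
    """Classify a candidate column by the average folder size it would give."""
    avg = total_bytes / distinct_values
    if avg < 10 * MiB:
        return "too many distinct values"
    if avg > GiB:
        return "partitions larger than 1GB"
    if avg < 128 * MiB:
        return "usable, but below the 128MB target"
    return "within the 128MB-1GB target"

dataset = 2 * GiB  # hypothetical table size

print(assess_partition_column(dataset, 4))          # within the 128MB-1GB target
print(assess_partition_column(dataset, 5_000_000))  # too many distinct values
```

For multi-column layouts, pass the product of the columns' distinct counts, since every combination becomes its own folder.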
# Hybrid: Date + Region — enables 4 query patterns with pruning
df_ymd.write \
.partitionBy("year", "month", "region") \
.parquet("s3://bucket/sales/")
# Directory structure:
# year=2024/month=1/region=US/ ← 4 files
# year=2024/month=1/region=EU/ ← 3 files
# year=2024/month=2/region=US/ ← 4 files
# ...
# Query pattern 1: All of 2024 → prunes to year=2024/ only
df.filter(col("year") == 2024)
# Query pattern 2: Jan 2024 → prunes year + month
df.filter((col("year") == 2024) & (col("month") == 1))
# Query pattern 3: US in Jan 2024 → prunes all three levels
df.filter((col("year") == 2024) &
(col("month") == 1) &
(col("region") == "US"))
# Query pattern 4: All US data → Spark must list every year/month folder
# but reads only the region=US subfolder inside each one
df.filter(col("region") == "US")
# Note: filtering on region alone still prunes, but with more listing overhead than a year/month filter
Skew-Aware Partitioning
Data skew is when some partitions have far more data than others, causing a few slow "straggler" tasks while others finish quickly. These three strategies fix it.
from pyspark.sql.functions import col, count, floor, rand, concat_ws, lit, explode, array, spark_partition_id
# Problem: "US" key is heavily skewed
df_sales.groupBy("country").count().show()
# country=US: 9,000,000 rows ← HOT KEY
# country=UK: 500,000 rows
# country=IN: 200,000 rows
# ---- Salting solution ----
SALT_FACTOR = 10 # Split into 10 sub-keys
# Step 1: Add salt to the skewed table
df_salted = df_sales.withColumn(
"salted_key",
concat_ws("_", col("country"), (floor(rand() * SALT_FACTOR)).cast("string"))
)
# Now "US" becomes: "US_0", "US_1", ..., "US_9" — 10 sub-partitions
# Step 2: Replicate the small table across salt values
df_lookup = spark.createDataFrame([
("US", "United States"), ("UK", "United Kingdom")
], ["country", "country_name"])
salt_range = [str(i) for i in range(SALT_FACTOR)]
df_lookup_salted = df_lookup.withColumn("salt", explode(array(*[lit(s) for s in salt_range]))) \
.withColumn("salted_key", concat_ws("_", col("country"), col("salt")))
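# Step 3: Join on salted key (drop the lookup's own "country" and "salt" columns
# first, otherwise "country" is ambiguous after the join)
result = df_salted.join(df_lookup_salted.drop("country", "salt"), on="salted_key", how="inner")
# Step 4: Group by original country (remove salt)
final = result.groupBy("country", "country_name").agg(count("*").alias("row_count"))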
# Now "US" work is split across 10 tasks in parallel!
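The effect of salting on task load can be checked with a small plain-Python simulation that uses the same row counts as the example. One assumption to note: the row index stands in for `rand()` so the output is deterministic, whereas real salts are random and only approximately even.

```python
from collections import Counter

SALT_FACTOR = 10
rows = ["US"] * 9000 + ["UK"] * 500 + ["IN"] * 200

# Without salting: all rows for a key are handled by one task
unsalted_load = Counter(rows)

# With salting: each key is spread over SALT_FACTOR sub-keys
salted_load = Counter(f"{country}_{i % SALT_FACTOR}"
                      for i, country in enumerate(rows))

# Second stage: strip the salt and combine the partial counts
final = Counter()
for salted_key, partial in salted_load.items():
    final[salted_key.rsplit("_", 1)[0]] += partial

print(max(unsalted_load.values()))  # 9000
print(max(salted_load.values()))    # 900
print(final == unsalted_load)       # True
```

The largest unit of work drops by the salt factor, and the second aggregation restores the original per-country totals.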
# Data: revenue column is skewed
# Most orders are $1-$100, but a few are $1M+ (outliers)
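# Hash partitioning on revenue can skew when a few exact values repeat very often,
# because every row with the same value lands in the same partition
df.repartition(10, "revenue") # ❌ Repeated values pile up in a few partitions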
# Range partitioning samples the data and creates even-sized buckets
df.repartitionByRange(10, "revenue") # ✅ Each partition has similar row count
# Spark internally does quantile estimation:
# partition 0: revenue $0 - $45
# partition 1: revenue $45 - $89
# ...each with ~10% of total rows
# Verify balance
df.repartitionByRange(10, "revenue") \
.withColumn("pid", spark_partition_id()) \
.groupBy("pid").count().orderBy("pid").show()
# Should show roughly equal counts per partition
# Enable AQE skew join optimization
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
# Tuning: a partition is "skewed" if it's larger than:
# max(256MB, median_size * skewedPartitionFactor)
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB")
# With AQE: Spark automatically splits skewed partitions during joins
# No code changes needed! Just enable the config above.
result = df_large.join(df_small, on="country")
# AQE detects "US" partition is 5x the median → splits it → handles in parallel
# Also: on Databricks, a SKEW hint gives explicit control (open-source Spark does not recognize this hint and ignores it)
result = df_large.hint("SKEW", "country") \
.join(df_small, on="country")
Start with AQE (spark.sql.adaptive.enabled=true) in Spark 3.x: it handles join skew automatically for most cases. Only resort to manual salting for extreme, persistent skew.
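The skew test quoted in the AQE snippet, max(256MB, median_size * skewedPartitionFactor), is simple enough to evaluate by hand. The sketch below applies it to a made-up list of shuffle partition sizes; the function name is hypothetical, and the defaults mirror the values set in the snippet.

```python
from statistics import median

MiB = 1024 ** 2

def skewed_partitions(sizes, factor=5, threshold=256 * MiB):
    """Return indexes of partitions that exceed both skew limits."""
    limit = max(threshold, median(sizes) * factor)
    return [i for i, size in enumerate(sizes) if size > limit]

# Hypothetical shuffle partition sizes; the median is 80 MiB
sizes = [64 * MiB, 80 * MiB, 70 * MiB, 900 * MiB, 300 * MiB]

print(skewed_partitions(sizes))  # [3]
```

The 300 MiB partition is above the 256MB threshold but below 5x the median (400 MiB), so only the 900 MiB partition qualifies for splitting.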
Custom Partitioners
When built-in hash and range partitioning don't fit your use case, you can write custom partitioning logic at the RDD level.
df.write.partitionBy("col") creates physical folder-based partitions on storage. This is for on-disk partitioning, not the in-memory partitioning used during processing. It controls how data is laid out for future reads.
from pyspark.sql.functions import col, year, month
# Basic write-time partition
df.write \
.partitionBy("year", "month") \
.mode("overwrite") \
.parquet("s3://bucket/sales/")
# The "year" and "month" columns are REMOVED from the data files
# They become the folder names instead
# s3://bucket/sales/year=2024/month=3/part-0000.parquet
# Control number of files PER PARTITION using repartition
# Without this, you might get many small files
df.repartition(col("year"), col("month")) \
.write \
.partitionBy("year", "month") \
.mode("overwrite") \
.parquet("output/")
# Now exactly 1 file per year/month combo (ideal for small-medium data)
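# For large data, spread each partition folder over several files with a random salt.
# (repartition(5, "year", "month") would NOT do this: each year/month combo still
# hashes to a single partition, so each folder still gets one file.)
from pyspark.sql.functions import rand
df.repartition(col("year"), col("month"), (rand() * 5).cast("int")) \
    .write \
    .partitionBy("year", "month") \
    .mode("overwrite") \
    .parquet("output/")
# Up to 5 files per partition folder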
from pyspark import SparkContext
from pyspark.rdd import RDD
# Custom partitioner: partition customers by first letter of name
# A-H → partition 0, I-P → partition 1, Q-Z → partition 2
class AlphaPartitioner:
    def __init__(self, num_partitions):
        self.num_partitions = num_partitions
    def __call__(self, key):
        # key is the name (string)
        first_letter = key[0].upper()
        if first_letter <= 'H':
            return 0
        elif first_letter <= 'P':
            return 1
        else:
            return 2
sc = spark.sparkContext
data = [("Alice", 100), ("Bob", 200), ("Zara", 300),
("Mike", 150), ("Kate", 250), ("Yuki", 175)]
# Create pair RDD (key, value)
rdd = sc.parallelize(data) # RDD of (name, amount) tuples
# Apply custom partitioner via partitionBy
partitioner = AlphaPartitioner(3)
rdd_partitioned = rdd.partitionBy(3, partitioner)
# Inspect partitions
for i, partition in enumerate(rdd_partitioned.glom().collect()):
    print(f"Partition {i}: {partition}")
# Partition 0: [('Alice', 100), ('Bob', 200)] ← A-H
# Partition 1: [('Mike', 150), ('Kate', 250)] ← I-P
# Partition 2: [('Zara', 300), ('Yuki', 175)] ← Q-Z
from pyspark.sql.functions import spark_partition_id
# Real use case: Route transactions to partitions by transaction_type
# We have: PURCHASE, REFUND, TRANSFER, DEPOSIT, WITHDRAWAL
# Business rule: each type gets its own partition for downstream processing
def business_partitioner(transaction_type):
    mapping = {
        "PURCHASE": 0,
        "REFUND": 1,
        "TRANSFER": 2,
        "DEPOSIT": 3,
        "WITHDRAWAL": 4
    }
    return mapping.get(transaction_type, 0)  # Default to 0 for unknown
sc = spark.sparkContext
NUM_PARTITIONS = 5
# Convert DataFrame to RDD pair (key, whole_row)
rdd = df_transactions.rdd.map(lambda row: (row.transaction_type, row))
# Partition by business rule
rdd_partitioned = rdd.partitionBy(NUM_PARTITIONS, business_partitioner)
# Convert back to DataFrame
df_partitioned = rdd_partitioned.map(lambda kv: kv[1]).toDF(df_transactions.schema)
# Verify: check which rows are in which partition
df_partitioned.withColumn("pid", spark_partition_id()) \
.groupBy("transaction_type", "pid") \
.count() \
.orderBy("pid") \
.show()
# All PURCHASE → partition 0
# All REFUND → partition 1
# etc.
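One detail worth checking with a custom partition function is how its return value maps onto the partition count. The sketch below models that mapping as "result modulo number of partitions", which is how PySpark's `RDD.partitionBy` is understood to behave; treat that as an assumption to verify against your Spark version. The `place` helper is hypothetical.

```python
def place(key, partition_func, num_partitions):
    """Model of where a key lands: the function's result modulo the count."""
    return partition_func(key) % num_partitions

MAPPING = {"PURCHASE": 0, "REFUND": 1, "TRANSFER": 2,
           "DEPOSIT": 3, "WITHDRAWAL": 4}

def business_partitioner(transaction_type):
    return MAPPING.get(transaction_type, 0)  # unknown types share partition 0

# With 5 partitions every transaction type has its own slot
print(sorted(place(k, business_partitioner, 5) for k in MAPPING))
# [0, 1, 2, 3, 4]

# With only 3 partitions, types start to share slots
print(sorted(place(k, business_partitioner, 3) for k in MAPPING))
# [0, 0, 1, 1, 2]
```

Under this model the one-type-per-partition guarantee only holds while the partition count is at least as large as the number of mapped values, and unknown types will always mix with PURCHASE.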
Partitioning Cheat Sheet
Quick reference for all Module 15 partition APIs and best practices.
# Count rows per partition
df.withColumn("pid", spark_partition_id()) \
    .groupBy("pid").count().show()
# Repartition (full shuffle)
df.repartition(N, "col") # hash
df.repartitionByRange(N, "col") # range
# Coalesce (no shuffle)
df.coalesce(N)
# Only reduces partitions
# Faster than repartition
# Enable AQE
spark.conf.set("spark.sql.adaptive.enabled", "true")
# Partition pruning
# Use literal filters on partition cols
# Write with physical partitions
df.write \
    .partitionBy("year", "month") \
    .mode("overwrite") \
    .parquet("s3://bucket/output/")
# Skew handling with AQE
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
# Shuffle partitions
spark.conf.set("spark.sql.shuffle.partitions", "50")
# Salting
df.withColumn("sk",
    concat_ws("_", "key",
        (floor(rand() * SALT)).cast("string")))
| Situation | Solution |
|---|---|
| Need more partitions | repartition(N) |
| Need fewer partitions (before write) | coalesce(N) |
| Several joins on the same column: shuffle once and reuse | repartition(N, "join_col") before the joins |
| Data skew in joins | Enable AQE skewJoin or use salting |
| Too many small files on disk | coalesce(5–20) before write |
| Queries always filter by date | partitionBy("year","month","day") on write |
| Default 200 shuffle partitions too high | Set spark.sql.shuffle.partitions lower |
| Need sorted output | repartitionByRange(N, "col") |
| Partition column has millions of values | ❌ Choose a lower-cardinality column instead |