What is an RDD?
RDD — Resilient Distributed Dataset — is the fundamental data structure of Apache Spark. Everything in Spark is built on top of RDDs. Understanding RDDs gives you deep insight into how Spark actually works under the hood.
R — Resilient: Fault-tolerant. If a partition is lost (executor dies), Spark can recompute it using the lineage graph.
D — Distributed: Data is split across many machines (executors) in a cluster.
D — Dataset: A collection of data elements — similar to a Python list, but distributed.
Every RDD has exactly five characteristics that define its behavior:
| # | Property | What it means | Benefit |
|---|---|---|---|
| 1 | List of Partitions | Data is split into chunks (partitions) | Enables parallelism — each partition processed by one task |
| 2 | Function to compute each partition | Logic to compute data from parent RDD | Forms the lineage / DAG |
| 3 | List of dependencies | Which parent RDDs it depends on | Enables fault recovery (recompute lost partitions) |
| 4 | Optionally: Partitioner | For key-value RDDs — how keys are distributed | Enables optimized joins and groupBy |
| 5 | Optionally: Preferred locations | Where to place each partition (data locality) | Reduces network I/O — compute near data |
Transformations on RDDs are lazy — they only describe what to do, they don't execute. Spark builds a DAG (Directed Acyclic Graph) of operations. Execution only happens when you call an Action.
RDDs are immutable — once created, you cannot change the data inside them. Every transformation creates a new RDD. This is exactly what makes fault tolerance easy: if a partition is lost, Spark knows the exact chain of transformations that created it and can recompute it.
rdd = rdd.map(...) — this doesn't modify the original RDD. It creates a brand new RDD and reassigns the Python variable. The old RDD still exists (until garbage collected).
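To make laziness, immutability, and lineage concrete, here is a toy pure-Python model. It is only a sketch and not how Spark is implemented: `MiniRDD` and its methods are hypothetical names invented for illustration. Each transformation returns a new object that merely remembers its parent and a function; work happens only when `collect()` is called, and a "lost" result can always be rebuilt by replaying the chain.

```python
class MiniRDD:
    """Toy stand-in for an RDD: immutable, lazy, and aware of its lineage."""

    def __init__(self, source=None, parent=None, func=None, name="source"):
        self._source = source   # only set on the root of the lineage
        self._parent = parent   # dependency on the parent "RDD"
        self._func = func       # how to compute this RDD from its parent
        self.name = name

    def map(self, f):
        # Returns a NEW MiniRDD; nothing is computed yet.
        return MiniRDD(parent=self, func=lambda it: (f(x) for x in it), name="map")

    def filter(self, pred):
        return MiniRDD(parent=self, func=lambda it: (x for x in it if pred(x)), name="filter")

    def lineage(self):
        # Walk the dependency chain back to the source.
        node, chain = self, []
        while node is not None:
            chain.append(node.name)
            node = node._parent
        return list(reversed(chain))

    def _compute(self):
        if self._parent is None:
            return iter(self._source)
        return self._func(self._parent._compute())

    def collect(self):
        # The "action": only now does the chain actually run.
        return list(self._compute())


base = MiniRDD(source=[1, 2, 3, 4, 5])
evens_squared = base.filter(lambda x: x % 2 == 0).map(lambda x: x ** 2)

print(evens_squared.lineage())  # ['source', 'filter', 'map']
print(evens_squared.collect())  # [4, 16]
print(base.collect())           # [1, 2, 3, 4, 5] (the original is untouched)
```

Calling `collect()` a second time replays the same chain, which is the idea behind recomputing a lost partition from its lineage.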
Spark has three APIs. Knowing when to use each is important:
| API | Type Safety | Schema | Optimizer | Best For |
|---|---|---|---|---|
| RDD | No | No schema | No Catalyst | Low-level control, custom serialization, legacy code |
| DataFrame | Runtime | Yes | Catalyst + Tungsten | Most production use cases — fast and optimized |
| Dataset | Compile-time | Yes | Catalyst + Tungsten | Scala/Java type safety — not available in Python |
Creating RDDs
There are three main ways to create an RDD: from a Python collection using parallelize(), from a file on disk, and from an existing RDD via transformation. Let's cover all of them with code.
sc.parallelize(collection) takes any Python iterable (list, range, tuple) and distributes it across partitions in the cluster. This is the most common way to create an RDD during learning and testing.
Think of parallelize as a post office splitting a pile of letters evenly among its workers, for example 25 letters per worker, so that they all deliver in parallel.
from pyspark import SparkContext
from pyspark.sql import SparkSession
# Create SparkSession (SparkContext is available as sc via spark.sparkContext)
spark = SparkSession.builder.appName("RDD Demo").master("local[*]").getOrCreate()
sc = spark.sparkContext # SparkContext — the entry point for RDD operations
# ── Method 1a: From a simple list ──
numbers = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
rdd = sc.parallelize(numbers)
print(rdd.collect()) # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
print(rdd.count()) # 10
print(rdd.getNumPartitions()) # depends on cluster cores
# ── Method 1b: Control the number of partitions ──
rdd_4p = sc.parallelize(numbers, 4) # force 4 partitions
print(rdd_4p.getNumPartitions()) # 4
# ── Method 1c: From a list of tuples (key-value pairs) ──
data = [("Alice", 30), ("Bob", 25), ("Charlie", 35)]
rdd_kv = sc.parallelize(data)
print(rdd_kv.collect())
# [('Alice', 30), ('Bob', 25), ('Charlie', 35)]
# ── Method 1d: From a range ──
rdd_range = sc.parallelize(range(1, 101)) # numbers 1 to 100
print(rdd_range.sum()) # 5050
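By default, the number of partitions comes from spark.default.parallelism; with master local[*], that is the number of available cores. You can always override it with the second argument: sc.parallelize(data, numPartitions).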
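The following sketch shows one way a collection can be cut into contiguous, near-equal partitions. It is a plain-Python model under the assumption of simple index-based slicing; `slice_collection` is a hypothetical helper, and the exact layout PySpark produces may differ.

```python
def slice_collection(items, num_slices):
    """Split a list into num_slices contiguous, near-equal chunks."""
    n = len(items)
    return [
        items[(i * n) // num_slices : ((i + 1) * n) // num_slices]
        for i in range(num_slices)
    ]


numbers = list(range(1, 11))
print(slice_collection(numbers, 4))
# [[1, 2], [3, 4, 5], [6, 7], [8, 9, 10]]
```

On a real RDD, `rdd.glom().collect()` returns the elements grouped by partition, which is a convenient way to inspect the actual layout.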
sc.textFile(path) reads a text file and creates an RDD where each line becomes one element. Works with local files, HDFS paths, S3 paths, or wildcards.
# ── Reading a local file ──
rdd_text = sc.textFile("data/employees.txt")
# Each element = one line of the file
print(rdd_text.first()) # First line
print(rdd_text.take(5)) # First 5 lines
print(rdd_text.count()) # Number of lines
# ── Reading from HDFS ──
rdd_hdfs = sc.textFile("hdfs://namenode:9000/data/logs/*.csv")
# ── Reading from S3 (open-source Hadoop builds typically use the s3a:// scheme) ──
rdd_s3 = sc.textFile("s3://my-bucket/data/transactions.txt")
# ── Reading with minPartitions hint ──
rdd_min = sc.textFile("data/big_file.txt", 8) # at least 8 partitions
# ── Practical example: Word count from a file ──
lines = sc.textFile("data/book.txt")
words = lines.flatMap(lambda line: line.split(" "))
word_count = words.map(lambda w: (w, 1)).reduceByKey(lambda a, b: a + b)
print(word_count.take(10))
# ── wholeTextFiles: filename + content as tuple ──
rdd_whole = sc.wholeTextFiles("data/logs/")
# Each element: ("file:///data/logs/file1.txt", "content of file1")
for filename, content in rdd_whole.collect():
    print(f"File: {filename}, Lines: {len(content.splitlines())}")
Spark's SparkContext offers several more ways to create RDDs from files:
# Sequence files (Hadoop format)
rdd_seq = sc.sequenceFile("hdfs://path/to/seqfile")
# Binary files → (filename, binary_content)
rdd_bin = sc.binaryFiles("data/images/*.jpg")
# Pickled Python objects
rdd_pkl = sc.pickleFile("data/pickled_rdd")
# Reading CSV manually via textFile (before Spark DataFrames were available)
rdd_csv = sc.textFile("data/sales.csv")
header = rdd_csv.first() # Get header line
data = rdd_csv.filter(lambda line: line != header).map(lambda line: line.split(","))
# data is now an RDD of lists: [["Alice", "100"], ["Bob", "200"], ...]
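Splitting on "," breaks as soon as a field contains a quoted comma. A small sketch of a more robust per-line parser, using Python's standard csv module, is shown below; `parse_csv_line` is a hypothetical helper name, and the sample line is made up for illustration.

```python
import csv


def parse_csv_line(line):
    """Parse one CSV line into a list of fields, honoring quoted commas."""
    return next(csv.reader([line]))


line = '"Smith, Alice",100'
print(line.split(","))       # ['"Smith', ' Alice"', '100']  (field split in the wrong place)
print(parse_csv_line(line))  # ['Smith, Alice', '100']
```

The function could then be passed to map in place of the lambda, for example `rdd_csv.filter(lambda line: line != header).map(parse_csv_line)`. Because textFile yields one line per element, quoted fields that span several lines are still not handled by this approach.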
You can convert between DataFrames and RDDs. This is useful when you need fine-grained RDD control for part of your pipeline but want the efficiency of DataFrames elsewhere.
# ── DataFrame → RDD ──
df = spark.createDataFrame([("Alice", 30), ("Bob", 25)], ["name", "age"])
rdd_from_df = df.rdd # RDD of Row objects
print(rdd_from_df.collect())
# [Row(name='Alice', age=30), Row(name='Bob', age=25)]
# Access Row fields
names = rdd_from_df.map(lambda row: row.name)
print(names.collect()) # ['Alice', 'Bob']
# ── RDD → DataFrame ──
rdd_data = sc.parallelize([("Alice", 30), ("Bob", 25)])
df_from_rdd = spark.createDataFrame(rdd_data, ["name", "age"])
df_from_rdd.show()
# +-----+---+
# | name|age|
# +-----+---+
# |Alice| 30|
# |  Bob| 25|
# +-----+---+
# ── Using toDF() shortcut ──
df2 = rdd_data.toDF(["name", "age"])
df2.show()
RDD Transformations
Transformations are lazy operations that produce a new RDD from an existing one. No computation happens until an action is called. Master every transformation — they form the backbone of Spark processing.
map(func) applies a function to every element in the RDD and returns a new RDD of the same size. One element in → one element out.
numbers = sc.parallelize([1, 2, 3, 4, 5])
# Square every number
squared = numbers.map(lambda x: x ** 2)
print(squared.collect()) # [1, 4, 9, 16, 25]
# Using a named function
def to_upper(text):
    return text.upper()
words = sc.parallelize(["hello", "world", "spark"])
upper_words = words.map(to_upper)
print(upper_words.collect()) # ['HELLO', 'WORLD', 'SPARK']
# map to parse each CSV line into a list of fields
employees = sc.parallelize(["Alice,Engineering,90000", "Bob,Marketing,70000"])
parsed = employees.map(lambda line: line.split(","))
print(parsed.collect())
# [['Alice', 'Engineering', '90000'], ['Bob', 'Marketing', '70000']]
flatMap(func) is like map but each element can produce zero or more output elements. The results are flattened into a single RDD.
# The classic word count example
sentences = sc.parallelize([
"Hello World",
"Apache Spark is fast",
"PySpark is Python"
])
# map → list of lists (NOT flat)
mapped = sentences.map(lambda s: s.split(" "))
print(mapped.collect())
# [['Hello', 'World'], ['Apache', 'Spark', 'is', 'fast'], ['PySpark', 'is', 'Python']]
# Count: 3 elements (each is a list)
# flatMap → flat list of words
words = sentences.flatMap(lambda s: s.split(" "))
print(words.collect())
# ['Hello', 'World', 'Apache', 'Spark', 'is', 'fast', 'PySpark', 'is', 'Python']
# Count: 9 individual words — fully flattened!
# flatMap can also return 0 elements (filter out)
numbers = sc.parallelize([1, 2, 3, 4, 5])
result = numbers.flatMap(lambda x: [x, x*10] if x % 2 == 0 else [])
print(result.collect()) # [2, 20, 4, 40]: odd numbers dropped, each even emits x and x*10
filter(func) keeps only elements where func returns True. The resulting RDD has fewer or equal elements.
numbers = sc.parallelize(range(1, 11))
# Keep only even numbers
evens = numbers.filter(lambda x: x % 2 == 0)
print(evens.collect()) # [2, 4, 6, 8, 10]
# Filter on string RDD
logs = sc.parallelize([
"INFO: Server started",
"ERROR: Connection failed",
"INFO: Request received",
"ERROR: Timeout",
"WARN: Low memory"
])
errors = logs.filter(lambda line: line.startswith("ERROR"))
print(errors.collect())
# ['ERROR: Connection failed', 'ERROR: Timeout']
# Chaining transformations (all lazy)
result = (numbers
.filter(lambda x: x % 2 == 0) # keep evens: 2,4,6,8,10
.map(lambda x: x ** 2) # square them: 4,16,36,64,100
.filter(lambda x: x > 20) # keep > 20: 36, 64, 100
)
print(result.collect()) # [36, 64, 100]
distinct() returns a new RDD with duplicate elements removed. Internally, it triggers a shuffle.
rdd = sc.parallelize([1, 2, 2, 3, 3, 3, 4])
print(rdd.distinct().collect()) # [1, 2, 3, 4] (order may vary)
# With string data
colors = sc.parallelize(["red", "blue", "red", "green", "blue"])
print(colors.distinct().collect()) # ['red', 'blue', 'green']
rdd1 = sc.parallelize([1, 2, 3, 4])
rdd2 = sc.parallelize([3, 4, 5, 6])
# union — combine both RDDs (keeps duplicates)
print(rdd1.union(rdd2).collect())
# [1, 2, 3, 4, 3, 4, 5, 6] ← duplicates kept!
# union then distinct — unique elements from both
print(rdd1.union(rdd2).distinct().collect())
# [1, 2, 3, 4, 5, 6]
# intersection — only elements in BOTH RDDs (triggers shuffle)
print(rdd1.intersection(rdd2).collect())
# [3, 4]
# subtract — elements in rdd1 but NOT in rdd2
print(rdd1.subtract(rdd2).collect())
# [1, 2]
# cartesian — every combination of elements (cross product)
a = sc.parallelize([1, 2])
b = sc.parallelize(["a", "b", "c"])
print(a.cartesian(b).collect())
# [(1,'a'), (1,'b'), (1,'c'), (2,'a'), (2,'b'), (2,'c')]
# ⚠️ Warning: cartesian is very expensive on large data — avoid in production!
groupByKey() groups all values for the same key together into an iterable. It shuffles ALL data over the network — very expensive and memory-intensive.
sales = sc.parallelize([
("Alice", 100), ("Bob", 200),
("Alice", 300), ("Bob", 150),
("Charlie", 400)
])
# groupByKey groups VALUES by key (all values for a key are shuffled to one partition)
grouped = sales.groupByKey()
for name, values in grouped.collect():
    print(f"{name}: {list(values)}")
# Alice: [100, 300]
# Bob: [200, 150]
# Charlie: [400]
# If you need sum, use reduceByKey instead:
total_sales = sales.reduceByKey(lambda a, b: a + b)
print(total_sales.collect())
# [('Alice', 400), ('Bob', 350), ('Charlie', 400)]
reduceByKey(func) combines values for the same key using an associative and commutative function. It runs a local pre-aggregation on each partition first, then shuffles only the combined values — much more efficient than groupByKey.
# Word count — the classic example
text = sc.parallelize(["spark is fast spark is great spark"])
word_counts = (text
.flatMap(lambda line: line.split(" "))
.map(lambda w: (w, 1))
.reduceByKey(lambda a, b: a + b)
)
print(word_counts.collect())
# [('spark', 3), ('is', 2), ('fast', 1), ('great', 1)]
# Sales total per region
regional = sc.parallelize([
("North", 500), ("South", 300),
("North", 700), ("East", 400),
("South", 200), ("East", 600)
])
totals = regional.reduceByKey(lambda a, b: a + b)
print(totals.collect())
# [('North', 1200), ('South', 500), ('East', 1000)]
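The difference between the two shuffles can be sketched in plain Python. The model below is an illustration only: the two lists stand in for partitions on two executors, and the function names are hypothetical. It counts how many records would cross the network with and without local pre-aggregation.

```python
# Two "partitions" of (key, value) records, as they might sit on two executors.
partitions = [
    [("North", 500), ("South", 300), ("North", 700)],
    [("East", 400), ("South", 200), ("East", 600)],
]


def shuffle_group_by_key(parts):
    """groupByKey-style: every record crosses the shuffle unchanged."""
    return [record for part in parts for record in part]


def shuffle_reduce_by_key(parts, func):
    """reduceByKey-style: combine within each partition, then shuffle."""
    shuffled = []
    for part in parts:
        local = {}
        for key, value in part:
            local[key] = func(local[key], value) if key in local else value
        shuffled.extend(local.items())
    return shuffled


def merge(records, func):
    """Reduce side: merge whatever arrived for each key."""
    out = {}
    for key, value in records:
        out[key] = func(out[key], value) if key in out else value
    return out


def add(a, b):
    return a + b


grouped = shuffle_group_by_key(partitions)
combined = shuffle_reduce_by_key(partitions, add)

print(len(grouped), len(combined))  # 6 4
print(merge(combined, add))         # {'North': 1200, 'South': 500, 'East': 1000}
```

With only six records the saving is small, but the shuffled volume of the combined variant is bounded by (distinct keys × partitions) rather than by the number of records.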
aggregateByKey(zeroValue, seqFunc, combFunc) is the most flexible key aggregation. Use it when the intermediate accumulator type differs from the input values (e.g., computing average: you need sum AND count, not just sum).
# Compute the average score per subject
# Problem: reduceByKey can't do averages directly
# Solution: accumulate (sum, count) then divide
scores = sc.parallelize([
("Math", 90), ("Math", 80), ("Math", 70),
("Science", 95), ("Science", 85)
])
# zeroValue: initial accumulator (sum=0, count=0)
zero = (0, 0)
# seqFunc: how to combine a new value into the partition accumulator
def seq_func(acc, val):
    return (acc[0] + val, acc[1] + 1) # (sum + val, count + 1)
# combFunc: how to merge two partition accumulators during shuffle
def comb_func(acc1, acc2):
    return (acc1[0] + acc2[0], acc1[1] + acc2[1])
result = scores.aggregateByKey(zero, seq_func, comb_func)
print(result.collect())
# [('Math', (240, 3)), ('Science', (180, 2))]
# Now compute the actual average
avg = result.map(lambda kv: (kv[0], kv[1][0] / kv[1][1]))
print(avg.collect())
# [('Math', 80.0), ('Science', 90.0)]
combineByKey(createCombiner, mergeValue, mergeCombiners) is the most general key-aggregation function. In PySpark, aggregateByKey is implemented on top of combineByKey.
# Average salary per department using combineByKey
salaries = sc.parallelize([
("Eng", 90000), ("Eng", 80000),
("Mkt", 70000), ("Mkt", 75000), ("Mkt", 65000)
])
# createCombiner: called for the FIRST value seen for a key in a partition
def create_combiner(val):
    return (val, 1) # (sum, count)
# mergeValue: add subsequent values to existing combiner in same partition
def merge_value(combiner, val):
    return (combiner[0] + val, combiner[1] + 1)
# mergeCombiners: merge combiners from different partitions during shuffle
def merge_combiners(c1, c2):
    return (c1[0] + c2[0], c1[1] + c2[1])
result = salaries.combineByKey(create_combiner, merge_value, merge_combiners)
avg_salary = result.map(lambda kv: (kv[0], kv[1][0] / kv[1][1]))
print(avg_salary.collect())
# [('Eng', 85000.0), ('Mkt', 70000.0)]
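When each of the three functions is called becomes clearer in a small pure-Python model. This is a sketch of the semantics only, not Spark's implementation; `combine_by_key` and the two hand-made partitions are assumptions for illustration.

```python
def combine_by_key(partitions, create_combiner, merge_value, merge_combiners):
    """Pure-Python model of combineByKey over a list of partitions."""
    per_partition = []
    for part in partitions:
        combiners = {}
        for key, value in part:
            if key in combiners:
                combiners[key] = merge_value(combiners[key], value)
            else:
                # First value for this key within this partition.
                combiners[key] = create_combiner(value)
        per_partition.append(combiners)

    merged = {}
    for combiners in per_partition:  # stands in for the shuffle + merge step
        for key, comb in combiners.items():
            merged[key] = merge_combiners(merged[key], comb) if key in merged else comb
    return merged


partitions = [
    [("Eng", 90000), ("Mkt", 70000), ("Mkt", 75000)],
    [("Eng", 80000), ("Mkt", 65000)],
]
sums = combine_by_key(
    partitions,
    lambda v: (v, 1),                         # createCombiner
    lambda c, v: (c[0] + v, c[1] + 1),        # mergeValue
    lambda a, b: (a[0] + b[0], a[1] + b[1]),  # mergeCombiners
)
averages = {k: total / count for k, (total, count) in sums.items()}

print(sums)      # {'Eng': (170000, 2), 'Mkt': (210000, 3)}
print(averages)  # {'Eng': 85000.0, 'Mkt': 70000.0}
```

Note that createCombiner runs once per key per partition, so with two partitions it is called twice for each department here.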
sortByKey(ascending=True) sorts key-value pairs by key. You can sort descending by setting ascending=False.
sales = sc.parallelize([
("Charlie", 400), ("Alice", 100), ("Bob", 300)
])
# Sort alphabetically by name (ascending)
print(sales.sortByKey().collect())
# [('Alice', 100), ('Bob', 300), ('Charlie', 400)]
# Sort descending
print(sales.sortByKey(ascending=False).collect())
# [('Charlie', 400), ('Bob', 300), ('Alice', 100)]
# Sort by value: swap key and value, then sort by the new key (map again to swap back if needed)
by_value = sales.map(lambda kv: (kv[1], kv[0])).sortByKey(ascending=False)
print(by_value.collect())
# [(400, 'Charlie'), (300, 'Bob'), (100, 'Alice')]
RDD Actions
Actions trigger actual computation and either return a value to the driver or write data to storage. Each action executes the chain of transformations its RDD depends on, reusing persisted results where available.
collect() returns ALL elements of the RDD as a Python list to the driver program. Only use on small data — if your RDD has millions of rows, this will OOM your driver.
Never call collect() on a large RDD in production. Use take(n), write the data out, or aggregate it first.
rdd = sc.parallelize([1, 2, 3, 4, 5])
result = rdd.collect()
print(result) # [1, 2, 3, 4, 5] — Python list on driver
print(type(result)) # <class 'list'>
rdd = sc.parallelize([30, 10, 50, 20, 40])
# count() — number of elements
print(rdd.count()) # 5
# first() — first element (same as take(1)[0])
print(rdd.first()) # 30
# take(n) — first n elements (no particular order guaranteed without sort)
print(rdd.take(3)) # [30, 10, 50]
# top(n) — n largest elements (sorted descending)
print(rdd.top(3)) # [50, 40, 30]
# takeOrdered(n) — n smallest elements (sorted ascending)
print(rdd.takeOrdered(3)) # [10, 20, 30]
# takeOrdered with custom key — reverse sorted (same as top)
print(rdd.takeOrdered(3, key=lambda x: -x)) # [50, 40, 30]
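top(n) and takeOrdered(n) do not need a full distributed sort. A sketch of the general idea, assuming each partition keeps only its own n best candidates before they are merged, is shown below; `top_n` and the two hand-made partitions are hypothetical.

```python
import heapq
from functools import reduce


def top_n(partitions, n, key=None):
    """Keep only n candidates per partition, then merge the candidates."""
    per_partition = [heapq.nlargest(n, part, key=key) for part in partitions]
    return reduce(lambda a, b: heapq.nlargest(n, a + b, key=key), per_partition)


partitions = [[30, 10], [50, 20, 40]]
print(top_n(partitions, 3))                    # [50, 40, 30]
print(top_n(partitions, 3, key=lambda x: -x))  # [10, 20, 30]
```

At most n elements per partition ever reach the driver, which is why these actions stay cheap even on large RDDs as long as n is small.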
reduce(func) aggregates the elements of the RDD using an associative and commutative binary function. The function takes two arguments and returns one.
numbers = sc.parallelize([1, 2, 3, 4, 5])
# Sum of all numbers
total = numbers.reduce(lambda a, b: a + b)
print(total) # 15
# Maximum value
max_val = numbers.reduce(lambda a, b: a if a > b else b)
print(max_val) # 5
# Product of all numbers
product = numbers.reduce(lambda a, b: a * b)
print(product) # 120 (1×2×3×4×5)
# Built-in shortcuts for these common cases
print(numbers.sum()) # 15
print(numbers.min()) # 1
print(numbers.max()) # 5
print(numbers.mean()) # 3.0
print(numbers.stdev()) # standard deviation
print(numbers.variance()) # variance
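Why the function passed to reduce must be associative and commutative can be shown with a small model. The sketch below reduces each partition locally and then reduces the partial results; `distributed_reduce` and the two partition layouts are assumptions for illustration.

```python
from functools import reduce


def distributed_reduce(partitions, func):
    """Reduce each non-empty partition locally, then reduce the partial results."""
    partials = [reduce(func, part) for part in partitions if part]
    return reduce(func, partials)


# The same five numbers, split across partitions in two different ways.
layout_a = [[1, 2], [3, 4, 5]]
layout_b = [[1, 2, 3], [4, 5]]


def add(a, b):
    return a + b


def sub(a, b):
    return a - b


print(distributed_reduce(layout_a, add), distributed_reduce(layout_b, add))  # 15 15
print(distributed_reduce(layout_a, sub), distributed_reduce(layout_b, sub))  # 5 -3
```

Addition gives the same answer for any layout, while subtraction depends on how the data happens to be partitioned, so its result would not be reproducible on a cluster.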
foreach(func) applies a function to each element. Unlike map, it doesn't return anything. Used for side-effects like writing to a database or logging.
# Print each element (runs on executors, output goes to executor logs)
rdd = sc.parallelize([1, 2, 3, 4, 5])
rdd.foreach(lambda x: print(x)) # runs on executors, not driver
# Common use: write to a database, one connection per partition
def write_to_db(records):
    # Called once per partition with an iterator over that partition's records
    import sqlite3
    conn = sqlite3.connect("/tmp/test.db")  # demo only: assumes a one-column table named logs already exists
    for record in records:
        conn.execute("INSERT INTO logs VALUES (?)", (record,))
    conn.commit()
    conn.close()
# foreachPartition is better than foreach for DB writes (one connection per partition)
rdd.foreachPartition(write_to_db)
# countByValue: count occurrences of each value (returns a dict-like defaultdict on the driver)
words = sc.parallelize(["a", "b", "a", "c", "b", "a"])
print(dict(words.countByValue()))
# {'a': 3, 'b': 2, 'c': 1} (held in driver memory; key order may vary)
saveAsTextFile(path) writes the RDD to a directory as text files. Each partition becomes one file. Writes one line per element (calls str(element)).
rdd = sc.parallelize(["Alice,30", "Bob,25", "Charlie,35"])
# Write to a local directory (creates part-00000, part-00001, ... files)
rdd.saveAsTextFile("output/employees")
# Creates: output/employees/part-00000, part-00001, etc.
# Write to HDFS
rdd.saveAsTextFile("hdfs://namenode:9000/output/employees")
# Write to S3
rdd.saveAsTextFile("s3://my-bucket/output/employees")
# saveAsPickleFile — saves Python objects in pickle format
rdd.saveAsPickleFile("output/employees_pickle")
# Statistics summary action
nums = sc.parallelize([10, 20, 30, 40, 50])
stats = nums.stats()
print(stats)
# (count: 5, mean: 30.0, stdev: 14.142, max: 50.0, min: 10.0)
| Category | Returns | Triggers Execution? | Examples |
|---|---|---|---|
| Transformation | New RDD | No (lazy) | map, filter, flatMap, groupByKey, reduceByKey |
| Action | Value / Side-effect | YES | collect, count, first, take, reduce, foreach, saveAsTextFile |
Pair RDDs — Key-Value Operations
A Pair RDD is an RDD of (key, value) tuples. This is the most important RDD type for real data processing. Most of Spark's aggregation, join, and grouping operations work on Pair RDDs.
# Direct creation
pair_rdd = sc.parallelize([("Alice", 1000), ("Bob", 1500), ("Alice", 2000)])
# From a regular RDD using map
employees = sc.parallelize(["Alice,Engineering,90000", "Bob,Marketing,70000", "Charlie,Engineering,95000"])
pair_rdd = employees.map(lambda line: (line.split(",")[1], int(line.split(",")[2])))
# pair_rdd: [('Engineering', 90000), ('Marketing', 70000), ('Engineering', 95000)]
rdd = sc.parallelize([("a", 1), ("b", 2), ("a", 3), ("b", 4)])
# keys() and values() — extract keys or values only
print(rdd.keys().collect()) # ['a', 'b', 'a', 'b']
print(rdd.values().collect()) # [1, 2, 3, 4]
# mapValues() — apply function to values only (key passes through unchanged)
doubled = rdd.mapValues(lambda v: v * 2)
print(doubled.collect()) # [('a', 2), ('b', 4), ('a', 6), ('b', 8)]
# flatMapValues() — map + flatten on value side
rdd2 = sc.parallelize([("a", [1, 2, 3]), ("b", [4, 5])])
print(rdd2.flatMapValues(lambda v: v).collect())
# [('a', 1), ('a', 2), ('a', 3), ('b', 4), ('b', 5)]
Pair RDDs support SQL-like joins. Both RDDs must have the same key type.
employees = sc.parallelize([
(1, "Alice"), (2, "Bob"), (3, "Charlie")
])
salaries = sc.parallelize([
(1, 90000), (2, 70000), (4, 85000) # id 3 missing, id 4 extra
])
# INNER JOIN — only matching keys
print(employees.join(salaries).collect())
# [(1, ('Alice', 90000)), (2, ('Bob', 70000))] ← Charlie and id 4 excluded
# LEFT OUTER JOIN — all employees, salary=None if no match
print(employees.leftOuterJoin(salaries).collect())
# [(1, ('Alice', 90000)), (2, ('Bob', 70000)), (3, ('Charlie', None))]
# RIGHT OUTER JOIN — all salaries, name=None if no match
print(employees.rightOuterJoin(salaries).collect())
# [(1, ('Alice', 90000)), (2, ('Bob', 70000)), (4, (None, 85000))]
# FULL OUTER JOIN
print(employees.fullOuterJoin(salaries).collect())
# [(1, ('Alice', 90000)), (2, ('Bob', 70000)), (3, ('Charlie', None)), (4, (None, 85000))]
# cogroup — group all values from multiple RDDs by key
grouped = employees.cogroup(salaries)
for key, (emp_iter, sal_iter) in grouped.collect():
    print(f"Key {key}: emps={list(emp_iter)}, sals={list(sal_iter)}")
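The relationship between cogroup and the joins can be sketched in plain Python: once the values of both sides are grouped by key, each join type is just a different way of pairing the two value lists. The helper names below are hypothetical, and the results are sorted by key only to make the output deterministic.

```python
from collections import defaultdict


def cogroup(left, right):
    """Group the values of two (key, value) lists by key."""
    groups = defaultdict(lambda: ([], []))
    for key, value in left:
        groups[key][0].append(value)
    for key, value in right:
        groups[key][1].append(value)
    return dict(groups)


def inner_join(left, right):
    # Emit a pair only when both sides have at least one value for the key.
    return [(k, (lv, rv))
            for k, (lvs, rvs) in sorted(cogroup(left, right).items())
            for lv in lvs for rv in rvs]


def left_outer_join(left, right):
    # Keep every left value; pair it with None when the right side is empty.
    return [(k, (lv, rv))
            for k, (lvs, rvs) in sorted(cogroup(left, right).items())
            for lv in lvs for rv in (rvs or [None])]


employees = [(1, "Alice"), (2, "Bob"), (3, "Charlie")]
salaries = [(1, 90000), (2, 70000), (4, 85000)]

print(inner_join(employees, salaries))
# [(1, ('Alice', 90000)), (2, ('Bob', 70000))]
print(left_outer_join(employees, salaries))
# [(1, ('Alice', 90000)), (2, ('Bob', 70000)), (3, ('Charlie', None))]
```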
rdd = sc.parallelize([("a", 1), ("b", 2), ("a", 3), ("c", 4)])
# countByKey: count elements per key (returns a dict-like defaultdict to the driver)
print(dict(rdd.countByKey())) # {'a': 2, 'b': 1, 'c': 1}
# lookup — get all values for a specific key
print(rdd.lookup("a")) # [1, 3]
# collectAsMap — collect as Python dict (unique keys only!)
rdd2 = sc.parallelize([("a", 1), ("b", 2), ("c", 3)])
print(rdd2.collectAsMap()) # {'a': 1, 'b': 2, 'c': 3}
# ⚠️ If duplicate keys exist, only last value is kept!
RDD Partitioning
Partitioning determines how data is distributed across executors. The right partitioning strategy can dramatically improve join and groupBy performance by avoiding unnecessary shuffles.
When you do a join or groupByKey, Spark needs to make sure all values with the same key end up in the same partition. If both RDDs are already partitioned on the same key with the same partitioner, Spark can skip the shuffle, which is a massive performance win.
Hash partitioning assigns each key to a partition using hash(key) % numPartitions. Same key always goes to the same partition. This is the default when you do operations like groupByKey or reduceByKey.
# No import needed: PySpark has no HashPartitioner class; partitionBy() hash-partitions by key by default
data = sc.parallelize([
("Alice", 1), ("Bob", 2), ("Alice", 3),
("Charlie", 4), ("Bob", 5), ("Dave", 6)
])
# Apply hash partitioner with 3 partitions
partitioned = data.partitionBy(3) # hash partitioning is the default
# Check partition contents
def show_partitions(rdd):
    return rdd.mapPartitionsWithIndex(
        lambda idx, it: [(f"P{idx}", elem) for elem in it]
    ).collect()
for p, elem in show_partitions(partitioned):
    print(f"{p}: {elem}")
# All "Alice" entries → same partition, "Bob" → same partition, etc.
# This is the whole point: grouped by key hash
# Check if RDD has a partitioner
print(data.partitioner) # None (no partitioner)
print(partitioned.partitioner) # <pyspark.rdd.Partitioner ...>
print(partitioned.getNumPartitions()) # 3
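The hash(key) % numPartitions rule can be tried out without a cluster. The sketch below uses zlib.crc32 as a deterministic stand-in for Spark's key hash (an assumption: Python's built-in hash of strings is salted per process, so it is not used here); `stable_hash` and `hash_partition` are hypothetical names.

```python
import zlib


def stable_hash(key):
    """Deterministic stand-in for the partitioner's key hash."""
    return zlib.crc32(str(key).encode("utf-8"))


def hash_partition(records, num_partitions):
    """Assign each (key, value) record to partition hash(key) % num_partitions."""
    parts = [[] for _ in range(num_partitions)]
    for key, value in records:
        parts[stable_hash(key) % num_partitions].append((key, value))
    return parts


records = [("Alice", 1), ("Bob", 2), ("Alice", 3),
           ("Charlie", 4), ("Bob", 5), ("Dave", 6)]
parts = hash_partition(records, 3)
for idx, part in enumerate(parts):
    print(f"P{idx}: {part}")
```

Whatever the concrete assignment turns out to be, every record with the same key lands in the same partition. Partition sizes are not guaranteed to be balanced, which is the root of key-skew problems.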
Range partitioning distributes keys across partitions such that each partition contains a contiguous range of key values. Useful when you need output sorted by key or when your keys have a natural order.
# No import needed: PySpark has no RangePartitioner class; sortByKey range-partitions internally
data = sc.parallelize([
(1, "a"), (5, "b"), (3, "c"), (7, "d"),
(2, "e"), (8, "f"), (4, "g"), (6, "h")
])
# sortByKey uses range partitioning internally
sorted_rdd = data.sortByKey(numPartitions=3)
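# Roughly: partition 0 holds keys 1-3, partition 1 keys 4-6, partition 2 keys 7-8 (boundaries come from sampling)
# Manual range-style partitioner: map each key to the index of its key range
range_partitioned = data.partitionBy(3, lambda key: min((key - 1) // 3, 2))
# keys 1-3 → partition 0, keys 4-6 → partition 1, keys 7-8 → partition 2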
# Why range partitioning?
# After sortByKey, each partition has sorted, non-overlapping key ranges.
# Reading partition N gives you the Nth chunk of sorted data — no reshuffle needed.
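A range partitioner boils down to a sorted list of boundaries and a binary search. The sketch below hard-codes the boundaries, which is an assumption made for illustration (Spark derives them by sampling the keys); `range_partition` is a hypothetical helper.

```python
import bisect


def range_partition(records, boundaries):
    """Place each (key, value) in the partition whose key range contains the key.

    boundaries holds the inclusive upper bound of every partition but the last.
    """
    parts = [[] for _ in range(len(boundaries) + 1)]
    for key, value in records:
        parts[bisect.bisect_left(boundaries, key)].append((key, value))
    return [sorted(p) for p in parts]  # sort within each partition


records = [(1, "a"), (5, "b"), (3, "c"), (7, "d"),
           (2, "e"), (8, "f"), (4, "g"), (6, "h")]
parts = range_partition(records, boundaries=[3, 6])
for idx, part in enumerate(parts):
    print(idx, [k for k, _ in part])
# 0 [1, 2, 3]
# 1 [4, 5, 6]
# 2 [7, 8]
```

Because the ranges do not overlap, concatenating the partitions in index order yields a globally sorted sequence without any further data movement.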
# Bad pattern: no pre-partitioning (2 shuffles for 2 joins)
users = sc.parallelize([(1, "Alice"), (2, "Bob")])
orders = sc.parallelize([(1, "Order1"), (2, "Order2")])
payments = sc.parallelize([(1, "Pay1"), (2, "Pay2")])
result1 = users.join(orders) # shuffle
result2 = result1.join(payments) # another shuffle
# Better pattern: pre-partition users (used in multiple joins)
# Then persist — partitioner is remembered!
users_p = users.partitionBy(3).cache() # partitioned + cached
result1 = users_p.join(orders) # orders shuffled to match users_p, users_p stays
result2 = users_p.join(payments) # only payments shuffled — users_p reused from cache
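When you call partitionBy(n) and then cache(), Spark remembers the partitioner. Subsequent join or groupByKey calls that use the same partitioner and partition count can skip the shuffle on the pre-partitioned RDD; the saving is greatest when the other side of the join is partitioned the same way.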
RDD Persistence — cache, persist, unpersist
By default, Spark recomputes an RDD from scratch every time an action is called. Persistence tells Spark to keep the computed data in memory (or disk) so it doesn't have to recompute it.
# Without persistence: the whole pipeline runs 3 times!
rdd = sc.textFile("huge_file.txt")
processed = rdd.flatMap(lambda l: l.split(" ")).map(lambda w: (w, 1)).reduceByKey(lambda a,b: a+b)
count = processed.count() # full recomputation
top10 = processed.top(10, key=lambda kv: kv[1]) # 10 most frequent words; full recomputation again
processed.saveAsTextFile("out") # full recomputation again (returns None)
# With cache() — computed once, stored in memory
processed = (rdd
.flatMap(lambda l: l.split(" "))
.map(lambda w: (w, 1))
.reduceByKey(lambda a,b: a+b)
.cache() # ← mark for caching
)
count = processed.count() # computed AND cached here
top10 = processed.top(10, key=lambda kv: kv[1]) # served from cache
processed.saveAsTextFile("out_cached") # served from cache (the output directory must not already exist)
# Always release cache when done
processed.unpersist()
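The effect of caching on the number of recomputations can be imitated in a few lines of plain Python. This is a toy model under stated assumptions, not Spark's block manager: `CountingPipeline` is a hypothetical class whose `runs` counter stands in for the expensive work.

```python
class CountingPipeline:
    """Counts how often the expensive computation actually runs."""

    def __init__(self, data, cached=False):
        self.data = data
        self.cached = cached
        self.runs = 0
        self._stored = None

    def _compute(self):
        if self.cached and self._stored is not None:
            return self._stored          # served from the "cache"
        self.runs += 1                   # stands in for the expensive work
        result = [x * x for x in self.data]
        if self.cached:
            self._stored = result
        return result

    def count(self):
        return len(self._compute())

    def total(self):
        return sum(self._compute())

    def unpersist(self):
        self._stored = None


plain = CountingPipeline(range(5))
plain.count()
plain.total()
print(plain.runs)   # 2

cached = CountingPipeline(range(5), cached=True)
cached.count()
cached.total()
print(cached.runs)  # 1
```

As in Spark, marking the pipeline as cached does nothing by itself; the result is stored the first time an action computes it, and calling `unpersist()` makes the next action recompute.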
| Storage Level | Memory | Disk | Serialized | Replicated | Use Case |
|---|---|---|---|---|---|
| MEMORY_ONLY | Yes | No | No | No | Default for cache(); fast, use when data fits in RAM |
| MEMORY_AND_DISK | Yes | Yes | No | No | Partitions that do not fit in memory spill to disk |
| MEMORY_ONLY_SER | Yes | No | Yes | No | Scala/Java API only; less memory, slightly slower; Kryo serialization helps |
| MEMORY_AND_DISK_SER | Yes | Yes | Yes | No | Scala/Java API only; best balance for most production use cases |
| DISK_ONLY | No | Yes | No | No | When memory is too small; slow but saves memory |
| MEMORY_ONLY_2 | Yes | No | No | 2x | Replicated for fault tolerance (expensive) |
| OFF_HEAP | Off-heap | No | Yes | No | Avoids GC pressure; advanced tuning |
from pyspark import StorageLevel
rdd = sc.parallelize(range(1000000))
# cache() = persist(StorageLevel.MEMORY_ONLY)
rdd.cache()
# Check if cached
print(rdd.is_cached) # True
# The storage level of a persisted RDD cannot be changed: unpersist first
rdd.unpersist()
rdd.persist(StorageLevel.MEMORY_AND_DISK)
rdd.unpersist()
rdd.persist(StorageLevel.DISK_ONLY)
# Note: PySpark always pickles stored data, so the *_SER levels are not available in Python
# Release when done
rdd.unpersist()
print(rdd.is_cached) # False
When working with DataFrames rather than RDDs, call df.cache() instead.
Shared Variables — Broadcast & Accumulators
Normally, when Spark executes a function on a remote executor, it sends a copy of all variables used. Shared variables are special variables that avoid this: Broadcast Variables share read-only data efficiently, and Accumulators let executors write back aggregated values to the driver.
When your closure (lambda/function) references a Python variable, Spark serializes and sends that variable to EVERY task — even if it's a 1 GB lookup table sent to 1000 tasks. That's 1 TB of network traffic!
# ── Problem: lookup table sent with every task ──
city_to_country = {"New York": "USA", "London": "UK", "Tokyo": "Japan", ...}
# If this dict is 100MB and there are 1000 tasks → 100GB sent over network!
# ── Solution: Broadcast it ──
city_map = sc.broadcast({
"New York": "USA",
"London": "UK",
"Tokyo": "Japan",
"Paris": "France"
})
# Sent ONCE to each executor. All tasks on that executor share the same copy.
# Use .value to access the broadcast variable inside a function
cities_rdd = sc.parallelize(["New York", "London", "Unknown", "Tokyo"])
result = cities_rdd.map(lambda city: (city, city_map.value.get(city, "Unknown")))
print(result.collect())
# [('New York', 'USA'), ('London', 'UK'), ('Unknown', 'Unknown'), ('Tokyo', 'Japan')]
# Destroy broadcast variable when done (releases memory on all executors)
city_map.unpersist() # removes from executor memory
city_map.destroy() # removes from both executor and driver
# Large RDD of transactions (millions of rows)
transactions = sc.parallelize([
("TXN001", "NYC", 500.0),
("TXN002", "LON", 750.0),
("TXN003", "TKY", 300.0)
])
# Small lookup dict — broadcast it!
city_codes = sc.broadcast({
"NYC": "New York",
"LON": "London",
"TKY": "Tokyo"
})
# Enrich transactions with city names — no shuffle needed!
enriched = transactions.map(lambda t: (
t[0],
city_codes.value.get(t[1], t[1]),
t[2]
))
print(enriched.collect())
# [('TXN001', 'New York', 500.0), ('TXN002', 'London', 750.0), ('TXN003', 'Tokyo', 300.0)]
Accumulators are variables that executors can only add to (not read). The driver is the only one that can read the accumulated value. They are used for counters and sums — like counting bad records, null values, or skipped lines.
# Basic numeric accumulator
error_count = sc.accumulator(0) # starts at 0
null_count = sc.accumulator(0)
data = sc.parallelize([
("Alice", 100), (None, 200), ("Bob", None),
("Charlie", 300), (None, None)
])
def process(record):
    name, amount = record
    if name is None:
        null_count.add(1) # increment on executor
    if amount is None:
        error_count.add(1) # increment on executor
        return (name, 0) # default value
    return (name, amount)
result = data.map(process)
result.collect() # action triggers execution and accumulation
# Read accumulated values on the DRIVER
print(f"Null names: {null_count.value}") # 2
print(f"Error amounts: {error_count.value}") # 2
from pyspark import AccumulatorParam
# Custom accumulator for tracking a Python set (e.g., unique error codes)
class SetAccumulator(AccumulatorParam):
    def zero(self, init_val):
        return set() # empty set: the identity element for union
    def addInPlace(self, s1, s2):
        s1 |= s2 # set union
        return s1
error_codes = sc.accumulator(set(), SetAccumulator())
logs = sc.parallelize(["E001", "E002", "E001", "E003", "E002"])
logs.foreach(lambda code: error_codes.add({code}))
print(error_codes.value) # {'E001', 'E002', 'E003'}
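Accumulator updates are guaranteed to be applied exactly once only inside actions (such as foreach), not transformations, where a retried or recomputed task can apply them again. For production use cases, prefer a proper aggregation via count() or filter().count().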
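The double-counting risk can be simulated without Spark. In the sketch below, `Counter` is a hypothetical add-only stand-in for an accumulator, and calling `run_task` twice imitates a task that is retried or a partition that is recomputed; it is not a model of Spark's scheduler.

```python
class Counter:
    """Minimal add-only counter standing in for an accumulator."""

    def __init__(self):
        self.value = 0

    def add(self, n):
        self.value += n


def run_task(partition, bad_records):
    """A transformation-style task that counts None values as a side effect."""
    out = []
    for x in partition:
        if x is None:
            bad_records.add(1)
        else:
            out.append(x)
    return out


partition = [1, None, 3, None]
bad = Counter()

run_task(partition, bad)  # first attempt (imagine its output was lost)
run_task(partition, bad)  # retry or recomputation of the same partition
print(bad.value)          # 4, although only 2 records are bad

# Deriving the number from the data itself is immune to retries.
print(sum(1 for x in partition if x is None))  # 2
```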
| Feature | Broadcast Variable | Accumulator |
|---|---|---|
| Direction | Driver → Executors | Executors → Driver |
| Access on Executor | Read only (.value) | Write only (.add()) |
| Access on Driver | Create, read (.value), unpersist/destroy | Read only (.value) |
| Use Case | Lookup tables, config, small datasets | Counters, sums, metrics collection |
| Size | Can be large (GBs), sent once per executor | Should be small (single values) |