MODULE 18C State Management
18C.1
State Store
The state store is Spark's memory system for streaming: it remembers facts across micro-batches so that stateful operations like aggregations, joins, and deduplication work correctly over time.
What is the State Store?
The Problem State Solves
Without state, each micro-batch is completely isolated: it sees only the rows that arrived since the previous trigger (for example, the last 10 seconds with a 10-second trigger). But what if you need to count total orders since the query started? Or detect when a user has placed 5 orders over the past hour? That requires memory that survives across batches.

The state store is a key-value store that lives in executor memory (and, with RocksDB, on local disk) and persists between batches. Every stateful operator reads from and writes to this store.
🧠 Analogy
Think of the state store like a whiteboard in an office. Each employee (batch) can read what was written before, add new information, and erase old notes. When a new employee arrives (next batch), the whiteboard is exactly as the previous one left it. Without the whiteboard, each employee would start with zero knowledge.
Key-Value Store
State is stored as key → value pairs. For groupBy("user_id"), the key is the user_id and the value is the aggregated result (count, sum, etc.).
Persisted to Checkpoint
State changes are written to the checkpoint directory on every batch commit (with periodic full snapshots), so state can be recovered after a driver or executor failure.
Partition-Aligned
Each state partition owns a slice of the state. State for key "user_123" always lives in the same partition (determined by a hash of the key).
Grows Over Time
Without cleanup, state accumulates forever: one entry per unique key seen. This is the #1 production issue with stateful streaming.
RocksDB State Store
Spark ships two built-in state store implementations. RocksDB, available since Spark 3.2, is the usual choice for production workloads with large state:
Feature | In-Memory (default, HDFS-backed) | RocksDB
Storage location | JVM heap (executor memory) | Native memory outside the JVM heap + local disk (SSD)
State size limit | Limited by executor heap | Can exceed heap (data lives on disk)
GC pressure | High: large objects on heap | Low: state is not on the heap
Checkpoint writes | Delta file per batch + periodic snapshots; the whole map stays on heap | Only new RocksDB files per batch (or a small changelog, if enabled)
Read/write speed | Faster for small state | Slightly slower for tiny state
Best for | Dev/testing, small state | Production, large state
python: Enabling RocksDB State Store
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("StatefulStreaming")
    .config(
        "spark.sql.streaming.stateStore.providerClass",
        "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider",
    )
    # Number of state partitions for a NEW checkpoint (fixed once the query has started)
    .config("spark.sql.shuffle.partitions", "8")
    .getOrCreate()
)

# RocksDB stores state in:
#   {checkpointDir}/state/{operatorId}/{partitionId}/
#
# Each commit uploads only the RocksDB files that are new since the
# previous version, so large, slowly changing state checkpoints much
# faster than rewriting everything
In-Memory State Store
The default state store (the HDFS-backed provider) keeps all state in the executor JVM heap as a hash map. It is simple and fast for small state, but problematic at scale:
• State grows → heap fills → GC pauses increase → batch latency spikes
• Recent versions of each partition's state map are also kept in memory for fast recovery (spark.sql.streaming.maxBatchesToRetainInMemory, default 2), adding to heap usage
• OOM errors when state exceeds executor memory
python: Default in-memory state store (no extra config needed)
from pyspark.sql import SparkSession

# Default: no config needed, uses the in-memory (HDFS-backed) store
# Fine for dev and small production workloads
spark = SparkSession.builder \
    .appName("SmallStateApp") \
    .getOrCreate()

# Monitor state size to know when to switch to RocksDB (query is a running StreamingQuery):
# query.lastProgress['stateOperators'][0]['numRowsTotal']
# If this keeps growing and approaches executor memory → switch to RocksDB
State Store Provider Configuration
You can also set the state store provider at the Spark configuration level (cluster-wide) or per-session:
python: Common state store config options
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("ProductionStateful")
    # Switch to RocksDB
    .config("spark.sql.streaming.stateStore.providerClass",
            "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")
    # RocksDB: block cache size in MB (default 8)
    .config("spark.sql.streaming.stateStore.rocksdb.blockCacheSizeMB", "32")
    # RocksDB: whether to reset RocksDB's internal ticker/histogram stats on load (default true)
    .config("spark.sql.streaming.stateStore.rocksdb.resetStatsOnLoad", "false")
    # Minimum number of recent batches whose checkpoint data is kept for recovery (default 100)
    .config("spark.sql.streaming.minBatchesToRetain", "2")
    .getOrCreate()
)
State Store Location in Checkpoint
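State is stored within the checkpoint directory under state/. The layout below is that of the default (HDFS-backed) store; RocksDB uses the same operator/partition directories but writes its own file types: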
{checkpointDir}/
├── state/
│   ├── 0/                  # operator 0 (e.g., groupBy aggregation)
│   │   ├── 0/              # partition 0
│   │   │   ├── 1.delta     # changes from batch 1
│   │   │   ├── 2.delta     # changes from batch 2
│   │   │   └── 5.snapshot  # full snapshot at batch 5
│   │   └── 1/              # partition 1
│   └── 1/                  # operator 1 (e.g., dedup)
├── offsets/
└── commits/
18C.2
Key-Based State
State in Spark is always organized by key. Understanding how groupBy keys map to state partitions is essential for performance and debugging.
How State is Keyed
groupBy Key as State Key
When you write .groupBy("user_id").agg(count("*")), Spark creates one state entry per unique user_id. The state store maps:

key = (user_id) → value = {count: 42, sum: 1500, ...}

The "key" in the state store is literally the value of your groupBy column(s). For compound keys like .groupBy("user_id", "product_category"), the key is a tuple (user_id, product_category).
๐Ÿ—„๏ธ State Store โ€” Batch 1
user_001count=3, sum=150
user_002count=1, sum=80
user_003count=5, sum=320
๐Ÿ—„๏ธ State Store โ€” Batch 2 (updated)
user_001count=5, sum=280
user_002count=1, sum=80
user_003count=7, sum=470
user_004count=2, sum=90
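The batch-to-batch update pictured above can be sketched in plain Python, with a dict standing in for one operator's state store. This is a toy model, not Spark's implementation; the input rows are hypothetical, chosen only to reproduce the counts and sums shown.

```python
from typing import Dict, Iterable, Tuple

# Toy stand-in for the state store: user_id -> (running count, running sum)
StateStore = Dict[str, Tuple[int, int]]

def apply_batch(store: StateStore, rows: Iterable[Tuple[str, int]]) -> None:
    """Fold one micro-batch of (user_id, amount) rows into the keyed state."""
    for user_id, amount in rows:
        count, total = store.get(user_id, (0, 0))  # an unseen key starts at zero
        store[user_id] = (count + 1, total + amount)

store: StateStore = {}

# Batch 1 (hypothetical rows matching the "Batch 1" picture)
apply_batch(store, [
    ("user_001", 50), ("user_001", 50), ("user_001", 50),
    ("user_002", 80),
    ("user_003", 60), ("user_003", 60), ("user_003", 60),
    ("user_003", 70), ("user_003", 70),
])

# Batch 2: only keys with new rows change; user_002 is carried over as-is
apply_batch(store, [
    ("user_001", 60), ("user_001", 70),
    ("user_003", 75), ("user_003", 75),
    ("user_004", 40), ("user_004", 50),
])

for user_id, (count, total) in sorted(store.items()):
    print(f"{user_id}: count={count}, sum={total}")
```

Only keys that appear in a batch are touched; everything else is simply carried forward, which is exactly why state keeps growing when new keys keep arriving.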
python: Stateful groupBy (each unique key gets its own state entry)
from pyspark.sql.functions import count, sum as spark_sum

# orders_stream: a streaming DataFrame with user_id, product_category, amount, event_time
# Each unique user_id → 1 state entry
# State entry stores: running count + running sum
# Note: the watermark does NOT evict state here, because event_time is not
# part of the grouping key; this state grows with every new user_id.
agg_df = (
    orders_stream
    .withWatermark("event_time", "10 minutes")
    .groupBy("user_id")                              # ← this becomes the state KEY
    .agg(
        count("*").alias("order_count"),             # state VALUE field
        spark_sum("amount").alias("total_spent"),    # state VALUE field
    )
)

# Compound key example:
# state key = (user_id, product_category) tuple
agg_df2 = (
    orders_stream
    .withWatermark("event_time", "10 minutes")
    .groupBy("user_id", "product_category")
    .agg(count("*").alias("count"))
)
Partitioning and State Locality
State is partitioned: each state partition owns a subset of keys. The partition for a given key is determined by a hash of the key modulo numPartitions (Spark uses a Murmur3-based hash).

numPartitions is controlled by spark.sql.shuffle.partitions (default 200). For streaming, this is critically important: too many partitions → too many small state files in the checkpoint; too few → state hotspots on one executor.
💡 State Locality Rule
All rows with the same groupBy key are always shuffled to the same partition. This means state for "user_001" always lives in the same state partition, so no cross-partition state lookups are needed. This is why a streaming groupBy requires a shuffle.
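The locality rule can be illustrated with a small Python sketch. It uses zlib.crc32 as a stand-in hash (Spark itself uses a Murmur3-based hash, so real partition numbers differ); the point is only that a deterministic hash pins each key to one partition for a fixed partition count, and that changing the count re-homes most keys.

```python
import zlib

def partition_for(key: str, num_partitions: int) -> int:
    # CRC32 is a stand-in: stable across Python processes (unlike hash()),
    # but NOT the Murmur3-based hash Spark uses, so numbers differ from Spark's.
    return zlib.crc32(key.encode("utf-8")) % num_partitions

keys = [f"user_{i:03d}" for i in range(1, 1001)]

# Locality: every micro-batch routes a given key to the same partition
batch1 = {k: partition_for(k, 32) for k in keys}
batch2 = {k: partition_for(k, 32) for k in keys}

# A different partition count re-homes most keys, which is why the state
# partition count must stay fixed for the life of a checkpoint
moved = sum(partition_for(k, 32) != partition_for(k, 48) for k in keys)
print(f"{moved} of {len(keys)} keys land in a different partition with 48 instead of 32")
```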
python: Tuning shuffle partitions for state management
# Rule of thumb for streaming:
# numPartitions ≈ 2-3x the number of executor cores
# (too many → small state files, overhead; too few → hotspots)

# For a cluster with 4 executors x 4 cores = 16 cores:
spark.conf.set("spark.sql.shuffle.partitions", "32")

# IMPORTANT: the number of state partitions is fixed when a query first
# starts with a given checkpoint. Spark records spark.sql.shuffle.partitions
# in the checkpoint metadata and reuses that value on every restart, so a
# new setting is ignored until you start from a fresh checkpoint.

# Verify your setting:
print(spark.conf.get("spark.sql.shuffle.partitions"))
โš ๏ธ Never Change shuffle.partitions Mid-Stream
Changing spark.sql.shuffle.partitions after a streaming query has started (even after a restart) means the state partition layout changes. State for key "user_001" was in partition 5 before, now it's in partition 12. Spark will not find it โ†’ wrong results. Always delete the checkpoint when changing this value.
18C.3
State Growth
The #1 production problem with stateful streaming: state grows without bound unless you actively manage it. Here's how to detect it and understand why it happens.
Unbounded State Problem
Why State Grows Forever
Without watermarking or TTL, the state store keeps every key it has ever seen. A streaming query running for months on a high-cardinality key (like user_id or session_id) accumulates millions of entries, eventually causing:
• Executor OOM errors
• Increasing GC pause times (with in-memory store)
• Checkpoint files growing to hundreds of GB
• Batch processing time keeps increasing
Example: The Problem
A clickstream pipeline groups by session_id. Each session_id is unique and never repeats. After 1 month: 10M sessions/day × 30 days = 300M state entries. At ~200 bytes each, that is 60 GB of state. Executor memory is 16 GB → OOM.
python: Monitoring state growth
import time

# Monitor state size after every few batches
while query.isActive:
    time.sleep(30)
    progress = query.lastProgress
    if progress and progress.get("stateOperators"):
        for op in progress["stateOperators"]:
            print(f"State rows: {op['numRowsTotal']}")         # TOTAL keys in store
            print(f"State memory: {op['memoryUsedBytes']} bytes")   # memory footprint
            print(f"Rows updated: {op['numRowsUpdated']}")        # keys touched this batch
            print(f"Rows removed: {op['numRowsRemoved']}")        # keys cleaned up this batch
            print("---")

# Healthy pattern: numRowsRemoved > 0 (cleanup happening)
# Danger signal: numRowsTotal keeps increasing, numRowsRemoved = 0
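The healthy/danger rule in the comments above can be turned into a small check over a history of progress dictionaries. This sketch assumes the dictionaries have the shape of query.lastProgress (a stateOperators list with numRowsTotal and numRowsRemoved); the function name, the exact rule, and the sample histories are illustrative, not part of Spark.

```python
from typing import Dict, List

def state_health(progress_history: List[Dict]) -> str:
    """Classify state growth from query.lastProgress-style dicts (oldest first)."""
    totals: List[int] = []
    removed = 0
    for progress in progress_history:
        ops = progress.get("stateOperators") or []
        totals.append(sum(op["numRowsTotal"] for op in ops))
        removed += sum(op.get("numRowsRemoved", 0) for op in ops)
    if len(totals) < 2:
        return "not enough data"
    always_growing = all(later > earlier for earlier, later in zip(totals, totals[1:]))
    if always_growing and removed == 0:
        return "danger: state only grows and nothing is removed"
    return "ok"

# Made-up progress histories for illustration
leaky = [{"stateOperators": [{"numRowsTotal": n, "numRowsRemoved": 0}]} for n in (100, 250, 400)]
healthy = [{"stateOperators": [{"numRowsTotal": n, "numRowsRemoved": 120}]} for n in (100, 110, 105)]
print(state_health(leaky))
print(state_health(healthy))
```

In a real job, you would append query.lastProgress to a list inside the polling loop above and call state_health on it periodically.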
State Accumulation Over Time
The table below shows how state explodes over time for a query with NO watermark on a high-cardinality key, assuming 10,000 new keys per hour and ~200 bytes per state entry:
Time Running | Unique Keys Seen | State Entries | Approx. State Size | Status
1 hour | 10,000 | 10,000 | ~2 MB | ✅ Fine
1 day | 240,000 | 240,000 | ~48 MB | ✅ Fine
1 week | 1.68M | 1.68M | ~336 MB | ⚠️ Watch
1 month | 7.2M | 7.2M | ~1.4 GB | ❌ OOM risk
6 months | 43M | 43M | ~8.6 GB | ❌ OOM likely
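The table's numbers follow from simple arithmetic, sketched below under the same assumptions the table makes: 10,000 brand-new keys per hour, no eviction, and about 200 bytes per state entry (sizes in decimal GB).

```python
def projected_state(keys_per_hour: int, hours: float, bytes_per_entry: int = 200):
    """Return (entries, bytes) for a query that never evicts and only sees new keys."""
    entries = int(keys_per_hour * hours)
    return entries, entries * bytes_per_entry

for label, hours in [("1 hour", 1), ("1 day", 24), ("1 week", 24 * 7),
                     ("1 month", 24 * 30), ("6 months", 24 * 180)]:
    entries, size = projected_state(10_000, hours)
    print(f"{label:>8}: {entries:>12,} entries  ~{size / 1e9:.2f} GB")
```

Because growth is linear in running time, the only way to flatten the curve is to remove keys, not to add memory.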
Memory and Disk Impact
The impact differs by state store type:

In-memory store: All state is on the JVM heap. As state grows, GC pauses get longer (seconds to minutes of GC), batch latency explodes, and eventually an OOM kills the executor.

RocksDB store: State lives in RocksDB files on local disk, with recently used data cached in native memory (block cache and memtables). As state grows, disk I/O increases and checkpoints get larger, but you won't OOM as easily. However, local SSD space can be exhausted.
18C.4
State Cleanup
Three strategies to prevent state from growing forever: watermark-based expiration, TTL-based expiration, and explicit removal in mapGroupsWithState.
State Expiration Strategies
TTL-Based Expiration (mapGroupsWithState)
Time-To-Live (TTL) is the most explicit cleanup mechanism. You set a timeout on a key's state; if no new data for that key arrives before it expires, Spark calls your function once more with state.hasTimedOut set to True, and your code removes the state.

Available via mapGroupsWithState and flatMapGroupsWithState (Scala/Java) using the GroupState API; in Python (Spark 3.4+), the equivalent is applyInPandasWithState. Two timeout types:
• GroupStateTimeout.ProcessingTimeTimeout: expires based on processing (wall-clock) time
• GroupStateTimeout.EventTimeTimeout: expires based on the watermark
python: TTL via GroupStateTimeout (applyInPandasWithState, Spark 3.4+)
from typing import Any, Iterator, Tuple

import pandas as pd
from pyspark.sql.streaming.state import GroupState, GroupStateTimeout
from pyspark.sql.types import LongType, StructType

# State schema: what we remember per user
state_schema = StructType() \
    .add("order_count", LongType()) \
    .add("total_amount", LongType())

# Output schema: one summary row per active user
output_schema = "user_id STRING, order_count LONG, total_amount LONG"

# PySpark's setTimeoutDuration takes milliseconds (Scala also accepts "30 minutes")
THIRTY_MINUTES_MS = 30 * 60 * 1000

def update_user_state(
    key: Tuple[Any, ...],
    pdfs: Iterator[pd.DataFrame],
    state: GroupState
) -> Iterator[pd.DataFrame]:
    """
    Called per unique user_id per batch, and once more when its timeout fires.
    state: the persisted state for this user (survives across batches)
    """
    (user_id,) = key

    # Check if this was a TIMEOUT call (no new data for 30 minutes)
    if state.hasTimedOut:
        # Clean up expired state: user hasn't ordered in 30 minutes
        state.remove()
        return  # emit nothing for the expired user

    # Get current state (or initialize if first time)
    curr_count, curr_amount = state.get if state.exists else (0, 0)

    # Process new rows for this user in this batch
    for pdf in pdfs:
        curr_count += len(pdf)
        curr_amount += int(pdf["amount"].sum())

    # Update state
    state.update((curr_count, curr_amount))

    # Set TTL: if no new orders arrive for 30 min → evict this user's state
    state.setTimeoutDuration(THIRTY_MINUTES_MS)

    yield pd.DataFrame({
        "user_id": [user_id],
        "order_count": [curr_count],
        "total_amount": [curr_amount],
    })

# Apply the stateful function (orders_stream: streaming DataFrame with user_id, amount)
result = orders_stream \
    .groupBy("user_id") \
    .applyInPandasWithState(
        update_user_state,
        outputStructType=output_schema,
        stateStructType=state_schema,
        outputMode="update",
        timeoutConf=GroupStateTimeout.ProcessingTimeTimeout,
    )
# In Scala/Java, the typed Dataset API uses groupByKey(...).mapGroupsWithState(...)
Watermark-Based Expiration
The most common and automatic cleanup mechanism. When you use .withWatermark() on a groupBy aggregation, Spark automatically evicts state for keys whose event_time window has passed the watermark.

This is the recommended approach for window aggregations: no manual cleanup code needed.
python: Automatic state cleanup via watermark
from pyspark.sql.functions import col, count, sum as spark_sum, window

# Tumbling window: count orders per 1-hour window
# Watermark: tolerate up to 10 minutes of late data
# State cleanup: once the watermark passes the window end → evict that window's state

windowed_agg = (
    orders_stream
    .withWatermark("event_time", "10 minutes")    # ← cleanup trigger
    .groupBy(
        window(col("event_time"), "1 hour"),      # tumbling 1-hour window
        col("product_category"),
    )
    .agg(
        count("*").alias("order_count"),
        spark_sum("amount").alias("revenue"),
    )
)

# State lifecycle for window [10:00–11:00):
# Batch 1: watermark = 09:45 → window still open → state kept
# Batch 2: watermark = 10:50 → window still open → state kept
# Batch 3: watermark = 11:10 → window [10:00–11:00) is PAST the watermark
#   → state for this window is EVICTED ✅
#   → result emitted in Append mode (window is now final)

query = (
    windowed_agg.writeStream
    .outputMode("append")       # append: only emit when a window is finalized
    .option("checkpointLocation", "/tmp/ckpt/windowed")
    .format("delta")
    .start("/delta/hourly_revenue")
)
๐Ÿ—„๏ธ State Before Cleanup (Watermark = 10:55)
[09:00โ€“10:00) Electronicscount=42
[10:00โ€“11:00) Electronicscount=87
[10:00โ€“11:00) Bookscount=23
๐Ÿ—„๏ธ State After Cleanup (Watermark = 11:15)
[09:00โ€“10:00) Electronicsevicted
[10:00โ€“11:00) Electronicsevicted
[10:00โ€“11:00) Booksevicted
[11:00โ€“12:00) Electronicscount=15
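The eviction rule can be simulated in plain Python. This toy model is not Spark's implementation: it keeps hour-aligned tumbling windows keyed by (window start, category), derives the watermark as max event time minus 10 minutes, and drops every window whose end is at or before the watermark. It also skips late-row dropping, and the timestamps are hypothetical, picked to mirror the before/after picture above.

```python
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple

WINDOW = timedelta(hours=1)     # tumbling window length
DELAY = timedelta(minutes=10)   # watermark delay, as in withWatermark(..., "10 minutes")

Key = Tuple[datetime, str]      # (window start, product_category)

class WindowedCounts:
    """Toy model of a windowed count whose state is evicted by the watermark."""

    def __init__(self) -> None:
        self.state: Dict[Key, int] = {}
        self.max_event_time: Optional[datetime] = None

    def process_batch(self, rows: List[Tuple[datetime, str]]) -> Dict[Key, int]:
        """Add rows, then evict and return every window whose end <= watermark."""
        for ts, category in rows:
            key = (ts.replace(minute=0, second=0, microsecond=0), category)
            self.state[key] = self.state.get(key, 0) + 1
            if self.max_event_time is None or ts > self.max_event_time:
                self.max_event_time = ts
        if self.max_event_time is None:
            return {}
        # Simplification: Spark applies the watermark computed at the end of
        # the previous batch; here it is advanced within the same batch.
        watermark = self.max_event_time - DELAY
        finalized = {k: v for k, v in self.state.items() if k[0] + WINDOW <= watermark}
        for k in finalized:
            del self.state[k]   # evicted from state; emitted once in append mode
        return finalized

def at(hhmm: str) -> datetime:
    return datetime(2024, 1, 1, int(hhmm[:2]), int(hhmm[3:]))

wc = WindowedCounts()
first = wc.process_batch([(at("09:50"), "Electronics"),
                          (at("10:05"), "Electronics"),
                          (at("10:02"), "Books")])          # watermark 09:55
second = wc.process_batch([(at("11:25"), "Electronics")])  # watermark 11:15
print(first)
print(sorted(second))
print(wc.state)
```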
Explicit State Removal in mapGroupsWithState
In mapGroupsWithState / flatMapGroupsWithState (and applyInPandasWithState in Python), you can call state.remove() at any time to explicitly delete the state for a key, giving you full programmatic control.
python: Explicit state.remove() example (applyInPandasWithState)
import pandas as pd

THIRTY_MINUTES_MS = 30 * 60 * 1000

def session_tracker(key, pdfs, state):
    # State holds one field: (page_views,)
    (session_id,) = key
    if state.hasTimedOut:            # no events for 30 min: drop the session
        state.remove()
        return

    page_views = state.get[0] if state.exists else 0
    for pdf in pdfs:
        page_views += len(pdf)
        if (pdf["event_type"] == "SESSION_END").any():
            # Session is done: no point keeping state
            state.remove()  # ← explicit cleanup (values were read BEFORE removing)
            yield pd.DataFrame({"session_id": [session_id], "status": ["CLOSED"], "page_views": [page_views]})
            return

    # Otherwise, update state and set TTL (auto-cleanup if no events for 30 min)
    state.update((page_views,))
    state.setTimeoutDuration(THIRTY_MINUTES_MS)
    yield pd.DataFrame({"session_id": [session_id], "status": ["ACTIVE"], "page_views": [page_views]})
Automatic Cleanup with Watermark
For simple window-based aggregations, watermark-based cleanup happens automatically, with no extra code. The rule is:

State for a window is evicted when: window end ≤ watermark

This keeps state bounded: instead of growing forever, state size stabilizes, because at any time the store only holds windows that are not yet finalized (roughly the most recent window length plus the watermark delay of data).
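✅ Best Practice
Always use .withWatermark() on stateful aggregations in production, and include the event-time column (or a window over it) in the grouping key; a watermark cannot evict state for keys that carry no event time. For state keyed only by IDs, use a timeout or explicit state.remove() instead. Skipping both is the most common cause of unbounded state growth and eventual OOM.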
Manual Cleanup Strategies
When watermark doesn't fit your use case (e.g., non-time-based state), use these patterns:
1. TTL in mapGroupsWithState: state.setTimeoutDuration("30 minutes") in Scala, or state.setTimeoutDuration(30 * 60 * 1000) in Python
2. Count-based eviction: Remove state after N events processed for a key
3. Flag-based eviction: Remove when business logic says "done" (e.g., order DELIVERED)
4. Periodic restart: Restart the query with a fresh checkpoint (nuclear option; avoid, since all state is lost)
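Count-based and flag-based eviction (items 2 and 3) boil down to a few lines of bookkeeping. The sketch below models them in plain Python with a dict as the per-key state; in Spark the same decisions would live inside your GroupState function as state.remove() calls. The MAX_EVENTS limit, the DELIVERED status, and the sample rows are hypothetical.

```python
from typing import Dict, Iterable, List, Tuple

MAX_EVENTS = 3  # hypothetical count-based limit per key

def process(state: Dict[str, int],
            batch: Iterable[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """state maps key -> events seen; batch holds (key, status) rows.

    Returns (key, reason) for every key whose state was removed.
    """
    evicted = []
    for key, status in batch:
        if status == "DELIVERED":          # flag-based: business logic says "done"
            state.pop(key, None)
            evicted.append((key, "flag"))
            continue
        state[key] = state.get(key, 0) + 1
        if state[key] >= MAX_EVENTS:       # count-based: stop tracking after N events
            del state[key]
            evicted.append((key, "count"))
    return evicted

state: Dict[str, int] = {}
log = process(state, [("o1", "PLACED"), ("o2", "PLACED"), ("o1", "SHIPPED"),
                      ("o1", "DELIVERED"), ("o2", "UPDATED"), ("o2", "UPDATED")])
print(log, state)
```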
18C.5
Memory Consumption
How state occupies memory differently in each store type, and how to monitor and control it before it becomes a production incident.
State Memory Footprint
Each state entry consumes memory for both the key and value, plus overhead. Rough estimates:
State Content | Key Size | Value Size | Overhead | Total per Entry
groupBy(string) + count/sum | ~20–50 bytes | ~16 bytes | ~64 bytes | ~100–130 bytes
groupBy(string) + collect_list | ~20 bytes | Variable (grows!) | ~64 bytes | Unbounded per key!
deduplication | ~20–50 bytes | ~4 bytes | ~64 bytes | ~88–118 bytes
stream-stream join | ~40 bytes | Full row (~200 bytes) | ~64 bytes | ~300 bytes
โš ๏ธ Never Use collect_list in Stateful Aggregation
collect_list in a streaming aggregation stores EVERY row for a key in state. If user_001 places 10,000 orders, the state entry for user_001 stores all 10,000 rows. This makes state size proportional to total events, not total keys โ€” will OOM immediately.
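The difference is easy to see with plain Python containers standing in for state values: a running count stores one number per key, while a collect_list-style value stores every event. The workload (100 keys, 10,000 events) is made up.

```python
from collections import defaultdict
from typing import DefaultDict, List

# Made-up workload: 10,000 events spread over 100 keys
events = [(f"user_{i % 100:03d}", i) for i in range(10_000)]

count_state: DefaultDict[str, int] = defaultdict(int)         # like count("*")
list_state: DefaultDict[str, List[int]] = defaultdict(list)   # like collect_list(...)
for key, value in events:
    count_state[key] += 1           # one small number per key, regardless of volume
    list_state[key].append(value)   # every event is kept in state

stored_counts = len(count_state)                          # grows with distinct keys
stored_items = sum(len(v) for v in list_state.values())   # grows with total events
print(f"count state: {stored_counts} values, collect_list state: {stored_items} values")
```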
RocksDB Off-Heap Usage
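RocksDB keeps state in files on local disk and caches recently used data blocks in a block cache held in native memory outside the JVM heap. The block cache behaves like an LRU cache: hot keys stay in memory, cold keys are read from disk.

This native memory is not governed by Spark's spark.memory.offHeap.* settings, so budget for it through the executor memory overhead: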
python: RocksDB memory configuration
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("RocksDBMemoryTuning")
    # Enable RocksDB
    .config("spark.sql.streaming.stateStore.providerClass",
            "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")
    # Block cache size in MB for each RocksDB instance (one per state partition on an executor)
    .config("spark.sql.streaming.stateStore.rocksdb.blockCacheSizeMB", "64")
    # JVM heap per executor (Spark execution, shuffle buffers, etc.)
    .config("spark.executor.memory", "4g")
    # Non-heap memory per executor container; RocksDB's native memory comes out of this
    .config("spark.executor.memoryOverhead", "2g")
    .getOrCreate()
)

# Memory layout per executor (container request ≈ 6g):
#   4g  JVM heap   (spark.executor.memory)
#   2g  overhead   (native memory: RocksDB block cache and memtables, JVM internals, etc.)
# Executor memory settings must be set before executors launch
# (e.g., via spark-submit --conf), not on an already running session.
In-Memory State Heap Usage
The in-memory state store puts all state on the JVM heap. This competes with Spark's execution memory (for shuffles and joins), storage memory (for caching), and the JVM itself.

Spark's unified memory manager doesn't explicitly reserve space for state โ€” the state store just uses whatever heap is available. This leads to unpredictable OOM behavior when state grows.
python: Diagnosing in-memory state OOM
# Signs you're running out of state memory:
# 1. Executor logs show: java.lang.OutOfMemoryError: Java heap space
# 2. GC time in Spark UI keeps increasing per batch
# 3. stateOperators[0].memoryUsedBytes approaches the heap available across executors

# Check via query progress (lastProgress is None before the first batch completes):
progress = query.lastProgress or {}
state_ops = progress.get("stateOperators", [])
for op in state_ops:
    mem_mb = op["memoryUsedBytes"] / (1024 * 1024)
    rows = op["numRowsTotal"]
    print(f"State: {rows:,} rows, {mem_mb:.1f} MB")

    # memoryUsedBytes is summed over all state partitions, so compare it
    # with the total heap of all executors, not with a single executor
    num_executors = 4            # adjust to your cluster
    executor_heap_mb = 4096      # 4GB per executor
    total_heap_mb = num_executors * executor_heap_mb
    if mem_mb > total_heap_mb * 0.8:
        print(f"⚠️  WARNING: State is {mem_mb / total_heap_mb * 100:.0f}% of total executor heap!")
        print("Consider: (1) Enable watermark (2) Switch to RocksDB")
State Size Metrics in Spark UI
The Structured Streaming tab in the Spark UI charts state metrics per query, and the same fields appear in each stateOperators entry of query.lastProgress:
• numRowsTotal: total keys in the state store right now
• memoryUsedBytes: current state memory footprint
• numRowsDroppedByWatermark: late input rows dropped for being behind the watermark (not evicted state)
• numRowsRemoved: state keys evicted this batch

A healthy query shows numRowsRemoved > 0 regularly. A query with numRowsRemoved = 0 and numRowsTotal ever-increasing is heading for an OOM.
18C.6
Checkpoint Size Management
State is checkpointed to durable storage. Without management, checkpoint directories grow to hundreds of GB. Here's how to control it.
State Checkpoint Size & Incremental Checkpointing
State Checkpoint Size
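Checkpoint size is driven by the state files Spark still retains: the most recent snapshots plus every delta or changelog written since them, for the versions covered by minBatchesToRetain. Both built-in stores write only changed data on each batch, so the biggest factors are how much state changes per batch and how many versions you keep. RocksDB matters most for large state because it keeps that state off the JVM heap and, with changelog checkpointing, makes each batch's checkpoint write small and predictable.
Store Type | Written Every Batch | Periodic Full Copy
In-memory (HDFS-backed, default) | .delta file with the changed keys | .snapshot file after spark.sql.streaming.stateStore.minDeltasForSnapshot deltas (default 10), written by a background maintenance task
RocksDB (default mode) | New SST files since the previous version, plus version metadata | Not needed: each version references the SST files it uses
RocksDB (changelog checkpointing) | Changelog file with the changed keys | Snapshot of the RocksDB files every minDeltasForSnapshot batches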
Incremental Checkpointing with RocksDB
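With changelog checkpointing enabled (Spark 3.4+), RocksDB writes a changelog file per batch containing only the keys that changed, instead of uploading a new copy of its files. Every spark.sql.streaming.stateStore.minDeltasForSnapshot batches (default 10), it also uploads a snapshot as a baseline for recovery.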
python: RocksDB changelog checkpoint config
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .config("spark.sql.streaming.stateStore.providerClass",
            "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")
    # Write per-batch changelogs instead of full RocksDB checkpoints (Spark 3.4+)
    .config("spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled", "true")
    # How many changelogs/deltas between snapshots (default: 10)
    .config("spark.sql.streaming.stateStore.minDeltasForSnapshot", "10")
    .getOrCreate()
)

# With changelog checkpointing:
# batch 1: writes keys {A:1, B:2} → changelog file
# batch 2: key A updated to {A:5} → changelog file (only A written)
# every minDeltasForSnapshot batches: a snapshot is uploaded
# Recovery: load latest snapshot + replay the changelogs written since
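The snapshot-plus-changelog recovery described in the comments can be modeled in a few lines of Python. This is a toy, not RocksDB's or Spark's file format: each commit records only the changed keys (None marks a removed key), a full copy is taken every MIN_DELTAS_FOR_SNAPSHOT batches (a small made-up value standing in for minDeltasForSnapshot), and recovery loads the newest snapshot at or before the target batch and replays the changelogs after it.

```python
from typing import Dict, Optional

MIN_DELTAS_FOR_SNAPSHOT = 3  # toy value; stands in for minDeltasForSnapshot

class ToyCheckpoint:
    def __init__(self) -> None:
        self.snapshots: Dict[int, Dict[str, int]] = {0: {}}          # batch -> full copy
        self.changelogs: Dict[int, Dict[str, Optional[int]]] = {}    # batch -> changed keys
        self.live: Dict[str, int] = {}

    @staticmethod
    def _apply(state: Dict[str, int], changes: Dict[str, Optional[int]]) -> None:
        for key, value in changes.items():
            if value is None:
                state.pop(key, None)    # key removed in this batch
            else:
                state[key] = value

    def commit(self, batch_id: int, changes: Dict[str, Optional[int]]) -> None:
        self.changelogs[batch_id] = dict(changes)   # only the changed keys are written
        self._apply(self.live, changes)
        if batch_id - max(self.snapshots) >= MIN_DELTAS_FOR_SNAPSHOT:
            self.snapshots[batch_id] = dict(self.live)

    def recover(self, batch_id: int) -> Dict[str, int]:
        base = max(b for b in self.snapshots if b <= batch_id)
        state = dict(self.snapshots[base])
        for b in range(base + 1, batch_id + 1):
            self._apply(state, self.changelogs[b])
        return state

ckpt = ToyCheckpoint()
ckpt.commit(1, {"A": 1, "B": 2})
ckpt.commit(2, {"A": 5})        # only A is written
ckpt.commit(3, {"C": 7})        # third changelog since batch 0 -> snapshot at 3
ckpt.commit(4, {"B": None})     # B removed
print(sorted(ckpt.snapshots), ckpt.recover(2), ckpt.recover(4))
```

Recovering batch 2 replays two changelogs on top of the empty snapshot at 0, while recovering batch 4 starts from the snapshot at 3 and replays only one: the snapshot interval bounds replay work.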
Checkpoint Retention Policy
Spark keeps the last N versions of state for recovery. This is controlled by spark.sql.streaming.minBatchesToRetain (default: 100).

You only need enough versions to recover from a failure. In practice, a small value (2–5) is often enough, as long as it still covers how far back you might need to restart.
python: Controlling checkpoint retention
# Reduce how many old checkpoint versions are kept
spark.conf.set("spark.sql.streaming.minBatchesToRetain", "5")
# The default is 100, i.e. ~95 more versions than this setting keeps

# What this controls:
# - Old offset log files (older than minBatchesToRetain batches) are cleaned
# - Old commit log files are cleaned
# - Old state files are cleaned (Spark keeps whatever snapshots/deltas the retained versions still need)

# A value of 5 means: keep about 5 recent batches of state history
# How much time that covers depends on your trigger interval
# If your recovery window is longer, increase this value

# Cleaning old checkpoints manually (if needed):
# hdfs dfs -rm -r /ckpt/orders/state/0/0/1.delta  # don't do manually!
# Always let Spark manage this via minBatchesToRetain
Cleaning Old Checkpoints
Spark automatically cleans old checkpoint versions based on minBatchesToRetain. Never manually delete checkpoint files while a query is running: this breaks recovery.

To fully reset a query (start fresh): stop the query โ†’ delete the entire checkpoint directory โ†’ restart. This means processing starts from the beginning (or from startingOffsets).
โš ๏ธ Never Delete Checkpoint While Query Runs
Deleting checkpoint files while a streaming query is running causes immediate failure. Always stop the query first. The only safe way to "reset" is: stop query โ†’ delete checkpoint dir โ†’ restart query.
18C.7
mapGroupsWithState
The most powerful and flexible stateful operator in Structured Streaming: it lets you implement any arbitrary stateful logic with full control over state reads, updates, and timeouts.
Arbitrary Stateful Processing
What mapGroupsWithState Does
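mapGroupsWithState (Scala/Java Dataset API) lets you supply a custom function that Spark calls once per key that has new rows in a batch, and once more for a key whose timeout has fired. In Python, the same model is exposed as applyInPandasWithState (Spark 3.4+); its function yields pandas DataFrames, so it behaves like flatMapGroupsWithState. You receive:
• The key value (e.g., user_id = "user_001")
• An iterator of all new rows for that key in this batch
• A GroupState object: your interface to the state store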

You can read the current state, update it however you like, set a timeout, or remove it entirely. You must return exactly one row per key (unlike flatMapGroupsWithState which can return 0 or many).
📬 Analogy
Think of mapGroupsWithState like a personal assistant for each customer. The assistant has a file (state) about the customer. Every time new mail arrives (new batch rows), the assistant reads the existing file, processes the new mail, updates the file, and files a single report (one output row). If no mail arrives for 30 minutes, the assistant shreds the file (timeout → state.remove()).
GroupState API
The GroupState object is your handle to the persistent state store for a given key:
python: GroupState API reference (PySpark names; Scala uses the same names)
from pyspark.sql.streaming.state import GroupState

def my_state_function(key, pdfs, state: GroupState):
    # Reference listing only: new_val and ts_ms are placeholders

    # ── READ state ─────────────────────────────────────
    state.exists          # bool: does state exist for this key? (property in Python)
    state.get             # current state tuple (raises if not exists)
    state.getOption       # returns None if not exists

    # ── WRITE state ────────────────────────────────────
    state.update(new_val) # replace state with a new tuple
    state.remove()        # delete state for this key entirely

    # ── TIMEOUT ────────────────────────────────────────
    state.hasTimedOut     # bool: was this call triggered by timeout?
    state.setTimeoutDuration(30 * 60 * 1000)  # ProcessingTime timeout in ms (Scala also accepts "30 minutes")
    state.setTimeoutTimestamp(ts_ms)          # EventTime timeout (epoch ms)

    # ── METADATA ───────────────────────────────────────
    state.getCurrentProcessingTimeMs()  # current processing time
    state.getCurrentWatermarkMs()       # current watermark (for EventTime timeout)
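Because the GroupState contract is small, update logic can be unit-tested without a cluster by passing in a local stand-in. FakeGroupState below is a hypothetical test double written for this example, not PySpark's class; it only mirrors the property/method split in the listing above (exists, get, getOption as properties; update, remove, setTimeoutDuration as methods). count_events is likewise an illustrative update function.

```python
from typing import Optional, Tuple

class FakeGroupState:
    """Hypothetical local stand-in for GroupState, for unit tests only."""

    def __init__(self, value: Optional[Tuple] = None, timed_out: bool = False):
        self._value = value
        self.hasTimedOut = timed_out
        self.timeout_ms: Optional[int] = None

    @property
    def exists(self) -> bool:
        return self._value is not None

    @property
    def get(self) -> Tuple:
        if self._value is None:
            raise ValueError("state does not exist")
        return self._value

    @property
    def getOption(self) -> Optional[Tuple]:
        return self._value

    def update(self, new_value: Tuple) -> None:
        self._value = tuple(new_value)

    def remove(self) -> None:
        self._value = None

    def setTimeoutDuration(self, duration_ms: int) -> None:
        self.timeout_ms = duration_ms

def count_events(key: str, n_new: int, state) -> Tuple[str, int]:
    """Example update logic: running count per key, dropped on timeout."""
    if state.hasTimedOut:
        state.remove()
        return (key, 0)
    (count,) = state.get if state.exists else (0,)
    count += n_new
    state.update((count,))
    state.setTimeoutDuration(30 * 60 * 1000)
    return (key, count)

s = FakeGroupState()
first = count_events("u1", 2, s)
second = count_events("u1", 3, s)
expired = FakeGroupState(value=(5,), timed_out=True)
after_timeout = count_events("u1", 0, expired)
print(first, second, after_timeout)
```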
State Timeout: ProcessingTimeTimeout vs EventTimeTimeout
Two timeout types for state expiration:
Timeout Type | Triggers When | Use Case | Requires
ProcessingTimeTimeout | Wall-clock time passes since last update | "Expire session if no activity for 30 min real time" | setTimeoutDuration("30 minutes")
EventTimeTimeout | Watermark advances past timeout timestamp | "Expire session when event time > session_start + 1 hour" | setTimeoutTimestamp(ts) + watermark
Update Function Lifecycle
The function is called in two scenarios each batch:
1. Normal call: New rows arrived for this key โ†’ process them
2. Timeout call: No new rows, but state timed out โ†’ state.hasTimedOut == True
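The two call types can be simulated with a small driver. In this toy model (not Spark's scheduler), a dict holds (value, timeout deadline) per key; each batch first calls the function for keys with new rows, then once more, with timed_out=True, for idle keys whose processing-time deadline has passed. The 30-minute timeout, run_batch, and session_count are hypothetical names for illustration.

```python
from typing import Callable, Dict, List, Optional, Tuple

TIMEOUT_MS = 30 * 60 * 1000

UpdateFn = Callable[[str, List[int], Optional[int], bool], Optional[int]]
Store = Dict[str, Tuple[int, int]]   # key -> (state value, timeout deadline in ms)

def _apply(store: Store, key: str, new_value: Optional[int], now_ms: int) -> None:
    if new_value is None:
        store.pop(key, None)                             # like state.remove()
    else:
        store[key] = (new_value, now_ms + TIMEOUT_MS)    # like update() + setTimeoutDuration()

def run_batch(now_ms: int, rows_by_key: Dict[str, List[int]],
              store: Store, fn: UpdateFn) -> List[Tuple[str, str]]:
    calls = []
    # 1. Normal calls: keys that received rows in this batch
    for key, rows in rows_by_key.items():
        old = store[key][0] if key in store else None
        _apply(store, key, fn(key, rows, old, False), now_ms)
        calls.append((key, "data"))
    # 2. Timeout calls: idle keys whose deadline has passed
    expired = [k for k, (_, deadline) in store.items()
               if k not in rows_by_key and deadline <= now_ms]
    for key in expired:
        _apply(store, key, fn(key, [], store[key][0], True), now_ms)
        calls.append((key, "timeout"))
    return calls

def session_count(key: str, rows: List[int], old: Optional[int], timed_out: bool) -> Optional[int]:
    if timed_out:
        return None                       # drop the session's state
    return (old or 0) + len(rows)         # running event count

store: Store = {}
calls = run_batch(0, {"u1": [1, 2], "u2": [3]}, store, session_count)
calls += run_batch(20 * 60_000, {"u1": [4]}, store, session_count)  # u2 idle 20 min: kept
calls += run_batch(40 * 60_000, {}, store, session_count)           # u2 idle 40 min: times out
print(calls, store)
```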
python: Full applyInPandasWithState example: User Session Tracker (Spark 3.4+)
from typing import Any, Iterator, Tuple

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, from_json
from pyspark.sql.streaming.state import GroupState, GroupStateTimeout
from pyspark.sql.types import IntegerType, LongType, StringType, StructType

spark = SparkSession.builder.appName("SessionTracker").getOrCreate()

# ── State Schema: what we store per user ──────────────
# (page_views, total_time_sec, session_start_ms)
SessionState = StructType() \
    .add("page_views", IntegerType()) \
    .add("total_time_sec", LongType()) \
    .add("session_start_ms", LongType())

# ── Input Schema ─────────────────────────────────────
EventSchema = StructType() \
    .add("user_id", StringType()) \
    .add("page", StringType()) \
    .add("duration_sec", IntegerType()) \
    .add("event_time", StringType())

# ── Output Schema ────────────────────────────────────
output_schema = "user_id STRING, page_views INT, total_time_sec LONG, status STRING"

THIRTY_MINUTES_MS = 30 * 60 * 1000  # PySpark timeouts are in milliseconds

def summary(user_id: str, page_views: int, total_time: int, status: str) -> pd.DataFrame:
    return pd.DataFrame({"user_id": [user_id], "page_views": [page_views],
                         "total_time_sec": [total_time], "status": [status]})

# ── State Update Function ────────────────────────────
def track_session(
    key: Tuple[Any, ...], pdfs: Iterator[pd.DataFrame], state: GroupState
) -> Iterator[pd.DataFrame]:
    """
    Yields one summary row per user per call.
    State = running session totals.
    """
    (user_id,) = key

    # Handle timeout: session expired (no activity for 30 min)
    if state.hasTimedOut:
        page_views, total_time, _ = state.get
        state.remove()
        yield summary(user_id, page_views, total_time, "EXPIRED")
        return

    # Initialize or restore state
    if state.exists:
        page_views, total_time, start_ms = state.get
    else:
        page_views, total_time, start_ms = 0, 0, state.getCurrentProcessingTimeMs()

    # Process new events for this user in this batch
    for pdf in pdfs:
        page_views += len(pdf)
        total_time += int(pdf["duration_sec"].sum())

    # Persist updated state
    state.update((page_views, total_time, start_ms))

    # Reset the 30-minute inactivity timeout
    state.setTimeoutDuration(THIRTY_MINUTES_MS)

    yield summary(user_id, page_views, total_time, "ACTIVE")

# ── Pipeline ─────────────────────────────────────────
# Requires the spark-sql-kafka connector on the classpath
events_stream = (
    spark.readStream
    .format("kafka")
    .option("kafka.bootstrap.servers", "localhost:9092")
    .option("subscribe", "clickstream")
    .load()
    .select(from_json(col("value").cast("string"), EventSchema).alias("e"))
    .select("e.*")
)

session_df = events_stream.groupBy("user_id").applyInPandasWithState(
    track_session,
    outputStructType=output_schema,
    stateStructType=SessionState,
    outputMode="append",
    timeoutConf=GroupStateTimeout.ProcessingTimeTimeout,
)

query = (
    session_df.writeStream
    .outputMode("append")       # the Delta sink accepts append/complete, not update
    .format("delta")
    .option("checkpointLocation", "/tmp/ckpt/sessions")
    .trigger(processingTime="10 seconds")
    .start("/delta/sessions")
)

query.awaitTermination()
18C.8
flatMapGroupsWithState
The sibling of mapGroupsWithState: same stateful control, but it can emit zero, one, or many output rows per key per batch. Essential for event-driven patterns.
flatMapGroupsWithState vs mapGroupsWithState
The Core Difference
The only difference is what you can return:
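Operator | Output per Key per Call | Return Type | Output Modes
mapGroupsWithState | Exactly 1 row | Single object/Row | Update only
flatMapGroupsWithState | 0, 1, or many rows | Iterator of objects/Rows | Append or Update (declared in the call)
applyInPandasWithState (Python) | 0, 1, or many rows | Iterator of pandas DataFrames | Append or Update (declared in the call)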
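The output contracts in the table can be contrasted with two plain Python functions over the same hypothetical batch: the map-style one returns exactly one row per key, while the flatMap-style generator yields zero or more. The alert-on-every-5th-failure rule is made up for illustration, and no Spark API is involved.

```python
from typing import Iterator, List, Tuple

def map_style(key: str, failed_logins: List[int]) -> Tuple[str, int]:
    # mapGroupsWithState contract: exactly one output row per key per call
    return (key, len(failed_logins))

def flat_map_style(key: str, failed_logins: List[int]) -> Iterator[Tuple[str, int]]:
    # flatMapGroupsWithState contract: zero, one, or many rows per key per call
    for i, _ in enumerate(failed_logins, start=1):
        if i % 5 == 0:                     # hypothetical rule: alert on every 5th failure
            yield (key, i)

batch = {"alice": [1, 2], "bob": list(range(11))}
mapped = [map_style(k, v) for k, v in batch.items()]
flat = [row for k, v in batch.items() for row in flat_map_style(k, v)]
print(mapped)
print(flat)
```

alice produces a row in the map version but nothing in the flatMap version, while bob produces two alert rows: that asymmetry is the whole reason to choose flatMapGroupsWithState.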
📬 Analogy
mapGroupsWithState is like a vending machine: you put in a coin (input events), you always get back one snack (one output row), no exceptions. flatMapGroupsWithState is like a lottery ticket: you scratch it and might get nothing, one prize, or multiple prizes.
Use Cases for flatMapGroupsWithState
Use flatMapGroupsWithState when:
• You only want to emit output when a specific condition is met (e.g., 5th consecutive login from a different IP → emit a fraud alert, otherwise emit nothing)
• You need to emit multiple events from a single state update (e.g., when a session ends, emit one row per page visited)
• You want no output while state accumulates and only output at the end
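The "emit many rows when a session ends" case can be illustrated without a cluster. The sketch below is plain Python, not Spark: `run_batch`, `emit_pages_on_close`, and the dict used as a state store are hypothetical names chosen for illustration. It assumes the operator groups each micro-batch by key, calls the user function once per key present in the batch, and fully consumes the generator it returns, so the function can yield zero rows while a session is open and several rows when it closes.

```python
from typing import Callable, Dict, Iterable, Iterator, List, Tuple

Row = Tuple[str, int, str]
StateStore = Dict[str, List[str]]  # toy state: key -> pages visited so far


def run_batch(
    batch: List[Tuple[str, str]],
    store: StateStore,
    fn: Callable[[str, List[str], StateStore], Iterator[Row]],
) -> List[Row]:
    """Group one micro-batch by key and call fn once per key present in it."""
    groups: Dict[str, List[str]] = {}
    for user_id, event in batch:
        groups.setdefault(user_id, []).append(event)
    out: List[Row] = []
    for user_id, events in groups.items():
        out.extend(fn(user_id, events, store))  # fn may yield 0, 1, or many rows
    return out


def emit_pages_on_close(user_id: str, events: Iterable[str], store: StateStore) -> Iterator[Row]:
    """Accumulate pages silently; on "CLOSE" emit one row per page, then drop the state."""
    pages = store.get(user_id, [])
    for event in events:
        if event == "CLOSE":
            for i, page in enumerate(pages):
                yield (user_id, i, page)  # many rows from a single state update
            pages = []
        else:
            pages = pages + [event]  # session still open: no output
    if pages:
        store[user_id] = pages  # like state.update(...)
    else:
        store.pop(user_id, None)  # like state.remove()


store: StateStore = {}
print(run_batch([("u1", "/home"), ("u1", "/cart")], store, emit_pages_on_close))  # []
print(run_batch([("u1", "/pay"), ("u1", "CLOSE")], store, emit_pages_on_close))
# [('u1', 0, '/home'), ('u1', 1, '/cart'), ('u1', 2, '/pay')]
print(store)  # {}
```

The first batch produces no rows at all (the flatMap "zero rows" case); the second produces three rows from one key, something a one-row-per-key operator cannot express.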
python: flatMapGroupsWithState Fraud Alert Pattern (PySpark: applyInPandasWithState)
from typing import Iterator, Tuple

import pandas as pd
from pyspark.sql.streaming.state import GroupState, GroupStateTimeout
from pyspark.sql.types import IntegerType, StructType

# flatMapGroupsWithState itself is a Scala/Java Dataset API. PySpark (3.4+)
# exposes the same "0..N rows per key per batch" contract via
# groupBy(...).applyInPandasWithState(...), used below.

# State: count of failed logins per user
AlertState = StructType().add("failed_attempts", IntegerType())
output_schema = "user_id STRING, alert_type STRING, attempt_count INT"

TEN_MINUTES_MS = 10 * 60 * 1000  # PySpark's setTimeoutDuration expects milliseconds

def detect_fraud(
    key: Tuple[str],
    pdf_iter: Iterator[pd.DataFrame],
    state: GroupState,
) -> Iterator[pd.DataFrame]:
    """
    Can yield 0 or many rows per key per batch.
    Only emits when fraud is detected.
    """
    user_id = key[0]

    # Timeout: reset counter after 10 minutes of no activity
    if state.hasTimedOut:
        state.remove()
        return  # yield nothing: no output for timeout cleanup

    curr_attempts = state.get[0] if state.exists else 0
    alerts = []

    for pdf in pdf_iter:
        for event_type in pdf["event_type"]:
            if event_type == "LOGIN_FAILED":
                curr_attempts += 1
                if curr_attempts >= 5:
                    # ALERT! Collect a fraud alert row
                    alerts.append((user_id, "FRAUD_ALERT", curr_attempts))
                    curr_attempts = 0  # reset after alerting
                # else: accumulate silently, NO output emitted
            elif event_type == "LOGIN_SUCCESS":
                curr_attempts = 0  # reset on successful login

    # Persist the counter and re-arm the timeout, or clear the state
    if curr_attempts > 0:
        state.update((curr_attempts,))
        state.setTimeoutDuration(TEN_MINUTES_MS)
    elif state.exists:
        state.remove()

    # No alert triggered -> nothing yielded -> 0 output rows for this key
    if alerts:
        yield pd.DataFrame(alerts, columns=["user_id", "alert_type", "attempt_count"])

# Apply the stateful function (login_stream is an existing streaming DataFrame)
fraud_alerts = login_stream \
    .groupBy("user_id") \
    .applyInPandasWithState(
        detect_fraud,
        outputStructType=output_schema,
        stateStructType=AlertState,
        outputMode="append",  # "update" is also accepted; append is typical for alerts
        timeoutConf=GroupStateTimeout.ProcessingTimeTimeout,
    )

query = fraud_alerts.writeStream \
    .outputMode("append") \
    .format("delta") \
    .option("checkpointLocation", "/tmp/ckpt/fraud") \
    .start("/delta/fraud_alerts")
Output Mode Restrictions
flatMapGroupsWithState output mode rules:

• Append mode (most common for flatMap): emitted rows are appended to the sink and never updated. Works well with Delta Lake and Kafka sinks.
• Update mode: supported but less common; each batch's output is treated as an update.
• Complete mode: NOT supported with flatMapGroupsWithState.
💡 mapGroups vs flatMapGroups: Decision Rule
Use mapGroupsWithState when: you always have one output per key per batch (session summary, running total dashboard)
Use flatMapGroupsWithState when: some batches produce no output (fraud alert, threshold detection), or you emit multiple rows (session events on close)
QUIZ
Quick Quiz โ€” Module 18C
Test your understanding of State Management in Structured Streaming.
1. Why is RocksDB the recommended state store for production over the default in-memory store?
✅ Correct! RocksDB has three key advantages: state lives in native memory outside the JVM heap (no GC pressure from large state), it can spill to local disk (handles state larger than executor memory), and it checkpoints incrementally (each checkpoint uploads only the files that changed).
2. A streaming query with groupBy("user_id") has been running for 3 months with NO watermark. What happens to the state store?
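✅ Exactly! Without a watermark or timeout, Spark never evicts state. After 3 months of unique user activity the state store may hold tens of millions of entries and eventually cause OOM. In production, bound state with .withWatermark() (which only cleans up when the event-time column or a window is part of the grouping key) or with setTimeoutDuration() in a stateful function.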
3. You change spark.sql.shuffle.partitions from 8 to 16 on a streaming query with existing checkpoint. What happens?
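✅ Correct! State is partition-local: hash(key) % numPartitions decides which state partition owns a key. A key such as "user_001" that lives in partition 5 with 8 partitions could map to partition 13 with 16, while its state is still stored in partition 5's files. Spark therefore pins the partition count in the checkpoint: when restarting from an existing checkpoint it keeps the original value (8) and ignores the new one. To actually change it, start from a new checkpoint (and rebuild the state).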
4. What is the key difference between mapGroupsWithState and flatMapGroupsWithState?
✅ Right! The "flat" in flatMapGroupsWithState means the output is flattened: you yield many rows or none. mapGroupsWithState always produces exactly one row. Use flatMap when you emit conditionally (fraud alerts) or emit many rows (session decomposition).
5. In mapGroupsWithState, when is the update function called with state.hasTimedOut == True?
✅ Perfect! hasTimedOut is a flag, not a separate callback: when a key's timeout expires, Spark calls your function for that key with no new events (an empty iterator) and hasTimedOut=True. That call is your chance to emit a final row (mapGroupsWithState must still return exactly one) and to call state.remove(); Spark does not drop the state for you.
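Question 5's timeout contract can be simulated in plain Python. `ToyGroupState`, `session_count`, and `run_batch_with_timeouts` are hypothetical stand-ins, not Spark classes, and "now" is passed in explicitly as milliseconds. The sketch assumes that a timed-out key is called with no events and a timed-out flag, that state disappears only when the function removes it, and (an assumption mirroring GroupState's documented behaviour) that the timeout is cleared on every call, so the function must re-arm it.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple

Row = Tuple[str, Optional[int], str]


@dataclass
class ToyGroupState:
    """Hypothetical stand-in for GroupState (not a Spark class)."""
    value: Optional[int] = None
    deadline_ms: Optional[int] = None
    now_ms: int = 0
    has_timed_out: bool = False

    @property
    def exists(self) -> bool:
        return self.value is not None

    def update(self, new_value: int) -> None:
        self.value = new_value

    def remove(self) -> None:
        self.value = None

    def set_timeout_duration(self, duration_ms: int) -> None:
        self.deadline_ms = self.now_ms + duration_ms


def session_count(user_id: str, events: List[str], state: ToyGroupState) -> Row:
    """mapGroupsWithState-style: returns exactly one row per call."""
    if state.has_timed_out:
        final = (user_id, state.value, "EXPIRED")
        state.remove()  # nothing is dropped unless the function does this
        return final
    state.update((state.value or 0) + len(events))
    state.set_timeout_duration(30 * 60 * 1000)  # re-arm on every call
    return (user_id, state.value, "ACTIVE")


def run_batch_with_timeouts(
    now_ms: int,
    batch: Dict[str, List[str]],
    states: Dict[str, ToyGroupState],
    fn: Callable[[str, List[str], ToyGroupState], Row],
) -> List[Row]:
    rows: List[Row] = []
    for user_id, events in batch.items():  # keys with new data
        st = states.setdefault(user_id, ToyGroupState())
        st.now_ms, st.deadline_ms = now_ms, None  # timeout cleared on every call
        rows.append(fn(user_id, events, st))
    for user_id, st in list(states.items()):  # idle keys past their deadline
        if user_id not in batch and st.deadline_ms is not None and now_ms >= st.deadline_ms:
            st.now_ms, st.deadline_ms, st.has_timed_out = now_ms, None, True
            rows.append(fn(user_id, [], st))  # empty events + timed-out flag
            st.has_timed_out = False
    for user_id in [k for k, st in states.items() if not st.exists]:
        del states[user_id]  # only state the function removed disappears
    return rows


states: Dict[str, ToyGroupState] = {}
print(run_batch_with_timeouts(0, {"u1": ["a", "b"]}, states, session_count))  # [('u1', 2, 'ACTIVE')]
print(run_batch_with_timeouts(60_000, {}, states, session_count))  # []
print(run_batch_with_timeouts(30 * 60 * 1000, {}, states, session_count))  # [('u1', 2, 'EXPIRED')]
print(states)  # {}
```

If the timed-out branch returned its final row without calling `remove()`, the key would stay in `states` indefinitely with no timeout armed, which is exactly the unbounded growth the cheat sheet warns about.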
CHEAT SHEET
Module 18C โ€” Quick Reference
Everything about State Management in one place.
📋
Core Concepts & Comparison Tables
Concept | Key Fact | Config / API
In-memory store | Default (HDFSBackedStateStoreProvider). State on JVM heap; delta files per batch plus periodic snapshots. | (no config needed)
RocksDB store | Native memory + local disk. Incremental checkpoints. Production choice. | stateStore.providerClass = RocksDBStateStoreProvider
State key | The groupBy column value(s). Hash-partitioned. | .groupBy("user_id")
State growth | Unbounded without watermark/timeout; can eventually OOM. | Monitor: query.lastProgress['stateOperators'][0]['numRowsTotal']
Watermark cleanup | Auto-evicts state once the watermark passes the window end (the event-time column or window must be in the grouping key). | .withWatermark("event_time", "10 minutes")
Timeout (TTL-style) cleanup | Calls your function with hasTimedOut=True after N of processing-time inactivity; you then call remove(). | Scala: state.setTimeoutDuration("30 minutes"); PySpark: state.setTimeoutDuration(30 * 60 * 1000) (ms)
Explicit cleanup | Remove a key programmatically at any time. | state.remove()
mapGroupsWithState | Exactly 1 output row per key per batch; Update mode only. | Scala/Java: .groupByKey(...).mapGroupsWithState(timeoutConf)(fn)
flatMapGroupsWithState | 0 to N output rows per key per batch (use yield). | Scala/Java: .flatMapGroupsWithState(OutputMode.Append, timeoutConf)(fn); PySpark 3.4+: .applyInPandasWithState(fn, ..., outputMode="append", ...)
Checkpoint retention | Controls how many old batches are kept for recovery (default 100). | spark.sql.streaming.minBatchesToRetain = 5
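The "State growth" row can be made concrete with a toy count of state rows. This sketch assumes every micro-batch brings 100 never-seen keys and models timeout-style cleanup as "evict a key after `ttl_batches` batches without activity"; `simulate_state_rows` is a hypothetical helper, and real eviction is driven by watermarks or processing-time timeouts rather than batch counts.

```python
from typing import Dict, List, Optional, Sequence


def simulate_state_rows(
    batches: Sequence[Sequence[str]],
    ttl_batches: Optional[int] = None,
) -> List[int]:
    """Return the number of state rows (like numRowsTotal) after each micro-batch."""
    last_seen: Dict[str, int] = {}  # key -> index of the batch it was last seen in
    totals: List[int] = []
    for b, keys in enumerate(batches):
        for k in keys:
            last_seen[k] = b  # upsert: one state row per distinct key
        if ttl_batches is not None:
            expired = [k for k, seen in last_seen.items() if b - seen >= ttl_batches]
            for k in expired:
                del last_seen[k]  # analogous to state.remove() on timeout
        totals.append(len(last_seen))
    return totals


batches = [[f"user_{b}_{i}" for i in range(100)] for b in range(5)]
print(simulate_state_rows(batches))                 # [100, 200, 300, 400, 500]
print(simulate_state_rows(batches, ttl_batches=2))  # [100, 200, 200, 200, 200]
```

Without cleanup the row count grows linearly with the number of distinct keys ever seen; with cleanup it plateaus at roughly "keys active within the TTL window".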
🗂️
Code Patterns Quick Reference
Enable RocksDB
spark.conf.set(
  "spark.sql.streaming.stateStore.providerClass",
  "...RocksDBStateStoreProvider"  # abbreviated; full class name in the production template
)
Watermark Cleanup
(df.withWatermark("event_time", "10 minutes")
   .groupBy(window("event_time", "1 hour"))
   .agg(count("*")))
GroupState API
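state.exists           # bool
state.get              # current value (a tuple in PySpark)
state.update(new_val)  # write (a tuple in PySpark)
state.remove()         # delete
state.hasTimedOut      # timeout check
state.setTimeoutDuration(30 * 60 * 1000)  # PySpark takes ms; Scala also accepts "30 minutes"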
mapGroupsWithState (Scala/Java Dataset API; not available in PySpark)
ds.groupByKey(_.key)
  .mapGroupsWithState(
    GroupStateTimeout.ProcessingTimeTimeout
  )(fn) // fn returns exactly one value per key
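flatMapGroupsWithState (PySpark 3.4+ equivalent: applyInPandasWithState)
(df.groupBy("key")
   .applyInPandasWithState(
     fn,  # uses yield (of pandas DataFrames)
     outputStructType="...",
     stateStructType=StateType,
     outputMode="append",
     timeoutConf=GroupStateTimeout.ProcessingTimeTimeout))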
Monitor State Size
p = query.lastProgress
ops = p["stateOperators"]
ops[0]["numRowsTotal"] # total keys
ops[0]["memoryUsedBytes"]
ops[0]["numRowsUpdated"] # this batch
ops[0]["numRowsRemoved"] # cleanup
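Building on the monitoring fields above, a small heuristic can flag an operator whose state never shrinks. `looks_unbounded` is a hypothetical helper, not part of PySpark; it assumes you append `query.lastProgress` to a list on each poll and that each entry is the dict form PySpark 3.x returns (entries may be `None` before the first batch completes). The thresholds are illustrative, not recommendations.

```python
from typing import Any, Dict, List, Optional


def looks_unbounded(
    history: List[Optional[Dict[str, Any]]],
    op_index: int = 0,
    min_samples: int = 3,
) -> bool:
    """Heuristic growth check over successive query.lastProgress snapshots.

    Flags an operator whose numRowsTotal never shrinks, ends higher than it
    started, and whose numRowsRemoved is always 0 (no cleanup happening).
    """
    ops = [
        p["stateOperators"][op_index]
        for p in history
        if p and len(p.get("stateOperators", [])) > op_index
    ]
    if len(ops) < min_samples:
        return False  # not enough evidence yet
    totals = [op["numRowsTotal"] for op in ops]
    never_shrinks = all(b >= a for a, b in zip(totals, totals[1:]))
    nothing_removed = all(op.get("numRowsRemoved", 0) == 0 for op in ops)
    return never_shrinks and nothing_removed and totals[-1] > totals[0]


growing = [{"stateOperators": [{"numRowsTotal": n, "numRowsRemoved": 0}]} for n in (10, 20, 30)]
print(looks_unbounded(growing))
```

A `True` result does not prove a leak (a new deployment legitimately fills state at first), but a query that keeps returning `True` for hours is a strong hint that no watermark or timeout is cleaning up.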
💡 The 3 Golden Rules of State Management
1. Always use a watermark or timeout (TTL): never let state grow unbounded in production
2. Use RocksDB for production: in-memory state can OOM at scale
3. Choose shuffle.partitions before the first run: the checkpoint pins it, so changing it requires a new checkpoint (and rebuilt state)
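Rule 3 follows from how keyed state is placed: each key's owning partition is its hash modulo the partition count. The toy model below uses `zlib.crc32` as a stand-in hash (an assumption; Spark actually uses Murmur3, and Python's built-in `hash()` is salted per process), and `owner_partition` is a hypothetical helper. It counts how many keys would change owner if the count went from 8 to 16, which is why Spark keeps the checkpoint's original partition count instead of silently reshuffling state.

```python
import zlib


def owner_partition(key: str, num_partitions: int) -> int:
    """Toy model of hash partitioning: which state partition owns `key`."""
    # crc32 stands in for Spark's Murmur3 hash; it is deterministic across runs.
    return zlib.crc32(key.encode("utf-8")) % num_partitions


keys = [f"user_{i:03d}" for i in range(1000)]
moved = [k for k in keys if owner_partition(k, 8) != owner_partition(k, 16)]
print(f"{len(moved)} of {len(keys)} keys change owner when going from 8 to 16 partitions")
```

Because 8 divides 16, a key in partition p (of 8) lands in either p or p + 8 (of 16); roughly half the keys move, and their existing state would be sitting in the wrong partition's files.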
python: Production stateful streaming template
from pyspark.sql import SparkSession
from pyspark.sql.functions import count, window

spark = (
    SparkSession.builder
    .appName("StatefulProd")
    # 1. Use RocksDB for state
    .config("spark.sql.streaming.stateStore.providerClass",
            "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")
    # 2. Fix the partition count up front (tune for your cluster); the checkpoint pins it
    .config("spark.sql.shuffle.partitions", "32")
    # 3. RocksDB allocates native memory outside the JVM heap; reserve headroom for it
    .config("spark.executor.memoryOverhead", "2g")
    # 4. Keep fewer checkpoint versions (default is 100); fewer means less storage
    #    but fewer recoverable batches
    .config("spark.sql.streaming.minBatchesToRetain", "5")
    .getOrCreate()
)

# 5. Always use a watermark on stateful ops (`stream` is an existing streaming DataFrame)
agg = (
    stream
    .withWatermark("event_time", "10 minutes")  # enables state cleanup
    .groupBy(window("event_time", "1 hour"), "category")
    .agg(count("*").alias("count"))
)

query = (
    agg.writeStream
    .outputMode("append")
    .option("checkpointLocation", "s3://bucket/ckpt/prod")
    .trigger(processingTime="30 seconds")
    .format("delta")
    .start("s3://bucket/delta/hourly")
)