Why Test PySpark Code?
Testing ensures your pipelines produce correct, reliable results. Without tests, a small bug in a transformation can silently corrupt millions of records. Learn the four pillars of a complete PySpark test strategy.
Unit testing means testing the smallest piece of code in isolation — typically a single function or transformation. In PySpark, this means testing one transformation function at a time on a tiny, controlled dataset.
Example: suppose you have a function clean_phone_numbers(df) that strips non-numeric characters from a phone column. A unit test creates a tiny 2-row DataFrame, passes it in, and checks that the output is correct, without touching any real data.
# functions.py: the function we want to test
from pyspark.sql import functions as F

def clean_phone_numbers(df):
    """Remove non-numeric chars from the phone column."""
    return df.withColumn(
        "phone",
        F.regexp_replace(F.col("phone"), r"[^0-9]", "")
    )
# test_functions.py: the unit test
import pytest
from pyspark.sql import SparkSession
from functions import clean_phone_numbers

@pytest.fixture(scope="session")
def spark():
    return SparkSession.builder \
        .master("local[1]") \
        .appName("unit-test") \
        .getOrCreate()

def test_clean_phone_numbers(spark):
    # ARRANGE: create tiny test data
    input_data = [(1, "(555) 123-4567"), (2, "+1-800-555-0199")]
    df = spark.createDataFrame(input_data, ["id", "phone"])

    # ACT: call the function
    result = clean_phone_numbers(df)

    # ASSERT: verify the output (sort first, because row order is not guaranteed)
    rows = result.orderBy("id").collect()
    assert rows[0]["phone"] == "5551234567"
    assert rows[1]["phone"] == "18005550199"
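If you want to cover many phone formats, one option is to compute the expected values with a plain-Python "oracle" that applies the same character class as the Spark regex. The sketch below makes a few assumptions. `expected_phone` is a hypothetical helper. It also relies on Python's `re` and Java's regex engine agreeing on the simple class `[^0-9]`, which they do for this pattern.

```python
import re

def expected_phone(raw):
    """Reference implementation (hypothetical helper): keep only the ASCII digits 0-9."""
    if raw is None:
        return None  # Spark's regexp_replace also returns null for a null input
    return re.sub(r"[^0-9]", "", raw)

# Inputs you might feed into the unit test, with oracle-computed expectations
cases = ["(555) 123-4567", "+1-800-555-0199", "555.123.4567 ext 8"]
expected = [expected_phone(c) for c in cases]
```

An oracle only proves that the Spark code agrees with the Python rule. Keep a few hand-written expected values as well, so that a shared misunderstanding does not go unnoticed.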
Integration testing verifies that multiple components work correctly together. For PySpark, this means testing a full pipeline (read → transform → write) with a real (but small and local) dataset, real file formats, or containerized databases.
from my_etl_pipeline import run_pipeline

# `spark` is the shared session fixture from conftest.py; `tmp_path` is built into pytest
def test_full_pipeline(spark, tmp_path):
    # Write sample input data to a temp Parquet file
    input_data = [(1, "Alice", 30), (2, "Bob", 25)]
    input_df = spark.createDataFrame(input_data, ["id", "name", "age"])
    input_path = str(tmp_path / "input")
    output_path = str(tmp_path / "output")
    input_df.write.parquet(input_path)

    # Run the full pipeline
    run_pipeline(spark, input_path, output_path)

    # Read the output and verify
    result_df = spark.read.parquet(output_path)
    assert result_df.count() == 2
    assert "processed_at" in result_df.columns  # check the audit column was added
Regression testing ensures that new code changes do not break existing functionality. You capture expected results as "golden files" or baseline values, and future runs must match them. This is critical in PySpark where a schema change or a refactor can silently change data.
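import json

# compute_revenue is the pipeline function under test; import it from your project
def test_revenue_regression(spark):
    # Load your reference "golden" results
    with open("tests/golden/revenue_expected.json") as f:
        expected = json.load(f)

    # Run the actual pipeline
    result_df = compute_revenue(spark, "tests/data/sales_sample.parquet")
    actual = [row.asDict() for row in result_df.orderBy("month").collect()]

    # Compare: any mismatch means a regression
    assert actual == expected, (
        f"Regression detected!\nExpected: {expected}\nActual:   {actual}"
    )

# How to create/update the golden file (sort the same way the test does).
# Note: pandas writes dates/timestamps as epoch milliseconds by default, so
# non-numeric types may need converting before they compare equal.
# result_df.orderBy("month").toPandas().to_json(
#     "tests/golden/revenue_expected.json", orient="records")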
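Golden files are easier to maintain when refreshing them is an explicit, opt-in step. The following is a minimal sketch, not a standard pytest feature. Both `check_against_golden` and the `UPDATE_GOLDEN=1` environment-variable convention are hypothetical choices. The function takes rows as a list of dicts (for example, from `row.asDict()`) and returns a list of differences.

```python
import json
import os
from pathlib import Path

def check_against_golden(actual_rows, golden_path, update=None):
    """Compare rows with a golden JSON file; rewrite the file when updating.

    Hypothetical helper. If `update` is None it is read from the UPDATE_GOLDEN
    environment variable (run `UPDATE_GOLDEN=1 pytest` to refresh baselines).
    """
    path = Path(golden_path)
    if update is None:
        update = os.environ.get("UPDATE_GOLDEN") == "1"
    # Round-trip through JSON so dates/decimals are serialized the same way on both sides
    actual = json.loads(json.dumps(actual_rows, default=str))
    if update:
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(actual, indent=2, sort_keys=True))
        return []
    if not path.exists():
        raise FileNotFoundError(f"No golden file at {path}; rerun with UPDATE_GOLDEN=1")
    expected = json.loads(path.read_text())
    # Index-by-index differences, plus a length mismatch entry if the sizes differ
    diffs = [
        {"index": i, "expected": exp, "actual": act}
        for i, (exp, act) in enumerate(zip(expected, actual))
        if exp != act
    ]
    if len(expected) != len(actual):
        diffs.append({"length": {"expected": len(expected), "actual": len(actual)}})
    return diffs
```

In the regression test you might call `diffs = check_against_golden([r.asDict() for r in rows], "tests/golden/revenue_expected.json")` followed by `assert not diffs, diffs`. A missing golden file raises an error instead of silently creating one, so a new baseline is always a deliberate choice.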
End-to-End (E2E) testing validates the entire pipeline from source to destination, mimicking real production conditions. It typically runs in a staging environment with real data sources (or realistic replicas), real storage (S3 dev bucket), and real Spark clusters.
Pytest — The Testing Framework
Pytest is the most popular Python testing framework. It's simple to use, powerful with fixtures and plugins, and integrates natively with PySpark testing workflows.
Fixtures are reusable setup/teardown functions decorated with @pytest.fixture. They inject dependencies into test functions. For PySpark, the most important fixture is a shared SparkSession — you create it once and reuse it across all tests in the session, saving huge amounts of time since SparkSession startup is slow.
Fixture scopes: scope="function" (the default; recreated for each test), scope="module" (once per test file), and scope="session" (once per entire test run). For a SparkSession, always use session scope.
# conftest.py: place this at the root of your tests/ folder
# Pytest auto-discovers this file and makes its fixtures available everywhere
import pytest
from pyspark.sql import SparkSession

@pytest.fixture(scope="session")
def spark():
    """Create a single SparkSession for the entire test session."""
    spark = SparkSession.builder \
        .master("local[2]") \
        .appName("pytest-pyspark") \
        .config("spark.sql.shuffle.partitions", "2") \
        .config("spark.default.parallelism", "2") \
        .getOrCreate()
    spark.sparkContext.setLogLevel("ERROR")  # suppress noisy INFO logs
    yield spark  # tests run here
    spark.stop()  # cleanup after ALL tests finish

@pytest.fixture
def sample_customers_df(spark):
    """Reusable sample customer DataFrame."""
    data = [
        (1, "Alice", "alice@example.com", 28),
        (2, "Bob", "bob@example.com", 35),
        (3, "Carol", "carol@example.com", 22),
    ]
    return spark.createDataFrame(data, ["id", "name", "email", "age"])

# Usage in any test file (no import needed):
def test_filter_adults(sample_customers_df):
    result = sample_customers_df.filter("age >= 25")
    assert result.count() == 2  # Alice (28) and Bob (35)
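A yield fixture behaves much like a generator-based context manager: the code before `yield` is setup, and the code after it is teardown. The sketch below uses `contextlib` only as an analogy, since pytest's internals work differently. It is meant to show why session scope runs the expensive setup once, while function scope repeats it around every test. `run_tests` is a hypothetical simulation helper.

```python
from contextlib import contextmanager

def run_tests(tests, scope):
    """Simulate how often a yield fixture's setup/teardown runs for a given scope."""
    log = []

    @contextmanager
    def fixture():
        log.append("setup")         # like the code before `yield` (e.g. getOrCreate())
        try:
            yield
        finally:
            log.append("teardown")  # like the code after `yield` (e.g. spark.stop())

    if scope == "session":
        # One setup/teardown wrapped around the whole run
        with fixture():
            for name in tests:
                log.append(f"run {name}")
    elif scope == "function":
        # A fresh setup/teardown around every single test
        for name in tests:
            with fixture():
                log.append(f"run {name}")
    else:
        raise ValueError(f"unsupported scope: {scope}")
    return log
```

With two tests, session scope produces `setup, run t1, run t2, teardown`. Function scope produces two full setup/teardown cycles instead, and that repeated cost becomes very noticeable when each setup starts a JVM.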
Parameterization lets you run the same test with multiple inputs using one test function. Instead of writing 5 separate test functions for 5 edge cases, you write one and provide the different inputs as a list.
import pytest
from my_transforms import categorize_age

# Each tuple is (input_age, expected_category)
@pytest.mark.parametrize("age, expected", [
    (10, "child"),
    (17, "child"),
    (18, "adult"),
    (64, "adult"),
    (65, "senior"),
    (90, "senior"),
])
def test_categorize_age(spark, age, expected):
    # Create a one-row DataFrame with this age
    df = spark.createDataFrame([(1, age)], ["id", "age"])
    result = categorize_age(df)  # adds an "age_category" column
    category = result.collect()[0]["age_category"]
    assert category == expected, f"Age {age}: expected '{expected}', got '{category}'"

# Pytest runs this test 6 times automatically, once per parameter set.
# Output shows: test_categorize_age[10-child] PASSED
#               test_categorize_age[65-senior] PASSED, etc.
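Off-by-one bugs usually hide at category boundaries, so one option is to generate the parameter list from the thresholds. This is a sketch under assumptions. The thresholds 18 and 65 are inferred from the parameter table above, and both `categorize_age_reference` and `boundary_cases` are hypothetical helpers, not part of `my_transforms`.

```python
def categorize_age_reference(age):
    """Plain-Python statement of the rule implied by the table above (assumed thresholds)."""
    if age < 18:
        return "child"
    if age < 65:
        return "adult"
    return "senior"

def boundary_cases(thresholds, reference):
    """Build (input, expected) pairs just below and exactly at each threshold."""
    cases = []
    for t in thresholds:
        for value in (t - 1, t):
            cases.append((value, reference(value)))
    return cases

AGE_CASES = boundary_cases([18, 65], categorize_age_reference)
# Usage: @pytest.mark.parametrize("age, expected", AGE_CASES)
```

Because the expected values come from the reference function, these cases check that the Spark implementation agrees with the Python rule. Keep the hand-written cases too.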
A well-organized test suite mirrors your project structure and groups tests by type. This makes it easy to run only unit tests during development (fast) and all tests in CI.
├── src/
│   ├── transforms/
│   │   ├── clean.py
│   │   └── aggregate.py
│   └── pipeline.py
├── tests/
│   ├── conftest.py          ← shared fixtures (SparkSession here)
│   ├── unit/
│   │   ├── test_clean.py
│   │   └── test_aggregate.py
│   ├── integration/
│   │   └── test_pipeline.py
│   └── data/                ← small sample files
│       └── customers_sample.parquet
└── pytest.ini               ← pytest config
[pytest]
testpaths = tests
python_files = test_*.py
python_functions = test_*
# Custom markers to tag test types
markers =
    unit: fast isolated unit tests
    integration: tests that use real Spark + files
    slow: tests that take a long time

# Run only unit tests:  pytest -m unit
# Run all tests:        pytest
# Run verbosely:        pytest -v
import pytest

@pytest.mark.unit
def test_remove_nulls(spark):
    df = spark.createDataFrame([(1, None), (2, "Alice")], ["id", "name"])
    result = df.dropna()
    assert result.count() == 1

@pytest.mark.integration
def test_read_parquet_and_transform(spark, tmp_path):
    # This test reads a real parquet file
    ...
Setting Up Spark for Testing
Learn how to create a lightweight, fast SparkSession specifically for testing — with the right configs to minimize overhead and maximize test speed.
In PySpark testing, "mocking" means creating a local SparkSession that runs entirely on your laptop (or CI machine) without needing a real cluster. You use master("local[*]") to tell Spark to use local threads. You also mock external dependencies (S3, JDBC) using test doubles or temp directories.
import pytest
from pyspark.sql import SparkSession

@pytest.fixture(scope="session")
def spark():
    """
    Lightweight SparkSession for unit/integration tests.
    Configs tuned for test speed, not production throughput.
    """
    spark = (
        SparkSession.builder
        .master("local[2]")  # 2 threads, enough for tests
        .appName("pyspark-test-suite")
        # Reduce shuffle partitions (the default of 200 is far too many for tiny test data)
        .config("spark.sql.shuffle.partitions", "2")
        .config("spark.default.parallelism", "2")
        # Disable Adaptive Query Execution so plans and partition counts stay predictable
        .config("spark.sql.adaptive.enabled", "false")
        # Local directory for streaming checkpoints (no HDFS needed)
        .config("spark.sql.streaming.checkpointLocation", "/tmp/spark-test-checkpoint")
        # Enable Delta only if your tests need it: the package is downloaded at startup,
        # and the delta-core version must match your Spark version (2.4.0 targets Spark 3.4)
        .config("spark.jars.packages", "io.delta:delta-core_2.12:2.4.0")
        .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
        .config("spark.sql.catalog.spark_catalog",
                "org.apache.spark.sql.delta.catalog.DeltaCatalog")
        .getOrCreate()
    )
    spark.sparkContext.setLogLevel("ERROR")  # silence noisy logs
    yield spark
    spark.stop()
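Downloading Delta on every test run slows down suites that never touch Delta tables. One way to handle this is to keep the test settings in a plain dict and add the Delta entries only on request. This is a minimal sketch. `spark_test_conf` and the `TEST_DELTA` environment variable are hypothetical names, while the config keys and values are the same ones used in the fixture above.

```python
def spark_test_conf(use_delta=False,
                    delta_package="io.delta:delta-core_2.12:2.4.0"):
    """Return test-friendly Spark settings; add the Delta settings only when requested."""
    conf = {
        "spark.sql.shuffle.partitions": "2",
        "spark.default.parallelism": "2",
        "spark.sql.adaptive.enabled": "false",
    }
    if use_delta:
        conf.update({
            "spark.jars.packages": delta_package,
            "spark.sql.extensions": "io.delta.sql.DeltaSparkSessionExtension",
            "spark.sql.catalog.spark_catalog":
                "org.apache.spark.sql.delta.catalog.DeltaCatalog",
        })
    return conf

# Possible use inside the session fixture:
#   builder = SparkSession.builder.master("local[2]").appName("pyspark-test-suite")
#   for key, value in spark_test_conf(os.environ.get("TEST_DELTA") == "1").items():
#       builder = builder.config(key, value)
#   spark = builder.getOrCreate()
```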
Local mode Spark runs driver and executors in a single JVM process. It supports all Spark features — DataFrames, Spark SQL, UDFs, even streaming — but all on your machine. Key patterns for local test data management:
from pathlib import Path

# Pattern 1: Inline test data (best for unit tests)
def test_transform_inline(spark):
    df = spark.createDataFrame([
        (1, "apple", 100),
        (2, "banana", 200),
    ], ["id", "product", "revenue"])
    result = df.filter("revenue > 150")
    assert result.count() == 1

# Pattern 2: Test files from disk (best for integration tests)
TEST_DATA_DIR = Path(__file__).parent / "data"

def test_read_csv(spark):
    df = spark.read.csv(str(TEST_DATA_DIR / "orders.csv"), header=True, inferSchema=True)
    assert df.count() > 0
    assert "order_id" in df.columns

# Pattern 3: Temp directory (pytest's tmp_path fixture)
def test_write_and_read_parquet(spark, tmp_path):
    df = spark.range(10)  # DataFrame with a single "id" column, values 0-9
    out = str(tmp_path / "output")
    df.write.parquet(out)
    result = spark.read.parquet(out)
    assert result.count() == 10
    # pytest gives each test its own tmp_path and prunes old ones (it keeps only the last few runs)
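For Pattern 2, the sample file has to exist before the test reads it. A small stdlib script can generate it once so that you can commit it under `tests/data/`. You could also call the same helper inside a test with `tmp_path`. This is a sketch: `write_orders_csv` is hypothetical, and every column except `order_id` (which the test checks for) is an assumption.

```python
import csv
from pathlib import Path

def write_orders_csv(path, rows):
    """Write a small orders CSV with a header row (hypothetical helper and columns)."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["order_id", "customer_id", "amount"])
        writer.writerows(rows)
    return path

# e.g. write_orders_csv("tests/data/orders.csv", [(1001, 1, 250.0), (1002, 2, 89.99)])
```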
Create reusable fixtures for common test DataFrames — customer data, order data, product data — so every test file can use them without re-writing setup code.
# conftest.py
import pytest
from datetime import date
from pyspark.sql.types import (StructType, StructField, IntegerType,
                               StringType, DoubleType, DateType)

@pytest.fixture
def orders_df(spark):
    """Sample orders DataFrame with explicit schema."""
    schema = StructType([
        StructField("order_id", IntegerType(), nullable=False),
        StructField("customer_id", IntegerType(), nullable=True),
        StructField("amount", DoubleType(), nullable=True),
        StructField("status", StringType(), nullable=True),
        StructField("order_date", DateType(), nullable=True),
    ])
    data = [
        (1001, 1, 250.0, "COMPLETED", date(2024, 1, 15)),
        (1002, 2, 89.99, "PENDING", date(2024, 1, 16)),
        (1003, 1, 500.0, "COMPLETED", date(2024, 1, 17)),
        (1004, 3, None, "FAILED", date(2024, 1, 18)),
    ]
    return spark.createDataFrame(data, schema)

@pytest.fixture
def products_df(spark):
    data = [(1, "Laptop", 999.99), (2, "Mouse", 29.99)]
    return spark.createDataFrame(data, ["product_id", "name", "price"])

# Tests using the orders_df fixture (in any test file):
def test_orders_are_all_positive(orders_df):
    # All non-null amounts are positive ("amount <= 0" never matches NULL)
    negatives = orders_df.filter("amount <= 0")
    assert negatives.count() == 0

def test_completed_orders_count(orders_df):
    completed = orders_df.filter("status = 'COMPLETED'")
    assert completed.count() == 2
Testing DataFrames Thoroughly
Learn the four essential DataFrame checks: schema validation, row count validation, data equality testing, and column validation. These form the backbone of any PySpark test.
Schema validation ensures your transformation produces the correct columns and data types. A schema mismatch (e.g., a column becoming StringType instead of IntegerType) can corrupt downstream aggregations silently.
from pyspark.sql.types import (StructType, StructField, StringType,
                               LongType, DoubleType)

def test_output_schema(spark):
    # Define what the schema SHOULD look like.
    # Python ints are inferred as LongType, so customer_id is LongType here.
    expected_schema = StructType([
        StructField("customer_id", LongType(), nullable=True),
        StructField("full_name", StringType(), nullable=True),
        StructField("total_spend", DoubleType(), nullable=True),
    ])
    input_df = spark.createDataFrame([
        (1, "Alice", "Smith", 300.0),
        (2, "Bob", "Jones", 150.5),
    ], ["customer_id", "first_name", "last_name", "total_spend"])
    result = transform_customers(input_df)  # should concat first + last into full_name

    # Method 1: Compare full schemas (names, types, and nullable flags)
    assert result.schema == expected_schema

    # Method 2: Check column names only (order matters)
    assert result.columns == ["customer_id", "full_name", "total_spend"]

    # Method 3: Check a specific column type
    schema_dict = {f.name: f.dataType for f in result.schema.fields}
    assert isinstance(schema_dict["total_spend"], DoubleType)

    # Method 4: Check there are no unexpected extra columns
    expected_cols = {"customer_id", "full_name", "total_spend"}
    actual_cols = set(result.columns)
    assert actual_cols == expected_cols, f"Unexpected columns: {actual_cols - expected_cols}"
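When a schema assertion fails, a plain `==` does not tell you *what* differs. A small pure-Python diff over `{column_name: type_string}` maps can report missing columns, extra columns, and type changes separately. This is a sketch, and `schema_diff` is a hypothetical helper. You could build the actual map with `{f.name: f.dataType.simpleString() for f in result.schema.fields}`, which yields strings such as `"bigint"`, `"string"`, and `"double"`.

```python
def schema_diff(expected, actual):
    """Compare {column_name: type_string} maps and describe the differences."""
    missing = sorted(set(expected) - set(actual))
    extra = sorted(set(actual) - set(expected))
    type_mismatches = {
        name: (expected[name], actual[name])
        for name in sorted(set(expected) & set(actual))
        if expected[name] != actual[name]
    }
    return {"missing": missing, "extra": extra, "type_mismatches": type_mismatches}
```

In a test, `assert diff == {"missing": [], "extra": [], "type_mismatches": {}}, diff` prints all three categories at once when the assertion fails.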
Row count validation ensures your transformation doesn't silently drop or duplicate rows. It's the first check to run because if the count is wrong, nothing else matters.
def test_no_rows_lost_in_join(spark):
    customers = spark.createDataFrame(
        [(1, "Alice"), (2, "Bob"), (3, "Carol")],
        ["cust_id", "name"]
    )
    orders = spark.createDataFrame(
        [(1, 100.0), (1, 200.0), (2, 150.0)],
        ["cust_id", "amount"]
    )
    # Left join: Carol has no orders but should still appear
    result = customers.join(orders, on="cust_id", how="left")

    # Alice has 2 orders, Bob has 1, Carol gets 1 row with a null amount: 4 rows
    assert result.count() == 4

    # Verify Carol appears with a null amount
    carol_rows = result.filter("name = 'Carol'")
    assert carol_rows.count() == 1
    assert carol_rows.collect()[0]["amount"] is None

def test_dedup_reduces_rows(spark):
    df = spark.createDataFrame(
        [(1, "A"), (1, "A"), (2, "B")],
        ["id", "val"]
    )
    result = df.dropDuplicates()
    assert result.count() == 2  # was 3, should be 2 after dedup

def test_filter_result_count(orders_df):
    completed = orders_df.filter("status = 'COMPLETED'")
    # Exact count
    assert completed.count() == 2
    # Range check (useful when the data size isn't fully known)
    total = orders_df.count()
    assert 0 < completed.count() < total  # some, but not all
Data equality testing checks that every row and column value matches the expected result exactly. Spark does not guarantee row order, so you must sort both sides (or otherwise compare in an order-independent way) before comparing.
Never compare df1.collect() == df2.collect() without sorting first: the same logical data can come back in a different order on different runs.
from pyspark.sql import functions as F

def assert_df_equals(df1, df2, sort_by):
    """
    Compare two DataFrames for equality.
    - Same column names and types (nullable flags ignored)
    - Same row count
    - Same data (order-independent via sort)
    """
    # Check column names and types. Nullable flags are ignored because they often
    # differ legitimately: e.g. count() produces a non-nullable column, while
    # createDataFrame infers every column as nullable.
    fields1 = [(f.name, f.dataType) for f in df1.schema.fields]
    fields2 = [(f.name, f.dataType) for f in df2.schema.fields]
    assert fields1 == fields2, f"Schema mismatch:\n{df1.schema}\nvs\n{df2.schema}"

    # Check row counts match
    count1, count2 = df1.count(), df2.count()
    assert count1 == count2, f"Row count mismatch: {count1} vs {count2}"

    # Sort both and compare row by row
    rows1 = df1.orderBy(sort_by).collect()
    rows2 = df2.orderBy(sort_by).collect()
    assert rows1 == rows2, "DataFrame contents do not match!"

def test_aggregate_result(spark):
    orders = spark.createDataFrame([
        (1, "A", 100.0), (2, "A", 200.0),
        (3, "B", 150.0),
    ], ["order_id", "category", "amount"])
    result = orders.groupBy("category").agg(
        F.sum("amount").alias("total_amount"),
        F.count("*").alias("order_count")
    )
    expected = spark.createDataFrame([
        ("A", 300.0, 2),
        ("B", 150.0, 1),
    ], ["category", "total_amount", "order_count"])
    assert_df_equals(result, expected, sort_by="category")
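As an alternative to sorting, you can compare collected rows as multisets with `collections.Counter`. This needs no sort key, and it still distinguishes "one copy" from "two copies" of a duplicated row. This is a sketch. `unordered_row_diff` is a hypothetical helper, and it assumes the rows are small enough to collect and contain hashable values. PySpark `Row` objects are tuples, so they qualify. Note that tuple equality ignores field names, and that `NaN` never equals itself.

```python
from collections import Counter

def unordered_row_diff(actual, expected):
    """Multiset difference of two row lists: what is extra and what is missing."""
    a, e = Counter(actual), Counter(expected)
    return {
        "unexpected": list((a - e).elements()),  # rows in actual but not in expected
        "missing": list((e - a).elements()),     # rows in expected but not in actual
    }

# In a test: diff = unordered_row_diff(result.collect(), expected.collect())
#            assert diff == {"unexpected": [], "missing": []}, diff
```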
Column validation goes beyond counts — it checks specific column values, null rates, value ranges, and allowed categorical values.
from pyspark.sql import functions as F

def test_column_values(spark):
    df = spark.createDataFrame([
        (1, "COMPLETED", 100.0),
        (2, "PENDING", 50.0),
        (3, "FAILED", 0.0),
    ], ["id", "status", "amount"])

    # 1. Check allowed values (enum check)
    allowed_statuses = {"COMPLETED", "PENDING", "FAILED", "CANCELLED"}
    actual_statuses = {row["status"] for row in df.select("status").distinct().collect()}
    assert actual_statuses.issubset(allowed_statuses), f"Invalid statuses: {actual_statuses - allowed_statuses}"

    # 2. Check no nulls in a critical column
    null_count = df.filter(F.col("id").isNull()).count()
    assert null_count == 0, "id column must not have nulls"

    # 3. Check value range
    min_amount, max_amount = df.agg(
        F.min("amount"), F.max("amount")
    ).collect()[0]
    assert min_amount >= 0, "Amount cannot be negative"
    assert max_amount < 1_000_000, "Amount seems unreasonably large"

    # 4. Uniqueness check
    total_rows = df.count()
    distinct_ids = df.select("id").distinct().count()
    assert total_rows == distinct_ids, "id column must be unique"

    # 5. Non-empty string check
    empty_status = df.filter(F.trim(F.col("status")) == "").count()
    assert empty_status == 0, "status must not be empty string"
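The checks above test null *counts*. A null *rate* lets you allow a tolerated fraction of nulls instead of demanding zero. For tiny test DataFrames you can compute it on collected rows, as in this sketch. `null_rates` is a hypothetical helper, and it assumes each row can be read like a dict (for example, `row.asDict()`).

```python
def null_rates(rows, columns):
    """Fraction of None values per column in a list of dict-like rows."""
    total = len(rows)
    if total == 0:
        return {c: 0.0 for c in columns}  # this sketch treats an empty input as 0% null
    return {c: sum(1 for r in rows if r[c] is None) / total for c in columns}

# e.g. rates = null_rates([r.asDict() for r in df.collect()], ["id", "status"])
#      assert rates["status"] <= 0.05, rates
```

For larger data it is better to compute the rate in Spark itself. One possible form is `df.select([F.avg(F.col(c).isNull().cast("int")).alias(c) for c in cols])`, which avoids collecting the rows.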
Chispa — Elegant DataFrame Assertions
Chispa is a lightweight Python library that provides elegant, readable assertion functions specifically for PySpark DataFrames. It handles the boilerplate of schema checking, sorting, and equality comparison.
Install it with pip install chispa (works with PySpark 2.x, 3.x, and 4.x).
assert_df_equality is the main function in Chispa. It compares two DataFrames and raises a descriptive error if they differ, showing exactly which rows are different. This is much more helpful than a plain assertion failure.
from chispa.dataframe_comparer import assert_df_equality
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

def test_with_chispa_basic(spark):
    actual = spark.createDataFrame(
        [(1, "Alice"), (2, "Bob")],
        ["id", "name"]
    )
    expected = spark.createDataFrame(
        [(2, "Bob"), (1, "Alice")],  # different order: chispa handles this
        ["id", "name"]
    )
    # ignore_row_order=True sorts both sides before comparing (recommended for most tests)
    assert_df_equality(actual, expected, ignore_row_order=True)

def test_with_chispa_ignore_nullability(spark):
    """Sometimes the schema nullable flag differs but the data is the same."""
    actual = spark.createDataFrame(
        [(1, "Alice")],
        StructType([StructField("id", IntegerType(), nullable=True),
                    StructField("name", StringType(), nullable=True)])
    )
    expected = spark.createDataFrame(
        [(1, "Alice")],
        StructType([StructField("id", IntegerType(), nullable=False),    # different
                    StructField("name", StringType(), nullable=False)])  # nullable flag
    )
    # ignore_nullable=True skips the nullable flag comparison
    assert_df_equality(actual, expected, ignore_nullable=True)

# When there's a mismatch, Chispa shows a clear diff:
# DataFramesNotEqualError:
# DF1 row: Row(id=1, name='Alice')
# DF2 row: Row(id=1, name='ALICE')   ← you see exactly what differs
Floating point arithmetic introduces tiny rounding differences. 100.0 / 3 might produce 33.333333333333336 vs 33.33333333333333. Strict equality fails. Chispa's approximate equality checks values within a tolerance.
from chispa.dataframe_comparer import assert_approx_df_equality
from pyspark.sql import functions as F

def test_average_calculation(spark):
    # Compute an average: a floating point result
    df = spark.createDataFrame(
        [(1, 100.0), (2, 200.0), (3, 300.0)],
        ["id", "revenue"]
    )
    result = df.agg(F.avg("revenue").alias("avg_revenue"))
    expected = spark.createDataFrame(
        [(200.00000001,)],  # slightly off, as if from floating point error
        ["avg_revenue"]
    )
    # precision=1e-3 means values within 0.001 of each other pass
    assert_approx_df_equality(result, expected, precision=1e-3)

def test_percentage_columns(spark):
    """Good for testing percentage/ratio columns."""
    result = compute_market_share(spark)  # your function under test; returns % values
    expected = spark.createDataFrame(
        [("Product A", 33.33), ("Product B", 66.67)],
        ["product", "share_pct"]
    )
    assert_approx_df_equality(result, expected,
                              precision=0.01,  # within 0.01 percentage points
                              ignore_row_order=True)
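To make the tolerance idea concrete, here is a pure-Python sketch of cell-by-cell approximate comparison. This is *not* how Chispa is implemented. It simply interprets `precision` as an absolute tolerance applied to float cells, while all other cells must match exactly. `rows_approx_equal` and `values_close` are hypothetical helpers, and the code assumes both row lists are already sorted the same way and share a column layout.

```python
import math

def values_close(v1, v2, precision):
    """Nulls must match; floats within an absolute tolerance; everything else exact."""
    if v1 is None or v2 is None:
        return v1 is None and v2 is None
    if isinstance(v1, float) or isinstance(v2, float):
        return math.isclose(v1, v2, rel_tol=0.0, abs_tol=precision)
    return v1 == v2

def rows_approx_equal(rows1, rows2, precision):
    """Compare two same-ordered row lists cell by cell."""
    if len(rows1) != len(rows2):
        return False
    for r1, r2 in zip(rows1, rows2):
        if len(r1) != len(r2):
            return False
        if not all(values_close(a, b, precision) for a, b in zip(r1, r2)):
            return False
    return True
```

For example, `0.1 + 0.2 == 0.3` is `False` in Python, yet `rows_approx_equal([(1, 0.1 + 0.2)], [(1, 0.3)], precision=1e-9)` is `True`.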
Sometimes you only want to validate the schema — not the data — for example after a column rename or type cast operation. Chispa has a dedicated schema assertion function.
from chispa.schema_comparer import assert_schema_equality
from pyspark.sql.types import *

def test_after_cast(spark):
    df = spark.createDataFrame(
        [("1", "100.5")],
        ["user_id", "amount"]
    )
    # Cast string columns to the correct types
    result = df.withColumn("user_id", df["user_id"].cast(IntegerType())) \
        .withColumn("amount", df["amount"].cast(DoubleType()))
    expected_schema = StructType([
        StructField("user_id", IntegerType(), nullable=True),
        StructField("amount", DoubleType(), nullable=True),
    ])
    assert_schema_equality(result.schema, expected_schema)

# Ignore nullable flags if they don't matter for your test:
from chispa.schema_comparer import assert_schema_equality_ignore_nullable

# result_df: a fixture that returns the DataFrame under test
def test_schema_ignoring_nullable(spark, result_df):
    expected_schema = StructType([
        StructField("id", IntegerType(), nullable=False),
        StructField("name", StringType(), nullable=False),
    ])
    # Won't fail if actual has nullable=True
    assert_schema_equality_ignore_nullable(result_df.schema, expected_schema)
Testcontainers — Real Services in Tests
Testcontainers is a Python library that spins up real Docker containers (PostgreSQL, Kafka, MySQL) during your tests and tears them down afterward. This lets you test against real databases and brokers without needing a shared dev environment.
Install with pip install testcontainers[postgres] testcontainers[kafka] (requires Docker to be running).
With Testcontainers, you get a real database, real Kafka — not mocks. This means your integration test catches bugs that only appear with real services (SQL dialect differences, serialization issues, etc.).
Use a real PostgreSQL container to test your Spark JDBC read/write code. This catches real issues — type mapping errors, null handling, character encoding — that mocks miss.
import pytest
import psycopg2
from testcontainers.postgres import PostgresContainer

@pytest.fixture(scope="session")
def postgres_container():
    """Start a real PostgreSQL container for the test session."""
    with PostgresContainer("postgres:15") as postgres:
        yield postgres  # container is running during tests
    # Container is stopped and removed once the with-block exits

@pytest.fixture(scope="session")
def seeded_postgres(postgres_container):
    """Create tables and insert test data into the container."""
    # Connect with explicit parameters: by default get_connection_url() returns a
    # SQLAlchemy-style URL (postgresql+psycopg2://...) that psycopg2 cannot parse.
    # username/password/dbname are testcontainers 4.x attribute names
    # (3.x used POSTGRES_USER / POSTGRES_PASSWORD / POSTGRES_DB).
    conn = psycopg2.connect(
        host=postgres_container.get_container_host_ip(),
        port=postgres_container.get_exposed_port(5432),
        user=postgres_container.username,
        password=postgres_container.password,
        dbname=postgres_container.dbname,
    )
    cursor = conn.cursor()
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS orders (
            order_id INT PRIMARY KEY,
            customer VARCHAR(100),
            amount NUMERIC(10, 2),
            status VARCHAR(20)
        )
    """)
    cursor.execute("""
        INSERT INTO orders VALUES
            (1, 'Alice', 250.00, 'COMPLETED'),
            (2, 'Bob', 89.99, 'PENDING'),
            (3, 'Carol', 500.00, 'COMPLETED')
    """)
    conn.commit()
    cursor.close()
    conn.close()
    return postgres_container

def test_spark_reads_postgres(spark, seeded_postgres):
    # Build a plain JDBC URL from the container's host, mapped port, and database
    jdbc_url = (
        f"jdbc:postgresql://{seeded_postgres.get_container_host_ip()}:"
        f"{seeded_postgres.get_exposed_port(5432)}/{seeded_postgres.dbname}"
    )

    # Read the table using Spark JDBC (the PostgreSQL JDBC driver jar must be on
    # Spark's classpath, e.g. via spark.jars.packages)
    df = spark.read.format("jdbc") \
        .option("url", jdbc_url) \
        .option("dbtable", "orders") \
        .option("user", seeded_postgres.username) \
        .option("password", seeded_postgres.password) \
        .option("driver", "org.postgresql.Driver") \
        .load()

    # Verify data
    assert df.count() == 3
    completed = df.filter("status = 'COMPLETED'")
    assert completed.count() == 2
    # NUMERIC maps to DecimalType, so convert the Decimal sum before float arithmetic
    total = df.agg({"amount": "sum"}).collect()[0][0]
    assert abs(float(total) - 839.99) < 0.01
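Deriving a JDBC URL by string-replacing a driver-specific connection URL is fragile. One alternative is to assemble it from the container's host, mapped port, and database name, and to pass credentials as Spark options rather than embedding them in the URL. This is a sketch. `postgres_jdbc_url` is a hypothetical helper, and the optional query parameters (for example `connectTimeout`) are passed through as-is to the PostgreSQL JDBC driver.

```python
from urllib.parse import quote

def postgres_jdbc_url(host, port, dbname, **params):
    """Build a jdbc:postgresql:// URL; keep user/password in Spark options, not here."""
    url = f"jdbc:postgresql://{host}:{port}/{quote(dbname)}"
    if params:
        query = "&".join(f"{k}={quote(str(v))}" for k, v in sorted(params.items()))
        url += "?" + query
    return url

# e.g. postgres_jdbc_url(container.get_container_host_ip(),
#                        container.get_exposed_port(5432), "test")
```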
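Testcontainers provides a real Kafka broker in a container. You can produce test messages, read them back with Spark's Kafka source, and assert on the output, all locally. The example reads the topic as a batch query (spark.read); the same Kafka source also backs Structured Streaming queries through spark.readStream.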
import json
import pytest
from kafka import KafkaProducer
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, IntegerType
from testcontainers.kafka import KafkaContainer

@pytest.fixture(scope="session")
def kafka_container():
    with KafkaContainer() as kafka:
        yield kafka

def test_streaming_from_kafka(spark, kafka_container):
    bootstrap = kafka_container.get_bootstrap_server()
    topic = "test-orders"

    # Produce 3 test messages to Kafka
    producer = KafkaProducer(
        bootstrap_servers=bootstrap,
        value_serializer=lambda v: json.dumps(v).encode()
    )
    for i in range(3):
        producer.send(topic, {"order_id": i, "amount": i * 100})
    producer.flush()
    producer.close()

    # Read the topic with Spark's Kafka source as a batch query (spark.read).
    # Requires the spark-sql-kafka-0-10 package on Spark's classpath.
    df = spark.read \
        .format("kafka") \
        .option("kafka.bootstrap.servers", bootstrap) \
        .option("subscribe", topic) \
        .option("startingOffsets", "earliest") \
        .load()

    # Parse the JSON values
    schema = StructType([
        StructField("order_id", IntegerType()),
        StructField("amount", IntegerType()),
    ])
    parsed = df.select(F.from_json(F.col("value").cast("string"), schema).alias("data")) \
        .select("data.*")

    assert parsed.count() == 3
    total = parsed.agg({"amount": "sum"}).collect()[0][0]
    assert total == 300  # 0 + 100 + 200
Data Quality Frameworks
Data quality testing uses specialized frameworks to validate data at scale. Beyond simple assertions, these tools define reusable "expectations" and "rules" that can be run in CI pipelines and production alike.
Great Expectations is the most widely-used Python data quality framework. You define expectations (e.g., "column X must not have nulls"), run them as a checkpoint, and get an HTML data quality report. It integrates with Spark natively.
# Uses the legacy SparkDFDataset API from older (0.x) Great Expectations releases;
# newer releases (including 1.x) replaced it with a context/validator-based API.
from great_expectations.dataset import SparkDFDataset

def test_orders_quality(spark):
    # Create a DataFrame
    df = spark.createDataFrame([
        (1, "Alice", "COMPLETED", 250.0),
        (2, "Bob", "PENDING", 89.99),
        (3, "Carol", "COMPLETED", 500.0),
    ], ["order_id", "customer", "status", "amount"])

    # Wrap it in a GX Spark dataset
    ge_df = SparkDFDataset(df)

    # Define expectations
    # 1. order_id must not have nulls
    result = ge_df.expect_column_values_to_not_be_null("order_id")
    assert result.success

    # 2. status must be one of the allowed values
    result = ge_df.expect_column_values_to_be_in_set(
        "status", ["COMPLETED", "PENDING", "FAILED", "CANCELLED"]
    )
    assert result.success

    # 3. amount must be positive
    result = ge_df.expect_column_values_to_be_between(
        "amount", min_value=0.01, max_value=100_000
    )
    assert result.success

    # 4. order_id must be unique
    result = ge_df.expect_column_values_to_be_unique("order_id")
    assert result.success

    # 5. Row count expectation
    result = ge_df.expect_table_row_count_to_be_between(min_value=1, max_value=10_000)
    assert result.success

    # 6. Column existence
    result = ge_df.expect_column_to_exist("customer")
    assert result.success
Deequ is an open-source data quality library built by Amazon, implemented on top of Apache Spark. It defines constraints (rules) and analyzers (metrics). PyDeequ is the Python wrapper.
# pip install pydeequ
# PyDeequ reads the SPARK_VERSION environment variable on import (e.g. "3.3"), and the
# Spark session needs the Deequ jar, e.g. in the spark fixture:
#   .config("spark.jars.packages", pydeequ.deequ_maven_coord)
#   .config("spark.jars.excludes", pydeequ.f2j_maven_coord)
from pydeequ.checks import Check, CheckLevel
from pydeequ.verification import VerificationSuite, VerificationResult

def test_with_deequ(spark):
    df = spark.createDataFrame([
        (1, "Alice", 28, "alice@email.com"),
        (2, "Bob", 35, "bob@email.com"),
        (3, None, 22, "carol@email.com"),
    ], ["id", "name", "age", "email"])

    check = (Check(spark, CheckLevel.Error, "Order Quality")
             .isComplete("id")                                  # id has no nulls
             .isUnique("id")                                    # id is unique
             .isNonNegative("age")                              # age >= 0
             .isContainedIn("name", ["Alice", "Bob", "Carol"])  # allowed names
             .hasCompleteness("name", lambda p: p >= 0.6)       # 60%+ non-null (2 of 3 here)
             )

    result = VerificationSuite(spark).onData(df).addChecks([check]).run()
    df_result = VerificationResult.checkResultsAsDataFrame(spark, result)

    # Show detailed results
    df_result.show(truncate=False)

    # Fail the test if any constraint failed
    failed = df_result.filter("constraint_status = 'Failure'")
    assert failed.count() == 0, "DQ constraints failed!"
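The completeness assertion is plain arithmetic: the fraction of non-null values in the column. In the example data, one of the three names is null, so name completeness is 2/3 ≈ 0.667. A threshold of `p >= 0.8` would therefore fail for this data, while `p >= 0.6` would pass. The following pure-Python sketch of the metric is a hypothetical helper, not PyDeequ's implementation.

```python
def completeness(values):
    """Fraction of non-null values, the metric hasCompleteness asserts on."""
    if not values:
        return 1.0  # this sketch's convention for an empty column
    return sum(v is not None for v in values) / len(values)

names = ["Alice", "Bob", None]
name_completeness = completeness(names)  # 2/3
```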
Soda defines data quality checks in YAML (SodaCL language) — no Python code needed for checks. Soda Core runs checks against your Spark DataFrames or SQL databases and reports pass/fail.
# checks.yml
checks for orders:
  # Row count must be between 100 and 1 million
  - row_count between 100 and 1000000
  # order_id must have zero nulls
  - missing_count(order_id) = 0
  # status only allows these values
  - invalid_count(status) = 0:
      valid values: [COMPLETED, PENDING, FAILED, CANCELLED]
  # amount must be positive
  - min(amount) > 0
  # customer_email must match the regex
  - invalid_count(customer_email) = 0:
      valid regex: '^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
  # Duplicate check
  - duplicate_count(order_id) = 0
# pip install soda-core-spark-df
from soda.scan import Scan

def test_soda_quality(spark):
    df = spark.read.parquet("data/orders.parquet")

    # Soda's Spark DataFrame data source queries temp views, so register the
    # DataFrame under the name used in checks.yml ("orders")
    df.createOrReplaceTempView("orders")

    scan = Scan()
    scan.set_scan_definition_name("orders-test")
    scan.set_data_source_name("spark_df")
    scan.add_spark_session(spark, data_source_name="spark_df")

    # Add the YAML check file
    scan.add_sodacl_yaml_file("checks/checks.yml")

    # Run checks
    scan.execute()

    # Fail the test if any check failed (raises AssertionError with details)
    scan.assert_no_checks_fail()
When GX/Deequ/Soda feel heavy, you can build a lightweight custom validation framework using plain Python and PySpark. Define rules as functions, return a results dictionary, and assert all rules pass.
from dataclasses import dataclass
from typing import List, Callable
from pyspark.sql import DataFrame
from pyspark.sql import functions as F

@dataclass
class ValidationRule:
    name: str
    check: Callable[[DataFrame], bool]
    description: str

class DataValidator:
    def __init__(self, df: DataFrame):
        self.df = df
        self.rules: List[ValidationRule] = []

    def add_rule(self, name, check, description=""):
        self.rules.append(ValidationRule(name, check, description))
        return self  # allow chaining

    def validate(self):
        results = []
        for rule in self.rules:
            try:
                passed = rule.check(self.df)
                results.append({"rule": rule.name, "passed": passed,
                                "desc": rule.description, "error": None})
            except Exception as e:
                results.append({"rule": rule.name, "passed": False,
                                "desc": rule.description, "error": str(e)})
        return results

# Using the custom validator:
def test_custom_validator(spark, orders_df):
    results = (
        DataValidator(orders_df)
        .add_rule(
            "no_null_order_ids",
            lambda df: df.filter(F.col("order_id").isNull()).count() == 0,
            "Order IDs must not be null"
        )
        .add_rule(
            "positive_amounts",
            lambda df: df.filter("amount <= 0").count() == 0,
            "All amounts must be positive"
        )
        .add_rule(
            "valid_statuses",
            lambda df: df.filter(
                ~F.col("status").isin(["COMPLETED", "PENDING", "FAILED"])
            ).count() == 0,
            "Status must be one of allowed values"
        )
        .validate()
    )
    failures = [r for r in results if not r["passed"]]
    assert not failures, f"Validation failed: {failures}"
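The raw list of result dicts becomes hard to read once there are many rules. A small formatter can turn the failures into one line per rule, and it distinguishes a rule that returned `False` from one that raised an exception. This is a sketch. `format_failures` is a hypothetical helper that assumes the result-dict shape produced by `DataValidator.validate()`.

```python
def format_failures(results):
    """Render failed DataValidator-style results (list of dicts) as readable lines."""
    lines = []
    for r in results:
        if r["passed"]:
            continue
        reason = f"raised {r['error']}" if r["error"] else "check returned False"
        lines.append(f"- {r['rule']}: {r['desc']} [{reason}]")
    return "\n".join(lines)

# In the test: assert not failures, "Validation failed:\n" + format_failures(results)
```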
Testing PySpark — Cheat Sheet
Quick-reference for all key patterns, commands, and code snippets from Module 25.
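# Session-scoped SparkSession fixture (conftest.py)
import pytest
from pyspark.sql import SparkSession

@pytest.fixture(scope="session")
def spark():
    s = (SparkSession.builder
         .master("local[2]")
         .config("spark.sql.shuffle.partitions", "2")
         .getOrCreate())
    yield s
    s.stop()

# Parametrized test
@pytest.mark.parametrize("age, expected", [
    (10, "child"),
    (18, "adult"),
    (65, "senior"),
])
def test_cat(spark, age, expected):
    ...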
# Chispa exact equality
from chispa.dataframe_comparer import assert_df_equality
assert_df_equality(
    actual, expected,
    ignore_row_order=True,
    ignore_nullable=True
)
# Chispa approximate equality
from chispa.dataframe_comparer import assert_approx_df_equality
assert_approx_df_equality(
    actual, expected,
    precision=1e-3,
    ignore_row_order=True
)
from pyspark.sql.functions import col
from pyspark.sql.types import IntegerType

# Column names
assert df.columns == ["a", "b", "c"]
# Data types
schema_dict = {f.name: f.dataType for f in df.schema.fields}
assert isinstance(schema_dict["age"], IntegerType)
# No nulls
assert df.filter(col("id").isNull()).count() == 0
# Unique
assert df.count() == df.select("id").distinct().count()
# Value range
assert df.filter("amount < 0").count() == 0
# Testcontainers PostgreSQL fixture
from testcontainers.postgres import PostgresContainer

@pytest.fixture(scope="session")
def pg():
    with PostgresContainer("postgres:15") as pg:
        yield pg
# Great Expectations (legacy SparkDFDataset API)
r = ge_df.expect_column_values_to_not_be_null("id")
assert r.success
r = ge_df.expect_column_values_to_be_in_set("status", ["COMPLETED", "PENDING"])
assert r.success
| Tool | Type | Best For | Language | Spark Native |
|---|---|---|---|---|
| Pytest | Test runner | All Python testing | Python | ✅ Yes |
| Chispa | DataFrame assertions | Comparing DataFrames elegantly | Python | ✅ PySpark only |
| Testcontainers | Integration testing | Real DB/Kafka in tests | Python | 🔶 Via JDBC/Kafka |
| Great Expectations | Data quality | Expectation suites, DQ reports | Python/YAML | ✅ Yes |
| PyDeequ | Data quality | AWS-native constraint checking | Python | ✅ Built on Spark |
| Soda | Data quality | YAML-driven checks, no-code DQ | YAML/Python | ✅ Yes |
| What you're testing | Test Type | Tool |
|---|---|---|
| A single transform function | Unit Test | pytest + inline data |
| Read Parquet → transform → write Parquet | Integration Test | pytest + tmp_path |
| Schema and types after cast | Schema Check | chispa assert_schema_equality |
| Float averages/percentages | Approx Equality | chispa assert_approx_df_equality |
| Read/write from real PostgreSQL | Containerized | testcontainers[postgres] |
| Data in production passes DQ rules | DQ Check | Great Expectations / Soda |
| Same code produces same results as before | Regression | pytest + golden files |