MODULE 25 Testing PySpark
25.1 — Testing Fundamentals

Why Test PySpark Code?

Testing ensures your pipelines produce correct, reliable results. Without tests, a small bug in a transformation can silently corrupt millions of records. Learn the four pillars of a complete PySpark test strategy.

💡 Real-World Analogy
Think of your PySpark job like a factory assembly line. Unit tests check each machine individually (does this function work?). Integration tests check if machines work together (does the whole line produce the right output?). Regression tests ensure fixing one machine doesn't break another. End-to-End tests simulate the full product from raw material to finished goods.
🔬
Unit Testing FOUNDATION
What is Unit Testing?

Unit testing means testing the smallest piece of code in isolation — typically a single function or transformation. In PySpark, this means testing one transformation function at a time on a tiny, controlled dataset.

Key Principle
A unit test should run fast (typically well under a second once a shared local SparkSession is running), test one thing only, and have no external dependencies (no real databases, no real S3).
Example Scenario
You write a function clean_phone_numbers(df) that strips non-digit characters from a phone column. A unit test creates a tiny DataFrame of a few rows, passes it in, and checks that the output is correct, without touching any real data.
python — unit test example
# functions.py — the function we want to test
from pyspark.sql import functions as F

def clean_phone_numbers(df):
    """Remove non-numeric chars from phone column."""
    return df.withColumn(
        "phone",
        F.regexp_replace(F.col("phone"), r"[^0-9]", "")
    )


# test_functions.py — the unit test
import pytest
from pyspark.sql import SparkSession
from functions import clean_phone_numbers

@pytest.fixture(scope="session")
def spark():
    return SparkSession.builder \
        .master("local[1]") \
        .appName("unit-test") \
        .getOrCreate()

def test_clean_phone_numbers(spark):
    # ARRANGE — create tiny test data
    input_data = [(1, "(555) 123-4567"), (2, "+1-800-555-0199")]
    df = spark.createDataFrame(input_data, ["id", "phone"])

    # ACT — call the function
    result = clean_phone_numbers(df)

    # ASSERT — verify the output
    # Key by id: Spark does not guarantee the row order returned by collect()
    phones = {row["id"]: row["phone"] for row in result.collect()}
    assert phones[1] == "5551234567"
    assert phones[2] == "18005550199"
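When part of a transformation's rule can be expressed as plain Python, a complementary pattern is to test that rule with no JVM at all, which keeps those checks in the millisecond range. The sketch below assumes a hypothetical `clean_phone` helper that applies the same `[^0-9]` pattern to a single string (Python's `re` and Spark's Java regex treat this character class identically). It does not replace the Spark test above, because it cannot catch mistakes in the Spark wiring itself.

```python
import re

# Same character class as the F.regexp_replace call in clean_phone_numbers.
NON_DIGITS = re.compile(r"[^0-9]")


def clean_phone(value):
    """Hypothetical pure-Python version of the rule: keep digits only.

    None passes through unchanged, mirroring how regexp_replace returns
    null for a null input.
    """
    if value is None:
        return None
    return NON_DIGITS.sub("", value)


# (raw input, expected output) pairs, including the edge cases.
PHONE_CASES = [
    ("(555) 123-4567", "5551234567"),
    ("+1-800-555-0199", "18005550199"),
    ("", ""),
    (None, None),
]


def test_clean_phone_pure_python():
    for raw, expected in PHONE_CASES:
        assert clean_phone(raw) == expected, f"{raw!r} -> {clean_phone(raw)!r}"
```

The same case table could also feed the Spark-based test, so both layers check identical inputs.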
🔗
Integration Testing IMPORTANT
What is Integration Testing?

Integration testing verifies that multiple components work correctly together. For PySpark, this means testing a full pipeline (read → transform → write) with a real (but small and local) dataset, real file formats, or containerized databases.

Example Scenario
Test that your ETL pipeline can read a Parquet file, apply transformations, and write the correct result to a temp output path — all in a local Spark session.
python — integration test
# `spark` is the shared session fixture (see conftest.py in 25.2); `tmp_path` is built into pytest
from my_etl_pipeline import run_pipeline

def test_full_pipeline(spark, tmp_path):
    # Write sample input data to a temp Parquet file
    input_data = [(1, "Alice", 30), (2, "Bob", 25)]
    input_df = spark.createDataFrame(input_data, ["id", "name", "age"])
    
    input_path = str(tmp_path / "input")
    output_path = str(tmp_path / "output")
    input_df.write.parquet(input_path)

    # Run the full pipeline
    run_pipeline(spark, input_path, output_path)

    # Read output and verify
    result_df = spark.read.parquet(output_path)
    assert result_df.count() == 2
    assert "processed_at" in result_df.columns  # check audit column added
🔁
Regression Testing STABILITY
What is Regression Testing?

Regression testing ensures that new code changes do not break existing functionality. You capture expected results as "golden files" or baseline values, and future runs must match them. This is critical in PySpark, where a schema change or a refactor can silently change the data.

python — regression / golden file test
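import json

from revenue import compute_revenue  # hypothetical module containing the code under test

def test_revenue_regression(spark):
    # Load your reference "golden" results (paths are relative to the repo root)
    with open("tests/golden/revenue_expected.json") as f:
        expected = json.load(f)

    # Run the actual pipeline; sort so the row order matches the golden file
    result_df = compute_revenue(spark, "tests/data/sales_sample.parquet")
    actual = [row.asDict() for row in result_df.orderBy("month").collect()]

    # Compare: any mismatch means a regression
    assert actual == expected, f"Regression detected!\nExpected: {expected}\nActual:   {actual}"


# How to create/update the golden file (same sort order and row format as the test):
# rows = [row.asDict() for row in result_df.orderBy("month").collect()]
# with open("tests/golden/revenue_expected.json", "w") as f:
#     json.dump(rows, f, indent=2)
# Dates, timestamps, and decimals are not JSON-serializable and must be converted first.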
⚠️ Watch Out
Update golden files intentionally and commit the change to Git with a clear message. Accidentally overwriting them defeats the purpose of regression testing.
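One way to make updates deliberate is to rewrite the golden file only when an explicit opt-in flag is set, and to compare floats with a small tolerance so harmless rounding changes do not register as regressions. The sketch below is standard-library only; `compare_to_golden`, the `UPDATE_GOLDEN` environment variable, and the default `rel_tol` are hypothetical choices, and it assumes the rows are JSON-serializable dicts such as `[r.asDict() for r in df.collect()]`.

```python
import json
import math
import os
from pathlib import Path

# Hypothetical opt-in switch: golden files are rewritten only when
# UPDATE_GOLDEN=1 is set, so a normal test run can never overwrite them.
UPDATE_ENV_VAR = "UPDATE_GOLDEN"


def _values_match(actual, expected, rel_tol):
    """Floats match within rel_tol; everything else must be exactly equal."""
    if isinstance(actual, float) or isinstance(expected, float):
        try:
            return math.isclose(actual, expected, rel_tol=rel_tol)
        except TypeError:  # e.g. None or a string compared with a float
            return False
    return actual == expected


def compare_to_golden(rows, golden_path, sort_key, rel_tol=1e-9):
    """Compare rows (a list of dicts) against a JSON golden file."""
    path = Path(golden_path)
    actual = sorted(rows, key=lambda r: r[sort_key])

    if os.environ.get(UPDATE_ENV_VAR) == "1":
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(actual, indent=2, sort_keys=True))
        return

    expected = sorted(json.loads(path.read_text()), key=lambda r: r[sort_key])
    assert len(actual) == len(expected), (
        f"Row count changed: {len(actual)} actual vs {len(expected)} golden"
    )
    for i, (act, exp) in enumerate(zip(actual, expected)):
        same = act.keys() == exp.keys() and all(
            _values_match(act[k], exp[k], rel_tol) for k in exp
        )
        if not same:
            raise AssertionError(
                f"Row {i} differs from golden file:\n  actual:   {act}\n  expected: {exp}"
            )
```

With this in place, regenerating is a conscious command such as `UPDATE_GOLDEN=1 pytest tests/test_revenue.py`, and the resulting file change shows up in the Git diff for review.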
🌐
End-to-End Testing PIPELINE
What is End-to-End Testing?

End-to-End (E2E) testing validates the entire pipeline from source to destination, mimicking real production conditions. It typically runs in a staging environment with real data sources (or realistic replicas), real storage (S3 dev bucket), and real Spark clusters.

📁 Source (CSV/Kafka) → ⚡ Spark job → 🗄️ Delta/S3 output → ✅ Row count + schema check
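The last step of that flow, a row count plus schema check, can be written without importing Spark by working on two values any DataFrame provides: `df.count()` and `df.dtypes`, the list of `(column_name, type_string)` pairs such as `("amount", "double")`. The sketch below is a hypothetical post-run check; its name, arguments, and default row bounds are assumptions to adapt.

```python
def check_output(row_count, dtypes, expected_types, min_rows=1, max_rows=None):
    """Return a list of human-readable problems; an empty list means the output passed.

    row_count:      result of df.count()
    dtypes:         result of df.dtypes, e.g. [("order_id", "int"), ("amount", "double")]
    expected_types: mapping of column name -> expected Spark type string
    """
    problems = []
    if row_count < min_rows:
        problems.append(f"too few rows: {row_count} < {min_rows}")
    if max_rows is not None and row_count > max_rows:
        problems.append(f"too many rows: {row_count} > {max_rows}")

    actual = dict(dtypes)
    for name, expected_type in expected_types.items():
        if name not in actual:
            problems.append(f"missing column: {name}")
        elif actual[name] != expected_type:
            problems.append(f"column {name}: expected {expected_type}, got {actual[name]}")
    # Sorted so the report is stable from run to run.
    for name in sorted(actual.keys() - expected_types.keys()):
        problems.append(f"unexpected column: {name}")
    return problems


# Usage after the E2E job has written its output:
# out = spark.read.format("delta").load(output_path)
# problems = check_output(out.count(), out.dtypes, {"order_id": "int", "amount": "double"})
# assert not problems, problems
```

Collecting all problems before failing gives one complete report per E2E run, which matters when each run takes minutes.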
🔬
Unit Tests
Milliseconds to seconds. Run on every commit. Hundreds of tests.
🔗
Integration Tests
Seconds to minutes. Run on every pull request, before merge. Dozens of tests.
🌐
E2E Tests
Minutes to hours. Run before each production deploy. A few key tests.
25.2 — Pytest

Pytest — The Testing Framework

Pytest is the most popular Python testing framework. It is simple to use, extensible through fixtures and plugins, and fits PySpark testing well: a SparkSession becomes just another fixture.

🔧
Fixtures CORE CONCEPT
What are Fixtures?

Fixtures are reusable setup/teardown functions decorated with @pytest.fixture. Pytest injects them into any test function that names them as a parameter. For PySpark, the most important fixture is a shared SparkSession: you create it once and reuse it across every test in the run, which saves a lot of time because starting a SparkSession (launching the JVM and initializing the SparkContext) takes seconds.

Scope Options
scope="function" (the default; recreated for every test), scope="module" (once per test file), scope="session" (once for the entire test run). For a SparkSession, always use session scope.
python — conftest.py (shared fixtures)
# conftest.py — place this at the root of your tests/ folder
# Pytest auto-discovers this file and makes fixtures available everywhere

import pytest
from pyspark.sql import SparkSession

@pytest.fixture(scope="session")
def spark():
    """Create a single SparkSession for the entire test session."""
    spark = SparkSession.builder \
        .master("local[2]") \
        .appName("pytest-pyspark") \
        .config("spark.sql.shuffle.partitions", "2") \
        .config("spark.default.parallelism", "2") \
        .getOrCreate()
    
    spark.sparkContext.setLogLevel("ERROR")  # suppress noisy INFO logs
    yield spark  # test runs here
    spark.stop()  # cleanup after ALL tests finish


@pytest.fixture
def sample_customers_df(spark):
    """Reusable sample customer DataFrame."""
    data = [
        (1, "Alice", "alice@example.com", 28),
        (2, "Bob",   "bob@example.com",   35),
        (3, "Carol", "carol@example.com", 22),
    ]
    return spark.createDataFrame(data, ["id", "name", "email", "age"])


# Usage in any test file:
def test_filter_adults(sample_customers_df):
    result = sample_customers_df.filter("age >= 25")
    assert result.count() == 2  # Alice (28) and Bob (35)
📋
Parameterization EFFICIENCY
What is Parameterization?

Parameterization lets you run the same test with multiple inputs using one test function. Instead of writing 5 separate test functions for 5 edge cases, you write one and provide the different inputs as a list.

python — parameterized tests
import pytest
from pyspark.sql import functions as F
from my_transforms import categorize_age

# Each tuple is (input_age, expected_category)
@pytest.mark.parametrize("age, expected", [
    (10, "child"),
    (17, "child"),
    (18, "adult"),
    (64, "adult"),
    (65, "senior"),
    (90, "senior"),
])
def test_categorize_age(spark, age, expected):
    # Create a one-row DataFrame with this age
    df = spark.createDataFrame([(1, age)], ["id", "age"])
    
    result = categorize_age(df)  # adds "age_category" column
    category = result.collect()[0]["age_category"]
    
    assert category == expected, f"Age {age}: expected '{expected}', got '{category}'"

# Pytest runs this test 6 times automatically — once per parameter set
# Output shows: test_categorize_age[10-child] PASSED
#               test_categorize_age[65-senior] PASSED   etc.
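The cases above cluster at the category boundaries (17/18 and 64/65), which is where off-by-one bugs such as `>` instead of `>=` hide. One option for keeping such a table in step with the thresholds is to generate the boundary cases from them. In the sketch below the thresholds 18 and 65 are inferred from the six cases above, and `AGE_BANDS`, `expected_category`, and `boundary_cases` are hypothetical names. The trade-off is that the generator restates the rules, so its output table deserves one review by hand.

```python
# Hypothetical band definition inferred from the parametrized cases above:
# under 18 -> "child", 18 to 64 -> "adult", 65 and over -> "senior".
BELOW_FIRST_BAND = "child"
AGE_BANDS = [(18, "adult"), (65, "senior")]  # (inclusive lower bound, label), ascending


def expected_category(age):
    """Plain-Python reference for the label a given age should receive."""
    category = BELOW_FIRST_BAND
    for lower_bound, label in AGE_BANDS:
        if age >= lower_bound:
            category = label
    return category


def boundary_cases():
    """(age, expected) pairs just below and exactly at every band boundary."""
    cases = []
    for lower_bound, _ in AGE_BANDS:
        for age in (lower_bound - 1, lower_bound):
            cases.append((age, expected_category(age)))
    return cases


# Usage with pytest (same test body as test_categorize_age):
# @pytest.mark.parametrize("age, expected", boundary_cases())
# def test_categorize_age_boundaries(spark, age, expected): ...
```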
📁
Test Organization BEST PRACTICE
How to Organize Your Tests

A well-organized test suite mirrors your project structure and groups tests by type. This makes it easy to run only unit tests during development (fast) and all tests in CI.

my_project/
├── src/
│   ├── transforms/
│   │   ├── clean.py
│   │   └── aggregate.py
│   └── pipeline.py
├── tests/
│   ├── conftest.py        ← shared fixtures (SparkSession here)
│   ├── unit/
│   │   ├── test_clean.py
│   │   └── test_aggregate.py
│   ├── integration/
│   │   └── test_pipeline.py
│   └── data/              ← small sample files
│       └── customers_sample.parquet
└── pytest.ini             ← pytest config
ini — pytest.ini configuration
[pytest]
testpaths = tests
python_files = test_*.py
python_functions = test_*

# Custom markers to tag test types
markers =
    unit: fast isolated unit tests
    integration: tests that use real Spark + files
    slow: tests that take a long time

# Run only unit tests: pytest -m unit
# Run all tests:       pytest
# Run with verbose:    pytest -v
python — using markers
import pytest

@pytest.mark.unit
def test_remove_nulls(spark):
    df = spark.createDataFrame([(1, None), (2, "Alice")], ["id", "name"])
    result = df.dropna()
    assert result.count() == 1

@pytest.mark.integration
def test_read_parquet_and_transform(spark, tmp_path):
    # This test reads a real parquet file
    ...
25.3 — Spark Testing

Setting Up Spark for Testing

Learn how to create a lightweight, fast SparkSession specifically for testing — with the right configs to minimize overhead and maximize test speed.

🎭
Mock SparkSession SETUP
What does "Mock" mean for Spark?

In PySpark testing, "mocking" Spark usually means replacing the cluster, not Spark itself: you create a real SparkSession that runs entirely on your laptop (or CI machine) in local mode. master("local[*]") runs Spark with one worker thread per CPU core, while master("local[2]") uses two threads. External dependencies (S3, JDBC) are replaced with test doubles or temp directories.

💡 Analogy
A mock SparkSession is like a film rehearsal on a small stage. The actors practice all their moves with fake props. The real production is the actual cluster, but the rehearsal catches most problems cheaply.
python — optimized test SparkSession
import pytest
from pyspark.sql import SparkSession

@pytest.fixture(scope="session")
def spark():
    """
    Lightweight SparkSession for unit/integration tests.
    Configs tuned for test speed, not production throughput.
    """
    spark = (
        SparkSession.builder
        .master("local[2]")           # 2 threads, enough for tests
        .appName("pyspark-test-suite")
        # Reduce shuffle partitions (default 200 is too many for tiny test data)
        .config("spark.sql.shuffle.partitions", "2")
        .config("spark.default.parallelism", "2")
        # Disable AQE so query plans and partition counts stay predictable on tiny data
        .config("spark.sql.adaptive.enabled", "false")
        # Default checkpoint directory for streaming queries, on local disk (no HDFS needed)
        .config("spark.sql.streaming.checkpointLocation", "/tmp/spark-test-checkpoint")
        # Enable Delta if needed in tests. The artifact must match your Spark version
        # (delta-core 2.4.x pairs with Spark 3.4; Delta 3.x is published as delta-spark),
        # and spark.jars.packages downloads it from Maven on first startup.
        .config("spark.jars.packages", "io.delta:delta-core_2.12:2.4.0")
        .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
        .config("spark.sql.catalog.spark_catalog",
                "org.apache.spark.sql.delta.catalog.DeltaCatalog")
        .getOrCreate()
    )
    spark.sparkContext.setLogLevel("ERROR")  # silence noisy logs
    yield spark
    spark.stop()
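Strictly speaking, the session above is a real Spark engine running locally, not a mock. For the external dependencies mentioned earlier, a true test double can be built with the standard library's `unittest.mock`. In this sketch, `load_orders` is a hypothetical function under test; a `MagicMock` stands in for the SparkSession and records how the JDBC reader chain was configured, so the test needs neither a JVM nor a database.

```python
from unittest.mock import MagicMock


def load_orders(spark, jdbc_url, table, user, password):
    """Hypothetical function under test: builds a Spark JDBC read."""
    return (
        spark.read.format("jdbc")
        .option("url", jdbc_url)
        .option("dbtable", table)
        .option("user", user)
        .option("password", password)
        .load()
    )


def test_load_orders_uses_expected_options():
    spark = MagicMock(name="spark")
    reader = spark.read.format.return_value
    # Make every .option(...) call return the same mock so the builder chain continues.
    reader.option.return_value = reader

    result = load_orders(spark, "jdbc:postgresql://db:5432/shop", "orders", "app", "secret")

    spark.read.format.assert_called_once_with("jdbc")
    reader.option.assert_any_call("url", "jdbc:postgresql://db:5432/shop")
    reader.option.assert_any_call("dbtable", "orders")
    assert reader.option.call_count == 4
    # The function should hand back whatever .load() produced.
    assert result is reader.load.return_value
```

This style only checks wiring (which format and options were requested); it says nothing about whether the data is read correctly, so it complements rather than replaces local Spark or Testcontainers tests.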
💻
Local Spark Testing PATTERNS
Working with Local Mode for Tests

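Local mode runs the Spark driver and executor in a single JVM process, with tasks running on local threads. It supports the DataFrame API, Spark SQL, UDFs, and even Structured Streaming, all on your machine, though it cannot reproduce cluster-only behavior such as executor loss or shuffles over the network. Key patterns for local test data management: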

python — local test patterns
import os
from pathlib import Path

# Pattern 1: Inline test data (best for unit tests)
def test_transform_inline(spark):
    df = spark.createDataFrame([
        (1, "apple", 100),
        (2, "banana", 200),
    ], ["id", "product", "revenue"])
    
    result = df.filter("revenue > 150")
    assert result.count() == 1


# Pattern 2: Test files from disk (best for integration tests)
TEST_DATA_DIR = Path(__file__).parent / "data"

def test_read_csv(spark):
    df = spark.read.csv(str(TEST_DATA_DIR / "orders.csv"), header=True, inferSchema=True)
    assert df.count() > 0
    assert "order_id" in df.columns


# Pattern 3: Temp directory (pytest's tmp_path fixture)
def test_write_and_read_parquet(spark, tmp_path):
    df = spark.range(10)  # creates DataFrame with id column 0-9
    
    out = str(tmp_path / "output")
    df.write.parquet(out)
    
    result = spark.read.parquet(out)
    assert result.count() == 10
    # tmp_path is automatically cleaned up by pytest after the test
🔩
Spark Test Fixtures REUSABILITY
Building a Library of Test Fixtures

Create reusable fixtures for common test DataFrames (customer data, order data, product data) so every test file can use them without rewriting setup code.

python — rich fixture library in conftest.py
# conftest.py
import pytest
from pyspark.sql.types import (StructType, StructField, IntegerType,
                               StringType, DoubleType, DateType)
from datetime import date

@pytest.fixture
def orders_df(spark):
    """Sample orders DataFrame with explicit schema."""
    schema = StructType([
        StructField("order_id",   IntegerType(), nullable=False),
        StructField("customer_id", IntegerType(), nullable=True),
        StructField("amount",      DoubleType(),  nullable=True),
        StructField("status",      StringType(),  nullable=True),
        StructField("order_date",  DateType(),    nullable=True),
    ])
    data = [
        (1001, 1, 250.0,  "COMPLETED", date(2024, 1, 15)),
        (1002, 2, 89.99, "PENDING",   date(2024, 1, 16)),
        (1003, 1, 500.0,  "COMPLETED", date(2024, 1, 17)),
        (1004, 3, None,   "FAILED",    date(2024, 1, 18)),
    ]
    return spark.createDataFrame(data, schema)


@pytest.fixture
def products_df(spark):
    data = [(1, "Laptop", 999.99), (2, "Mouse", 29.99)]
    return spark.createDataFrame(data, ["product_id", "name", "price"])


# Tests that request a fixture by naming it as a parameter:
def test_orders_are_all_positive(orders_df):
    # Test that all non-null amounts are positive
    negatives = orders_df.filter("amount <= 0")
    assert negatives.count() == 0

def test_completed_orders_count(orders_df):
    completed = orders_df.filter("status = 'COMPLETED'")
    assert completed.count() == 2
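As the fixture library grows, many tests need "a normal order, except for one field". A common answer is a test-data builder that fills in sensible defaults so each test states only what it cares about. The sketch assumes a hypothetical `make_order` helper whose column order matches the `orders_df` schema above; the default values are illustrative.

```python
from datetime import date

# Must match the StructType field order used in the orders_df fixture.
ORDER_COLUMNS = ["order_id", "customer_id", "amount", "status", "order_date"]

# Hypothetical defaults describing a "typical" valid order.
ORDER_DEFAULTS = {
    "order_id": 1,
    "customer_id": 1,
    "amount": 100.0,
    "status": "COMPLETED",
    "order_date": date(2024, 1, 1),
}


def make_order(**overrides):
    """Return one order row as a tuple, with overrides applied to the defaults."""
    unknown = overrides.keys() - ORDER_DEFAULTS.keys()
    if unknown:
        # Catch typos such as make_order(amout=5.0) instead of silently ignoring them.
        raise KeyError(f"unknown order fields: {sorted(unknown)}")
    row = {**ORDER_DEFAULTS, **overrides}
    return tuple(row[col] for col in ORDER_COLUMNS)


# Usage inside a test, with the same StructType as in orders_df:
# df = spark.createDataFrame(
#     [make_order(order_id=1), make_order(order_id=2, amount=None, status="FAILED")],
#     schema,
# )
```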
25.4 — DataFrame Testing

Testing DataFrames Thoroughly

Learn the four essential DataFrame checks: schema validation, row count validation, data equality testing, and column validation. These form the backbone of any PySpark test.

📐
Schema Validation CRITICAL
Why Validate Schema?

Schema validation ensures your transformation produces the correct columns and data types. A schema mismatch (e.g., a column becoming StringType instead of IntegerType) can corrupt downstream aggregations silently.

python — schema validation patterns
from pyspark.sql.types import StructType, StructField, StringType, LongType, DoubleType

def test_output_schema(spark):
    # Define what the schema SHOULD look like.
    # Python ints are inferred as LongType (bigint), not IntegerType.
    expected_schema = StructType([
        StructField("customer_id", LongType(),   nullable=True),
        StructField("full_name",   StringType(), nullable=True),
        StructField("total_spend", DoubleType(), nullable=True),
    ])

    input_df = spark.createDataFrame([
        (1, "Alice", "Smith", 300.0),
        (2, "Bob",   "Jones", 150.5),
    ], ["customer_id", "first_name", "last_name", "total_spend"])

    # transform_customers is the function under test; it should concatenate
    # first_name and last_name into full_name
    result = transform_customers(input_df)

    # Method 1: Compare full schemas (this also compares nullable flags)
    assert result.schema == expected_schema

    # Method 2: Check column names only
    assert result.columns == ["customer_id", "full_name", "total_spend"]

    # Method 3: Check specific column type
    schema_dict = {f.name: f.dataType for f in result.schema.fields}
    assert isinstance(schema_dict["total_spend"], DoubleType)

    # Method 4: Check no extra unexpected columns
    expected_cols = {"customer_id", "full_name", "total_spend"}
    actual_cols = set(result.columns)
    assert actual_cols == expected_cols, f"Unexpected columns: {actual_cols - expected_cols}"
🔢
Row Count Validation BASIC CHECK
Count Checks

Row count validation ensures your transformation doesn't silently drop or duplicate rows. It's the first check to run because if the count is wrong, nothing else matters.

python — row count validation
def test_no_rows_lost_in_join(spark):
    customers = spark.createDataFrame(
        [(1, "Alice"), (2, "Bob"), (3, "Carol")],
        ["cust_id", "name"]
    )
    orders = spark.createDataFrame(
        [(1, 100.0), (1, 200.0), (2, 150.0)],
        ["cust_id", "amount"]
    )
    
    # Left join — Carol has no orders but should still appear
    result = customers.join(orders, on="cust_id", how="left")
    
    # Alice has 2 orders, Bob 1, Carol none (kept once with a null amount): 2 + 1 + 1 = 4 rows
    assert result.count() == 4
    
    # Verify Carol appears with null amount
    carol_rows = result.filter("name = 'Carol'")
    assert carol_rows.count() == 1
    assert carol_rows.collect()[0]["amount"] is None


def test_dedup_reduces_rows(spark):
    df = spark.createDataFrame(
        [(1, "A"), (1, "A"), (2, "B")],
        ["id", "val"]
    )
    result = df.dropDuplicates()
    assert result.count() == 2  # was 3, should be 2 after dedup


def test_filter_result_count(orders_df):
    completed = orders_df.filter("status = 'COMPLETED'")
    
    # Exact count
    assert completed.count() == 2
    
    # Range check (useful when data size isn't fully known)
    total = orders_df.count()
    assert 0 < completed.count() < total  # some, but not all
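Instead of hand-counting join output as in the Carol example, the expected size of a left join can be computed from key counts on small inputs: each left row appears once per matching right row, or once with nulls if nothing matches. Null keys never match in a Spark equi-join, and the sketch mirrors that. `expected_left_join_count` is a hypothetical helper, practical only for test-sized data that can be collected to the driver; it is especially useful for catching accidental fan-out from duplicate keys.

```python
from collections import Counter


def expected_left_join_count(left_keys, right_keys):
    """Expected row count of left.join(right, on=key, how="left").

    Each left row is repeated once per matching right row, or kept once
    (with nulls on the right side) when nothing matches. Null keys never
    match, as in a Spark equi-join.
    """
    right_counts = Counter(k for k in right_keys if k is not None)
    # Counter returns 0 for missing keys (including None), and max(1, 0)
    # keeps an unmatched left row exactly once.
    return sum(max(1, right_counts[k]) for k in left_keys)


# Usage with the DataFrames from test_no_rows_lost_in_join:
# left = [r["cust_id"] for r in customers.collect()]
# right = [r["cust_id"] for r in orders.collect()]
# assert result.count() == expected_left_join_count(left, right)
```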
⚖️
Data Equality Testing DEEP CHECK
Comparing DataFrames Row by Row

Data equality testing checks that every row and column value matches the expected result exactly. Spark does not guarantee row order unless you sort explicitly, so sort both DataFrames before comparing.

⚠️ Important
Never use df1.collect() == df2.collect() without sorting first. Spark DataFrames have no guaranteed row order — the same logical data can appear in different orders on different runs.
python — data equality patterns
from pyspark.sql import functions as F

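def assert_df_equals(df1, df2, sort_by):
    """
    Compare two DataFrames for equality.
    - Same column names and types (nullable flags ignored)
    - Same row count
    - Same data (order-independent via sort)
    """
    # Check column names and types. Nullability is ignored on purpose: F.count("*")
    # produces a non-nullable column while createDataFrame infers nullable columns,
    # so a strict schema comparison would fail for the aggregate test below.
    types1 = [(f.name, f.dataType) for f in df1.schema.fields]
    types2 = [(f.name, f.dataType) for f in df2.schema.fields]
    assert types1 == types2, f"Schema mismatch:\n{df1.schema}\nvs\n{df2.schema}"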

    # Check row counts match
    count1, count2 = df1.count(), df2.count()
    assert count1 == count2, f"Row count mismatch: {count1} vs {count2}"

    # Sort both and compare row by row
    rows1 = df1.orderBy(sort_by).collect()
    rows2 = df2.orderBy(sort_by).collect()
    assert rows1 == rows2, "DataFrame contents do not match!"


def test_aggregate_result(spark):
    orders = spark.createDataFrame([
        (1, "A", 100.0), (2, "A", 200.0),
        (3, "B", 150.0),
    ], ["order_id", "category", "amount"])

    result = orders.groupBy("category").agg(
        F.sum("amount").alias("total_amount"),
        F.count("*").alias("order_count")
    )

    expected = spark.createDataFrame([
        ("A", 300.0, 2),
        ("B", 150.0, 1),
    ], ["category", "total_amount", "order_count"])

    assert_df_equals(result, expected, sort_by="category")
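Sorting by one column only works when that column is unique. If `sort_by` has ties, `orderBy` may return the tied rows in either order and the comparison becomes flaky. For small collected results, one way around this is to sort on every column. `canonical_rows` is a hypothetical helper; it relies on `collect()` returning `Row` objects, which are tuples and therefore sortable.

```python
def canonical_rows(rows):
    """Sort row tuples by all columns so ties on a single key cannot reorder them.

    Each value becomes (is_not_null, value), so None sorts before any real
    value and is never compared against it directly.
    """
    return sorted(rows, key=lambda row: tuple((v is not None, v) for v in row))


# Usage inside a test:
# assert canonical_rows(df1.collect()) == canonical_rows(df2.collect())
```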
🔎
Column Validation VALUE CHECKS
Checking Column Values and Distributions

Column validation goes beyond counts — it checks specific column values, null rates, value ranges, and allowed categorical values.

python — column-level assertions
from pyspark.sql import functions as F

def test_column_values(spark):
    df = spark.createDataFrame([
        (1, "COMPLETED", 100.0),
        (2, "PENDING",   50.0),
        (3, "FAILED",    0.0),
    ], ["id", "status", "amount"])

    # 1. Check allowed values (enum check)
    allowed_statuses = {"COMPLETED", "PENDING", "FAILED", "CANCELLED"}
    actual_statuses = {row["status"] for row in df.select("status").distinct().collect()}
    assert actual_statuses.issubset(allowed_statuses), f"Invalid statuses: {actual_statuses - allowed_statuses}"

    # 2. Check no nulls in critical column
    null_count = df.filter(F.col("id").isNull()).count()
    assert null_count == 0, "id column must not have nulls"

    # 3. Check value range
    min_amount, max_amount = df.agg(
        F.min("amount"), F.max("amount")
    ).collect()[0]
    assert min_amount >= 0, "Amount cannot be negative"
    assert max_amount < 1_000_000, "Amount seems unreasonably large"

    # 4. Uniqueness check
    total_rows = df.count()
    distinct_ids = df.select("id").distinct().count()
    assert total_rows == distinct_ids, "id column must be unique"

    # 5. Non-empty string check
    empty_status = df.filter(F.trim(F.col("status")) == "").count()
    assert empty_status == 0, "status must not be empty string"
25.5 — Chispa Library

Chispa — Elegant DataFrame Assertions

Chispa is a lightweight Python library that provides elegant, readable assertion functions specifically for PySpark DataFrames. It handles the boilerplate of schema checking, sorting, and equality comparison.

📦 Install
pip install chispa — Works with PySpark 2.x, 3.x, 4.x
DataFrame Equality — assert_df_equality CORE API
assert_df_equality

The main function in Chispa. It compares two DataFrames and raises a descriptive error message if they differ — showing exactly which rows are different, which is much more helpful than a plain assertion failure.

python — chispa assert_df_equality
from chispa.dataframe_comparer import assert_df_equality
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

def test_with_chispa_basic(spark):
    actual = spark.createDataFrame(
        [(1, "Alice"), (2, "Bob")],
        ["id", "name"]
    )
    expected = spark.createDataFrame(
        [(2, "Bob"), (1, "Alice")],  # different order — chispa handles this
        ["id", "name"]
    )
    
    # ignore_row_order=True — sorts before comparing (recommended for most tests)
    assert_df_equality(actual, expected, ignore_row_order=True)


def test_with_chispa_ignore_nullability(spark):
    """Sometimes schema nullable flag differs but data is the same."""
    actual = spark.createDataFrame(
        [(1, "Alice")],
        StructType([StructField("id", IntegerType(), nullable=True),
                    StructField("name", StringType(), nullable=True)])
    )
    expected = spark.createDataFrame(
        [(1, "Alice")],
        StructType([StructField("id", IntegerType(), nullable=False),   # different
                    StructField("name", StringType(), nullable=False)])  # nullable flag
    )
    
    # ignore_nullable=True skips nullable flag comparison
    assert_df_equality(actual, expected, ignore_nullable=True)


# When there's a mismatch, Chispa shows a clear diff:
# DataFramesNotEqualError:
# DF1 row: Row(id=1, name='Alice')
# DF2 row: Row(id=1, name='ALICE')   ← you see exactly what differs
Approximate Equality FLOATS
Why Approximate Equality?

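Floating-point arithmetic introduces tiny rounding differences. In Python, 0.1 + 0.2 evaluates to 0.30000000000000004, not 0.3, and Spark may add the same values in a different order across partitions, which can change the last digits of a sum or average. Strict equality then fails. Chispa's approximate equality checks that values are within a tolerance instead.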
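The size of the tolerance matters as much as having one, because the gap between adjacent doubles grows with magnitude. The standard-library sketch below (no Spark or Chispa involved; `math.ulp` needs Python 3.9+) shows a sum that misses 1.0 exactly, and how a fixed absolute window like the `precision=1e-3` used in this section compares with double spacing near one trillion. `approx_equal` is a hypothetical wrapper around `math.isclose`.

```python
import math

# Ten additions of 0.1 do not land exactly on 1.0 in binary floating point.
ten_tenths = sum([0.1] * 10)
exact_match = ten_tenths == 1.0                              # False
close_match = math.isclose(ten_tenths, 1.0, abs_tol=1e-9)   # True

# Spacing between adjacent doubles: about 2.2e-16 near 1.0 but about 1.2e-4
# near one trillion, so a 1e-3 absolute window there spans only ~8 steps.
gap_near_one = math.ulp(1.0)
gap_near_trillion = math.ulp(1e12)


def approx_equal(a, b, rel_tol=1e-9, abs_tol=0.0):
    """True if |a - b| <= max(rel_tol * max(|a|, |b|), abs_tol).

    A relative tolerance scales with the values being compared, which suits
    columns whose magnitude varies widely (e.g. revenue totals).
    """
    return math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
```

For ratio or percentage columns near 1 to 100, an absolute window is easy to reason about; for large monetary totals, a relative tolerance is usually the safer default.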

python — approximate equality
from chispa.dataframe_comparer import assert_approx_df_equality

def test_average_calculation(spark):
    # Compute an average — floating point result
    df = spark.createDataFrame(
        [(1, 100.0), (2, 200.0), (3, 300.0)],
        ["id", "revenue"]
    )
    
    from pyspark.sql import functions as F
    result = df.agg(F.avg("revenue").alias("avg_revenue"))
    
    expected = spark.createDataFrame(
        [(200.00000001,)],  # deliberately offset; the true average is exactly 200.0
        ["avg_revenue"]
    )
    
    # precision=1e-3 means values within 0.001 of each other pass
    assert_approx_df_equality(result, expected, precision=1e-3)


def test_percentage_columns(spark):
    """Good for testing percentage/ratio columns."""
    result = compute_market_share(spark)  # hypothetical function under test; returns % values
    expected = spark.createDataFrame(
        [("Product A", 33.33), ("Product B", 66.67)],
        ["product", "share_pct"]
    )
    
    assert_approx_df_equality(result, expected,
                              precision=0.01,          # within 0.01 percentage points
                              ignore_row_order=True)
📋
Schema Comparison SCHEMA ONLY
assert_schema_equality

Sometimes you only want to validate the schema — not the data — for example after a column rename or type cast operation. Chispa has a dedicated schema assertion function.

python — schema-only assertions
from chispa.schema_comparer import assert_schema_equality
from pyspark.sql.types import StructType, StructField, IntegerType, DoubleType, StringType

def test_after_cast(spark):
    df = spark.createDataFrame(
        [("1", "100.5")],
        ["user_id", "amount"]
    )
    
    # Cast string columns to correct types
    result = df.withColumn("user_id", df["user_id"].cast(IntegerType())) \
               .withColumn("amount",  df["amount"].cast(DoubleType()))
    
    expected_schema = StructType([
        StructField("user_id", IntegerType(), nullable=True),
        StructField("amount",  DoubleType(),  nullable=True),
    ])
    
    assert_schema_equality(result.schema, expected_schema)


# Ignore nullable flags if they don't matter for your test:
from chispa.schema_comparer import assert_schema_equality_ignore_nullable

def test_schema_ignoring_nullable(spark, result_df):
    expected_schema = StructType([
        StructField("id",   IntegerType(), nullable=False),
        StructField("name", StringType(),  nullable=False),
    ])
    # Won't fail if actual has nullable=True
    assert_schema_equality_ignore_nullable(result_df.schema, expected_schema)
25.6 — Testcontainers

Testcontainers — Real Services in Tests

Testcontainers is a Python library that spins up real Docker containers (PostgreSQL, Kafka, MySQL) during your tests and tears them down afterward. This lets you test against real databases and brokers without needing a shared dev environment.

📦 Install
pip install "testcontainers[postgres,kafka]" psycopg2-binary kafka-python (the last two are used by the examples below). Requires a running Docker daemon.
🐳
Containerized Testing REAL SERVICES
Why Use Testcontainers?

With Testcontainers, you get a real database, real Kafka — not mocks. This means your integration test catches bugs that only appear with real services (SQL dialect differences, serialization issues, etc.).

💡 Analogy
Mocking a database is like rehearsing with a cardboard prop. Testcontainers gives you the real prop — a real working door — for the rehearsal, then throws it away after. The production stage still has a different door, but at least you practiced with a real one.
Test starts → Docker container spins up → Test runs against real service → Container destroyed
🗄️
Database Testing — PostgreSQL JDBC
Testing Spark JDBC Reads/Writes

Use a real PostgreSQL container to test your Spark JDBC read/write code. This catches real issues — type mapping errors, null handling, character encoding — that mocks miss.

python — Spark + PostgreSQL testcontainer
import pytest
from testcontainers.postgres import PostgresContainer
import psycopg2

@pytest.fixture(scope="session")
def postgres_container():
    """Start a real PostgreSQL container for the test session."""
    with PostgresContainer("postgres:15") as postgres:
        yield postgres  # container is running during tests
    # Container auto-destroyed after yield block exits


@pytest.fixture(scope="session")
def seeded_postgres(postgres_container):
    """Create tables and insert test data into the container."""
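    # get_connection_url() returns a SQLAlchemy-style URL ("postgresql+psycopg2://...")
    # that psycopg2 cannot parse, so pass the parts explicitly.
    # Attribute names are from testcontainers 4.x (3.x used POSTGRES_USER etc.).
    conn = psycopg2.connect(
        host=postgres_container.get_container_host_ip(),
        port=postgres_container.get_exposed_port(5432),
        user=postgres_container.username,
        password=postgres_container.password,
        dbname=postgres_container.dbname,
    )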
    cursor = conn.cursor()
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS orders (
            order_id   INT PRIMARY KEY,
            customer   VARCHAR(100),
            amount     NUMERIC(10, 2),
            status     VARCHAR(20)
        )
    """)
    cursor.execute("""
        INSERT INTO orders VALUES
        (1, 'Alice', 250.00, 'COMPLETED'),
        (2, 'Bob',   89.99,  'PENDING'),
        (3, 'Carol', 500.00, 'COMPLETED')
    """)
    conn.commit()
    cursor.close()
    conn.close()
    return postgres_container


def test_spark_reads_postgres(spark, seeded_postgres):
    # Build a JDBC URL from the container's host and mapped port
    # (attribute names from testcontainers 4.x)
    host = seeded_postgres.get_container_host_ip()
    port = seeded_postgres.get_exposed_port(5432)
    jdbc_url = f"jdbc:postgresql://{host}:{port}/{seeded_postgres.dbname}"

    # Read the table using Spark JDBC. The PostgreSQL JDBC driver must be on
    # Spark's classpath, e.g. spark.jars.packages=org.postgresql:postgresql:<version>
    df = spark.read.format("jdbc") \
        .option("url", jdbc_url) \
        .option("dbtable", "orders") \
        .option("user", seeded_postgres.username) \
        .option("password", seeded_postgres.password) \
        .option("driver", "org.postgresql.Driver") \
        .load()

    # Verify data
    assert df.count() == 3
    completed = df.filter("status = 'COMPLETED'")
    assert completed.count() == 2

    # NUMERIC(10, 2) arrives as DecimalType, so the sum is a Python Decimal
    total = df.agg({"amount": "sum"}).collect()[0][0]
    assert abs(float(total) - 839.99) < 0.01
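Depending on the testcontainers version and its `driver` argument, `get_connection_url()` returns a SQLAlchemy-style URL such as `postgresql+psycopg2://user:password@host:port/db`, which is not a JDBC URL. If you prefer deriving everything from that one string, a small standard-library parser can split it. `sqlalchemy_url_to_jdbc` is a hypothetical helper and assumes a PostgreSQL URL that includes host, port, and database.

```python
from urllib.parse import unquote, urlsplit


def sqlalchemy_url_to_jdbc(url):
    """Split 'postgresql+psycopg2://user:pw@host:port/db' into a JDBC URL
    plus the user/password values for Spark's JDBC options."""
    parts = urlsplit(url)
    base_scheme = parts.scheme.split("+", 1)[0]  # drop a "+driver" suffix
    if base_scheme not in ("postgresql", "postgres"):
        raise ValueError(f"not a PostgreSQL URL: {url!r}")
    if parts.hostname is None or parts.port is None:
        raise ValueError(f"URL must include host and port: {url!r}")

    jdbc_url = f"jdbc:postgresql://{parts.hostname}:{parts.port}{parts.path}"
    credentials = {
        # Credentials may be percent-encoded inside a URL, so decode them.
        "user": unquote(parts.username) if parts.username else None,
        "password": unquote(parts.password) if parts.password else None,
    }
    return jdbc_url, credentials


# Usage inside the Spark JDBC test:
# jdbc_url, creds = sqlalchemy_url_to_jdbc(seeded_postgres.get_connection_url())
# spark.read.format("jdbc").option("url", jdbc_url) \
#     .option("user", creds["user"]).option("password", creds["password"]) ...
```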
📨
Kafka Testing STREAMING
Testing Spark Structured Streaming with Kafka

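Testcontainers provides a real Kafka broker in a container. You can produce test messages, read them back with Spark's Kafka source, and assert on the output, all locally. The example uses a batch read of the topic, which is simpler to assert on than a long-running streaming query but exercises the same Kafka connector and message parsing.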

python — Spark + Kafka testcontainer
import pytest, json
from testcontainers.kafka import KafkaContainer
from kafka import KafkaProducer

@pytest.fixture(scope="session")
def kafka_container():
    with KafkaContainer() as kafka:
        yield kafka


def test_streaming_from_kafka(spark, kafka_container, tmp_path):
    bootstrap = kafka_container.get_bootstrap_server()
    topic = "test-orders"

    # Produce 3 test messages to Kafka
    producer = KafkaProducer(
        bootstrap_servers=bootstrap,
        value_serializer=lambda v: json.dumps(v).encode()
    )
    for i in range(3):
        producer.send(topic, {"order_id": i, "amount": i * 100})
    producer.flush()

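    # Batch-read the topic with Spark's Kafka source (spark.read, not readStream).
    # Needs the spark-sql-kafka-0-10 package matching your Spark and Scala versions.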
    df = spark.read \
        .format("kafka") \
        .option("kafka.bootstrap.servers", bootstrap) \
        .option("subscribe", topic) \
        .option("startingOffsets", "earliest") \
        .load()
    
    # Parse JSON values
    from pyspark.sql.types import StructType, StructField, IntegerType
    from pyspark.sql import functions as F
    
    schema = StructType([
        StructField("order_id", IntegerType()),
        StructField("amount", IntegerType()),
    ])
    parsed = df.select(F.from_json(F.col("value").cast("string"), schema).alias("data")) \
               .select("data.*")
    
    assert parsed.count() == 3
    total = parsed.agg({"amount": "sum"}).collect()[0][0]
    assert total == 300  # 0 + 100 + 200
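The batch read above can see all three messages because `producer.flush()` blocks until the send requests complete. When a test instead waits on something asynchronous (a `readStream` query writing to a sink, or a downstream consumer), asserting immediately is a common source of flaky tests, and a bounded polling helper is one way to handle it. `wait_for` is a hypothetical helper with illustrative timeout and interval defaults; for a streaming query the test itself controls, `StreamingQuery.processAllAvailable()` is usually the more direct tool.

```python
import time


def wait_for(predicate, timeout=30.0, interval=0.5):
    """Call `predicate` until it returns a truthy value or `timeout` seconds pass.

    Returns the truthy value; raises TimeoutError once the deadline is reached.
    """
    deadline = time.monotonic() + timeout
    while True:
        result = predicate()
        if result:
            return result
        if time.monotonic() >= deadline:
            raise TimeoutError(f"condition not met within {timeout} seconds")
        time.sleep(interval)


# Usage sketch (output_path is hypothetical, e.g. a directory a stream writes to):
# from pathlib import Path
# wait_for(lambda: any(Path(output_path).glob("*.parquet")), timeout=60)
```

A hard deadline keeps a broken pipeline from hanging CI, while the short interval keeps a passing test fast.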
25.7 — Data Quality Testing

Data Quality Frameworks

Data quality testing uses specialized frameworks to validate data at scale. Beyond simple assertions, these tools define reusable "expectations" and "rules" that can be run in CI pipelines and production alike.

🎯
Great Expectations INDUSTRY STANDARD
What is Great Expectations (GX)?

Great Expectations is one of the most widely used Python data quality frameworks. You define expectations (e.g., "column X must not have nulls"), run them through a checkpoint, and get an HTML data quality report (Data Docs). It can validate Spark DataFrames directly.

💡 Analogy
Think of Great Expectations as a "contract" for your data. Just like a legal contract specifies what must be delivered, a GX expectation suite specifies what quality the data must meet.
python — Great Expectations with PySpark
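# Uses the legacy Dataset API (SparkDFDataset), which is not available in GX 1.x;
# pin an older great_expectations release that still ships it to run this as written.
from great_expectations.dataset import SparkDFDataset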

def test_orders_quality(spark):
    # Create a DataFrame
    df = spark.createDataFrame([
        (1, "Alice", "COMPLETED", 250.0),
        (2, "Bob",   "PENDING",   89.99),
        (3, "Carol", "COMPLETED", 500.0),
    ], ["order_id", "customer", "status", "amount"])

    # Wrap in GX Spark Dataset
    ge_df = SparkDFDataset(df)

    # Define expectations
    # 1. order_id must not have nulls
    result = ge_df.expect_column_values_to_not_be_null("order_id")
    assert result.success

    # 2. status must be one of the allowed values
    result = ge_df.expect_column_values_to_be_in_set(
        "status", {"COMPLETED", "PENDING", "FAILED", "CANCELLED"}
    )
    assert result.success

    # 3. amount must be positive and at most 100,000
    result = ge_df.expect_column_values_to_be_between(
        "amount", min_value=0.01, max_value=100_000
    )
    assert result.success

    # 4. order_id must be unique
    result = ge_df.expect_column_values_to_be_unique("order_id")
    assert result.success

    # 5. Row count expectation
    result = ge_df.expect_table_row_count_to_be_between(min_value=1, max_value=10_000)
    assert result.success

    # 6. Column existence
    result = ge_df.expect_column_to_exist("customer")
    assert result.success
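The test above stops at the first failing `assert`, so one bad column can hide problems in the others. In GX, a checkpoint instead runs a whole expectation suite and reports every result together. The sketch below models that idea in plain Python over a list of row dicts. `run_suite` and the tuple-based expectations are hypothetical stand-ins, not the GX API. The `SHIPPED` row is made-up sample data that triggers a failure.

```python
from typing import Callable, Dict, List, Tuple

Row = Dict[str, object]
# A hypothetical "expectation": (name, column, predicate applied to each value).
Expectation = Tuple[str, str, Callable[[object], bool]]


def run_suite(rows: List[Row], suite: List[Expectation]) -> dict:
    """Evaluate every expectation and report all results together."""
    results = []
    for name, column, predicate in suite:
        unexpected = [r.get(column) for r in rows if not predicate(r.get(column))]
        results.append({
            "expectation": name,
            "column": column,
            "success": not unexpected,
            "unexpected_count": len(unexpected),
        })
    # The suite passes only if every expectation passed.
    return {"success": all(r["success"] for r in results), "results": results}


orders = [
    {"order_id": 1, "status": "COMPLETED", "amount": 250.0},
    {"order_id": 2, "status": "PENDING", "amount": 89.99},
    {"order_id": 3, "status": "SHIPPED", "amount": 500.0},  # status not allowed
]

suite = [
    ("not_null", "order_id", lambda v: v is not None),
    ("in_set", "status", lambda v: v in {"COMPLETED", "PENDING", "FAILED", "CANCELLED"}),
    ("between", "amount", lambda v: v is not None and 0.01 <= v <= 100_000),
]

report = run_suite(orders, suite)
print(report["success"])  # False
for r in report["results"]:
    print(r["expectation"], r["column"], r["success"], r["unexpected_count"])
```

Collecting results before failing means a single CI run surfaces every broken expectation instead of only the first one.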
🔍
Deequ (AWS) AWS NATIVE
What is Deequ?

Deequ is an open-source data quality library built by Amazon and written in Scala on top of Apache Spark. It defines constraints (rules) and analyzers (metrics). PyDeequ is the Python wrapper. It calls Deequ through the Spark JVM, so the Deequ JAR must be on the Spark classpath.

python — PyDeequ constraints
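# pip install pydeequ
# Assumes SPARK_VERSION is set (PyDeequ reads it on import) and that the
# `spark` fixture loads the Deequ JAR, e.g.
# .config("spark.jars.packages", pydeequ.deequ_maven_coord)
from pydeequ.checks import Check, CheckLevel
from pydeequ.verification import VerificationSuite, VerificationResult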

def test_with_deequ(spark):
    df = spark.createDataFrame([
        (1, "Alice", 28,  "alice@email.com"),
        (2, "Bob",   35,  "bob@email.com"),
        (3, None,   22,  "carol@email.com"),
    ], ["id", "name", "age", "email"])

    check = (Check(spark, CheckLevel.Error, "User Quality")
        .isComplete("id")                      # id has no nulls
        .isUnique("id")                         # id is unique
        .isNonNegative("age")                   # age >= 0
        .isContainedIn("name", ["Alice", "Bob", "Carol"])  # allowed names
        .hasCompleteness("name", lambda p: p >= 0.6)  # 60%+ non-null (2 of 3 rows here)
    )

    result = VerificationSuite(spark).onData(df).addCheck(check).run()
    df_result = VerificationResult.checkResultsAsDataFrame(spark, result)
    
    # Show detailed results
    df_result.show(truncate=False)
    
    # Fail test if any constraint failed
    failed = df_result.filter("constraint_status = 'Failure'")
    assert failed.count() == 0, \
        f"DQ constraints failed: {failed.select('constraint', 'constraint_message').collect()}"
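Deequ separates the metric from the rule. An analyzer computes a number over a column, and a constraint such as `hasCompleteness` passes that number to your assertion lambda. The plain-Python sketch below models that split for the `name` column in the test data. The helper names are hypothetical, not PyDeequ calls. With one null in three rows, completeness is 2/3 ≈ 0.667, so a `lambda p: p >= 0.8` assertion fails while `p >= 0.6` passes.

```python
from typing import Callable, List, Optional


def completeness(values: List[Optional[object]]) -> float:
    """Analyzer-style metric: fraction of non-null values."""
    if not values:
        return 0.0  # assumption of this sketch: an empty column scores 0.0
    return sum(v is not None for v in values) / len(values)


def has_completeness(values: List[Optional[object]],
                     assertion: Callable[[float], bool]) -> bool:
    """Constraint-style check: apply the assertion to the computed metric."""
    return assertion(completeness(values))


names = ["Alice", "Bob", None]  # the `name` column from the test data

metric = completeness(names)
print(round(metric, 3))                             # 0.667
print(has_completeness(names, lambda p: p >= 0.8))  # False
print(has_completeness(names, lambda p: p >= 0.6))  # True
```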
🫧
Soda YAML DRIVEN
What is Soda?

Soda defines data quality checks in YAML using SodaCL (Soda Checks Language), so the checks themselves need no Python code. Soda Core runs them against Spark DataFrames or SQL databases and reports pass/fail for each check.

yaml — SodaCL check file (checks.yml)
# checks.yml
checks for orders:
  # Row count must be between 100 and 1 million
  - row_count between 100 and 1000000

  # order_id must have zero nulls
  - missing_count(order_id) = 0

  # status only allows these values
  - invalid_count(status) = 0:
      valid values: [COMPLETED, PENDING, FAILED, CANCELLED]

  # amount must be positive
  - min(amount) > 0

  # customer_email must match regex
  - invalid_count(customer_email) = 0:
      valid regex: '^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'

  # Duplicate check
  - duplicate_count(order_id) = 0
python — Running Soda checks from Python
# pip install soda-core-spark-df
from soda.scan import Scan

def test_soda_quality(spark):
    df = spark.read.parquet("data/orders.parquet")
    
    # Expose the DataFrame to Soda as a temp view; the view name must
    # match the table in "checks for orders" in checks.yml
    df.createOrReplaceTempView("orders")

    scan = Scan()
    scan.set_scan_definition_name("orders-test")
    scan.set_data_source_name("spark_df")
    scan.add_spark_session(spark, data_source_name="spark_df")
    
    # Add the YAML check file
    scan.add_sodacl_yaml_file("checks/checks.yml")
    
    # Run checks
    scan.execute()
    
    # Assert no failures
    assert not scan.has_check_fails(), \
        f"Soda checks failed:\n{scan.get_checks_fail_text()}"
⚙️
Custom Validation Rules BUILD YOUR OWN
Building a Reusable Validation Framework

When GX, Deequ, or Soda feel too heavy, you can build a lightweight validation framework with plain Python and PySpark. Define each rule as a function that takes a DataFrame and returns True or False. Collect the outcomes as a list of result dictionaries, and assert that every rule passed.

python — custom validation framework
from dataclasses import dataclass
from typing import List, Callable
from pyspark.sql import DataFrame
from pyspark.sql import functions as F

@dataclass
class ValidationRule:
    name: str
    check: Callable[[DataFrame], bool]
    description: str


class DataValidator:
    def __init__(self, df: DataFrame):
        self.df = df
        self.rules: List[ValidationRule] = []
    
    def add_rule(self, name, check, description=""):
        self.rules.append(ValidationRule(name, check, description))
        return self  # allow chaining
    
    def validate(self):
        results = []
        for rule in self.rules:
            try:
                passed = rule.check(self.df)
                results.append({"rule": rule.name, "passed": passed,
                                 "desc": rule.description, "error": None})
            except Exception as e:
                results.append({"rule": rule.name, "passed": False,
                                 "desc": rule.description, "error": str(e)})
        return results


# Using the custom validator
# (orders_df is assumed to be a DataFrame fixture defined in conftest.py):
def test_custom_validator(spark, orders_df):
    results = (
        DataValidator(orders_df)
        .add_rule(
            "no_null_order_ids",
            lambda df: df.filter(F.col("order_id").isNull()).count() == 0,
            "Order IDs must not be null"
        )
        .add_rule(
            "positive_amounts",
            lambda df: df.filter("amount <= 0").count() == 0,
            "All amounts must be positive"
        )
        .add_rule(
            "valid_statuses",
            lambda df: df.filter(
                ~F.col("status").isin(["COMPLETED", "PENDING", "FAILED"])
            ).count() == 0,
            "Status must be one of allowed values"
        )
        .validate()
    )

    failures = [r for r in results if not r["passed"]]
    assert not failures, f"Validation failed: {failures}"
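The assertion message `f"Validation failed: {failures}"` prints raw dictionaries, which gets hard to read as rules accumulate. A small formatter can turn the list returned by `validate()` into one line per failed rule. The helper `format_failures` below is a hypothetical addition, not part of the framework above. It only assumes the `rule`, `passed`, `desc`, and `error` keys that `validate()` produces. The error text in the sample data is illustrative.

```python
from typing import Dict, List


def format_failures(results: List[Dict]) -> str:
    """Turn DataValidator-style result dicts into a readable failure report."""
    lines = []
    for r in results:
        if r["passed"]:
            continue  # only failed rules are reported
        reason = f"error: {r['error']}" if r["error"] else "check returned False"
        lines.append(f"- {r['rule']}: {r['desc']} ({reason})")
    return "\n".join(lines)


# Result dicts in the shape DataValidator.validate() returns.
results = [
    {"rule": "no_null_order_ids", "passed": True,
     "desc": "Order IDs must not be null", "error": None},
    {"rule": "positive_amounts", "passed": False,
     "desc": "All amounts must be positive", "error": None},
    {"rule": "valid_statuses", "passed": False,
     "desc": "Status must be one of allowed values",
     "error": "column not found"},  # hypothetical error text
]

report = format_failures(results)
print(report)
# - positive_amounts: All amounts must be positive (check returned False)
# - valid_statuses: Status must be one of allowed values (error: column not found)
```

In the test, the final line could then read `assert not failures, format_failures(results)`.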
Module 25 — Knowledge Check

Test Your Understanding

7 questions covering all topics in Module 25.

Q1. What scope should you use for a SparkSession pytest fixture to ensure it is created only once for all tests?
✅ Correct: scope="session" — SparkSession startup takes 2–5 seconds. Using session scope creates it once for the entire pytest run and reuses it across all test files.
Q2. You need to run the same test function with 6 different input values. Which pytest feature is best suited for this?
✅ Correct: @pytest.mark.parametrize — This decorator runs the same test function multiple times with different input values, one test per parameter set.
Q3. Why should you NOT compare two DataFrames with df1.collect() == df2.collect() without sorting first?
✅ Correct: No guaranteed row order — Spark is a distributed system. The same logical data can appear in different physical orders on different runs. Always sort by a key column before comparing.
Q4. Which Chispa function should you use when comparing DataFrames with float columns that have tiny rounding differences?
✅ Correct: assert_approx_df_equality — This function compares float values within a configurable precision tolerance (e.g., precision=1e-3), preventing tests from failing due to floating point rounding differences.
Q5. What is the key advantage of Testcontainers over mocking a database connection?
✅ Correct: Real database — Testcontainers spins up an actual PostgreSQL/Kafka/MySQL container. Real services expose real SQL dialect quirks, type mappings, and serialization behaviors that mocks cannot replicate.
Q6. In Great Expectations, what is a "Checkpoint"?
✅ Correct: Pipeline step — A GX Checkpoint combines a datasource, an expectation suite, and actions (like alerting or failing the pipeline). It's the "run DQ check" step you embed in your ETL pipeline.
Q7. What is the key difference between Unit Testing and Integration Testing in PySpark?
✅ Correct: Isolation vs. composition — Unit tests isolate a single function with tiny inline data. Integration tests verify that multiple functions/steps (read → transform → write) work together correctly end-to-end.
Module 25 — Reference

Testing PySpark — Cheat Sheet

Quick-reference for all key patterns, commands, and code snippets from Module 25.

SparkSession Fixture
import pytest
from pyspark.sql import SparkSession

@pytest.fixture(scope="session")
def spark():
    s = (SparkSession.builder
         .master("local[2]")
         .config("spark.sql.shuffle.partitions", "2")
         .getOrCreate())
    yield s
    s.stop()
Parametrize Tests
@pytest.mark.parametrize(
  "age, expected", [
    (10, "child"),
    (18, "adult"),
    (65, "senior"),
])
def test_cat(spark, age, expected):
  ...
Chispa Equality
from chispa.dataframe_comparer import assert_df_equality

assert_df_equality(
  actual, expected,
  ignore_row_order=True,
  ignore_nullable=True
)
Chispa Approx
from chispa.dataframe_comparer import assert_approx_df_equality

assert_approx_df_equality(
  actual, expected,
  precision=1e-3,
  ignore_row_order=True
)
Schema Check
from pyspark.sql.types import IntegerType

# Column names
assert df.columns == ["id", "name", "age"]

# Data types
schema_dict = {f.name: f.dataType for f in df.schema.fields}
assert isinstance(schema_dict["age"], IntegerType)
Column Validation
from pyspark.sql.functions import col

# No nulls
assert df.filter(col("id").isNull()).count() == 0

# Unique
assert df.count() == df.select("id").distinct().count()

# Value range
assert df.filter("amount < 0").count() == 0
Testcontainers Setup
from testcontainers.postgres import PostgresContainer

@pytest.fixture(scope="session")
def pg():
    with PostgresContainer("postgres:15") as pg:
        yield pg
Great Expectations
from great_expectations.dataset import SparkDFDataset  # legacy GX API

ge_df = SparkDFDataset(df)

r = ge_df.expect_column_values_to_not_be_null("id")
assert r.success

r = ge_df.expect_column_values_to_be_in_set(
    "status", {"COMPLETED", "PENDING"})
assert r.success
📋
Testing Framework Comparison
| Tool | Type | Best For | Language | Spark Native |
|---|---|---|---|---|
| Pytest | Test runner | All Python testing | Python | ➖ Not Spark-specific |
| Chispa | DataFrame assertions | Comparing DataFrames elegantly | Python | ✅ PySpark only |
| Testcontainers | Integration testing | Real DB/Kafka in tests | Python | 🔶 Via JDBC/Kafka |
| Great Expectations | Data quality | Expectation suites, DQ reports | Python/YAML | ✅ Yes |
| PyDeequ | Data quality | AWS-native constraint checking | Python | ✅ Built on Spark |
| Soda | Data quality | YAML-driven checks, no-code DQ | YAML/Python | ✅ Yes |
🧪
Test Type Decision Guide
| What you're testing | Test Type | Tool |
|---|---|---|
| A single transform function | Unit Test | pytest + inline data |
| Read Parquet → transform → write Parquet | Integration Test | pytest + tmp_path |
| Schema and types after cast | Schema Check | chispa assert_schema_equality |
| Float averages/percentages | Approx Equality | chispa assert_approx_df_equality |
| Read/write from real PostgreSQL | Containerized | testcontainers[postgres] |
| Data in production passes DQ rules | DQ Check | Great Expectations / Soda |
| Same code produces same results as before | Regression | pytest + golden files |
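The last row of the guide pairs regression tests with golden files. A golden file is a stored snapshot of known-good output that later runs must reproduce. Below is a minimal plain-Python sketch of that pattern. It assumes the DataFrame output has already been collected into a list of dicts, for example with `[row.asDict() for row in df.collect()]`. `assert_matches_golden` is a hypothetical helper, not a pytest or chispa API. The order totals are made-up sample values.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict, List


def assert_matches_golden(rows: List[Dict], golden_path: Path, sort_key: str) -> None:
    """Compare rows with a stored golden file, creating it on the first run.

    Rows are sorted by `sort_key` first because Spark does not guarantee
    row order, so the same logical result can come back in any order.
    """
    actual = sorted(rows, key=lambda r: r[sort_key])
    if not golden_path.exists():
        # First run: record the current output as the known-good baseline.
        golden_path.write_text(json.dumps(actual, indent=2, sort_keys=True))
        return
    expected = json.loads(golden_path.read_text())
    assert actual == expected, f"Output drifted from golden file {golden_path}"


# Usage with hypothetical aggregated output (totals per order status).
with tempfile.TemporaryDirectory() as tmp:
    golden = Path(tmp) / "totals_by_status.json"
    baseline = [
        {"status": "PENDING", "total": 89.99},
        {"status": "COMPLETED", "total": 750.0},
    ]
    assert_matches_golden(baseline, golden, sort_key="status")  # writes the file
    created = golden.exists()

    # Same rows in a different order still match.
    assert_matches_golden(list(reversed(baseline)), golden, sort_key="status")

    # A changed value is reported as a regression.
    changed = [
        {"status": "PENDING", "total": 0.0},
        {"status": "COMPLETED", "total": 750.0},
    ]
    try:
        assert_matches_golden(changed, golden, sort_key="status")
        drift_detected = False
    except AssertionError:
        drift_detected = True

print(created, drift_detected)  # True True
```

Writing the baseline automatically on the first run is convenient. The catch is that a bug present at that moment becomes the baseline, so golden files are best reviewed and committed like any other code.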
✅ Module 25 Complete!
You've covered all 7 topics: Testing Fundamentals, Pytest, Spark Testing, DataFrame Testing, Chispa, Testcontainers, and Data Quality Testing with Great Expectations, Deequ, and Soda. Ready for Module 26: Logging & Monitoring!