
PySpark Accumulator with Example


In this article, we will look at the PySpark Accumulator with an example. An Accumulator in PySpark is a shared variable used to perform aggregations or maintain counters across the worker nodes in a cluster. Accumulators let you aggregate values from multiple tasks and retrieve the result in the driver program.

They are write-only on worker nodes (tasks can only add to the accumulator value) and readable only on the driver.
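
Before the full example below, here is a minimal sketch of that lifecycle, with illustrative names: the accumulator is created on the driver, tasks add to it, and the driver reads the result. Reading the value inside a task raises an error.

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("AccumulatorBasics").getOrCreate()

# Created on the driver with an initial value of 0
counter = spark.sparkContext.accumulator(0)

# Tasks can only add to it; foreach is an action, so this runs on the workers
spark.sparkContext.parallelize([1, 2, 3, 4]).foreach(lambda x: counter.add(1))
# Note: calling counter.value inside a task would raise an error

# Only the driver can read the accumulated value
print(counter.value)  # 4

spark.stop()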

Why Use an Accumulator?

There are several reasons to use a PySpark Accumulator:

  • Global Aggregation: Accumulators aggregate data from multiple partitions or tasks into a single value.
  • Counters: Often used for debugging, tracking progress, or counting events like error occurrences.
  • Shared State: They allow the driver program to collect information about tasks executed in parallel.

When Should You Use an Accumulator?

  • When you need to perform global aggregations or maintain counters during distributed processing.
  • When you need to track metrics (e.g., number of invalid rows, failed operations) while processing large datasets in a distributed fashion.
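
As a hedged sketch of the metric-tracking use case, a job can carry several accumulators at once, one per metric. The metric names and validation rule below are only illustrative:

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("MetricTracking").getOrCreate()
sc = spark.sparkContext

blank_lines = sc.accumulator(0)    # rows that are empty
parse_errors = sc.accumulator(0)   # rows whose second field is not a number

def check(line):
    if not line.strip():
        blank_lines.add(1)
        return
    try:
        int(line.split(",")[1])
    except (IndexError, ValueError):
        parse_errors.add(1)

# foreach is an action, so the checks run on the workers
sc.parallelize(["a,1", "", "b,x", "c,3"]).foreach(check)

print(f"Blank lines: {blank_lines.value}, parse errors: {parse_errors.value}")  # 1 and 1

spark.stop()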

Key Points About Accumulators

Here are some important points to keep in mind about PySpark Accumulators.

Driver Readability

  • The accumulator value is only available to the driver program.

Worker Write Only

  • Worker nodes can only add to the accumulator value but cannot read or modify it.

Fault Tolerance

  • Accumulator updates made inside actions (such as foreach) are applied only once per task, even if the task is retried. Updates made inside transformations (such as map) may be applied more than once if a task or stage is re-executed, so avoid relying on them for exact counts.
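
For example, if an accumulator update lives inside a transformation such as map and the RDD is not cached, running a second action recomputes the lineage and re-applies the updates. Here is a hedged sketch of that pitfall, with illustrative names:

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("AccumulatorRecompute").getOrCreate()
sc = spark.sparkContext

seen = sc.accumulator(0)

def tag(x):
    seen.add(1)  # update inside a transformation (map)
    return x

mapped = sc.parallelize(range(100)).map(tag)

mapped.count()    # first action evaluates the map: seen is 100
mapped.collect()  # mapped is not cached, so the map runs again

print(seen.value)  # typically 200, not 100
# Caching mapped, or moving the update into an action such as foreach, avoids the double count.

spark.stop()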

PySpark Accumulator with Example

Let’s take a real-world scenario: counting the invalid records in a dataset. I am using a small dataset here, but in real applications you might have a much larger dataset with thousands, or even hundreds of thousands, of records.

from pyspark.sql import SparkSession

# Initialize SparkSession
spark = SparkSession.builder.appName("Accumulator Example").getOrCreate()

# Create an accumulator
invalid_count = spark.sparkContext.accumulator(0)

# Example data
data = [
    "John,25,M",
    "Jane,30,F",
    "InvalidRecord",
    "Alice,,F",
    "Bob,40,M",
]

# Parallelize the data
rdd = spark.sparkContext.parallelize(data)


# Function to process and count invalid records
def process_record(record):
    global invalid_count
    fields = record.split(",")
    if len(fields) != 3 or not fields[1].isdigit():
        invalid_count += 1
        return None
    return record


# Process the RDD
processed_rdd = rdd.map(process_record).filter(lambda x: x is not None)

# Action to trigger processing
processed_rdd.collect()

# Print the result of the accumulator
print(f"Number of invalid records: {invalid_count.value}")

# Stop the Spark session
spark.stop()

After executing the above code, the output is:

Number of invalid records: 2

Explanation of the Code:

  • Create the SparkSession, which is the entry point for working with Spark.
  • Create an accumulator with invalid_count = spark.sparkContext.accumulator(0), initializing it to 0.
  • Each record in data is processed, and if it doesn’t meet the criteria, invalid_count is incremented.
  • The collect action triggers the transformations, causing the accumulator to aggregate updates from all tasks.
  • After the action is completed, the driver can access the final value of the accumulator with invalid_count.value.

This is how you can implement a PySpark Accumulator.
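
If you only need the count and not the filtered records, a common variation keeps the accumulator update inside an action (foreach) instead of a transformation, which sidesteps the re-computation caveat from the key points above. Here is a self-contained sketch of that variation, using the same data and validation rule:

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("Accumulator Foreach Variation").getOrCreate()

invalid_count = spark.sparkContext.accumulator(0)
data = ["John,25,M", "Jane,30,F", "InvalidRecord", "Alice,,F", "Bob,40,M"]

def count_invalid(record):
    fields = record.split(",")
    if len(fields) != 3 or not fields[1].isdigit():
        invalid_count.add(1)

# foreach is an action, so the updates are applied as part of a single job
spark.sparkContext.parallelize(data).foreach(count_invalid)

print(f"Number of invalid records: {invalid_count.value}")  # 2

spark.stop()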


Conclusion

I hope you found this PySpark Accumulator tutorial helpful. PySpark Accumulators are powerful tools for aggregating metrics and tracking global statistics during distributed computations.

Remember: worker nodes can only add to an accumulator’s value, and only the driver program can read it.

PySpark Accumulator Docs: see the Accumulators section of the official Spark RDD Programming Guide.

Thanks for your time.
