In today’s article, we are going to look at PySpark Accumulators with an example. An Accumulator in PySpark is a shared variable used for aggregations or counters across worker nodes in a cluster. Accumulators allow you to aggregate values across multiple tasks and retrieve the result in the driver program.
They are write-only on worker nodes (tasks can only add to the accumulator value) and readable only on the driver.
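To make that concrete, here is a minimal sketch of the basic pattern (the app name and identifiers below are just for illustration): the driver creates the accumulator, tasks add to it, and only the driver reads the final value.

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("AccumulatorBasics").getOrCreate()
sc = spark.sparkContext

# Create an accumulator on the driver with an initial value of 0
counter = sc.accumulator(0)

# Worker tasks can only add to it; they cannot read counter.value
sc.parallelize(range(10)).foreach(lambda _: counter.add(1))

# Only the driver can read the accumulated result
print(counter.value)  # 10

spark.stop()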
Why Use an Accumulator?
There are several reasons to use a PySpark Accumulator.
- Global Aggregation: Accumulators aggregate data from multiple partitions or tasks into a single value.
- Counters: Often used for debugging, tracking progress, or counting events like error occurrences.
- Shared State: They allow the driver program to collect information about tasks executed in parallel.
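As a small illustration of the global-aggregation and counter use cases above, the sketch below (with arbitrary sample numbers and illustrative names) adds up values and counts records across partitions:

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("GlobalAggregation").getOrCreate()
sc = spark.sparkContext

# Two accumulators: one for the running total, one for the record count
total = sc.accumulator(0)
records = sc.accumulator(0)

def track(value):
    # Each task adds its contribution; Spark merges the updates on the driver
    total.add(value)
    records.add(1)

sc.parallelize([4, 8, 15, 16, 23, 42], numSlices=3).foreach(track)

print(f"sum={total.value}, count={records.value}")  # sum=108, count=6

spark.stop()

For a plain sum you would normally use sum() or reduce(); accumulators earn their keep when you want to gather such totals as a side effect of a job that is already doing other work.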
When Should You Use an Accumulator?
- When you need to perform global aggregations or maintain counters during distributed processing.
- When you need to track metrics (e.g., number of invalid rows, failed operations) while processing large datasets in a distributed fashion.
Key Points About Accumulators
Here are some important points to keep in mind about PySpark Accumulators.
Driver Readability
- The accumulator value is only available to the driver program.
Worker Write Only
- Worker nodes can only add to the accumulator value but cannot read or modify it.
Fault Tolerance
- Accumulators are fault-tolerant, with a caveat: for updates made inside actions (such as foreach), Spark applies each task’s update exactly once, even if the task is retried. For updates made inside transformations (such as map), a retried or recomputed task may apply its update again, so avoid relying on accumulators for exact counts there.
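The following sketch (illustrative names only) shows that caveat: the accumulator is updated inside a map transformation and the RDD is not cached, so running two actions re-evaluates the map and applies the updates twice.

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("AccumulatorRetries").getOrCreate()
sc = spark.sparkContext

seen = sc.accumulator(0)

def tag(x):
    # Updated inside a transformation: applied every time the RDD is evaluated
    seen.add(1)
    return x

rdd = sc.parallelize(range(5)).map(tag)

rdd.count()  # first evaluation: seen.value becomes 5
rdd.count()  # no cache, so the map runs again: seen.value becomes 10
print(seen.value)  # 10, not 5

spark.stop()

Caching the RDD, or moving the update into an action such as foreach, avoids this kind of double counting.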
PySpark Accumulator with Example
Let’s take a real-world scenario: counting the invalid records in a dataset. I am using a small dataset here, but in real-life applications you might have a large dataset with thousands or even millions of records.
from pyspark.sql import SparkSession

# Initialize SparkSession
spark = SparkSession.builder.appName("Accumulator Example").getOrCreate()

# Create an accumulator
invalid_count = spark.sparkContext.accumulator(0)

# Example data
data = [
    "John,25,M",
    "Jane,30,F",
    "InvalidRecord",
    "Alice,,F",
    "Bob,40,M",
]

# Parallelize the data
rdd = spark.sparkContext.parallelize(data)

# Function to process and count invalid records
def process_record(record):
    global invalid_count
    fields = record.split(",")
    if len(fields) != 3 or not fields[1].isdigit():
        invalid_count += 1
        return None
    return record

# Process the RDD
processed_rdd = rdd.map(process_record).filter(lambda x: x is not None)

# Action to trigger processing
processed_rdd.collect()

# Print the result of the accumulator
print(f"Number of invalid records: {invalid_count.value}")

# Stop the Spark session
spark.stop()
After executing the above code, the output is:
Number of invalid records: 2
Explanation of the Code:
- First, a SparkSession is created to access Spark’s features.
- An accumulator is created with invalid_count = spark.sparkContext.accumulator(0), which initializes it with a starting value of 0.
- Each record of the data is then processed, and if it doesn’t meet the criteria, invalid_count is incremented.
- The collect action triggers the transformations, causing the accumulator to aggregate values from the tasks.
- After the action completes, the driver can access the final value of the accumulator with invalid_count.value.
This is how you can implement a PySpark Accumulator.
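As a variant of the example above (a sketch, not the original code), you can update the accumulator inside foreach instead of map plus collect. Because foreach is an action, Spark applies each task’s update exactly once even if a task is retried, and no data is shipped back to the driver just to trigger the job.

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("Accumulator Example").getOrCreate()

invalid_count = spark.sparkContext.accumulator(0)

data = ["John,25,M", "Jane,30,F", "InvalidRecord", "Alice,,F", "Bob,40,M"]
rdd = spark.sparkContext.parallelize(data)

def count_invalid(record):
    # Pure side effect: only add to the accumulator, nothing is returned
    fields = record.split(",")
    if len(fields) != 3 or not fields[1].isdigit():
        invalid_count.add(1)

# foreach is an action, so it triggers the job directly
rdd.foreach(count_invalid)

print(f"Number of invalid records: {invalid_count.value}")  # 2

spark.stop()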
Read Also:
- PySpark Broadcast Variables Tutorial
- Top 30 PySpark DataFrame Methods with Example
- Partitions in PySpark: A Comprehensive Guide with Examples
- How to Create Temp View in PySpark
Conclusion
I hope you found this PySpark Accumulator tutorial helpful. PySpark Accumulators are powerful tools for aggregating metrics and tracking global statistics during distributed computations.
Remember, worker nodes can only add to the accumulator variable, and its value can only be read by the driver program of the application.
PySpark Accumulator Docs: Click Here
Thanks for your time.