Optimising machine learning inference with PySpark and Pandas UDFs
Introduction to parallel processing in machine learning


In machine learning, working with large datasets and complex models can quickly become time-consuming and resource-intensive. Parallelisation is crucial for speeding up this process. The technique splits tasks into smaller subtasks that can be processed simultaneously on multiple CPU cores or on distributed machines within a cluster. By spreading the workload, you can process data faster and more efficiently at scale.
*This article is written in English for better readability.
Imagine you have a large set of data points that need to be processed by a machine learning model. Rather than processing each point one after another (sequentially), you divide the data into smaller chunks. Each chunk is then processed at the same time (in parallel), either on different CPU cores or across multiple machines. The result is that predictions on large datasets can be made significantly faster by distributing the computation.
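To make the idea concrete, here is a minimal plain-Python sketch of splitting data into chunks and processing the chunks in parallel. The "model" (predict_chunk) and the helper names are hypothetical, and a thread pool is used only to keep the sketch self-contained; CPU-heavy models need separate processes or a cluster, as in the methods below.

```python
from concurrent.futures import ThreadPoolExecutor


def chunk(items, chunk_size):
    """Split a list into consecutive chunks of at most chunk_size items."""
    return [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)]


def predict_chunk(chunk_of_points):
    # Hypothetical stand-in for a model: "predicts" the square of each data point
    return [x * x for x in chunk_of_points]


def parallel_predict(data_points, chunk_size=4, max_workers=4):
    chunks = chunk(data_points, chunk_size)
    # Each chunk is handed to a worker; map() yields results in chunk order
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        chunk_results = list(executor.map(predict_chunk, chunks))
    # Flatten the per-chunk results back into a single list
    return [prediction for result in chunk_results for prediction in result]


predictions = parallel_predict(list(range(10)))
print(predictions)  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
```

Because executor.map returns results in the same order as the chunks, the flattened predictions line up with the input data.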
In this blog, we will explore various methods for parallelising machine learning inference, focusing primarily on Python within the Databricks environment. Specifically, we will examine strategies for transcribing large volumes of audio files using OpenAI's Whisper model. While these techniques are tailored for Databricks, their principles can be applied more broadly. Throughout this article, you will find practical code snippets to illustrate how these methods work.
Method 1: Using Python's multiprocessing package
Overview of Python’s multiprocessing
Python's multiprocessing package offers a way to run multiple processes simultaneously, leveraging multi-core processors to speed up computation. This method is suitable for a single machine with multiple cores and is a good starting point for parallelism.
Implementation details
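```python
import multiprocessing
from faster_whisper import WhisperModel

model = None  # one model instance per worker process, set by init_worker

def init_worker():
    global model
    # float16 is not supported on CPU; int8 is a common choice for CPU inference
    model = WhisperModel('large-v3', device="cpu", compute_type="int8")

def transcribe_call(file_name):
    # transcribe() returns a lazy generator of segments plus metadata;
    # list() runs the transcription and produces a picklable result
    segments, info = model.transcribe(file_name, word_timestamps=True, vad_filter=True)
    return {file_name: list(segments)}

def transcribe_audio_files(audio_files, num_workers=4):
    with multiprocessing.Pool(num_workers, initializer=init_worker) as pool:
        return pool.map(transcribe_call, audio_files)

if __name__ == "__main__":
    audio_files = ['audio1.wav', 'audio2.wav', 'audio3.wav']
    num_workers = multiprocessing.cpu_count()
    transcriptions = transcribe_audio_files(audio_files, num_workers)
```
Advantages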
- Leverages multi-core processors to distribute work and reduce processing time.
- Simple to implement on smaller-scale systems that don’t require a distributed framework.
Disadvantages
- Overhead: Spawning and managing multiple processes introduces overhead, which might negate the performance benefits for smaller datasets.
- Memory Consumption: Each process requires its own memory space, which can quickly add up, especially with large models such as Whisper.
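A rough cost model shows why the overhead can cancel out the gains. This is our own simplification, not a measurement: assume N audio files, a transcription time t per file, a model loading time L that every process pays concurrently, p worker processes and a start-up overhead c per process.

```latex
T_1 = L + N t, \qquad T_p \approx p\,c + L + \left\lceil \frac{N}{p} \right\rceil t

T_p < T_1 \iff p\,c < \left( N - \left\lceil \frac{N}{p} \right\rceil \right) t
```

For example, with N = 3 files and p = 8 processes, ceil(3/8) = 1, so the pool only pays off if 8c < 2t. Under the same assumptions, memory grows roughly linearly with p, because every process holds its own copy of the model.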
Method 2: Using RDDs and broadcasting in PySpark
What is PySpark?
PySpark, the Python API for Apache Spark, provides powerful tools for distributed computing. One of its fundamental components is the Resilient Distributed Dataset (RDD): an immutable, distributed collection of objects that can be processed in parallel. When working with large machine learning models, RDDs can be used to parallelise predictions across a cluster.
Broadcasting in Spark
Broadcasting in Spark distributes a read-only variable (such as a machine learning model) to all worker nodes. Instead of sending the variable along with every task, Spark ships it to each executor once and caches it in memory there, so all tasks on that executor can reuse it.
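The toy model below mimics this behaviour in plain Python: the value is serialised once on the "driver", and each "executor" deserialises it on first use and then reuses its cached copy, so ten tasks on two executors trigger only two deserialisations. ToyBroadcast and value_on are hypothetical names, and this is not how Spark implements broadcast variables internally; the real API is sc.broadcast(value) with .value on the workers. The sketch also shows why a broadcast value has to be picklable.

```python
import pickle


class ToyBroadcast:
    """Toy model of a broadcast variable; NOT Spark's actual implementation."""

    def __init__(self, value):
        # Serialised once, on the "driver"
        self._payload = pickle.dumps(value)
        # executor_id -> deserialised copy cached on that "executor"
        self._executor_cache = {}
        self.deserialisations = 0

    def value_on(self, executor_id):
        # An executor deserialises the payload on first use, then reuses its copy
        if executor_id not in self._executor_cache:
            self._executor_cache[executor_id] = pickle.loads(self._payload)
            self.deserialisations += 1
        return self._executor_cache[executor_id]


model_weights = {"layer_1": list(range(1_000))}  # hypothetical "model"
broadcast_weights = ToyBroadcast(model_weights)

# 10 tasks spread round-robin over 2 executors
for task_id in range(10):
    executor_id = task_id % 2
    weights = broadcast_weights.value_on(executor_id)

print(broadcast_weights.deserialisations)  # 2: one per executor, not one per task
```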
Using flatMap transformation
The flatMap transformation in PySpark applies a function that returns an iterable to each element of an RDD and then flattens the results into a single RDD. Here it is used to run a prediction for each audio file, while the RDD's partitions are processed in parallel across the cluster.
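On a local Python list, flatMap semantics can be sketched as follows. flat_map is a hypothetical helper, not part of PySpark; on an RDD, Spark applies the function within each partition on the executors.

```python
from itertools import chain


def flat_map(func, elements):
    """Apply func (which returns an iterable) to every element, then flatten."""
    return list(chain.from_iterable(func(element) for element in elements))


def split_words(sentence):
    return sentence.split()


sentences = ["hello world", "parallel inference"]
print(flat_map(split_words, sentences))
# ['hello', 'world', 'parallel', 'inference']

# Wrapping each result in a one-element list gives one output per input element
files = ["audio1.wav", "audio2.wav"]
print(flat_map(lambda name: [{name: "transcript"}], files))
# [{'audio1.wav': 'transcript'}, {'audio2.wav': 'transcript'}]
```

The second call mirrors the Whisper example: wrapping each transcription in a one-element list makes flatMap behave like map.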
Implementation details
- Loading the model: The model is loaded on the driver and then broadcasted to all worker nodes.
- Broadcasting: The broadcasted model ensures that each worker node has access to the model without needing to reload it.
- Predicting with flatMap: The flatMap transformation is applied to the RDD to perform predictions on the data.
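```python
from pyspark import SparkContext
from faster_whisper import WhisperModel

# Initialise the Spark context
sc = SparkContext.getOrCreate()

# Load the model on the driver and broadcast it to all worker nodes.
# Note: sc.broadcast() pickles its value, so the model object must be picklable.
model = WhisperModel('large-v3', device="cuda", compute_type="float16")
broadcast_model = sc.broadcast(model)

def transcribe_call(file_name):
    model = broadcast_model.value
    # transcribe() returns a lazy generator of segments plus metadata;
    # list() runs the transcription so the result can be sent back to the driver
    segments, info = model.transcribe(file_name, word_timestamps=True, vad_filter=True)
    return {file_name: list(segments)}

audio_files = ['audio1.wav', 'audio2.wav', 'audio3.wav']
audio_rdd = sc.parallelize(audio_files)
# Each element yields a one-element list, so flatMap returns one dict per file
transcriptions_rdd = audio_rdd.flatMap(lambda file_name: [transcribe_call(file_name)])
transcriptions = transcriptions_rdd.collect()
```
Advantages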
- Better for complex operations: For complex transformations that are difficult to express in SQL or DataFrames, RDDs (and flatMap) provide a more straightforward API.
- Lazy Evaluation: Operations on RDDs are lazily evaluated, meaning computations are only executed when necessary, potentially optimising resource usage.
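Python's built-in map() gives a small analogy for lazy evaluation: defining the transformation runs nothing, and the work only happens when the results are requested, much like Spark waiting for an action such as collect(). fake_transcribe is a hypothetical stand-in for the model call.

```python
calls = {"transcribe": 0}


def fake_transcribe(file_name):
    # Hypothetical stand-in for an expensive model call; counts invocations
    calls["transcribe"] += 1
    return f"transcript of {file_name}"


files = ["audio1.wav", "audio2.wav", "audio3.wav"]

# Defining the "transformation" runs nothing yet: map() returns a lazy iterator
pipeline = map(fake_transcribe, files)
print(calls["transcribe"])  # 0

# The "action" (here: building a list) triggers the actual work
results = list(pipeline)
print(calls["transcribe"])  # 3
```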
Disadvantages
- Driver Memory Limitation: Broadcasting large models such as Whisper can lead to driver memory exhaustion as the driver needs to keep a copy of the broadcasted variable.
- Complexity: Managing RDDs and their transformations is more complex than working with DataFrames, especially when debugging operations that fail.
Method 3: Using PySpark and Pandas UDFs
Overview of Pandas UDFs
While RDDs and broadcasting offer one way to parallelise predictions, PySpark also provides a more modern and efficient approach based on DataFrames and user-defined functions (UDFs). UDFs let you define custom functions in Python and apply them to Spark DataFrames.
Traditional UDFs in PySpark can be slow because of the serialisation and deserialisation overhead between Spark and Python, and because the Python function is called once per row. Pandas UDFs, also known as vectorised UDFs, overcome this limitation by processing data in batches and using Apache Arrow for efficient data exchange, which significantly reduces the overhead compared with row-by-row UDFs.
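The difference in call overhead can be illustrated with a toy example that counts how often each kind of function is invoked for ten rows. row_udf and batch_udf are hypothetical plain-Python functions, not Spark APIs; in Spark, the size of the Arrow batches passed to Pandas UDFs is controlled by the spark.sql.execution.arrow.maxRecordsPerBatch setting.

```python
rows = [f"audio{i}.wav" for i in range(1, 11)]  # 10 hypothetical input rows
invocations = {"row_udf": 0, "batch_udf": 0}


def row_udf(file_name):
    # Row-at-a-time UDF: the Python function runs once per row
    invocations["row_udf"] += 1
    return file_name.upper()


def batch_udf(file_names):
    # Vectorised UDF: the Python function runs once per batch of rows
    invocations["batch_udf"] += 1
    return [name.upper() for name in file_names]


row_results = [row_udf(row) for row in rows]

batch_size = 4
batch_results = []
for start in range(0, len(rows), batch_size):
    batch_results.extend(batch_udf(rows[start:start + batch_size]))

print(invocations)  # {'row_udf': 10, 'batch_udf': 3}
```

Both variants produce the same output; only the number of Python calls (and, in Spark, the number of data transfers) differs.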
Implementing batch predictions with Pandas UDFs
To run a large machine learning model in parallel, we can use PySpark's predict_batch_udf function, which is built on top of a Pandas UDF. With this approach, the large model is loaded only once per worker, and the resulting prediction function is cached in memory and reused for subsequent batches.
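The load-once-and-cache pattern can be sketched in plain Python. This is a simplified illustration under our own assumptions, not PySpark's internal implementation: _predict_fn_cache and run_batches are hypothetical, and str.upper stands in for the model. Two tasks that run in the same worker process share a single model load.

```python
_predict_fn_cache = {}  # per-process cache: one entry per make_predict_fn
load_count = {"model": 0}


def make_predict_fn():
    # Expensive setup runs here, once; in the article this is loading Whisper
    load_count["model"] += 1
    model = str.upper  # hypothetical stand-in "model"

    def predict(batch):
        return [model(item) for item in batch]

    return predict


def run_batches(make_fn, inputs, batch_size):
    # Build the predict function on first use, then reuse the cached instance
    if make_fn not in _predict_fn_cache:
        _predict_fn_cache[make_fn] = make_fn()
    predict = _predict_fn_cache[make_fn]
    outputs = []
    for start in range(0, len(inputs), batch_size):
        outputs.extend(predict(inputs[start:start + batch_size]))
    return outputs


# Two "tasks" running in the same worker process share the cached model
first = run_batches(make_predict_fn, ["a.wav", "b.wav"], batch_size=1)
second = run_batches(make_predict_fn, ["c.wav"], batch_size=1)
print(load_count["model"])  # 1
print(first + second)  # ['A.WAV', 'B.WAV', 'C.WAV']
```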
Implementation details
```python
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql import Row, SparkSession
from pyspark.sql.types import StringType
spark = SparkSession.builder.getOrCreate()

def transcribe_parallel():
    # Runs on the worker: imports and model loading happen once per worker
    import json
    import numpy as np
    from faster_whisper import WhisperModel
    model = WhisperModel('large-v3', device="cuda", compute_type="float16")

    def transcribe_call(file_names):
        # file_names is a NumPy array holding one batch of 'filename' values
        results = []
        for file_name in file_names:
            segments, info = model.transcribe(file_name, word_timestamps=True, vad_filter=True)
            # Consume the lazy generator and keep JSON-serialisable fields only
            transcription = [
                {"start": s.start, "end": s.end, "text": s.text,
                 "words": [{"start": w.start, "end": w.end, "word": w.word} for w in (s.words or [])]}
                for s in segments
            ]
            results.append(json.dumps(transcription, ensure_ascii=False))
        return np.array(results)

    return transcribe_call

transcribe_and_return_json = predict_batch_udf(transcribe_parallel,
                                               return_type=StringType(), batch_size=1)

audio_files = ['audio1.wav', 'audio2.wav', 'audio3.wav']
list_of_rows = [Row(filename=file) for file in audio_files]
files_df = spark.createDataFrame(list_of_rows)
transcription_df = files_df.withColumn('output_transcription', transcribe_and_return_json('filename'))
```
We define a function, transcribe_parallel, that initialises the Whisper model once per worker. It returns an inner function, transcribe_call, which uses that model to transcribe each audio file in a batch and returns the transcriptions as JSON strings.
By wrapping transcribe_parallel with predict_batch_udf, we turn it into a UDF that Spark can apply to large datasets in parallel. The UDF processes the data in batches; here batch_size=1, so each batch contains one audio file. If the model can process several audio files at once, the batch size can be increased.
Finally, we apply this UDF to a DataFrame, adding a new column output_transcription that contains the transcription for each audio file in the ‘filename’ column.
Advantages
- Efficiency: By loading the model once per worker and caching the predict function, this approach minimises both time and memory consumption.
- Scalability: PySpark's distributed nature allows for easy scaling with autoscaling workers as needed.
- Simplicity: Using Pandas UDFs simplifies the code and leverages Spark's optimised execution engine.
Disadvantages
- JVM and PySpark overhead: Spark runs its core tasks on the JVM, so when Python code such as a Pandas UDF is invoked, data has to be transferred back and forth between the JVM and the Python process. This serialisation and deserialisation can be time-consuming. However, the same overhead applies to every operation that is not a native Spark DataFrame operation.
Conclusion
Using PySpark's predict_batch_udf function to run batch predictions with a large machine learning model offers significant advantages. Because the model is loaded once per worker and the predict function is cached, this approach minimises both time and memory consumption. In addition, Spark's distributed nature allows for easy scaling with autoscaling workers. PySpark's predict_batch_udf therefore provides a more efficient and scalable solution for large-scale model inference than the other approaches discussed here, making it a valuable tool for data scientists working with extensive datasets.
This is an article by Cas Hortensius, Data Engineer at Digital Power
Cas is a data professional with a passion for bridging the gap between raw data and actionable insights for analytics teams. Drawing on his engineering background, he makes data accessible and tailored for analysis in sectors such as music, finance and energy.