Optimising Machine Learning inference with PySpark and Pandas UDFs

Introduction to parallel processing in Machine Learning


In the world of machine learning, working with large datasets and complex models can quickly become time-consuming and resource-intensive. To speed up this process, parallelisation becomes crucial. This technique involves breaking down tasks into smaller subtasks that can be processed simultaneously across multiple CPU cores or distributed machines within a cluster. By spreading out the workload, you can achieve faster and more efficient data processing on a large scale.

Imagine you have a large set of data points that need to be processed by a machine learning model. Rather than processing each point one after another (sequentially), you divide the data into smaller chunks. Each chunk is then processed at the same time (in parallel), either on different CPU cores or across multiple machines. The result is that predictions on large datasets can be made significantly faster by distributing the computation.
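The chunk-and-distribute idea can be sketched in a few lines of plain Python. This is a minimal illustration that assumes a hypothetical predict function standing in for a real model. A thread pool keeps the sketch portable; ProcessPoolExecutor from the same module is the drop-in choice for CPU-bound pure-Python work.

```python
from concurrent.futures import ThreadPoolExecutor


def predict(x):
    # Hypothetical stand-in for a model's prediction on one data point
    return x * 2


def chunked(data, chunk_size):
    # Split the data into consecutive chunks of at most chunk_size items
    return [data[i:i + chunk_size] for i in range(0, len(data), chunk_size)]


def predict_chunk(chunk):
    # Each worker processes one whole chunk sequentially
    return [predict(x) for x in chunk]


def parallel_predict(data, chunk_size=4, max_workers=4):
    chunks = chunked(data, chunk_size)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # executor.map runs chunks concurrently but yields results in input order
        return [y for chunk_result in executor.map(predict_chunk, chunks) for y in chunk_result]


print(parallel_predict(list(range(10)), chunk_size=3))  # [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
```

Because the results come back in input order, the flattened output lines up with the original data even though the chunks were processed concurrently.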

In this blog, we will explore various methods for parallelising machine learning inference, focusing primarily on Python within the Databricks environment. Specifically, we will examine strategies for transcribing large volumes of audio files using OpenAI's Whisper model, run through faster-whisper (a reimplementation of Whisper built on CTranslate2). While these techniques are tailored for Databricks, their principles can be applied more broadly. Throughout this article, you will find practical code snippets to illustrate how these methods work.

Method 1: Using Python's multiprocessing package

Overview of Python’s multiprocessing

Python's multiprocessing package offers a way to run multiple processes simultaneously, leveraging multi-core processors to speed up computation. This method is suitable for a single machine with multiple cores and is a good starting point for parallelism.

Implementation details

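import multiprocessing
from faster_whisper import WhisperModel

# Each worker process loads its own copy of the model once, via the Pool initializer
model = None

def init_worker():
    global model
    # float16 requires a GPU; int8 is a CPU-friendly compute type
    model = WhisperModel('large-v3', device="cpu", compute_type="int8")

def transcribe_call(file_name):
    # transcribe() returns a lazy generator of segments plus metadata;
    # consuming it here runs the transcription and produces a picklable result
    segments, _ = model.transcribe(file_name, word_timestamps=True, vad_filter=True)
    text = " ".join(segment.text.strip() for segment in segments)
    return {file_name: text}

def transcribe_audio_files(audio_files, num_workers=4):
    with multiprocessing.Pool(num_workers, initializer=init_worker) as pool:
        results = pool.map(transcribe_call, audio_files)

    return results

if __name__ == "__main__":
    audio_files = ['audio1.wav', 'audio2.wav', 'audio3.wav']

    # Every worker holds a full model in memory, so keep the number of workers modest
    num_workers = min(len(audio_files), multiprocessing.cpu_count())

    transcriptions = transcribe_audio_files(audio_files, num_workers)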

Advantages

  • Leverages multi-core processors to distribute work and reduce processing time.
  • Simple to implement on smaller-scale systems that don’t require a distributed framework.

Disadvantages

  • Overhead: Spawning and managing multiple processes introduces overhead, which might negate the performance benefits for smaller datasets.
  • Memory Consumption: Each process requires its own memory space, which can quickly add up, especially with large models such as Whisper.
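To make the memory point concrete, a back-of-the-envelope estimate multiplies the parameter count by the bytes per parameter and by the number of model copies. The sketch below assumes Whisper large-v3 has roughly 1.55 billion parameters and counts weights only, so real usage (activations, runtime buffers, the Python process itself) will be higher.

```python
def estimate_weights_memory_gb(num_params, bytes_per_param, num_copies=1):
    # Weights only: ignores activations, caches and interpreter overhead
    return num_params * bytes_per_param * num_copies / 1e9


WHISPER_LARGE_PARAMS = 1.55e9  # approximate parameter count (assumption)

# float16 stores each parameter in 2 bytes
one_copy = estimate_weights_memory_gb(WHISPER_LARGE_PARAMS, bytes_per_param=2)
eight_workers = estimate_weights_memory_gb(WHISPER_LARGE_PARAMS, bytes_per_param=2, num_copies=8)
print(f"float16, 1 process: ~{one_copy:.1f} GB; 8 processes: ~{eight_workers:.1f} GB")
# float16, 1 process: ~3.1 GB; 8 processes: ~24.8 GB
```

Even as a lower bound, eight processes each holding their own copy already need roughly 25 GB for the weights alone.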

Method 2: Using RDDs and broadcasting in PySpark

What is PySpark?

PySpark, the Python API for Apache Spark, provides powerful tools for distributed computing. One of the fundamental components of PySpark is the Resilient Distributed Dataset (RDD). RDDs are immutable distributed collections of objects that can be processed in parallel. When working with large machine learning models, RDDs can be used to parallelise predictions across a cluster.

Broadcasting in Spark

Broadcasting in Spark distributes a read-only variable (like a machine learning model) to all worker nodes. Spark sends the variable to each executor once, instead of with every task, and caches it in memory there, so it is available to each node without being repeatedly sent over the network. Because PySpark serialises the variable with Python's pickle module, the variable must be picklable.
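Since a broadcast variable has to survive pickling, a quick local check is to try pickle.dumps on the object before broadcasting it. In this sketch the ModelWithNativeHandle class is hypothetical: its threading.Lock stands in for the kind of native resource (a GPU context, a C++ model handle) that standard pickling cannot serialise.

```python
import pickle
import threading


class ModelWithNativeHandle:
    """Hypothetical model wrapper holding an unpicklable native resource."""

    def __init__(self):
        self.weights = [0.1, 0.2, 0.3]
        # A lock stands in for a native handle such as a GPU context
        self._handle = threading.Lock()


def is_picklable(obj):
    # Return True if the object survives standard pickling, which broadcasting relies on
    try:
        pickle.dumps(obj)
        return True
    except (TypeError, AttributeError, pickle.PicklingError):
        return False


print(is_picklable({"weights": [0.1, 0.2, 0.3]}))  # True
print(is_picklable(ModelWithNativeHandle()))        # False
```

Plain data such as weight arrays or configuration dictionaries pickle fine, while objects wrapping native handles usually do not.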

Using flatMap transformation

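The flatMap transformation in PySpark applies a function that returns an iterable to each element of an RDD and then flattens the results into a single RDD. Here it is used to run a prediction for every audio file, with the elements spread across the RDD's partitions. To process a whole partition in one call, for example to run predictions on batches of data, PySpark offers the related mapPartitions transformation.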

Implementation details

  1. Loading the model: The model is loaded on the driver and then broadcast to all worker nodes.
  2. Broadcasting: The broadcast model ensures that each worker node has access to the model without needing to reload it.
  3. Predicting with flatMap: The flatMap transformation is applied to the RDD to perform predictions on the data.
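from pyspark import SparkContext
from faster_whisper import WhisperModel

# Initialise the Spark context
sc = SparkContext.getOrCreate()

# The model is loaded on the driver and then broadcast to the executors.
# Note: sc.broadcast pickles the value, so this only works if the model object
# is picklable; CTranslate2-backed models such as faster-whisper's may not be.
model = WhisperModel('large-v3', device="cuda", compute_type="float16")
broadcast_model = sc.broadcast(model)

def transcribe_call(file_name):
    # Read the broadcast model on the executor
    model = broadcast_model.value
    # transcribe() returns a lazy generator of segments plus metadata;
    # consume it here so that collect() receives plain, picklable strings
    segments, _ = model.transcribe(file_name, word_timestamps=True, vad_filter=True)
    text = " ".join(segment.text.strip() for segment in segments)
    return {file_name: text}

audio_files = ['audio1.wav', 'audio2.wav', 'audio3.wav']

audio_rdd = sc.parallelize(audio_files)
# Each file yields a one-element list, which flatMap flattens into the result RDD
transcriptions_rdd = audio_rdd.flatMap(lambda file_name: [transcribe_call(file_name)])

# collect() triggers the lazy computation and returns the results to the driver
transcriptions = transcriptions_rdd.collect()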

Advantages

  • Better for complex operations: For complex transformations that are difficult to express in SQL or DataFrames, RDDs (and flatMap) provide a more straightforward API.
  • Lazy Evaluation: Operations on RDDs are lazily evaluated, meaning computations are only executed when necessary, potentially optimising resource usage.

Disadvantages

  • Driver Memory Limitation: Broadcasting large models such as Whisper can lead to driver memory exhaustion, because the driver needs to keep a copy of the broadcast variable.
  • Complexity: Managing RDDs and transformations can be more complex than working with DataFrames, especially when debugging operations that went wrong.

Method 3: Using PySpark and Pandas UDFs

Overview of Pandas UDFs

While RDDs and broadcasting offer a way to parallelise predictions, PySpark also provides a more modern and efficient approach using DataFrames and User-Defined Functions (UDFs). UDFs allow you to define custom functions in Python and apply them to Spark DataFrames.

Traditional Python UDFs in PySpark can be slow because they invoke the Python function once per row and incur serialisation and deserialisation overhead between Spark and Python. Pandas UDFs, also known as vectorised UDFs, overcome this limitation by processing batches of data at once and using Apache Arrow for efficient data exchange. This significantly reduces the overhead compared to row-by-row UDFs.
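The benefit of batching can be illustrated with plain pandas, without Spark. The sketch counts Python function calls for a row-at-a-time function versus a vectorised one applied to fixed-size batches. The batch size of 1,000 is an arbitrary assumption; in Spark, the Arrow batch size for Pandas UDFs is governed by the spark.sql.execution.arrow.maxRecordsPerBatch setting.

```python
import pandas as pd

calls = {"row": 0, "batch": 0}


def double_row(x):
    # Row-at-a-time: one Python call per value
    calls["row"] += 1
    return x * 2


def double_batch(values: pd.Series) -> pd.Series:
    # Vectorised: one Python call per batch; the arithmetic runs inside pandas/NumPy
    calls["batch"] += 1
    return values * 2


data = pd.Series(range(10_000))
batch_size = 1_000  # arbitrary batch size for illustration

row_result = data.map(double_row)
batch_result = pd.concat(
    [double_batch(data.iloc[i:i + batch_size]) for i in range(0, len(data), batch_size)]
)

print(row_result.equals(batch_result), calls)  # True {'row': 10000, 'batch': 10}
```

Both versions produce the same values, but the vectorised one crosses into Python 10 times instead of 10,000.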

Implementing batch predictions with Pandas UDFs

To address the challenge of running large machine learning models in parallel, we can use PySpark's predict_batch_udf function (available in pyspark.ml.functions since Spark 3.4), which is built on top of a Pandas UDF. It takes a function that loads the model and returns a prediction function. Spark calls this loader once per Python worker process and caches the returned prediction function, so a large model is loaded only once per worker and then reused for every batch.

Implementation details

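from pyspark.ml.functions import predict_batch_udf
from pyspark.sql import Row
from pyspark.sql.types import StringType


def transcribe_parallel():
    # Called once per Python worker: load the model and return the predict function
    from faster_whisper import WhisperModel
    import json
    import numpy as np
    model = WhisperModel('large-v3', device="cuda", compute_type="float16")

    def transcribe_call(file_names):
        # file_names is a NumPy array holding one batch of values from the input column
        results = []
        for file_name in file_names:
            segments, info = model.transcribe(file_name, word_timestamps=True, vad_filter=True)
            # Convert the lazy segment generator into JSON-serialisable dictionaries
            transcription = {
                "language": info.language,
                "segments": [
                    {"start": s.start, "end": s.end, "text": s.text,
                     "words": [{"start": w.start, "end": w.end, "word": w.word} for w in (s.words or [])]}
                    for s in segments
                ],
            }
            results.append(json.dumps(transcription, ensure_ascii=False))
        # predict_batch_udf expects exactly one output value per input row
        return np.array(results)

    return transcribe_call


transcribe_and_return_json = predict_batch_udf(transcribe_parallel,
                                               return_type=StringType(), batch_size=1)

audio_files = ['audio1.wav', 'audio2.wav', 'audio3.wav']
list_of_rows = [Row(filename=file) for file in audio_files]
# spark is the SparkSession that Databricks notebooks provide
files_df = spark.createDataFrame(list_of_rows)
transcription_df = files_df.withColumn('output_transcription', transcribe_and_return_json('filename'))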

We define a function transcribe_parallel that initialises the Whisper model and returns an inner function, transcribe_call, which transcribes the audio files in each batch using that model. Because the model is created in transcribe_parallel rather than in transcribe_call, it is initialised only once per worker instead of once per file.

By using predict_batch_udf, we transform transcribe_parallel into a user-defined function (UDF) that Spark can apply to large datasets in parallel. The UDF processes data in batches; here batch_size=1, so each batch contains a single audio file. If a model can process multiple audio files at the same time, the batch size can be increased.
Finally, we apply this UDF to a DataFrame, adding a new column output_transcription that contains the transcription for each audio file in the ‘filename’ column.
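When increasing the batch size, the predict function must handle any number of rows and return exactly one value per input row. This can be checked locally by emulating the batching. In the sketch below, fake_transcribe is a hypothetical stand-in for model.transcribe and apply_in_batches is a hypothetical helper, not part of PySpark; the check is simply that batch_size=1 and batch_size=3 give identical output.

```python
import json

import numpy as np


def fake_transcribe(file_name):
    # Hypothetical stand-in for model.transcribe(...); returns a plain dict
    return {"file": file_name, "text": f"transcript of {file_name}"}


def transcribe_batch(file_names: np.ndarray) -> np.ndarray:
    # Works for any batch length: one JSON string out per file name in
    return np.array([json.dumps(fake_transcribe(str(f)), ensure_ascii=False) for f in file_names])


def apply_in_batches(predict, values, batch_size):
    # Emulates splitting a column into batches of at most batch_size rows
    values = np.asarray(values)
    return np.concatenate([predict(values[i:i + batch_size]) for i in range(0, len(values), batch_size)])


files = ["audio1.wav", "audio2.wav", "audio3.wav"]
one_by_one = apply_in_batches(transcribe_batch, files, batch_size=1)
all_at_once = apply_in_batches(transcribe_batch, files, batch_size=3)
print(list(one_by_one) == list(all_at_once))  # True
```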

Advantages

  • Efficiency: By loading the model once per worker and caching the predict function, this approach minimises both time and memory consumption.
  • Scalability: PySpark's distributed nature allows for easy scaling with autoscaling workers as needed.
  • Simplicity: Using Pandas UDFs simplifies the code and leverages Spark's optimised execution engine.

Disadvantages

  • JVM and PySpark overhead: Spark relies on the JVM to execute its core tasks, so when Python code (such as a Pandas UDF) is invoked, data must be transferred back and forth between the JVM and the Python worker. Serialising and deserialising data between the two can be time-consuming. However, this overhead applies to every operation that runs Python code rather than Spark's native DataFrame operations.

Conclusion

Using PySpark's predict_batch_udf function to run batch predictions with a large machine learning model offers significant advantages. Loading the model once per worker and caching the predict function reduces both time and memory consumption, and Spark's distributed nature allows for easy scaling with autoscaling workers. This makes PySpark's predict_batch_udf a more efficient and scalable solution for large-scale model inference, and a valuable tool for data scientists working with extensive datasets.

This is an article by Cas Hortensius, Data Engineer at Digital Power

Cas is a data professional with a passion for bridging the gap between raw data and actionable insights for analytics teams. Leveraging his engineering background, he ensures data is accessible and tailored for analysis across sectors such as the music industry, finance, and energy industry.

