Machine Learning-inferentie optimaliseren met PySpark en Pandas UDF's

Introductie tot parallelle verwerking in Machine Learning

  • Artikel
  • Data Engineering
  • Machine learning operations
machine learning

In de wereld van machine learning kan het werken met grote datasets en complexe modellen al snel tijdrovend en resource-intensief worden. Om dit proces te versnellen is parallellisatie cruciaal. Deze techniek bestaat uit het opsplitsen van taken in kleinere subtaken die gelijktijdig verwerkt kunnen worden op meerdere CPU cores of gedistribueerde machines binnen een cluster. Door de werklast te spreiden, kun je data sneller en efficiënter verwerken op grote schaal.

*Het artikel is geschreven in het Engels voor een betere leesbaarheid.

Imagine you have a large set of data points that need to be processed by a machine learning model. Rather than processing each point one after another (sequentially), you divide the data into smaller chunks. Each chunk is then processed at the same time (in parallel), either on different CPU cores or across multiple machines. The result is that predictions on large datasets can be made significantly faster by distributing the computation.

In this blog, we will explore various methods for parallelising machine learning inference, focusing primarily on Python within the Databricks environment. Specifically, we will examine strategies for transcribing large volumes of audio files using OpenAI's Whisper model. While these techniques are tailored for Databricks, their principles can be applied more broadly. Throughout this article, you will find practical code snippets to illustrate how these methods work.

Method 1: Using Python's multiprocessing package

Overview of Python’s multiprocessing

Python's multiprocessing package offers a way to run multiple processes simultaneously, leveraging multi-core processors to speed up computation. This method is suitable for a single machine with multiple cores and is a good starting point for parallelism.

Implementation details

import multiprocessing
from faster_whisper import WhisperModel

model = WhisperModel('large-v3', compute_type="float16")

def transcribe_call(file_name):
    transcription = model.transcribe(file_name, word_timestamps=True, vad_filter=True)
    return {file_name: transcription}

def transcribe_audio_files(audio_files, num_workers=4):
    with multiprocessing.Pool(num_workers) as pool:
        results = pool.map(transcribe_call, audio_files)

    return results

audio_files = ['audio1.wav', 'audio2.wav', 'audio3.wav']

num_workers = multiprocessing.cpu_count()

transcriptions = transcribe_audio_files(audio_files, num_workers)

Advantages

  • Leverages multi-core processors to distribute work and reduce processing time.
  • Simple to implement on smaller-scale systems that don’t require a distributed framework.

Disadvantages

  • Overhead: Spawning and managing multiple processes introduces overhead, which might negate the performance benefits for smaller datasets.
  • Memory Consumption: Each process requires its own memory space, which can quickly add up, especially with large models such as Whisper.

Method 2: Using RDDs and broadcasting in PySpark

What is PySpark?

PySpark, the Python API for Apache Spark, provides powerful tools for distributed computing. One of the fundamental components of PySpark is the Resilient Distributed Dataset (RDD). RDDs are immutable distributed collections of objects that can be processed in parallel. When working with large machine learning models, RDDs can be used to parallelise predictions across a cluster

Broadcasting in spark

Broadcasting in Spark involves distributing a read-only variable (like a machine learning model) to all worker nodes. This ensures that the variable is efficiently available to each node without repeatedly sending it over the network. In this approach, the variable is cached in memory.

Using flatMap transformation

The flatMap transformation in PySpark allows you to apply a function that returns an iterable to each element of an RDD and then flattens the results. This can be used to perform predictions on batches of data within each RDD partition.

Implementation details

  1. Loading the model: The model is loaded on the driver and then broadcasted to all worker nodes.
  2. Broadcasting: The broadcasted model ensures that each worker node has access to the model without needing to reload it.
  3. Predicting with flatMap: The flatMap transformation is applied to the RDD to perform predictions on the data.
from pyspark import SparkContext
from faster_whisper import WhisperModel

# Initialize the Spark context
sc = SparkContext.getOrCreate()

model = WhisperModel('large-v3', device="cuda", compute_type="float16")
broadcast_model = sc.broadcast(model)

def transcribe_call(file_name):

    model = broadcast_model.value
    transcription = model.transcribe(file_name, word_timestamps=True, vad_filter=True)
    return {file_name: transcription}

audio_files = ['audio1.wav', 'audio2.wav', 'audio3.wav']

audio_rdd = sc.parallelize(audio_files)
transcriptions_rdd = audio_rdd.flatMap(lambda file_name: [transcribe_call(file_name)])

transcriptions = transcriptions_rdd.collect()

Advantages

  • Better for complex operations: For complex transformations that are difficult to express in SQL or DataFrames, RDDs (and flatMap) provide a more straightforward API.
  • Lazy Evaluation: Operations on RDDs are lazily evaluated, meaning computations are only executed when necessary, potentially optimising resource usage.

Disadvantages

  • Driver Memory Limitation: Broadcasting large models such as Whisper can lead to driver memory exhaustion as the driver needs to keep a copy of the broadcasted variable.
  • Complexity: Managing RDDs and transformations can be more complex compared to DataFrames. Especially with debugging of the operations if something went wrong.

Method 3: Using PySpark and Pandas UDFs

Overview of Pandas UDFs

While RDDs and broadcasting offer a way to parallelise predictions, PySpark also provides a more modern and efficient approach using DataFrames and User-Defined Functions (UDFs). UDFs allow you to define custom functions in Python and apply them to Spark DataFrames

Traditional UDFs in PySpark can be slow due to the serialisation and deserialisation overhead between Spark and Python. Pandas UDFs, also known as vectorised UDFs, overcome this limitation by processing batches of data at once and leveraging the performance of Apache Arrow for efficient data exchange.Pandas UDFs allow you to apply functions to Spark DataFrames in a highly efficient manner. They process batches of data at once, significantly reducing the overhead compared to row-by-row UDFs.

Implementing batch predictions with Pandas UDFs

To address the challenge of running large machine learning models in parallel, we can use the predict_batch_udf function within PySpark, with an underlying Pandas UDF. This approach ensures that a large model is only loaded once on a single worker, and the actual prediction function is cached in memory and can be used repeatedly.

Implementation details

from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.types import StringType, Row


def transcribe_parallel():
    from faster_whisper import WhisperModel
    import json
    import numpy as np
    model = WhisperModel('large-v3', device="cuda", compute_type="float16")

    def transcribe_call(file_name):
        file_name = file_name[0]
        transcription = model.transcribe(file_name, word_timestamps=True, vad_filter=True)

        return np.array([json.dumps(transcription, ensure_ascii=False)])

    return transcribe_call


transcribe_and_return_json = predict_batch_udf(transcribe_parallel,
                                               return_type=StringType(), batch_size=1)

audio_files = ['audio1.wav', 'audio2.wav', 'audio3.wav']
list_of_rows = [Row(filename=file) for file in audio_files]
files_df = spark.createDataFrame(list_of_rows)
transcription_df = files_df.withColumn('output_transcription', transcribe_and_return_json('filename'))

We define a function transcribe_parallel which initialises a Whisper model only once per worker. This function, transcribe_parallel, contains an inner function transcribe_call that transcribes an audio file by using the Whisper model.

By using predict_batch_udf, we transform transcribe_parallel into a user-defined function (UDF) that Spark can apply to large datasets in parallel. The UDF processes data in batches, specified to be 1 audio file per batch. If a model can handle multiple audio_files at the same time, the batch size can be adjusted.
Finally, we apply this UDF to a DataFrame, adding a new column output_transcription that contains the transcription for each audio file in the ‘filename’ column.

Advantages

  • Efficiency: By loading the model once per worker and caching the predict function, this approach minimises both time and memory consumption.
  • Scalability: PySpark's distributed nature allows for easy scaling with autoscaling workers as needed.
  • Simplicity: Using Pandas UDFs simplifies the code and leverages Spark's optimised execution engine.

Disadvantages

  • JVM and Pyspark overhead: Spark relies on the JVM for executing core Spark tasks, but when Python code (like Pandas UDFs) is invoked, data must be transferred back and forth between the JVM and Python. The serialisation and deserialisation of between the Python environment and JVM can be time consuming. However, this overhead counts for all non-native Spark Dataframe operations.

Conclusion

Using PySpark'spredict_batch_udf function to run batch predictions with a large machine learning model offers significant advantages. By loading the model once per worker and caching the predict function, this approach minimises both time and memory consumption. Additionally, Spark's distributed nature allows for easy scaling with autoscaling workers. Therefore, PySpark’spredict_batch_udf functionality provides a more efficient and scalable solution for large-scale model inference, making it a valuable tool for data scientists working with extensive datasets.

This is an article by Cas Hortensius, Data Engineer at Digital Power

Cas is a data professional with a passion for bridging the gap between raw data and actionable insights for analytics teams. Leveraging his engineering background, he ensures data is accessible and tailored for analysis across sectors such as the music industry, finance, and energy industry.

Cas Hortensius

1x per maand data insights, praktijkcases en een kijkje achter de schermen ontvangen?

Meld je aan voor onze maillijst en blijf 'up to data':

Dit vind je misschien ook interessant

Jouw Data Engineering partner

Genereer betrouwbare en betekenisvolle inzichten uit een solide, veilige en schaalbare infrastructuur. Ons team van 25+ Data Engineers staat klaar om jouw dataproducten en -infrastructuur end-to-end te implementeren, te onderhouden én te optimaliseren.

Lees meer

De beste Python-projectmanagers vergelijken

In de steeds veranderende wereld van Python is het belangrijk om pakketten, omgevingen en versies efficiënt te beheren. Traditionele tools zoals pip en conda hebben ons goed gediend, maar naarmate projecten complexer worden, nemen ook onze eisen toe. Deze Engelstalige gids kijkt naar moderne alternatieven - Poetry, PDM, Hatch en Rye - die elk unieke mogelijkheden bieden om Python projectbeheer te stroomlijnen.

Lees meer

Wat doet een (Cloud) Data Engineer versus een Machine Learning Engineer?

In de wereld van data en technologie zijn Data Engineers en Machine Learning Engineers cruciale spelers. Beide rollen zijn essentieel voor het ontwerpen, bouwen en onderhouden van moderne data-infrastructuren en geavanceerde machine learning (ML) toepassingen. In deze blog focussen we specifiek op de taken en verantwoordelijkheden van een Data Engineer en Machine Learning Engineer.

Lees meer

Hoe werkt de AI Document Explorer in de praktijk?

De AI Document Explorer (AIDE) is een cloudoplossing, ontwikkeld door Digital Power, die gebruik maakt van het OpenAI’s GPT-model. Je kunt het inzetten om snel inzicht te krijgen in bedrijfsdocumenten. AIDE indexeert jouw bestanden op een veilige manier waardoor het mogelijk wordt om vragen te stellen over jouw eigen documenten. Niet alleen geeft het jou de antwoorden waar je naar op zoek bent, het geeft ook de referenties naar de plekken waar deze antwoorden staan.

Lees meer

Snelle en betrouwbare interne informatie met behulp van AI Document Explorer

Financiële instellingen moeten grote hoeveelheden documentatie verwerken. Voor deze specifieke instelling faciliteert een intern team dit door bijvoorbeeld samenvattingen te maken met behulp van tekstanalyse en natural language processing (NLP). Deze maken ze beschikbaar voor de verschillende business units. Om audits efficiënter uit te voeren, wilden ze een vraag- en antwoordmodel ontwikkelen om sneller de juiste informatie tot hun beschikking te hebben. Toen ChatGPT werd gelanceerd, vroegen ze ons een proof of concept te maken.

Lees meer

Een dataplatform implementeren

Deze blog is bedoeld om onze kennis en ervaring over te dragen aan de gemeenschap door richtlijnen te beschrijven voor de implementatie van een dataplatform in een organisatie, gebaseerd op onze knowhow. We weten dat de specifieke behoeften van elke organisatie anders zijn, dat ze een impact zullen hebben op de gebruikte technologieën en dat één enkele architectuur die aan al deze behoeften voldoet, niet realistisch is. Daarom houden we het in deze blog zo algemeen mogelijk.

Lees meer

Efficiënter werken dankzij migratie naar Databricks

Het Kadaster beschikt onder andere over complexe (geo)data van al het vastgoed in Nederland. Alle data wordt opgeslagen en verwerkt via een on-premise data warehouse in Postgres. Voor het onderhoud van dit warehouse zijn ze afhankelijk van een IT-partner. Het Kadaster wil kosten besparen en efficiënter gaan werken door te migreren naar een Databricks-omgeving. Ze vroegen ons te helpen bij de implementatie van dit data lakehouse in Microsoft Azure Cloud.

Lees meer

Miljarden streams omgezet in bruikbare inzichten met een nieuw data- en analytics platform

Merlin is de grootste digitale muzieklicentiepartner voor onafhankelijke labels, distributeurs en andere rechthebbenden. De leden van Merlin vertegenwoordigen 15% van de wereldwijde markt voor muziekopnames. Het bedrijf heeft overeenkomsten met Apple, Facebook, Spotify, YouTube en 40 andere innovatieve digitale platforms over de hele wereld voor de opnames van haar leden. Het team van Merlin volgt betalingen en gebruiksrapporten van digitale partners nauwlettend en zorgt ervoor dat hun leden nauwkeurig, efficiënt en consistent worden betaald en van rapportages worden voorzien.

Lees meer

20% minder klachten dankzij datagedreven onderhoudsrapportages

Een belangrijk onderdeel van de bedrijfsvoering van Otis is het onderhoud van hun liften. Om dit goed te timen en klanten proactief te informeren over de status van hun lift, wilde Otis continue monitoring inzetten. Ze zagen veel potentie in predictive maintenance en onderhoud op afstand.

Lees meer

Kubernetes-based event-driven autoscaling met KEDA: een praktische gids

In dit Engelstalige artikel beginnen we met een uitleg van wat Kubernetes Event Driven Autoscaling (KEDA) inhoudt. Vervolgens richten we een lokale ontwikkelomgeving in die het mogelijk maakt om KEDA te demonstreren met behulp van Docker en Minikube. Daarna leggen we het scenario uit dat geïmplementeerd zal worden om KEDA te demonstreren, en doorlopen we dit scenario stap voor stap. Aan het einde van het artikel heeft de lezer een duidelijk beeld van wat KEDA is en hoe hij of zij zelf een architectuur met KEDA kan implementeren.

Lees meer

Azure App functions configureren

In dit Engelstalige artikel beginnen we met het bespreken van Serverless Functions. Vervolgens demonstreren we hoe je Terraform-bestanden gebruikt om het implementatieproces van een doelinfrastructuur te vereenvoudigen, hoe een Function App in Azure kan worden gemaakt, het gebruik van GitHub-workflows om continuous integration en implementatie te beheren, en hoe branching strategieën kunnen worden gebruikt om code wijzigingen selectief uit te rollen naar specifieke instanties van Function Apps.

Lees meer

Opzet van een toekomstbestendige data-infrastructuur

Valk Exclusief is een keten van 4 sterren+ hotels en heeft 43 hotels in Nederland. De hotelketen wil gasten graag een persoonlijke ervaring bieden, zowel in het hotel als online.

Lees meer

Een dag in het leven van een Data Engineer

Voor het ontwikkelen van moderne datatoepassingen is de Data Engineer onmisbaar. Maar wat betekent het eigenlijk om Data Engineer te zijn en wat doe je dan precies? Onze collega Oskar, Data Engineer bij Digital Power, legt het je uit.

Lees meer

Een goed georganiseerde data-infrastructuur

FysioHolland is een overkoepelende organisatie voor fysiotherapeuten in Nederland. Een centraal serviceteam ontlast therapeuten van bijkomende werkzaamheden, zodat zij zich vooral kunnen focussen op het leveren van de beste zorg. Naast de organische groei sluit FysioHolland nieuwe praktijken aan bij de organisatie. Deze hebben stuk voor stuk hun eigen systemen, werkprocessen en behandelcodes. Dit heeft de datahuishouding van FysioHolland groot en complex gemaakt.

Lees meer

Een schaalbaar machine learning-platform voor het voorspellen van billboard-impressies

The Neuron biedt een programmatisch biedingsplatform om digitale Out-Of-Home-advertenties in realtime te plannen, kopen en beheren. Ze vroegen ons het aantal verwachte impressies voor digitale advertenties op billboards op een schaalbare en efficiënte manier te voorspellen.

Lees meer