
Deploy a PyTorch ML Model in K8s and Cache Results in Redis

Jeff Vincent
December 23, 2022

ML elements of applications often require a significant amount of processing time. Learn how to speed up your microservice-based ML app with asynchronous networking and in-memory data caching in this post.

Machine learning processes can be very time intensive, which can leave users waiting longer than they're accustomed to for a response from an application. This can be mitigated in several ways. For example, increasing the processing power of the server dedicated to carrying out the ML task associated with your app can speed things up considerably. Additionally, there are architectural and design elements you can incorporate into your application to ensure that it runs as quickly as possible.

Today, we’ll look at two such design considerations – leveraging FastAPI for asynchronous microservices networking in Python, and caching computed results in Redis to ensure that your ML model carries out these time-intensive processes as few times as possible.

Project Overview

Our project will consist of two FastAPI-driven APIs. The first will handle incoming web traffic, and the second will host our ML model – a PyTorch implementation of the ResNet-50 image classification model, pre-trained on the ImageNet dataset. ResNet is a convolutional neural network (CNN) that won the ImageNet Large Scale Visual Recognition Challenge in 2015 with its remarkable accuracy.

The application will allow users to upload a given image and submit it for the ML model to classify. However, before the Web API passes the request to the PyTorch service, it will first check to see if that specific image binary has been processed before by querying Redis for cached data.

If it has been processed before, it won’t be processed again – rather, the cached data will simply be returned directly to the user. Otherwise, the Web API will send an asynchronous request to the PyTorch service for the image to be classified. Before returning the result, the PyTorch service will cache the resulting data in Redis for future requests.
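The flow boils down to a cache-aside pattern keyed on a fingerprint of the raw image bytes. Here is a minimal sketch of that pattern using an in-memory dict and a stand-in classifier (both are placeholders; the real services below use Redis, FastAPI, and ResNet-50):

import json
import zlib

CACHE = {}  # stands in for Redis in this sketch

def classify(image_bytes: bytes) -> dict:
    # Placeholder for the call to the PyTorch service.
    return {"label": "example", "confidence": "0.0"}

def handle_upload(image_bytes: bytes) -> dict:
    key = zlib.adler32(image_bytes)   # fingerprint of the uploaded binary
    if key in CACHE:                  # cache hit: skip the model entirely
        return json.loads(CACHE[key])
    result = classify(image_bytes)    # cache miss: run the (slow) model
    CACHE[key] = json.dumps(result)   # cache the result for next time
    return result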

Once we have developed the above functionality, we’ll deploy the application in Kubernetes.

Workflow diagram

The full project is available on GitHub.

Web API

Below, we have a simple API that handles two varieties of requests. It serves an index page, which will allow users to upload a given image to be classified, and it has an endpoint defined to handle this file upload process.

import json
import os
import zlib

import aiohttp
import redis
from fastapi import FastAPI, File
from fastapi.responses import HTMLResponse

app = FastAPI()

# The PyTorch service's address is read from environment variables, which are set in the K8s manifests below.
PYTORCH_HOST = os.environ.get('PYTORCH_HOST', 'localhost')
PYTORCH_PORT = os.environ.get('PYTORCH_PORT', '8000')

@app.on_event('startup')
async def initialize():
  # Connect to Redis once, at application startup, inside the main asyncio event loop.
  pool = redis.ConnectionPool(
      host=os.environ.get('REDIS_HOST', 'localhost'),
      port=int(os.environ.get('REDIS_PORT', 6379)),
      db=0)
  global REDIS
  REDIS = redis.Redis(connection_pool=pool)

@app.get('/')
async def index_view():
  return HTMLResponse("""
      <div style="background-color: #707bb2; margin: 15px; border-radius: 5px; padding: 15px; width: 300px">
      <b>Upload an image: </b>
      <form action="/classify" method="post" enctype="multipart/form-data">
          <p><input type=file name=file value="Pick an image">
          <p><input type=submit value="Upload">
      </form>
      </div>""")

@app.post('/classify')
async def classify_image(file: bytes = File()):
  cached_data = await check_for_cached(file)
  if cached_data is None:
      form = aiohttp.FormData()
      form.add_field('data', file)
      try:
          async with aiohttp.ClientSession() as session:
              async with session.post(f'http://{PYTORCH_HOST}:{PYTORCH_PORT}/classify', data=form) as response:
                  r = await response.text()
                  data = json.loads(r)
                  return data
      except Exception as e:
          return HTMLResponse(f'<h3>Error:{str(e)}</h3>')
  return cached_data

async def check_for_cached(file):
  hash = zlib.adler32(file)
  data = REDIS.get(hash)
  if data:
      return json.loads(data)
  return None

Notice that the /classify endpoint accepts a file upload, and then calls the check_for_cached() function. This function queries our Redis instance with a “fingerprint” hash of the uploaded image’s binary, which will have been set as a key in Redis by the PyTorch service if that specific binary has been processed previously. If it is found, that cached result is returned directly to the user from the Web API.
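One note on the fingerprint itself: zlib.adler32 is a fast 32-bit checksum, so collisions between two different images are unlikely but not impossible. If you'd rather trade a few extra bytes per key for an effectively collision-free fingerprint, a cryptographic digest is a drop-in alternative – a small sketch (the fingerprint() helper is ours, not part of the project):

import hashlib

def fingerprint(file: bytes) -> str:
    # SHA-256 of the raw image bytes; longer than an adler32 checksum,
    # but effectively collision-free as a cache key.
    return hashlib.sha256(file).hexdigest()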

If that fingerprint is not found, an asynchronous request containing the uploaded file is forwarded to the PyTorch service with an aiohttp POST request. Making this request asynchronous allows the Web API to keep handling incoming web traffic while it waits for a response from the PyTorch service.
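If you'd like to exercise the endpoint without the HTML form, you can post an image straight to /classify. A quick sketch using the requests library (the host, port, and filename are assumptions for a local run):

import requests

# The form field must be named "file" to match the FastAPI parameter above.
with open('cat.jpg', 'rb') as f:
    response = requests.post('http://localhost:8000/classify', files={'file': f})

print(response.json())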

PyTorch Service

Next, we’ll define our PyTorch service – also as a FastAPI app. The API will handle requests to a single endpoint /classify, which will then call a series of functions in order to pre-process the image, classify it, cache the result in Redis, and finally return the computed result to the Web API to be returned to the user.

Notice that this service, like the one we just defined, includes an @app.on_event('startup') event handler. In both services, we make the Redis connection at application startup so that the connection is created inside the main asyncio event loop, where FastAPI can access it. To learn more about this, check out the post Asynchronous Video Streaming in Python with FastAPI and MongoDB’s GridFS.

Here, in addition to making the connection to Redis, we are also initializing our model. Notice that we are defining the model and the labels as global variables so that they can easily be accessed by all functions that interact with them.

import json
import os
import zlib
from io import BytesIO

import redis
import torch
import torch.nn.functional as F
from fastapi import FastAPI, File
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights

app = FastAPI()

@app.on_event('startup')
async def initialize():
  # Load the pre-trained ResNet-50 model and the ImageNet labels once, at startup.
  weights = ResNet50_Weights.DEFAULT
  global MODEL
  global LABELS
  with open("imagenet_class_index.json", 'r') as f:
      class_idx = json.load(f)
      LABELS = [class_idx[str(k)][1] for k in range(len(class_idx))]
  MODEL = resnet50(weights=weights)
  MODEL.eval()
  # Connect to Redis; in K8s the host and port come from the manifest's environment variables.
  pool = redis.ConnectionPool(
      host=os.environ.get('REDIS_HOST', 'localhost'),
      port=int(os.environ.get('REDIS_PORT', 6379)),
      db=0)
  global REDIS
  REDIS = redis.Redis(connection_pool=pool)

def _preprocess_image(img):
  img_pil = Image.open(BytesIO(img)).convert('RGB')
  imagenet_mean = [0.485, 0.456, 0.406]
  imagenet_std = [0.229, 0.224, 0.225]
  t = transforms.Compose(
              [transforms.Resize(256),
               transforms.CenterCrop(224),
               transforms.ToTensor(),
               transforms.Normalize(mean=imagenet_mean, std=imagenet_std)]
          )
  img_tensor = t(img_pil)
  return torch.unsqueeze(img_tensor, 0)

def _classify_image(preprocessed_image):
  global MODEL
  out = MODEL(preprocessed_image)
  _, index = torch.max(out, 1)
  pct = F.softmax(out, dim=1)[0] * 100
  return (LABELS[index[0]], pct[index[0]].item())

@app.post('/classify')
async def image(data: bytes = File()):
  preprocessed_image = _preprocess_image(data)
  result = _classify_image(preprocessed_image)
  await write_to_cache(data, result)
  return JSONResponse({result[0]: str(result[1])})

async def write_to_cache(file, result):
  hash = zlib.adler32(file)
  REDIS.set(hash, json.dumps({result[0]: str(result[1])}))

When the /classify endpoint receives a request with an image to be classified, the FastAPI File() class stores the image's binary in a variable called data. That variable is then passed to the _preprocess_image() function, where the image is resized, cropped, and converted to a tensor. That tensor is then passed to our ML model for processing. Once processed, the result is written to Redis as described above and returned.
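If you want to sanity-check the model pipeline outside of FastAPI and Redis, the same preprocessing and inference steps can be run as a standalone script. A minimal sketch, assuming torchvision 0.13+ and a local test image named dog.jpg:

from io import BytesIO

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights

# Load the same pre-trained weights the service uses.
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
model.eval()

# Mirror the service's preprocessing: resize, center-crop, tensorize, normalize.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

with open('dog.jpg', 'rb') as f:
    img = Image.open(BytesIO(f.read())).convert('RGB')
batch = torch.unsqueeze(preprocess(img), 0)

with torch.no_grad():
    out = model(batch)
pct = F.softmax(out, dim=1)[0] * 100
idx = int(torch.argmax(pct))
print(weights.meta['categories'][idx], round(pct[idx].item(), 2))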

Dockerize the services

With our services defined, it’s time to Dockerize them so they can be deployed in K8s. Notice that we are using an Alpine base image to keep the resulting containers as small as possible, so that they pull and start quickly – again, in the interest of offsetting time-intensive processes.

FROM python:3.10-alpine
COPY requirements.txt .
RUN pip install -r requirements.txt
COPY main.py .
CMD ["sh", "-c", "uvicorn main:app --host 0.0.0.0 --port 8000 --reload"]

Deploy in K8s

With our containers pushed to a registry, such as Docker Hub, we can now move on to defining the Kubernetes manifests required to deploy this app in K8s. Specifically, we will need to define a Deployment and a ClusterIP Service for each of our three microservices – the Web API, the PyTorch service, and Redis. Additionally, we’ll need to define an Ingress for the Web API, so that it can handle incoming web traffic.

Below, we have all three of these resources that relate to the Web API.

---
apiVersion: apps/v1
kind: Deployment
metadata:
  name: web-api
  labels:
    app: web-api
spec:
  selector:
    matchLabels:
      api: web-api
  replicas: 3
  template:
    metadata:
      labels:
        app: web-api
        api: web-api
    spec:
      containers:
        - name: web-api
          image: jdvincent/pytorch-web-api:latest
          env:
            - name: PYTORCH_HOST
              value: {{ .Values.pytorch_host | quote }}
            - name: PYTORCH_PORT
              value: {{ .Values.pytorch_port | quote }}
            - name: REDIS_HOST
              value: {{ .Values.redis_host | quote }}
            - name: REDIS_PORT
              value: {{ .Values.redis_port | quote }}
          ports:
            - name: web-api
              containerPort: 8000
              protocol: TCP
---
apiVersion: v1
kind: Service
metadata:
  name: web-api
spec:
  ports:
    - port: 8000
      targetPort: 8000
      name: web-api
  selector:
    app: web-api
  type: ClusterIP
---
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  name: web-api
spec:
  ingressClassName: {{ .Values.ingress_class_name | quote }}
  rules:
    - host: {{ .Values.ingress_host | quote }}
      http:
        paths:
          - path: /
            pathType: Prefix
            backend:
              service:
                name: web-api
                port:
                  number: 8000

Notice that some of the values are template expressions rather than hard-coded values. That's because this is actually a Helm template rather than a vanilla K8s resource definition. Helm templates let us populate the same manifests with different sets of values, which is very useful for deploying the same app to multiple environments, as we’ll see below.

Deploy in Minikube

Because we defined our manifests as Helm templates, we can deploy the app to Minikube by passing the following values file to Helm when we render the templates.

values.yaml

pytorch_host: pytorch
pytorch_port: "8000"
redis_port: "6379"
redis_host: redis
ingress_host: null
ingress_class_name: kong

But first, we’ll need to start Minikube and enable the Kong ingress controller add-on, like so:

minikube start
minikube addons enable kong
minikube tunnel

Then, we can deploy the app by rendering the templates with the above values and piping the output to kubectl, like so:

helm template . --values values.yaml | kubectl apply -f -

Conclusion

Applications that contain ML elements can be made more performant by incorporating architectural and design elements that ensure as few ML operations as possible are carried out, and by reducing latency in the services that make up the rest of the application.

We demonstrated both of these approaches above, by incorporating Redis caching of ML results and asynchronous microservices networking with FastAPI and Asyncio.

Join the discussion!

Have any questions or comments about this post? Maybe you have a similar project or an extension to this one that you'd like to showcase? Join the Velocity Discord server to ask away, or just stop by to talk K8s development with the community.
