Stable Diffusion XL for JAX + TPU v5e

TPU v5e is a new generation of TPUs from Google Cloud. It is the most cost-effective, versatile, and scalable Cloud TPU to date, which makes it ideal for serving and scaling large diffusion models.

JAX is a high-performance numerical computation library that is well suited to developing and deploying diffusion models.

👉 Try it out for yourself: Hugging Face Spaces

Stable Diffusion XL pipeline in JAX

Once you have access to a TPU VM (a TPU newer than v3), first install a TPU-compatible version of JAX:

pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Next, we install Flax along with the Diffusers and Transformers libraries:

pip install flax diffusers transformers

In sdxl_single.py we give a simple example of how to write a text-to-image generation pipeline in JAX using StabilityAI’s Stable Diffusion XL.

Let’s explain it step-by-step:

Imports and Setup

import time

import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from diffusers import FlaxStableDiffusionXLPipeline

# Persist compiled programs to disk so subsequent runs can skip recompilation
from jax.experimental.compilation_cache import compilation_cache as cc
cc.initialize_cache("/tmp/sdxl_cache")

# Number of TPU devices available for parallel generation
NUM_DEVICES = jax.device_count()

First, we import the necessary libraries: JAX and its NumPy API for array computation, NumPy, Flax's replicate helper to copy data onto every device, and the FlaxStableDiffusionXLPipeline from Diffusers. We also enable JAX's persistent compilation cache and record the number of available TPU devices.

1. Downloading Model and Loading Pipeline

pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True
)

Here, the pre-trained stabilityai/stable-diffusion-xl-base-1.0 checkpoint is loaded. from_pretrained returns both the pipeline used for inference and its parameters. The revision argument pins a specific revision of the model repository, and split_head_dim=True splits the attention head dimension into a separate axis, which typically speeds up the attention computation on TPU.

2. Casting Parameter Types

# Keep the scheduler state in full precision; cast all other weights to bfloat16
scheduler_state = params.pop("scheduler")
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
params["scheduler"] = scheduler_state

This section adjusts the data types of the model parameters. We convert all parameters to bfloat16 to speed up computation with the model weights. Note that the scheduler state is not converted to bfloat16, because the loss in precision there degrades the quality of the results too significantly.

3. Define Inputs to Pipeline

default_prompt = ...
default_neg_prompt = ...
default_seed = 33
default_guidance_scale = 5.0
default_num_steps = 25

Here, various default inputs for the pipeline are set, including the prompt, negative prompt, random seed, guidance scale, and the number of inference steps.

4. Tokenizing Inputs

def tokenize_prompt(prompt, neg_prompt):
    prompt_ids = pipeline.prepare_inputs(prompt)
    neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)
    return prompt_ids, neg_prompt_ids

This function tokenizes the given prompts. It is needed because SDXL's text encoders operate on token IDs rather than raw text; prepare_inputs converts each prompt into the numeric arrays the pipeline expects.

5. Parallelization and Replication

p_params = replicate(params)

def replicate_all(prompt_ids, neg_prompt_ids, seed):
    ...

To utilize JAX's parallel capabilities, the parameters and input tensors are replicated across all devices. The replicate_all function also ensures that every device produces a different image by deriving a unique random key for each device from the seed.
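
The function body is elided above; a minimal sketch of what replicate_all can look like, assuming the replicate helper and NUM_DEVICES constant imported and defined earlier:

def replicate_all(prompt_ids, neg_prompt_ids, seed):
    # Copy the token arrays onto every device
    p_prompt_ids = replicate(prompt_ids)
    p_neg_prompt_ids = replicate(neg_prompt_ids)
    # Derive one PRNG key per device so each device generates a different image
    rng = jax.random.PRNGKey(seed)
    rng = jax.random.split(rng, NUM_DEVICES)
    return p_prompt_ids, p_neg_prompt_ids, rng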

6. Putting Everything Together

def generate(...):
    ...

This function integrates all the previous steps to produce the desired outputs from the model: it takes the prompts, tokenizes them, replicates them across devices, runs them through the pipeline, and converts the resulting arrays to the more interpretable PIL image format.
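
Again, the body is elided; a possible implementation, assuming the pipeline's jit=True calling convention and the helpers defined above, is sketched here:

def generate(prompt, negative_prompt, seed=default_seed,
             guidance_scale=default_guidance_scale, num_steps=default_num_steps):
    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
    images = pipeline(
        prompt_ids,
        p_params,
        rng,
        num_inference_steps=num_steps,
        neg_prompt_ids=neg_prompt_ids,
        guidance_scale=guidance_scale,
        jit=True,
    ).images
    # Flatten the per-device batch dimension and convert the arrays to PIL images
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    return pipeline.numpy_to_pil(np.array(images))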

7. Compilation Step

start = time.time()
print(f"Compiling ...")
generate(default_prompt, default_neg_prompt)
print(f"Compiled in {time.time() - start}")

The initial run of the generate function will be slow because JAX compiles the function during this call. By running it once here, subsequent calls will be much faster. This section measures and prints the compilation time.

8. Fast Inference

start = time.time()
prompt = ...
neg_prompt = ...
images = generate(prompt, neg_prompt)
print(f"Inference in {time.time() - start}")

Now that the function is compiled, this section shows how to use it for fast inference. It measures and prints the inference time.

In summary, the code demonstrates how to load a pre-trained model using Flax and JAX, prepare it for inference, and run it efficiently using JAX’s capabilities.

Ahead of Time (AOT) Compilation

FlaxStableDiffusionXLPipeline takes care of parallelization across multiple devices using jit. Now let’s build parallelization ourselves.

For this we will be using a JAX feature called Ahead of Time (AOT) lowering and compilation. AOT lets us fully compile a function prior to execution time and gives us control over the different parts of the compilation process.
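
As a minimal, standalone illustration of the pattern (not part of the SDXL scripts), a function can be lowered and compiled explicitly before it is ever called:

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0

lowered = jax.jit(f).lower(jnp.ones((8,)))  # trace and lower for this input shape/dtype
compiled = lowered.compile()                # compile ahead of time
y = compiled(jnp.ones((8,)))                # run the precompiled executable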

In sdxl_single_aot.py we give a simple example of how to write our own parallelization logic for a text-to-image generation pipeline in JAX using StabilityAI's Stable Diffusion XL.

We add an aot_compile function that compiles the pipeline._generate function, telling JAX which input arguments are static, that is, arguments that are known at compile time and won't change. In our case, these are num_inference_steps, height, width, and return_latents.

Once the function is compiled, these parameters are omitted from future calls and cannot be changed without modifying the code and recompiling.

def aot_compile(
    prompt=default_prompt,
    negative_prompt=default_neg_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
):
    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
    g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)
    g = g[:, None]

    # height, width and p_params are defined earlier in the full script
    return (
        jax.pmap(pipeline._generate, static_broadcasted_argnums=[3, 4, 5, 9])
        .lower(
            prompt_ids,
            p_params,
            rng,
            num_inference_steps,  # static: num_inference_steps
            height,               # static: height
            width,                # static: width
            g,
            None,
            neg_prompt_ids,
            False,                # static: return_latents
        )
        .compile()
    )

Next, we run aot_compile to obtain the compiled generation function, p_generate.

start = time.time()
print("Compiling ...")
p_generate = aot_compile()
print(f"Compiled in {time.time() - start}")

And again we put everything together in a generate function, this time calling the precompiled p_generate.

def generate(
    prompt,
    negative_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
):
    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
    g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)
    g = g[:, None]
    # The static arguments (num_inference_steps, height, width, return_latents)
    # were baked in at compile time, so they are not passed here.
    images = p_generate(
        prompt_ids,
        p_params,
        rng,
        g,
        None,
        neg_prompt_ids,
    )

    # convert the images to PIL
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    return pipeline.numpy_to_pil(np.array(images))

The first forward pass after AOT compilation still takes a bit longer than subsequent passes. This is because, on the first call, JAX uses Python dispatch, which fills the C++ dispatch cache. When using jit, this extra step happens automatically, but with AOT compilation it is deferred until the first function call.

start = time.time()
prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
neg_prompt = "cartoon, illustration, animation. face. male, female"
images = generate(prompt, neg_prompt)
print(f"First inference in {time.time() - start}")

From this point forward, any call to generate will run at the fast inference time, which no longer changes.

start = time.time()
prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
neg_prompt = "cartoon, illustration, animation. face. male, female"
images = generate(prompt, neg_prompt)
print(f"Inference in {time.time() - start}")