The challenge is to run Stable Diffusion, which includes a large transformer model with almost 1 billion parameters, on a Raspberry Pi Zero 2, a microcomputer with 512MB of RAM, without adding more swap space and without offloading intermediate results to disk. The recommended minimum RAM/VRAM for Stable Diffusion is typically 8GB.
Major machine learning frameworks and libraries generally focus on minimizing inference latency and/or maximizing throughput, both at the cost of RAM usage. So I decided to write a super small and hackable inference library specifically focused on minimizing memory consumption: OnnxStream.
OnnxStream is based on the idea of decoupling the inference engine from the component responsible for providing the model weights, which is a class derived from `WeightsProvider`. A `WeightsProvider` specialization can implement any type of loading, caching and prefetching of the model parameters. For example, a custom `WeightsProvider` can decide to download its data from an HTTP server directly, without loading or writing anything to disk (hence the word “Stream” in “OnnxStream”). Two default `WeightsProviders` are available: `DiskNoCache` and `DiskPrefetch`.
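The decoupling described above can be illustrated with a minimal sketch. It is written in Python only for brevity and is not OnnxStream's actual API: the class names `WeightsProviderSketch`, `NoCacheDiskProvider` and `InMemoryProvider`, the `get` method and the toy `run_layers` engine are all hypothetical. The point is only that the engine asks for one tensor at a time and never needs to know where the bytes come from.

```python
import os
import tempfile
from abc import ABC, abstractmethod


class WeightsProviderSketch(ABC):
    """Hypothetical interface: the engine requests one tensor's bytes by name."""

    @abstractmethod
    def get(self, name: str) -> bytes:
        ...


class NoCacheDiskProvider(WeightsProviderSketch):
    """Reads a tensor from disk on every request and caches nothing."""

    def __init__(self, folder: str):
        self.folder = folder

    def get(self, name: str) -> bytes:
        with open(os.path.join(self.folder, name), "rb") as f:
            return f.read()


class InMemoryProvider(WeightsProviderSketch):
    """Stand-in for any other source, e.g. bytes fetched over the network."""

    def __init__(self, blobs):
        self.blobs = dict(blobs)

    def get(self, name: str) -> bytes:
        return self.blobs[name]


def run_layers(provider: WeightsProviderSketch, layer_names):
    """Toy 'engine': uses one layer's weights at a time, then lets them go."""
    total = 0
    for name in layer_names:
        weights = provider.get(name)  # only this layer's weights are alive here
        total += sum(weights)         # placeholder for the real computation
    return total
```

Because `run_layers` depends only on the interface, swapping the disk-backed provider for the in-memory one (or for a prefetching or streaming one) requires no change to the engine side.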
OnnxStream can consume up to 55x less memory than OnnxRuntime while being only 0.5-2x slower (on CPU, see the Performance section below).
These images were generated by the Stable Diffusion example implementation included in this repo, using OnnxStream, at different precisions of the VAE decoder. The VAE decoder is the only model of Stable Diffusion that could not fit into the RAM of the Raspberry Pi Zero 2 in single or half precision. This is caused by the presence of residual connections and very big tensors and convolutions in the model. The only solution was static quantization (8 bit). The third image was generated by my RPI Zero 2 in about 3 hours. The first image was generated on my PC using the same latents generated by the RPI Zero 2, for comparison:
VAE decoder in W16A16 precision:
VAE decoder in W8A32 precision:
VAE decoder in W8A8 precision (generated by my RPI Zero 2 in about 3 hours):
- Inference engine decoupled from the `WeightsProvider`
- `WeightsProvider` can be `DiskNoCache`, `DiskPrefetch` or custom
- Attention slicing
- Dynamic quantization (8 bit unsigned, asymmetric, percentile)
- Static quantization (W8A8 unsigned, asymmetric, percentile)
- Easy calibration of a quantized model
- FP16 support (with or without FP16 arithmetic)
- 24 ONNX operators implemented (the most common)
- Operations executed sequentially but all operators are multithreaded
- Single implementation file + header file
- XNNPACK calls wrapped in the `XnnPack` class (for future replacement)
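To make the quantization entries in the list above more concrete, here is a minimal sketch of 8-bit unsigned, asymmetric quantization with a percentile-based range. It is a generic illustration, not OnnxStream's code: the function names, the way the percentile is applied (clipping the same fraction of values on each side) and the sample values are all assumptions.

```python
def percentile_range(values, pct):
    """Return (lo, hi) after dropping the most extreme (100 - pct)% per side."""
    ordered = sorted(values)
    k = int(len(ordered) * (100.0 - pct) / 100.0)  # values clipped on each side
    lo, hi = ordered[k], ordered[len(ordered) - 1 - k]
    # Keep 0.0 inside the range so that it maps exactly to an integer code.
    return min(lo, 0.0), max(hi, 0.0)


def make_quantizer(lo, hi):
    """Build quantize/dequantize functions mapping [lo, hi] onto 0..255."""
    scale = (hi - lo) / 255.0 or 1.0               # avoid a zero scale
    zero_point = max(0, min(255, round(-lo / scale)))

    def quantize(x):
        # Values outside [lo, hi] saturate at 0 or 255.
        return max(0, min(255, round(x / scale) + zero_point))

    def dequantize(q):
        return (q - zero_point) * scale

    return quantize, dequantize, scale, zero_point


# Sample data with one large outlier (made up for illustration).
values = [float(i) for i in range(-10, 11)] + [1000.0]
lo, hi = percentile_range(values, 95.0)
quantize, dequantize, scale, zero_point = make_quantizer(lo, hi)
```

Without the percentile step, the single outlier would stretch the range to about 1010 units and leave very few of the 256 codes for the values between -10 and 10. With it, the outlier simply saturates at 255 and the remaining values keep a fine resolution.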
OnnxStream depends on XNNPACK for some (accelerated) primitives: MatMul, Convolution, element-wise Add/Sub/Mul/Div, Sigmoid and Softmax.
Stable Diffusion consists of three models: a text encoder (672 operations and 123 million parameters), the UNET model (2050 operations and 854 million parameters) and the VAE decoder (276 operations and 49 million parameters). Assuming that the batch size is equal to 1, a full image generation with 10 steps, which yields good results (with the Euler Ancestral scheduler), requires 2 runs of the text encoder, 20 (i.e. 2*10) runs of the UNET model and 1 run of the VAE decoder.
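The run counts above follow a simple pattern, which the short sketch below spells out. The operation and parameter counts are copied from the paragraph; the dictionary layout and function names are illustrative only, and the factor of 2 per step is taken from the text as stated.

```python
# Figures from the paragraph above (parameters in millions).
MODELS = {
    "text_encoder": {"ops": 672, "params_millions": 123},
    "unet": {"ops": 2050, "params_millions": 854},
    "vae_decoder": {"ops": 276, "params_millions": 49},
}


def runs_per_image(steps):
    """Model runs needed for one image at batch size 1."""
    return {"text_encoder": 2, "unet": 2 * steps, "vae_decoder": 1}


def total_operator_executions(steps):
    """Number of ONNX operations executed across all runs for one image."""
    runs = runs_per_image(steps)
    return sum(MODELS[name]["ops"] * runs[name] for name in MODELS)
```

For 10 steps this gives 20 UNET runs, and the UNET accounts for 41,000 of the 42,620 operator executions, which is why its memory footprint and speed dominate the whole generation.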
This table shows the various inference times of the three models of Stable Diffusion, together with the memory consumption (i.e. the `Peak Working Set Size` in Windows or the `Maximum Resident Set Size` in Linux).
Model / Library | 1st run | 2nd run | 3rd run |
---|---|---|---|
FP16 UNET / OnnxStream | 0.133 GB – 18.2 secs | 0.133 GB – 18.7 secs | 0.133 GB – 19.8 secs |
FP16 UNET / OnnxRuntime | 5.085 GB – 12.8 secs | 7.353 GB – 7.28 secs | 7.353 GB – 7.96 secs |
FP32 Text Enc / OnnxStream | 0.147 GB – 1.26 secs | 0.147 GB – 1.19 secs | 0.147 GB – 1.19 secs |
FP32 Text Enc / OnnxRuntime | 0.641 GB – 1.02 secs | 0.641 GB – 0.06 secs | 0.641 GB – 0.07 secs |
FP32 VAE Dec / OnnxStream | 1.004 GB – 20.9 secs | 1.004 GB – 20.6 secs | 1.004 GB – 21.2 secs |
FP32 VAE Dec / OnnxRuntime | 1.330 GB – 11.2 secs | 2.026 GB – 10.1 secs | 2.026 GB – 11.1 secs |
In the case of the UNET model (when run in FP16 precision, with FP16 arithmetic enabled in OnnxStream), OnnxStream can consume up to 55x less memory than OnnxRuntime while being 0.5-2x slower.
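The memory figure can be checked directly against the FP16 UNET rows of the table. The sketch below only does that arithmetic; the data layout and function names are illustrative.

```python
# (peak memory in GB, time in seconds) for the three runs, from the table above.
UNET_FP16 = {
    "OnnxStream": [(0.133, 18.2), (0.133, 18.7), (0.133, 19.8)],
    "OnnxRuntime": [(5.085, 12.8), (7.353, 7.28), (7.353, 7.96)],
}


def peak_memory_ratio(rows):
    """Highest OnnxRuntime peak divided by highest OnnxStream peak."""
    stream_peak = max(mem for mem, _ in rows["OnnxStream"])
    runtime_peak = max(mem for mem, _ in rows["OnnxRuntime"])
    return runtime_peak / stream_peak


def time_ratios(rows):
    """Per-run OnnxStream time divided by OnnxRuntime time."""
    return [s / r for (_, s), (_, r) in zip(rows["OnnxStream"], rows["OnnxRuntime"])]
```

On these rows the peak memory ratio is about 55.3, and OnnxStream takes roughly 1.4x to 2.6x as long per run, that is, about 0.4x to 1.6x of additional time.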
Notes:
- The first run for OnnxRuntime is a warm up inference, since its `InferenceSession` is created before the first run and reused for all the subsequent runs. No such thing as a warm up exists for OnnxStream since it is purely eager by design (however subsequent runs can benefit from the caching of the weights files by the OS).
- At the moment OnnxStream doesn’t support inputs with a batch size != 1, unlike OnnxRuntime, which can greatly speed up the whole diffusion process using a batch size = 2 when running the UNET model.
- In my tests, changing OnnxRuntime’s `SessionOptions` (like `EnableCpuMemArena` and `ExecutionMode`) produces no significant difference in the results.
- Performan