I am a huge admirer of SemiAnalysis, and this project was originally meant to be an InferenceX for robotics and world models. I was ultimately too GPU-constrained to execute the full vision, but I did get some interesting results. Given those constraints, much of my focus shifted toward understanding the inference optimizations that providers may run. This post explains those optimization techniques and the results I got from implementing them. You can find my full repo here.
The model I worked with is from NVIDIA’s Cosmos World Model family; specifically, I used predict2-2b-video2world. Beyond NVIDIA’s native sparse-attention checkpoint for this model, I implemented lower precision for some operations, feature caching for certain hidden layers, and CUDA graphs.
The Cosmos World Model is a diffusion transformer (DiT). At its core, the Cosmos video2world model takes a text prompt and an image (or a couple of video frames) and outputs a short video. Both the text prompt and the image condition the video output.
Here is how the model works:
[Figure: video VAE, latent grid, and DiT tokens]
Internally, the video canvas is represented as a 3-channel RGB video with 93 frames at 704x1280 pixels (height x width). An encoder (the video VAE, or tokenizer) converts this raw video into a latent video with 16 channels, 24 latent time positions, and an 88x160 spatial grid. Unlike an LLM tokenizer, this video tokenizer does not produce discrete vocabulary tokens. It produces continuous latent features instead. The latent video is then split into many small patch tokens, and each patch token is projected up to a 2048-dimensional vector.
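As a quick sanity check on these shapes, here is the arithmetic as a small Python sketch. Several tokenizer details are my assumptions rather than facts from the post: 4x temporal compression with the first frame kept separate, 8x spatial compression, and a 1x2x2 patch size. The post itself only states the resulting 24x88x160 latent and the 2048-dimensional tokens.

```python
# Shape arithmetic for the video2world pipeline described above.
# Assumptions (not stated in the post): 4x temporal compression that keeps the
# first frame separate, 8x spatial compression, and 1x2x2 (t, h, w) patches.

def latent_shape(frames=93, height=704, width=1280, channels=16,
                 t_down=4, s_down=8):
    """Return (channels, latent_frames, latent_h, latent_w) for a raw video."""
    latent_frames = (frames - 1) // t_down + 1  # first frame + groups of 4
    return (channels, latent_frames, height // s_down, width // s_down)


def num_patch_tokens(latent, patch=(1, 2, 2)):
    """Number of DiT tokens after cutting the latent into (t, h, w) patches."""
    _, t, h, w = latent
    pt, ph, pw = patch
    return (t // pt) * (h // ph) * (w // pw)


latent = latent_shape()
tokens = num_patch_tokens(latent)
hidden = 2048                      # each patch token becomes a 2048-dim vector
bf16_bytes = tokens * hidden * 2   # one BF16 activation tensor [tokens, hidden]

print(latent)                      # (16, 24, 88, 160)
print(tokens)                      # 84480
print(bf16_bytes // 2**20)         # 330 (MiB)
```

Under these assumptions, every DiT block processes tens of thousands of tokens per step. That is why attention cost dominates so much of the discussion below.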
These pieces are learned, but not all in the same training run. The VAE/tokenizer is trained separately to compress and reconstruct videos. The text prompt is tokenized and encoded by Google’s T5 encoder, which was pretrained independently rather than trained from scratch for Cosmos. T5 outputs a sequence of contextual text embeddings, roughly 512 vectors of 1024 dimensions, which the video model later consumes through cross-attention.
Unlike an LLM, which predicts one token at a time, the model generates the entire video at once. The first couple of frames are fixed to the input video, and the remaining frames are initialized as random noise. The model’s job is to use the initial frames and the text prompt to “denoise” the future frames into a coherent continuation of the input.
The core model runs 35 “sampling” or “denoising” steps over this initialized video. Each step passes through 28 transformer blocks, and each block has one self-attention layer, one cross-attention layer, and one multi-layer perceptron (MLP). In self-attention, video tokens attend to other video tokens. In cross-attention, video tokens attend to the text tokens. The MLP transforms each token independently. Around each sublayer sit LayerNorm and AdaLN conditioning, which use the denoising timestep to shift, scale, and gate the block’s activations. Each block has its own trained weights, but every denoising step runs the same 28 blocks.
[Figure: latent-video DiT denoising loop]
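Here is a toy, pure-Python sketch of this structure: one AdaLN-conditioned sublayer with a gated residual, plus the step/block loop. The scalar shift/scale/gate values and the stand-in sublayers are simplifications for illustration. Real DiT blocks derive per-channel modulation vectors from the timestep embedding and run actual attention and MLP layers.

```python
import math

NUM_STEPS = 35
NUM_BLOCKS = 28


def layer_norm(x, eps=1e-6):
    """LayerNorm over one token's feature vector (no learned affine)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def adaln_sublayer(x, sublayer, shift, scale, gate):
    """One AdaLN-conditioned sublayer with a gated residual connection.

    shift/scale/gate would come from the timestep embedding; they are plain
    scalars here for brevity (real DiTs use per-channel vectors).
    """
    h = [v * (1.0 + scale) + shift for v in layer_norm(x)]
    y = sublayer(h)
    return [xi + gate * yi for xi, yi in zip(x, y)]


def count_block_calls(num_steps=NUM_STEPS, num_blocks=NUM_BLOCKS):
    """Every denoising step runs the same stack of blocks once."""
    calls = 0
    for _step in range(num_steps):
        for _block in range(num_blocks):
            # self-attention, cross-attention, and MLP sublayers would run here
            calls += 1
    return calls


print(count_block_calls())  # 980
```

With `gate = 0` a sublayer contributes nothing, so the timestep can effectively switch parts of a block on or off. This is the knob AdaLN exposes.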
One thing stands out immediately: there are a lot of layers. With 28 DiT blocks, there is bound to be some redundancy, and the model has enough depth that we can trim some of it to make inference more efficient. One way to do this is “sparse attention”. Instead of having every token attend to every other token, each token attends to only a subset of tokens. This technique has gained significant popularity, and a new sparse attention mechanism was one of DeepSeek’s core breakthroughs.
This intuitively makes sense for a video model. For any given token, the tokens nearest to it in height, width, and time matter most. Tokens far away, at the top or edge of the frame, likely have little relevance to whatever is happening in the centre. Given how many layers we have, we can try many things. We can mix full, or “dense”, layers, where every token attends to every other token, with some sparse layers. Some layers can attend only to tokens close in time, width, or height. We can even try diagonal lines and checkerboard grids.
Indeed, NVIDIA trained a separate checkpoint of this model with this sparse attention mechanism. The self-attention pattern varies per block. All 28 self-attention blocks use sparse NATTEN patterns; none are dense self-attention layers. They all attend across the full latent time dimension, but vary the height/width window. Early blocks use smaller or more dilated windows, such as 4x16 or 12x16. Many middle and later blocks use wider local windows, such as 20x40, and one block uses a 28x56 window. All cross-attention layers are still dense.
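To get a feel for how much work these windows skip, the sketch below counts the keys each query attends to. It assumes a 44x80 token grid per latent frame, which follows from the 88x160 latent only if patches are 2x2 (my assumption). It also assumes full attention across all 24 latent frames, as described above.

```python
# Assumptions: a 44x80 token grid per latent frame (88x160 latent cut into
# 2x2 patches) and 24 latent frames, with every window spanning all frames.
GRID_H, GRID_W, FRAMES = 44, 80, 24


def keys_per_query(window_h, window_w, frames=FRAMES):
    """Keys one query attends to with a full-time, window_h x window_w window."""
    return frames * min(window_h, GRID_H) * min(window_w, GRID_W)


def attended_fraction(window_h, window_w):
    """Share of the dense key set each query sees.

    Dilation spaces the keys out but does not change how many there are.
    """
    return keys_per_query(window_h, window_w) / keys_per_query(GRID_H, GRID_W)


for window in [(4, 16), (12, 16), (20, 40), (28, 56)]:
    print(window, f"{attended_fraction(*window):.1%}")
```

Self-attention cost scales roughly with keys per query, but the end-to-end speedup is smaller than these ratios suggest. The MLPs, cross-attention, VAE, and T5 encoder are unaffected by sparse self-attention.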
To implement this, NVIDIA used NATTEN (Neighborhood Attention Extension), which provides custom CUDA kernels for running sparse neighborhood-attention operations efficiently. It is essentially a library that takes PyTorch tensors as input, calls its custom CUDA/CUTLASS kernels, and returns PyTorch tensors as output. From the developer’s perspective, it feels like a standard PyTorch operation:
```python
out = neighborhood_attention_generic(
    query=q,
    key=k,
    value=v,
    kernel_size=window_size,
    stride=stride,
    dilation=dilation,
    is_causal=is_causal,
)
```
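The `kernel_size` and `dilation` arguments are easiest to understand one axis at a time. Below is a simplified pure-Python model of which positions a query attends to, following the neighborhood-attention idea:

- Windows are shifted at the borders, so every query sees the same number of keys.
- Dilation splits positions into interleaved groups.

This is not NATTEN’s code. It ignores `stride` and causal masking, and NATTEN’s exact handling of some cases (such as even kernel sizes) may differ.

```python
def neighborhood_1d(i, length, kernel_size, dilation=1):
    """Indices that position i attends to along one axis (simplified model).

    With dilation d, positions are split into d interleaved groups (same i % d).
    The window is taken inside i's group and shifted at the borders so it
    always holds kernel_size positions (assumes each group is long enough).
    """
    group = list(range(i % dilation, length, dilation))
    j = i // dilation  # index of i inside its group
    start = min(max(j - kernel_size // 2, 0), len(group) - kernel_size)
    return group[start:start + kernel_size]


def neighborhood_2d(h, w, grid, kernel_size, dilation=(1, 1)):
    """(row, col) keys a query at (h, w) sees: product of the per-axis windows."""
    rows = neighborhood_1d(h, grid[0], kernel_size[0], dilation[0])
    cols = neighborhood_1d(w, grid[1], kernel_size[1], dilation[1])
    return [(r, c) for r in rows for c in cols]


print(neighborhood_1d(0, 10, 3))              # [0, 1, 2]: shifted at the border
print(neighborhood_1d(5, 10, 3, dilation=2))  # [3, 5, 7]: every other position
```

Near the border, a query’s window is shifted rather than truncated. This keeps the amount of work per query uniform, which helps the kernels stay efficient.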
NVIDIA benchmarked this checkpoint at 94.2 seconds on an H100 SXM, about 2.4x faster than the dense version. NVIDIA also saw virtually no quality loss: its pBench scores were nearly identical. Sparse attention was a virtually free lunch here.
Beyond sparse attention, another major optimization pathway is precision. NVIDIA’s newer GPU generations are increasingly optimized for lower precision, because moving and multiplying 8-bit numbers is generally cheaper and faster than doing the same with 32-bit numbers. A recurring theme on Twitter is accusations that labs are “quantizing” their models: running some parts at lower precision (and potentially lower quality) to make them cheaper to serve.
NVIDIA runs this model natively in BF16. However, we can selectively run some operations in FP8 where it helps. The biggest gains come from converting the most computationally expensive operations, which are generally the large GEMMs (general matrix-matrix multiplications).
To run operations in lower precision, we use NVIDIA’s TransformerEngine. Like NATTEN, it sits on top of PyTorch. Unlike plain PyTorch, TransformerEngine is built for low-precision execution. It abstracts away managing FP8 scaling factors, choosing which FP8 format to use (E4M3 vs. E5M2), and writing custom GPU kernels.
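TransformerEngine hides the scaling details, but they are easy to simulate. This pure-Python sketch first scales a tensor so its largest magnitude maps to 448, the maximum finite E4M3 value. It then rounds each element to the nearest representable E4M3 number. It is a simplified per-tensor illustration, not TransformerEngine’s implementation; TE’s actual scaling recipes, such as delayed scaling with an amax history, differ.

```python
import math

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value


def round_to_e4m3(x):
    """Round a float to the nearest FP8 E4M3 value (saturates; ignores NaN/inf)."""
    if x == 0.0:
        return 0.0
    sign = -1.0 if x < 0 else 1.0
    mag = min(abs(x), E4M3_MAX)
    # Values below the smallest normal (2**-6) share the subnormal step size.
    exp = max(math.floor(math.log2(mag)), -6)
    step = 2.0 ** (exp - 3)  # 3 mantissa bits -> 8 values per power of two
    return sign * min(round(mag / step) * step, E4M3_MAX)


def quantize_per_tensor(values):
    """Scale so the largest magnitude maps to E4M3_MAX, then round to FP8."""
    amax = max(abs(v) for v in values)
    scale = E4M3_MAX / amax if amax > 0 else 1.0
    return [round_to_e4m3(v * scale) for v in values], scale


def dequantize(q, scale):
    """Undo the scaling to recover approximate original values."""
    return [v / scale for v in q]


q, scale = quantize_per_tensor([0.3, -1.7, 3.14159])
print(scale, q, dequantize(q, scale))
```

The scale factor is what keeps small tensors from collapsing to zero and large ones from saturating in FP8’s narrow range. Managing it per tensor is exactly the bookkeeping TransformerEngine automates.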
To implement FP8, we first wrap the entire forward pass in TransformerEngine’s FP8 context. We then selectively switch some operations from the native PyTorch implementation to TransformerEngine’s; for example, we change “torch.nn.Linear” to “transformer_engine.pytorch.Linear”. We also avoid asking TransformerEngine to repeatedly prepare the same weights for FP8 computation. Instead, we let it keep a lower-precision copy of the selected weights ready to use. Only the operations where FP8 helps most are allowed to use these lower-precision weights.
With quantization, we shaved off roughly another 13 seconds: my quantized run came in at 81.31 seconds, versus NVIDIA’s 94.2 seconds.
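The “only where it helps most” part is a selection policy. Here is a hypothetical sketch of one. The module names, dimensions, and thresholds are illustrative assumptions, not the repo’s actual code. The real swap to `transformer_engine.pytorch.Linear` (including copying the weights over) is left out. The multiple-of-16 check reflects that FP8 GEMMs generally require aligned dimensions; check the TransformerEngine documentation for the exact constraint in your version.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LinearSpec:
    name: str          # hypothetical module path inside the DiT
    in_features: int
    out_features: int


# Hypothetical naming scheme: only large per-token projections inside the
# transformer blocks are FP8 candidates; small conditioning layers stay BF16.
FP8_SUFFIXES = ("q_proj", "k_proj", "v_proj", "out_proj", "mlp.fc1", "mlp.fc2")


def use_fp8(spec, min_features=1024):
    """Decide whether a linear layer should be swapped to an FP8-capable Linear."""
    in_blocks = spec.name.startswith("blocks.")
    big_gemm = min(spec.in_features, spec.out_features) >= min_features
    aligned = spec.in_features % 16 == 0 and spec.out_features % 16 == 0
    return in_blocks and spec.name.endswith(FP8_SUFFIXES) and big_gemm and aligned


layers = [  # hypothetical names and shapes
    LinearSpec("blocks.0.self_attn.q_proj", 2048, 2048),
    LinearSpec("blocks.0.cross_attn.k_proj", 1024, 2048),
    LinearSpec("blocks.0.mlp.fc1", 2048, 8192),
    LinearSpec("blocks.0.adaln_modulation.linear", 256, 6144),
    LinearSpec("final_layer.linear", 2048, 64),
]
print([spec.name for spec in layers if use_fp8(spec)])
```

Keeping small or numerically sensitive layers in BF16 costs little speed, because most of the FLOPs are in the big block projections.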
After this, we get into more experimental territory. In theory, we don’t need to run every block in every denoising step. In total, we run 28 x 35 = 980 DiT blocks to generate each video, so what harm would running only 970 do? The simplest approach is to run fewer denoising steps, for example 30 instead of 35.
However, we can try another approach. Intuitively, the most important steps are the earliest ones, because running them poorly derails the rest of the denoising trajectory. The final steps might matter too, since only near the end does the output become coherent. We could therefore try “skipping” some of the middle blocks in the middle steps, which is exactly what feature caching does. For example, we can reuse block 20’s output from step 10 as block 20’s output in step 11 instead of recomputing it.
I implemented a fixed caching schedule, but caching can also be tuned dynamically by checking whether two consecutive denoising calls are similar enough. This is “drift gating”: the cache is reused only if the block’s input has not drifted too far from the previous step. Otherwise, the block is recomputed. The algorithm looks something like this:
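A fixed schedule like this can be written as a set of (step, block) pairs that reuse cached output. The ranges below are hypothetical: steps 10 to 25, blocks 10 to 19, recomputing every other step. The post does not list the exact blocks and steps that were cached.

```python
def static_cache_plan(cache_steps=range(10, 26), cache_blocks=range(10, 20), every=2):
    """Return the (step, block) pairs that reuse that block's cached output.

    Within the caching window, blocks are recomputed on every `every`-th step
    and reuse their most recently computed output on the steps in between.
    All other steps and blocks always run.
    """
    plan = set()
    for step in cache_steps:
        if (step - cache_steps.start) % every == 0:
            continue  # recompute step: refreshes the cached outputs
        for block in cache_blocks:
            plan.add((step, block))
    return plan


NUM_STEPS, NUM_BLOCKS = 35, 28
plan = static_cache_plan()
computed = NUM_STEPS * NUM_BLOCKS - len(plan)
print(len(plan), computed)  # 80 900
```

Inside the sampler, a block whose (step, block) pair is in the plan returns its stored output; every other block runs and stores its output for later reuse.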
```
step 10:
    input to block 20 = output of block 19
    run block 20
    save block 20 output
    optionally save block 20 input for drift checking

step 11:
    current input to block 20 = current output of block 19
    compare current block-20 input with previous block-20 input
    if similar enough:
        reuse previous block-20 output
    else:
        recompute block 20
```
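The pseudocode translates into a small per-block wrapper. This is a hedged pure-Python sketch. Using relative L2 distance as the drift measure is my assumption, as is the threshold value, and vectors are plain lists rather than GPU tensors.

```python
import math


def relative_l2(a, b):
    """||a - b|| / ||b||: a scale-free measure of how far the input drifted."""
    diff = math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))
    norm = math.sqrt(sum(y * y for y in b))
    return diff / norm if norm > 0 else diff


class DriftGatedBlockCache:
    """Wrap one block; reuse its last output while its input stays close."""

    def __init__(self, block, threshold=0.05):
        self.block = block
        self.threshold = threshold  # hypothetical tolerance; tune per model
        self.reset()

    def reset(self):
        """Clear cached state (call before generating a new video)."""
        self.prev_input = None
        self.prev_output = None
        self.hits = 0
        self.misses = 0

    def __call__(self, x):
        if (self.prev_input is not None
                and relative_l2(x, self.prev_input) < self.threshold):
            self.hits += 1
            return self.prev_output  # similar enough: reuse previous output
        # Drifted too far (or nothing cached yet): recompute and refresh.
        # prev_input only changes on recompute, so drift is always measured
        # against the input that actually produced the cached output.
        self.misses += 1
        out = self.block(x)
        self.prev_input, self.prev_output = list(x), out
        return out


cache = DriftGatedBlockCache(lambda x: [2.0 * v for v in x], threshold=0.05)
cache([1.0, 2.0])    # first call: compute
cache([1.01, 2.0])   # tiny drift: reuse
cache([2.0, 2.0])    # large drift: recompute
print(cache.hits, cache.misses)  # 1 2
```

Comparing against the last *computed* input, rather than the most recent one, stops small drifts from accumulating silently across several reused steps.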
In terms of latency, caching gave a modest ~5% improvement over the FP8-only run. In terms of quality, the cached + FP8 run had 13.74% less movement than the FP8-only run, so caching somewhat dampened and smoothed the video. Depending on the prompt, this can be good or bad. It is good if the prompt calls for a static, stable scene, and bad if it calls for a lot of movement. The reduction in movement makes intuitive sense: with caching, the model has fewer chances to update the latent, and therefore fewer chances to add motion. Caching biases the model toward keeping frames the same.
Interestingly, the luminosity and sharpness of the video were virtually unchanged. This might suggest that the layers we cached contribute mostly to motion or temporal dynamics, while the layers we didn’t cache handle sharpness.
The final thing I tried was CUDA graphs. Rather than speeding up or simplifying the computation itself, CUDA graphs remove other overhead. Normally, our Python/PyTorch code on the CPU keeps asking the GPU to run many small and large kernels: linear layers, attention, normalization, activations, and so on. CUDA graphs let us record that repeated GPU work once, after warmup, and replay it on later calls with the same tensor shapes. This removes much of the repeated CPU-side overhead of launching the same kernel sequence again and again.
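The post doesn’t specify how movement, luminosity, and sharpness were measured. For reference, here are common simple proxies; they are an assumption, not necessarily the metrics behind the numbers above:

- Motion: mean absolute difference between consecutive frames.
- Luminosity: mean pixel value.
- Sharpness: variance of the Laplacian.

All three operate on grayscale frames stored as nested lists.

```python
def motion(frames):
    """Mean absolute per-pixel change between consecutive frames."""
    total, count = 0.0, 0
    for prev, cur in zip(frames, frames[1:]):
        for row_prev, row_cur in zip(prev, cur):
            for a, b in zip(row_prev, row_cur):
                total += abs(b - a)
                count += 1
    return total / count if count else 0.0


def luminosity(frames):
    """Mean pixel value over the whole clip."""
    values = [v for frame in frames for row in frame for v in row]
    return sum(values) / len(values)


def sharpness(frame):
    """Variance of the 4-neighbour Laplacian over interior pixels of one frame."""
    lap = [
        frame[i - 1][j] + frame[i + 1][j] + frame[i][j - 1] + frame[i][j + 1]
        - 4 * frame[i][j]
        for i in range(1, len(frame) - 1)
        for j in range(1, len(frame[0]) - 1)
    ]
    mean = sum(lap) / len(lap)
    return sum((v - mean) ** 2 for v in lap) / len(lap)


def percent_change(new, old):
    """Signed percentage change from old to new."""
    return 100.0 * (new - old) / old
```

With metrics like these, “13.74% less movement” would correspond to `percent_change(motion(cached), motion(fp8_only))` of about -13.74, while the luminosity and sharpness changes stay near zero.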
With CUDA graphs, the initial steps are the same. However, once the input has been processed and we are ready to run the GPU-intensive operations, we capture the repeated DiT sequence: the transformer blocks, the final layer, and the unpatchify step. The first time the wrapper sees a particular tensor shape/signature, it creates dummy tensors of the same shape, warms up the module, and records the GPU work into a CUDA graph. Later calls with the same signature replay the captured graph instead of asking PyTorch to launch the same long sequence of GPU kernels again.
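The signature-keyed capture/replay logic can be sketched independently of the GPU. In this pure-Python sketch, `capture` is a stand-in for the real warmup-and-record step; in PyTorch, that step would use `torch.cuda.CUDAGraph` with the `torch.cuda.graph` context manager. `FakeTensor`, `GraphedModule`, and `toy_capture` are hypothetical names for illustration.

```python
from dataclasses import dataclass, field


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor: only shape and dtype form the key."""
    shape: tuple
    dtype: str
    data: list = field(default_factory=list)


class GraphedModule:
    """Capture once per input signature, then replay for matching signatures."""

    def __init__(self, capture):
        # capture(signature) -> replay(*inputs). In PyTorch this step would
        # allocate static inputs, warm up, and record a CUDA graph.
        self._capture = capture
        self._graphs = {}
        self.num_captures = 0

    def __call__(self, *inputs):
        signature = tuple((t.shape, t.dtype) for t in inputs)
        replay = self._graphs.get(signature)
        if replay is None:  # first time we see this shape/dtype combination
            replay = self._capture(signature)
            self._graphs[signature] = replay
            self.num_captures += 1
        return replay(*inputs)


def toy_capture(signature):
    """Toy 'graph': a fixed function recorded for one signature."""
    def replay(*inputs):
        return sum(sum(t.data) for t in inputs)
    return replay


dit = GraphedModule(toy_capture)
x = FakeTensor((1, 24, 44, 80), "bf16", [1.0, 2.0])
dit(x)                                          # new signature: capture
dit(FakeTensor((1, 24, 44, 80), "bf16", [3.0]))  # same signature: replay
dit(FakeTensor((1, 12, 44, 80), "bf16", [0.5]))  # new shape: capture again
print(dit.num_captures)  # 2
```

Keying on the signature matters because a captured graph bakes in fixed shapes and memory addresses. Any new shape needs its own capture, which is why the post stresses replay only for calls with the same tensor shapes.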
When feature caching is also enabled, the implementation cannot graph the entire DiT as one uninterrupted chain, because some middle block computations may be skipped or reused. So it graphs the uncached block ranges as static segments, while the cached blocks still go through the cache wrapper.
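Choosing those static segments amounts to finding maximal runs of uncached blocks. Here is a small sketch, with the cached block range chosen hypothetically:

```python
def graph_segments(num_blocks, cached_blocks):
    """Return [start, end) ranges of consecutive uncached blocks.

    Each range can be captured as one static CUDA graph; cached blocks run
    eagerly through the cache wrapper between segments.
    """
    segments, start = [], None
    for block in range(num_blocks):
        if block in cached_blocks:
            if start is not None:
                segments.append((start, block))
                start = None
        elif start is None:
            start = block
    if start is not None:
        segments.append((start, num_blocks))
    return segments


print(graph_segments(28, set(range(10, 20))))  # [(0, 10), (20, 28)]
```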
After implementing CUDA graphs, our final time, with all other optimizations included, came out to 76.24 seconds. That is about 19% less time than NVIDIA’s 94.2-second baseline, or roughly a 1.24x speedup. Ultimately, CUDA graphs didn’t move the needle much. Our bottleneck is compute-heavy GPU work, and the overhead of CPU-to-GPU kernel launches isn’t that significant, so it isn’t surprising that our biggest speedups came from quantization and caching. Since CUDA graphs don’t change the actual computation, they are mostly a free lunch; the main cost was about 3% more VRAM than the baseline.
It was a lot of fun tinkering with the code and learning some ML engineering. If you would like to explore more, you can check out my repo here. NVIDIA also has an updated version of this model, Predict 2.5, available. I personally used Modal for this project, but it shouldn’t be hard to adapt the code to any other cloud provider.
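The percentages above follow directly from the reported times. This snippet reproduces the arithmetic, comparing NVIDIA’s reported 94.2-second baseline against the measured runs.

```python
BASELINE_S = 94.2  # NVIDIA's reported time for the sparse checkpoint (H100 SXM)

runs_s = {
    "FP8": 81.31,
    "FP8 + caching + CUDA graphs": 76.24,
}


def summarize(t, baseline=BASELINE_S):
    """Return (percent less time than baseline, speedup factor)."""
    return 100.0 * (baseline - t) / baseline, baseline / t


for name, t in runs_s.items():
    saved, speedup = summarize(t)
    print(f"{name}: {t:.2f} s, {saved:.1f}% less time, {speedup:.2f}x")
# FP8: 81.31 s, 13.7% less time, 1.16x
# FP8 + caching + CUDA graphs: 76.24 s, 19.1% less time, 1.24x
```

“Percent less time” and “speedup” are easy to conflate: 19% less wall-clock time corresponds to a 1.24x throughput gain.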