Paxml: Google's JAX Framework for Training Models at Trillion-Parameter Scale

Hook

While most ML frameworks celebrate reaching 50% hardware utilization, Google's Paxml routinely achieves over 60% Model FLOPs Utilization (MFU) on TPU pods, a gap that can translate into millions of dollars in compute savings when training trillion-parameter models.

Context

Training large language models is an exercise in fighting inefficiency. When you're running a model with hundreds of billions of parameters across hundreds of accelerators, every percentage point of hardware utilization matters. A 10% improvement in Model FLOPs Utilization (MFU) on a multi-month training run can mean the difference between a viable project and one that burns through your compute budget before convergence.
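MFU is the ratio of the model FLOPs a run actually performs per second to the hardware's peak FLOPs per second. Below is a minimal sketch of the usual estimate. It assumes a dense decoder-only transformer and the common approximation of about 6 FLOPs per parameter per training token. The model size, throughput, chip count and per-chip peak are hypothetical inputs, not Pax measurements.

```python
def model_flops_utilization(num_params: float,
                            tokens_per_second: float,
                            num_chips: int,
                            peak_flops_per_chip: float) -> float:
    """Estimate MFU for a dense decoder-only transformer.

    Assumes ~6 FLOPs per parameter per training token (about 2 for the
    forward pass and 4 for the backward pass) and ignores attention FLOPs,
    so it slightly underestimates the true model FLOPs at long sequences.
    """
    achieved_flops_per_second = 6.0 * num_params * tokens_per_second
    peak_flops_per_second = num_chips * peak_flops_per_chip
    return achieved_flops_per_second / peak_flops_per_second


# Hypothetical inputs: a 100B-parameter model processing 140k tokens/s on
# 512 chips, each assumed to peak at 275 TFLOP/s in bfloat16.
mfu = model_flops_utilization(num_params=100e9, tokens_per_second=140_000,
                              num_chips=512, peak_flops_per_chip=275e12)
print(f"MFU = {mfu:.1%}")  # prints "MFU = 59.7%"
```

Because the denominator is fixed by the hardware you rent, every point of MFU corresponds directly to more tokens processed per dollar.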

Most popular frameworks were designed when "large" meant models with tens of millions of parameters. PyTorch and TensorFlow added distributed training capabilities as an afterthought, resulting in complex APIs that require developers to manually orchestrate parallelism strategies. Google built Paxml (Pax) from the ground up for a different reality: one where models routinely exceed 100 billion parameters and training runs consume entire TPU pods for weeks. By embracing JAX's functional paradigm and SPMD parallelism, Pax treats distributed training not as a feature bolted on, but as the fundamental design constraint. The result is a framework that has powered some of Google's largest language model training runs with MFU rates that make other frameworks look wasteful.

Technical Insight

Paxml's architecture revolves around three core principles: configuration-as-code, SPMD-first parallelism, and TPU-native optimization. Understanding these principles is essential to grasping why Pax achieves such exceptional efficiency.

The configuration-as-code approach means experiments are defined as structured Python classes rather than scattered command-line flags or YAML files. Here's what a basic GPT-style language model configuration looks like:

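from paxml import base_experiment
from paxml import experiment_registry
from paxml import tasks_lib
from praxis import base_layer
from praxis import pax_fiddle
from praxis.layers import models
from praxis.layers import transformers

@experiment_registry.register
class GPT3Small(base_experiment.BaseExperiment):
  """GPT-3 Small (~125M parameters) configuration."""

  def datasets(self):
    # Input pipelines omitted for brevity; a real experiment returns its
    # training and evaluation input configs here.
    return []

  def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]:
    task_p = pax_fiddle.Config(tasks_lib.SingleTask, name='gpt3_small_task')
    task_p.model = pax_fiddle.Config(models.LanguageModel, name='gpt3_small_lm')
    model_p = task_p.model

    # Model architecture: the decoder-only TransformerLm held in lm_tpl
    lm_p = model_p.lm_tpl
    lm_p.vocab_size = 50257
    lm_p.model_dims = 768
    stacked_p = pax_fiddle.Config(transformers.StackedTransformer)
    stacked_p.num_layers = 12
    stacked_p.model_dims = 768
    stacked_p.num_heads = 12
    stacked_p.hidden_dims = 3072  # feed-forward width
    lm_p.stacked_transformer_tpl = stacked_p

    # Parallelism strategy: logical mesh axes and their sizes
    model_p.mesh_axis_names = ['replica', 'data', 'mdl']
    model_p.ici_mesh_shape = [1, 4, 1]  # 4-way data parallelism

    # Weight initialization and sharding annotations
    model_p.params_init = base_layer.WeightInit.Gaussian(0.02)
    # Attention projections have shape [model_dims, num_heads, dim_per_head]:
    # split model_dims over 'data' and the heads over 'mdl'.
    atten_p = stacked_p.transformer_layer_params_tpl.tr_atten_tpl
    atten_p.weight_split_dims_mapping.proj = ['data', 'mdl', None]

    return task_p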

This configuration explicitly declares the parallelism mesh (how accelerators are logically organized) and how weight tensors should be sharded across that mesh. The mesh_axis_names define logical axes: 'replica' for model replication, 'data' for data parallelism, and 'mdl' for model parallelism. ici_mesh_shape gives the size of each axis across the devices of a slice connected by the inter-chip interconnect (ICI), so the product of its entries must equal the number of devices; a separate dcn_mesh_shape covers multi-slice jobs. This declarative approach means you can switch from 4-way data parallelism to 2-way data and 2-way model parallelism simply by changing the mesh shape from [1, 4, 1] to [1, 2, 2], without touching the model code.
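Here is a minimal pure-Python sketch of the bookkeeping a sharding annotation implies. Given the mesh axis names and an ici_mesh_shape, it computes the per-device shape of one weight. The helper shard_shape is hypothetical and not a Pax API. The [768, 12, 64] attention projection (model_dims, num_heads, dim_per_head) split as ['data', 'mdl', None] is an illustrative choice.

```python
from math import prod
from typing import Optional, Sequence


def shard_shape(tensor_shape: Sequence[int],
                split_dims_mapping: Sequence[Optional[str]],
                mesh_axis_names: Sequence[str],
                mesh_shape: Sequence[int]) -> tuple:
    """Per-device shape of a tensor sharded over a logical device mesh.

    Each entry of split_dims_mapping names the mesh axis that splits the
    corresponding tensor dimension, or None to replicate that dimension.
    """
    axis_sizes = dict(zip(mesh_axis_names, mesh_shape))
    local = []
    for dim, axis in zip(tensor_shape, split_dims_mapping):
        n = 1 if axis is None else axis_sizes[axis]
        if dim % n != 0:
            raise ValueError(f"dim {dim} not divisible by axis {axis!r} of size {n}")
        local.append(dim // n)
    return tuple(local)


AXES = ['replica', 'data', 'mdl']
WQ_SHAPE = (768, 12, 64)            # [model_dims, num_heads, dim_per_head]
WQ_MAPPING = ['data', 'mdl', None]  # illustrative split

for mesh_shape in ([1, 4, 1], [1, 2, 2]):
    print(mesh_shape, prod(mesh_shape), "devices ->",
          shard_shape(WQ_SHAPE, WQ_MAPPING, AXES, mesh_shape))
```

Both meshes use 4 devices. The first gives each device a (192, 12, 64) slice. The second gives each device a (384, 6, 64) slice, so each device holds half of the heads. Only the mesh shape changed.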

The SPMD parallelism, powered by JAX's pjit, is where Pax's efficiency truly shines. Unlike traditional frameworks where you manually split batches or insert communication collectives, SPMD compiles a single program that runs on all devices simultaneously. The XLA compiler automatically inserts the necessary all-reduce, all-gather, and reduce-scatter operations based on your sharding annotations. This eliminates an entire class of bugs where manual parallelization introduces subtle communication errors.
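In current JAX releases, pjit's functionality has been folded into jax.jit, so plain JAX is enough to show the idea. This is a minimal sketch, not Pax's own API. It simulates four devices on CPU through XLA_FLAGS and builds a mesh with the same axis names as the config above. It shards the batch over 'data' and then compiles ordinary single-device code, leaving XLA to partition the program and add the cross-device reduction.

```python
import os

# Simulate 8 CPU devices; this flag must be set before JAX is imported.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8")

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Logical mesh with the same axis names as the Pax config: 1 x 4 x 1 devices.
devices = np.array(jax.devices("cpu")[:4]).reshape(1, 4, 1)
mesh = Mesh(devices, axis_names=("replica", "data", "mdl"))

# Sharding annotations: split the batch over 'data', replicate the weight.
x_sharding = NamedSharding(mesh, P("data", None))
w_sharding = NamedSharding(mesh, P(None, None))


@jax.jit
def mean_squared_output(x, w):
    # Plain single-device code: no explicit all-reduce or all-gather here.
    # XLA partitions it and inserts the cross-device reduction for the mean.
    return jnp.mean((x @ w) ** 2)


x = jax.device_put(jnp.arange(32.0).reshape(8, 4) / 32.0, x_sharding)
w = jax.device_put(jnp.ones((4, 2)), w_sharding)

loss = mean_squared_output(x, w)  # one program, run on all 4 devices
print(x.addressable_shards[0].data.shape, float(loss))
```

Each device holds a 2-row slice of the 8-row batch. Nevertheless, the function body contains no communication code at all.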

For very large models, Pax supports pipeline parallelism layered on top of SPMD. Here's how you'd configure a model to use both model parallelism and pipeline parallelism:

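stage_p = pax_fiddle.Config(transformers.StackedTransformer)
stage_p.num_layers = 3                    # 12 layers / 4 stages
stage_p.model_dims = 768
stage_p.num_heads = 12
stage_p.hidden_dims = 3072

pipeline_p = pax_fiddle.Config(transformers.PipelinedTransformer)
pipeline_p.pipeline_stage = stage_p       # layers each stage runs
pipeline_p.num_pipeline_stages = 4
pipeline_p.num_pipeline_microbatches = 8  # microbatching keeps stages busy
pipeline_p.circular_repeat = 1            # >1 interleaves layer blocks per stage
model_p.lm_tpl.stacked_transformer_tpl = pipeline_p

# Add a 'stage' mesh axis: 4 pipeline stages x 2-way model parallelism
model_p.mesh_axis_names = ['stage', 'replica', 'data', 'mdl']
model_p.ici_mesh_shape = [4, 1, 1, 2]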

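The framework splits each batch into microbatches, streams them through the stages, and accumulates their gradients. This keeps every stage busy and shrinks the pipeline "bubble", which is the idle time while the pipeline fills and drains. Keeping that bubble small is critical for high MFU. Setting circular_repeat above 1 shrinks it further by giving each stage several interleaved blocks of layers instead of one contiguous block.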
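To see why microbatching matters, consider the standard idle-time estimate for a simple GPipe-style schedule. This is a sketch of the textbook formula, not Pax's exact scheduler, whose circular mode reduces the bubble further. With p stages and m microbatches, each stage is busy for m of the m + p - 1 time slots.

```python
def gpipe_bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    """Fraction of pipeline time each stage spends idle (GPipe-style schedule).

    With p stages and m microbatches the schedule spans m + p - 1 slots and
    each stage works in m of them, so the idle fraction is (p - 1) / (m + p - 1).
    """
    p, m = num_stages, num_microbatches
    return (p - 1) / (m + p - 1)


for m in (1, 4, 8, 32):
    print(f"4 stages, {m:>2} microbatches -> {gpipe_bubble_fraction(4, m):.0%} idle")
```

With no microbatching, 4 stages sit idle 75% of the time. Using 32 microbatches cuts that to about 9%, at the cost of smaller per-microbatch matmuls.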

Pax's TPU-native optimizations go deep. Activation rematerialization (recomputing activations during the backward pass instead of storing them) is controlled through checkpoint policies on the transformer stack. This lets you choose which intermediates to keep and trade extra compute for lower memory. The checkpointing system integrates with TensorStore, enabling asynchronous checkpoint writes that don't block training, which is essential when your model state is hundreds of gigabytes. The framework also supports bfloat16 training. Mixed precision is set through layer configuration (such as fprop_dtype) rather than hand-written casts. Pax also integrates with TPU's optimized collective communication primitives.
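The underlying mechanism is JAX's jax.checkpoint (also called remat) with a saveability policy. The minimal sketch below uses plain JAX rather than Pax's configuration knobs, and the two-layer MLP is a hypothetical stand-in for a transformer block. The policy keeps plain matmul outputs and recomputes the tanh activations in the backward pass. The gradients match the non-rematerialized version.

```python
import jax
import jax.numpy as jnp


def mlp_block(params, x):
    h = jnp.tanh(x @ params["w1"])
    return jnp.tanh(h @ params["w2"])


# Rematerialized version: keep only matmul outputs that have no dot_general
# batch dimensions and recompute everything else (here, the tanh outputs)
# during the backward pass.
remat_block = jax.checkpoint(
    mlp_block, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)


def make_loss(block):
    def loss(params, x):
        return jnp.mean(block(params, x) ** 2)
    return loss


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {"w1": 0.1 * jax.random.normal(k1, (16, 64)),
          "w2": 0.1 * jax.random.normal(k2, (64, 16))}
x = jax.random.normal(k3, (8, 16))

# Same gradients; the remat version trades recomputation for activation memory.
grads_plain = jax.grad(make_loss(mlp_block))(params, x)
grads_remat = jax.grad(make_loss(remat_block))(params, x)
```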

Each training step is pure JAX, compiled with XLA for maximum performance. Pax compiles the forward and backward passes, the optimizer update and the metric computation into that one program, eliminating almost all Python overhead during training. This is a large part of why Pax can achieve 60%+ MFU while frameworks with eager-mode components often plateau at 45-50%.
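The pattern looks like this in plain JAX. This is a minimal sketch, not Pax's trainer, and the hand-written SGD update and toy linear model are hypothetical stand-ins for Pax's optimizers and models. A single jax.jit covers the forward pass, backward pass, parameter update and metrics, so each loop iteration issues one compiled call.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_step(params, x, y, learning_rate):
    # Forward pass, backward pass, optimizer update and metrics are traced
    # together and compiled into a single XLA executable.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads)
    grad_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in jax.tree_util.tree_leaves(grads)))
    return new_params, {"loss": loss, "grad_norm": grad_norm}


x = jax.random.normal(jax.random.PRNGKey(0), (64, 3))
y = x @ jnp.array([[1.0], [-2.0], [0.5]]) + 0.3  # noiseless linear targets

params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
history = []
for step in range(200):
    params, metrics = train_step(params, x, y, 0.1)
    history.append(float(metrics["loss"]))
print(f"first loss {history[0]:.3f}, last loss {history[-1]:.2e}")
```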

Gotcha

Paxml's exceptional performance comes with significant accessibility trade-offs that will frustrate teams not already deep in the Google ecosystem. The learning curve is genuinely steep—you need working knowledge of JAX's functional programming model, comfort with SPMD sharding concepts, understanding of how pjit differs from pmap, and familiarity with Google Cloud TPU infrastructure. The documentation assumes this background, with examples that jump straight into mesh partitioning without explaining the fundamentals. If you're coming from PyTorch, expect weeks of ramp-up time before you're productive.

The TPU lock-in is real and consequential. While NVIDIA has created a fork with H100 FP8 support, it's a separate codebase that lags behind Google's releases. If you're primarily a GPU shop without access to TPU pods, you're fighting against the framework's design philosophy. The performance optimizations that make Pax shine on TPU v4 and v5e—like specific collective communication patterns and memory layout assumptions—don't translate directly to GPU architectures. Community support is limited compared to PyTorch-based frameworks; you'll find far fewer StackOverflow answers, blog posts, and third-party tutorials. With only 550 GitHub stars, you're often on your own when debugging issues. The integration with popular tools like Weights & Biases or Hugging Face datasets requires custom glue code that the community hasn't standardized yet.

Verdict

Use if: You're training models above 10 billion parameters on Google Cloud TPU infrastructure and hardware efficiency directly impacts your project viability. Pax's MFU advantage compounds dramatically at scale. For the same total FLOPs, cost scales inversely with MFU, so a run that costs $500K at 55% utilization would cost about $423K at 65%, a saving of roughly $77K. It's ideal for research labs and companies running multi-week training jobs where squeezing every FLOP matters more than developer convenience. Choose it if you're already committed to JAX and want the most battle-tested large-scale training framework in that ecosystem.
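The cost arithmetic is simple to check. It assumes the same total model FLOPs, the same hardware and a fixed hourly price, so wall-clock time (and therefore cost) scales with 1/MFU. The run_cost helper is a hypothetical illustration.

```python
def run_cost(reference_cost: float, reference_mfu: float, new_mfu: float) -> float:
    """Cost of the same training run (same total model FLOPs) at a new MFU.

    With hardware and hourly price unchanged, wall-clock time and therefore
    cost scale inversely with MFU.
    """
    return reference_cost * reference_mfu / new_mfu


cost_at_55 = 500_000
cost_at_65 = run_cost(cost_at_55, reference_mfu=0.55, new_mfu=0.65)
print(f"${cost_at_65:,.0f} at 65% MFU, saving ${cost_at_55 - cost_at_65:,.0f}")
# prints "$423,077 at 65% MFU, saving $76,923"
```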

Skip if: You're working with models under 10B parameters, need GPU-first support without forking to NVIDIA's branch, prefer PyTorch's mature ecosystem and gentler learning curve, or lack team expertise in JAX and functional programming. For most teams, PyTorch FSDP or DeepSpeed provides 80% of the performance with 20% of the complexity. Also skip if you need rapid experimentation—Pax's compilation-heavy approach means iteration cycles are slower than eager-mode frameworks, making it poorly suited for exploratory research where you're frequently changing model architectures.
