Pax: Google’s Production Framework for Training Trillion-Parameter Models on TPUs
Hook
While most ML frameworks struggle to use more than 40% of TPU compute capacity, Pax consistently achieves the model FLOPs utilization (MFU) rates discussed in Google's PaLM paper, the same metric used to evaluate large-scale training efficiency.
Context
Training large language models is an exercise in wrestling with hardware inefficiency. A typical training run might use only 30-40% of your accelerator’s theoretical compute capacity, with the rest lost to communication overhead, memory bottlenecks, and suboptimal parallelization strategies. For organizations spending millions on TPU pods, every percentage point of hardware utilization directly impacts both training cost and time-to-convergence.
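To make the stakes concrete, here is a back-of-envelope sketch of model FLOPs utilization (MFU) arithmetic, using the common "6 × params × tokens" estimate of training FLOPs. The chip count and throughput below are illustrative assumptions, not measured Pax numbers:

```python
# Back-of-envelope MFU (model FLOPs utilization) arithmetic. All numbers
# here are illustrative assumptions, not measured Pax results.
PEAK_BF16_FLOPS = 275e12          # TPU v4: ~275 TFLOP/s peak bf16 per chip
n_chips = 4
params = 1e9                      # a 1B-parameter model
tokens_per_sec = 110_000          # hypothetical measured training throughput

achieved_flops = 6 * params * tokens_per_sec   # model FLOPs actually spent
peak_flops = n_chips * PEAK_BF16_FLOPS         # what the hardware could do
mfu = achieved_flops / peak_flops

print(f"MFU: {mfu:.0%}")          # 60% under these assumed numbers
```

Every point of MFU recovered at a fixed budget translates directly into more tokens trained, which is why frameworks compete on this metric.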
Google built Paxml (nicknamed Pax) to solve this problem at scale. Unlike research frameworks that prioritize flexibility, Pax is a production system designed from the ground up for one thing: extracting maximum performance from Cloud TPU infrastructure when training models from 1 billion parameters into the multi-billion range. The framework appears to be related to Google’s large language model research, including work on models like PaLM that demonstrated state-of-the-art results. Built on JAX and integrated with the Praxis layer library, Pax takes a configuration-as-code approach where entire experiments—model architecture, optimizer settings, parallelization strategy, and data pipelines—are defined through Python classes that inherit from base configuration templates.
Technical Insight
Pax’s architecture revolves around declarative experiment configuration combined with JAX’s SPMD (Single Program Multiple Data) parallelization via pjit. Instead of imperatively defining training loops and manually managing device placement, you define an experiment by subclassing base configuration classes and specifying hyperparameters, model architecture, and sharding strategies.
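The underlying mechanism is plain JAX rather than anything Pax-specific: arrays are annotated with shardings over a device mesh, and the compiler partitions the computation. A minimal sketch (not Pax API; it forces virtual CPU devices so it runs anywhere):

```python
# Illustrative sketch of the SPMD mechanism Pax builds on: plain JAX
# sharding, not the Pax API. Force 8 virtual CPU devices so this runs
# on any machine; in real use, these would be TPU cores.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D mesh of 8 devices along a "data" axis.
mesh = Mesh(np.array(jax.devices()).reshape(8), axis_names=("data",))

# Shard the batch dimension across the mesh; replicate the weights.
x = jax.device_put(jnp.ones((8, 128)), NamedSharding(mesh, P("data", None)))
w = jax.device_put(jnp.ones((128, 64)), NamedSharding(mesh, P(None, None)))

@jax.jit
def forward(x, w):
    # Written once, as if for a single device; the compiler inserts the
    # collectives needed to run it across all 8.
    return x @ w

y = forward(x, w)
print(y.shape)  # (8, 64)
```

Swapping the `PartitionSpec` annotations, not the training code, is what changes the parallelization strategy, and Pax's configs expose exactly that knob.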
Here’s how you’d configure and run a 1B parameter language model on a TPU v4-8:
```shell
python3 .local/lib/python3.8/site-packages/paxml/main.py \
  --exp=tasks.lm.params.c4.C4Spmd1BAdam4Replicas \
  --job_log_dir=gs://your-bucket
```
Under the hood, the C4Spmd1BAdam4Replicas configuration class defines the entire experiment: a transformer model with specific layer dimensions, attention heads, and embedding sizes; an Adam optimizer with learning rate schedule; data loading from the C4 dataset; and critically, the sharding annotations that tell JAX how to partition model weights and activations across TPU cores. The SPMD approach means you write the training code once, and JAX’s compiler automatically handles the distribution across 8, 128, or even thousands of TPU cores.
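The configuration-as-code pattern itself is ordinary Python class inheritance. A hypothetical sketch of the idea (the class and attribute names below are illustrative stand-ins, not the real Praxis/Pax base classes):

```python
# Hypothetical sketch of the configuration-as-code pattern. These class
# and attribute names are illustrative, not the real Praxis/Pax classes.
class BaseLmExperiment:
    NUM_LAYERS = 12
    MODEL_DIM = 768
    LEARNING_RATE = 1e-3
    MESH_SHAPE = [1, 1, 1]        # (replica, data, model) parallelism degrees

class Lm1BFourReplicas(BaseLmExperiment):
    # A new experiment overrides only what differs from the template;
    # everything else (here, the learning rate) is inherited.
    NUM_LAYERS = 24
    MODEL_DIM = 2048
    MESH_SHAPE = [1, 4, 1]        # 4-way data parallelism

cfg = Lm1BFourReplicas()
print(cfg.MODEL_DIM, cfg.LEARNING_RATE)  # 2048 0.001
```

Because an experiment is just a class, scaling from 1B to 16B parameters means writing a new subclass that overrides dimensions and mesh shape, nothing more.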
The framework’s power emerges when scaling beyond single-host TPUs. Pax supports sophisticated parallelization strategies including data parallelism (replicating the model across devices), model parallelism (sharding individual layers), and pipeline parallelism (distributing layers across devices). You control these through configuration rather than rewriting training code. For a 16B parameter model that won’t fit in a single TPU’s memory, you’d adjust sharding annotations in the config to specify which tensor dimensions get split across which device mesh axes.
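A rough memory calculation shows why sharding becomes mandatory at this size. The byte counts below are simplified assumptions (bf16 weights plus fp32 Adam moments, ignoring activations, gradients, and master copies):

```python
# Simplified per-device memory arithmetic behind a sharding decision for
# a 16B-parameter model. Assumptions: bf16 weights (2 bytes/param) plus
# fp32 Adam moments (8 bytes/param); activations and gradients ignored.
params = 16e9
bytes_per_param = 2 + 8                 # weights + optimizer state
total_bytes = params * bytes_per_param  # 160 GB: no single chip holds this

# Pure data parallelism replicates the full 160 GB on every device,
# far beyond the ~32 GB of HBM on a TPU v4 chip. Sharding across an
# 8-way model-parallel mesh axis divides it instead:
mesh_axis_size = 8
per_device_bytes = total_bytes / mesh_axis_size

print(per_device_bytes / 1e9)           # 20.0 GB per device: it fits
```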
Pax integrates TensorStore for checkpoint management, enabling efficient saving and loading of multi-terabyte model states directly to Google Cloud Storage buckets. This isn’t just convenience—at scale, checkpoint I/O becomes a significant bottleneck. TensorStore’s async writes and optimized GCS integration mean checkpointing doesn’t block training for extended periods. The framework also handles the orchestration complexity of multi-host training: collective communications, barrier synchronization, and fault recovery across TPU pod slices.
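The async-write idea can be sketched generically in a few lines of Python. This is the general pattern, not the TensorStore API; the function names and pickle-based serialization are illustrative:

```python
# Generic sketch of the async-checkpoint pattern (plain Python, not the
# TensorStore API): snapshot the state on the main thread, hand the slow
# write to a background thread, and keep training.
import os
import pickle
import tempfile
from concurrent.futures import ThreadPoolExecutor

_writer = ThreadPoolExecutor(max_workers=1)  # one writer keeps saves ordered

def save_checkpoint_async(state, path):
    snapshot = dict(state)               # cheap shallow copy, taken synchronously
    def _write():
        with open(path, "wb") as f:      # slow I/O happens off the training path
            pickle.dump(snapshot, f)
    return _writer.submit(_write)        # a future; wait on it before overwriting

# Hypothetical usage inside a training loop:
state = {"step": 100, "weights": [0.1, 0.2, 0.3]}
fd, path = tempfile.mkstemp(suffix=".ckpt")
os.close(fd)
future = save_checkpoint_async(state, path)  # returns immediately
future.result()                              # block only when you must, e.g. at exit
```

The snapshot-then-write split is the key design point: the training loop pays only for the in-memory copy, while the expensive storage round-trip overlaps with subsequent steps.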
The configuration-driven design has another benefit: reproducibility. Your entire experiment—including hyperparameters that actually worked—lives in version-controlled Python code rather than scattered across command-line flags and external config files. Teams can share proven configurations for specific model sizes and datasets, as demonstrated by Pax’s included C4 convergence runs for 1B and 16B parameter models.
For developers familiar with PyTorch or TensorFlow, Pax’s abstraction level sits higher in the stack. You’re not writing custom backward passes or manually managing gradient accumulation. Instead, you’re composing experiments from battle-tested components (attention layers, optimizers, data loaders) and focusing on the architectural decisions that actually matter at scale: how to shard your model, what parallelization strategy matches your hardware topology, and how to structure your data pipeline to keep TPUs fed.
Gotcha
Pax’s tight coupling to Google Cloud TPU infrastructure is both its strength and its primary limitation. While the README mentions GPU support exists, it immediately directs you to NVIDIA’s Rosetta fork—a tacit admission that GPU support is a second-class citizen. If you’re running on A100 or H100 clusters, you’re either using a third-party fork or choosing a different framework entirely. This isn’t a minor portability issue; the framework’s design assumptions—from checkpoint formats to parallelization primitives—reflect TPU characteristics.
The learning curve is steep. Pax expects you to understand JAX’s functional programming model, Praxis’s base configuration classes, and distributed training concepts like device meshes and sharding specs. The documentation includes Jupyter notebook tutorials, but jumping from those examples to configuring a custom model architecture requires reading through the Praxis codebase to understand which configuration knobs exist and how they interact. With only 550 GitHub stars, the community is relatively small—expect to read source code rather than find abundant Stack Overflow answers. For teams without existing JAX expertise, the ramp-up time is measured in weeks, not days.
Verdict
Use if: You’re training large language models (1B+ parameters) on Google Cloud TPUs and hardware utilization directly impacts your budget or research timeline. The framework excels when you need production-grade reliability for long convergence runs and want configuration-based reproducibility. It’s particularly valuable if you’re scaling experiments from smaller to larger model sizes using proven architectural templates. Skip if: You’re primarily using GPUs (use NVIDIA’s Rosetta fork or PyTorch FSDP instead), you’re training models under 1B parameters where simpler frameworks suffice, you need rapid prototyping with minimal configuration overhead, or you want a large community for support. If you don’t have access to TPU infrastructure or aren’t already invested in the Google Cloud ecosystem, Pax’s benefits won’t outweigh the operational complexity of adopting a TPU-first framework.