CAGE: Curvature-Aware Gradient Estimation For Accurate Quantization-Aware Training

Published as an arXiv preprint, 2025

Recommended citation: S. Tabesh, M. Safaryan, A. Panferov, A. Volkova, D. Alistarh. (2025). "CAGE: Curvature-Aware Gradient Estimation For Accurate Quantization-Aware Training." arXiv preprint. https://arxiv.org/abs/2510.18784

CAGE (Curvature-Aware Gradient Estimation) augments the straight-through estimator with a curvature-aware correction designed to counteract the loss increase induced by quantization. Derived from a multi-objective view of QAT that balances loss minimization with quantization constraints, CAGE yields strong convergence guarantees in the smooth non-convex setting. In QAT fine-tuning, it halves the accuracy loss from compression relative to the best prior methods, while in QAT pre-training of Llama models, its 3-bit weights-and-activations (W3A3) accuracy matches the 4-bit (W4A4) accuracy of prior methods.
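To make the idea concrete, here is a minimal PyTorch-style sketch contrasting the plain straight-through estimator with a curvature-aware variant. The correction form used below — the quantization error `w - Q(w)` weighted by a diagonal curvature estimate `h` and a trade-off coefficient `gamma` — is a hypothetical simplification for illustration only; the paper derives the exact correction, quantizer, and curvature estimate.

```python
import torch

def uniform_quantize(w, bits=3):
    # Symmetric uniform quantizer (illustrative; the paper's quantizer may differ).
    qmax = 2 ** (bits - 1) - 1
    scale = w.abs().max().clamp(min=1e-8) / qmax
    return torch.round(w / scale).clamp(-qmax, qmax) * scale

class STEQuantize(torch.autograd.Function):
    """Plain straight-through estimator: backward treats quantization as identity."""
    @staticmethod
    def forward(ctx, w):
        return uniform_quantize(w)

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out  # d(quantize)/dw approximated by 1

class CurvatureAwareQuantize(torch.autograd.Function):
    """STE plus a curvature-aware correction (hypothetical form):
    the STE gradient is augmented with a term proportional to the
    quantization error, weighted by a diagonal curvature estimate h
    and a trade-off coefficient gamma."""
    @staticmethod
    def forward(ctx, w, h, gamma):
        q = uniform_quantize(w)
        ctx.save_for_backward(w, q, h)
        ctx.gamma = gamma
        return q

    @staticmethod
    def backward(ctx, grad_out):
        w, q, h = ctx.saved_tensors
        # Additive correction steering weights to account for the loss
        # increase induced by quantization (sketch, not the paper's formula).
        correction = ctx.gamma * h * (w - q)
        return grad_out + correction, None, None

# Usage: h would come from a curvature estimate (here a placeholder of ones).
w = torch.randn(8, requires_grad=True)
h = torch.ones_like(w)
CurvatureAwareQuantize.apply(w, h, 0.1).sum().backward()
print(w.grad)
```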

Access the paper here: https://arxiv.org/abs/2510.18784