Aug 30, 2023

Making FP16 and FP8 easy to use with our new unit scaling library

Written By:

Charlie Blake

We’re pleased to announce the release of a PyTorch library to facilitate unit scaling — a method for designing models that makes low-precision number formats such as FP16 and FP8 easy to use.

In July, Graphcore researchers presented the paper Unit Scaling: Out-of-the-Box Low-Precision Training at ICML in Hawaii. We’re now releasing the software tools to make this method available to a wider audience.

The development of hardware with FP8 support, such as the Graphcore® C600 IPU-Processor PCIe Card, offers users substantial efficiency improvements. However, naïvely casting values from higher precision down into FP8 tends to degrade performance. Unit scaling addresses this, offering a simple path to making the most of FP8 hardware for training.
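
To give a feel for why naïve casting hurts, here is a small pure-Python sketch. It is only an illustration: the `quantize_e4m3` helper is hypothetical, it assumes the E4M3 variant of FP8 (3 mantissa bits, smallest subnormal 2^-9, largest value 448), and it saturates out-of-range values, which real hardware casts may handle differently. It measures how many Gaussian samples are rounded to zero at two different scales.

```python
import math
import random


def quantize_e4m3(x: float) -> float:
    """Round x to the nearest value on a simplified FP8 E4M3 grid (assumed format)."""
    if x == 0.0:
        return 0.0
    mag = min(abs(x), 448.0)   # saturate at the largest E4M3 value (an assumption)
    _, e = math.frexp(mag)     # mag = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, -6)       # below 2**-6 the grid becomes subnormal (fixed step)
    step = 2.0 ** (exp - 3)    # 3 mantissa bits give 8 steps per power of two
    return math.copysign(round(mag / step) * step, x)


def fraction_flushed_to_zero(scale: float, n: int = 10_000, seed: int = 0) -> float:
    """Fraction of N(0, scale^2) samples that become exactly zero after quantisation."""
    rng = random.Random(seed)
    values = [rng.gauss(0.0, scale) for _ in range(n)]
    return sum(quantize_e4m3(v) == 0.0 for v in values) / n


unit_scale_loss = fraction_flushed_to_zero(1.0)   # values with standard deviation 1
tiny_scale_loss = fraction_flushed_to_zero(1e-4)  # values with standard deviation 1e-4
print(f"flushed to zero at scale 1.0:  {unit_scale_loss:.2%}")
print(f"flushed to zero at scale 1e-4: {tiny_scale_loss:.2%}")
```

Under these assumptions, values with a standard deviation near one almost all survive the cast, while values at a scale of 1e-4 fall below the smallest representable magnitude and are lost. Keeping tensors near unit scale is what lets them sit comfortably inside such a narrow range.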

Check out the library documentation
Read our ICML paper

Demonstrating the library in action

To show users how to apply unit scaling to their own models, we’re also releasing a notebook to accompany the library. This demonstrates the training of the nanoGPT model in FP8 with and without unit scaling.

With only a single line of code — model = unit_scale(model) — users can turn their PyTorch module into a unit-scaled model.

We illustrate this in the notebook, training the following models:

from nanoGPT.model import GPT
from notebook_utils import config, train
from unit_scaling.transforms import simulate_fp8, unit_scale
gpt = GPT(config)  # model unchanged from original nanoGPT
fp8_gpt = simulate_fp8(gpt)
unit_scaled_fp8_gpt = unit_scale(fp8_gpt)
models = [gpt, fp8_gpt, unit_scaled_fp8_gpt]
for model in models:
    train(model)

Training the base model directly in FP8 causes a significant degradation in accuracy. However, full accuracy is recovered by using unit scaling.

This one-line transform can be applied to arbitrary PyTorch models, with negligible overhead when used with torch.compile.

Implementing unit scaling

The one-line automatic unit_scale() transform is an experimental feature. We recommend most users implement unit scaling manually, in the following way.

Consider this common approach to importing PyTorch modules/functions:

from torch import nn
from torch.nn import functional as F

In this setting unit scaling can be applied by first adding:

import unit_scaling as uu
from unit_scaling import functional as U

and then replacing nn with uu and F with U for each class and function that should be unit-scaled. For example:

import torch  # needed for the torch.Tensor annotations below
class UnitScaledMLP(nn.Module):
    def __init__(self, d: int) -> None:
        super().__init__()
        self.linear_1 = uu.Linear(d, d * 4)  # Changed `nn` to `uu`
        self.linear_2 = uu.Linear(d * 4, d)  # Changed `nn` to `uu`

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear_1(x)
        x = U.gelu(x)  # Changed `F` to `U`
        return self.linear_2(x)
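
For intuition about what a unit-scaled linear layer aims for, the following plain-Python sketch may help. It is a simplified illustration of the principle and not the library's code: it assumes unit-variance weights and inputs, looks only at the forward pass, and the `linear_outputs` helper is hypothetical. A dot product over `fan_in` such terms has a standard deviation of about sqrt(fan_in), so multiplying by a fixed factor of 1/sqrt(fan_in) brings the output back to roughly unit scale.

```python
import random
import statistics


def linear_outputs(fan_in: int, scale: float, n: int = 2_000, seed: int = 0) -> list[float]:
    """Return n samples of scale * (w . x), with w and x drawn from N(0, 1)."""
    rng = random.Random(seed)
    outputs = []
    for _ in range(n):
        dot = sum(rng.gauss(0.0, 1.0) * rng.gauss(0.0, 1.0) for _ in range(fan_in))
        outputs.append(scale * dot)
    return outputs


fan_in = 64
# No scaling: the output scale grows with sqrt(fan_in).
unscaled_std = statistics.pstdev(linear_outputs(fan_in, scale=1.0))
# Fixed 1/sqrt(fan_in) factor: the output stays near unit scale.
scaled_std = statistics.pstdev(linear_outputs(fan_in, scale=fan_in ** -0.5))
print(f"unscaled output std: {unscaled_std:.2f}")
print(f"scaled output std:   {scaled_std:.2f}")
```

With `fan_in = 64` the unscaled standard deviation should come out near 8 and the scaled one near 1. The scaling applied to gradients in the backward pass is not modelled here.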

There are a few additional considerations required to make unit scaling work properly, which are covered in our User Guide. Particular care should be taken to correctly scale skip/residual additions and loss functions.
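
To see why residual additions need this care, consider the sketch below. It is an illustration under stated assumptions, not the library's API: `weighted_residual_add` and `tau` are hypothetical names, and the skip and branch signals are assumed to be independent with unit variance. A plain sum of two such signals has a standard deviation of about sqrt(2), and this growth compounds with depth, whereas weighting the two inputs by sqrt(1 - tau) and sqrt(tau) keeps the result near unit scale.

```python
import math
import random
import statistics


def weighted_residual_add(skip: list[float], branch: list[float], tau: float = 0.5) -> list[float]:
    """Combine skip and branch so that independent unit-variance inputs stay unit-variance."""
    skip_weight = math.sqrt(1.0 - tau)
    branch_weight = math.sqrt(tau)
    return [skip_weight * s + branch_weight * b for s, b in zip(skip, branch)]


rng = random.Random(0)
n = 20_000
skip = [rng.gauss(0.0, 1.0) for _ in range(n)]
branch = [rng.gauss(0.0, 1.0) for _ in range(n)]

# Plain addition: the variances add, so the scale grows.
plain_std = statistics.pstdev([s + b for s, b in zip(skip, branch)])
# Weighted addition: the squared weights sum to one, so the scale is preserved.
weighted_std = statistics.pstdev(weighted_residual_add(skip, branch, tau=0.5))
print(f"plain add std:    {plain_std:.3f}")
print(f"weighted add std: {weighted_std:.3f}")
```

Here `tau` sets how much the branch contributes relative to the skip connection. The independence assumption is only approximate in a real network, so the User Guide remains the reference for the library's own handling of residuals.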

Trying out the library

Unit scaling can be installed with:

pip install git+https://github.com/graphcore-research/unit-scaling.git

unit_scaling is a new library and (despite our best efforts!) we can't guarantee it will be bug-free or feature-complete. We're keen to assist anyone who wants to use the library and help them work through any problems.

Please reach out to our developer community on our Slack channel, or raise a GitHub issue.