<img height="1" width="1" style="display:none" src="https://www.facebook.com/tr?id=145304570664993&amp;ev=PageView&amp;noscript=1">
Papers wide (1)

Jun 06, 2024

xLSTM, Schedule-Free Optimizers, Multi-token prediction: POTM May

Written By:

Douglas Orr, Luke Prince, Luka Ribar


May is always an eventful time of year for ML researchers, with final ICML paper decisions and ICLR taking place in early May, and NeurIPS submission deadlines closing out the month. As ever, arXiv submissions continue to grow!

This month we take a look at three papers exploring new techniques that challenge the mainstream large-scale pretraining setup: transformers trained with next-token prediction and optimised with Adam/AdamW.

The first paper, xLSTM, is a long-awaited deep dive into Sepp Hochreiter’s new, improved RNN architecture, nearly 30 years after the original LSTM was published. Drawing inspiration from linear attention, the authors demonstrate scaling comparable to transformers up to 1.3B parameters.

We then take a look at Schedule-Free optimizers from a team at FAIR. The authors propose a new class of optimizers that require no finicky learning rate scheduling. By replacing gradient momentum terms in standard optimizers with parameter averages, the authors show faster convergence than scheduled optimizers on a wide battery of small-scale deep learning tasks.

A further paper from FAIR extends the standard pretraining setup for large language models from next-token to multi-token prediction. This particularly seems to improve performance for larger models and offers a natural choice of model to use for speculative sampling to accelerate inference.

Here’s our summary of this month’s chosen papers:

xLSTM: Extended Long Short-Term Memory

Authors: Maximilian Beck, Korbinian Pöppel, et al. (NXAI, Johannes Kepler University Linz)

The key idea

Recurrent neural networks based on Long Short-Term Memory (LSTM) units were the backbone of NLP models before the advent of the now-ubiquitous transformer. This work seeks to close the gap between LSTMs and transformers in the crucial model-scaling regime of LLMs. The authors do this by extending the LSTM in two ways to create sLSTM and mLSTM, then incorporating these layers into a deep residual architecture called xLSTM.


Their method

We’ll focus on the mLSTM variant, as the sLSTM variant is omitted from many of the best-performing models in their results. I think the best way to understand the architecture is to stare at a wall of maths for a while:

[Figure: the mLSTM forward-pass equations]

To build some intuition, the components are (a reconstruction of the full recurrence follows after this list):

  • Inputs 𝐱 and parameters 𝐖𝐪,𝐤,𝐯,𝐨, 𝐛𝐪,𝐤,𝐯,𝐨, 𝐰𝐢,𝐟, 𝐛𝐢,𝐟.
  • Six linear + activation ops, depending only on the inputs: 𝐪,𝐤,𝐯,𝑖,𝑓,𝐨. The 𝑓 (forget) and 𝐨 (output) gates have sigmoid activation, giving outputs in the range [0,1], but 𝑖 (input) has an exponential activation. 𝐪,𝐤,𝐯 are linear.
  • A “cell” 𝐂: a decayed and weighted sum of the outer products 𝐯𝐤ᵀ (which I’ll call the KV mapping) over time. At each step, the state is decayed according to the forget gate 𝑓 and the KV mapping is weighted according to the input gate 𝑖. The cell maps queries to values by matching them against keys.
  • A normaliser 𝐧: similar, but accumulates just 𝐤 instead of the KV mapping.
  • An output 𝐡: the cell applied to the query (𝐂𝐪), divided by the magnitude of the inner product of 𝐪 and the normaliser, and multiplied elementwise by the output gate 𝐨.

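Piecing this description together, a rough reconstruction of the mLSTM recurrence looks something like the following (our notation; the paper's version may differ in details such as key scaling and stabilisation terms):

```latex
\begin{aligned}
\mathbf{q}_t &= \mathbf{W_q}\,\mathbf{x}_t + \mathbf{b_q}, \qquad
\mathbf{k}_t = \mathbf{W_k}\,\mathbf{x}_t + \mathbf{b_k}, \qquad
\mathbf{v}_t = \mathbf{W_v}\,\mathbf{x}_t + \mathbf{b_v}, \\
i_t &= \exp\!\left(\mathbf{w_i}^\top \mathbf{x}_t + b_i\right), \qquad
f_t = \sigma\!\left(\mathbf{w_f}^\top \mathbf{x}_t + b_f\right), \qquad
\mathbf{o}_t = \sigma\!\left(\mathbf{W_o}\,\mathbf{x}_t + \mathbf{b_o}\right), \\
\mathbf{C}_t &= f_t\,\mathbf{C}_{t-1} + i_t\,\mathbf{v}_t \mathbf{k}_t^\top, \qquad
\mathbf{n}_t = f_t\,\mathbf{n}_{t-1} + i_t\,\mathbf{k}_t, \\
\mathbf{h}_t &= \mathbf{o}_t \odot \frac{\mathbf{C}_t\,\mathbf{q}_t}{\left\lvert \mathbf{n}_t^\top \mathbf{q}_t \right\rvert}.
\end{aligned}
```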
Like softmax dot-product self-attention, this involves a normalised sum of exponentials; a key difference is that the input to exp depends only on the “source” (key, value), not on the “target” (query). It bears some similarities to linear attention, Mamba and RWKV, permitting a parallel scan over the inputs since the time dependency is linear. It retains the RNN’s advantage of summarising the context in a fixed-size representation, 𝐂, for efficient autoregressive inference.
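Unrolling the cell recurrence (with 𝐂0 = 0, 𝐧0 = 0) makes this linear time dependency explicit, which is what permits the scan:

```latex
\mathbf{C}_t \;=\; \sum_{s=1}^{t} \Big( \prod_{r=s+1}^{t} f_r \Big)\, i_s\, \mathbf{v}_s \mathbf{k}_s^\top,
\qquad
\mathbf{n}_t \;=\; \sum_{s=1}^{t} \Big( \prod_{r=s+1}^{t} f_r \Big)\, i_s\, \mathbf{k}_s.
```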

In the xLSTM architecture, this is used in a custom residual block that applies a position-wise up-projection before the multi-headed mLSTM (a rough sketch follows below).
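As a very rough sketch of that block structure, assuming a hypothetical `mlstm` module implementing the recurrence above (the real block includes further components, such as convolution and extra gating, that we omit here):

```python
import torch.nn as nn

class MLSTMBlockSketch(nn.Module):
    """Pre-up-projection residual block: norm -> up-project -> mLSTM -> down-project -> residual."""
    def __init__(self, d_model: int, mlstm: nn.Module, expansion: int = 2):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.up = nn.Linear(d_model, expansion * d_model)    # position-wise up projection
        self.mlstm = mlstm                                   # multi-headed mLSTM in the expanded space
        self.down = nn.Linear(expansion * d_model, d_model)  # project back before the residual add

    def forward(self, x):  # x: [batch, time, d_model]
        return x + self.down(self.mlstm(self.up(self.norm(x))))
```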

Results

Downstream results for LLMs of up to 1.3B parameters, trained on 300B SlimPajama tokens:

[Figure: downstream task results for xLSTM variants and baseline models]

(I haven’t been able to confirm if these are zero-shot or few-shot results.) Here, xLSTM[1:0] uses only the mLSTM layer described above, while xLSTM[7:1] includes 7 mLSTM layers per 1 sLSTM layer. These results appear to demonstrate the sufficiency of mLSTM for LLMs. The paper also includes a helpful set of ablations and synthetic tasks.

Takeaways

It’s refreshing to see non-transformer LLMs trained at scale, and that the xLSTM architecture appears competitive with transformers. More research could help us understand the benefits of these alternatives, and whether the scaling properties are robust.

Full paper: xLSTM: Extended Long Short-Term Memory

The Road Less Scheduled

Authors: Aaron Defazio, Xingyu (Alice) Yang, et al. (FAIR at Meta)

The key idea

Deep learning practitioners often use two key hacks to make optimisation of deep neural networks work in practice:

  1. Learning rate schedules
  2. Weight averaging for evaluation

Here the authors propose a principled approach that adapts commonly used optimisers by replacing the first-order gradient moment estimate with an averaged parameter state, removing the need for either of these hacks at no extra overhead.


Their method

We’ll present scheduled and schedule-free AdamW side-by-side, identify key differences, and explain how they are motivated.

Algorithm comparison

Given:

  • initial parameter state 𝑥1,
  • learning rate 𝛾,
  • weight decay 𝜆,
  • warmup steps 𝑇𝑤𝑎𝑟𝑚𝑢𝑝,
  • AdamW hyperparameters (𝛽1, 𝛽2, 𝜖)

We compute:

[Algorithm: scheduled AdamW vs schedule-free AdamW, Lines 1-8 compared side by side]

Let’s go through line by line:

  • Initialisation: Standard scheduled AdamW initialises the gradient moment variables 𝑧 and 𝑣 at 0. Schedule-free AdamW keeps the second gradient moment variable 𝑣, but 𝑧 now represents a raw, un-averaged parameter state and is initialised to match the averaged parameter state 𝑥1.
  • Optimizer state updates (Lines 1-4): Standard scheduled AdamW computes gradients at the current parameter state 𝑥𝑡 (Line 1), updates the moments as exponential moving averages with coefficients 𝛽1 and 𝛽2 (Lines 2-3), and corrects the moment estimation bias (Line 4). Schedule-free AdamW first computes an interpolation 𝑦𝑡 between the raw 𝑧𝑡 and averaged 𝑥𝑡 parameter states (Line 1), then computes gradients at this interpolated point (Line 2), updates the second moment (Line 3), and corrects its estimation bias (Line 4).
  • Parameter state updates (Lines 5-8): Scheduled AdamW first determines the learning rate coefficient given the warmup and decay schedule (Lines 5-7), before applying the standard update rule using the moments 𝑧𝑡, 𝑣𝑡 with weight decay from 𝑥𝑡 (Line 8). Schedule-free AdamW likewise applies a warmup to the learning rate (Line 5), then updates the non-averaged parameter state 𝑧𝑡 using the gradient estimate 𝑔𝑡, second moment 𝑣𝑡, and weight decay from the interpolated weights 𝑦𝑡 (Line 6). We then update the weighted average of parameters 𝑥𝑡, with weights chosen to discount parameters visited during warmup (Lines 7-8). A minimal sketch of this schedule-free update follows after this list.
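To make the line references concrete, here is a minimal single-tensor sketch of the schedule-free AdamW update as described above (not the authors' reference implementation; the exact averaging weights and placement of the bias correction may differ):

```python
import torch

def schedule_free_adamw_step(state, grad_fn, t, lr=1e-3, warmup=1000,
                             beta1=0.9, beta2=0.999, eps=1e-8, wd=1e-2):
    """One update step. `state` holds x (averaged params), z (raw params),
    v (second moment) and w_sum (sum of averaging weights). `grad_fn(params)`
    is a hypothetical closure returning the stochastic gradient at `params`."""
    x, z, v, w_sum = state["x"], state["z"], state["v"], state["w_sum"]

    y = (1 - beta1) * z + beta1 * x                     # Line 1: interpolate raw and averaged params
    g = grad_fn(y)                                      # Line 2: gradient at the interpolated point
    v = beta2 * v + (1 - beta2) * g * g                 # Line 3: second-moment EMA
    v_hat = v / (1 - beta2 ** t)                        # Line 4: bias correction

    lr_t = lr * min(1.0, t / warmup)                    # Line 5: linear warmup, no decay schedule
    z = z - lr_t * (g / (v_hat.sqrt() + eps) + wd * y)  # Line 6: raw update, weight decay at y

    w = lr_t ** 2                                       # Lines 7-8: weighted average of iterates,
    w_sum = w_sum + w                                   # with weights that discount warmup steps
    x = (1 - w / w_sum) * x + (w / w_sum) * z

    state.update(x=x, z=z, v=v, w_sum=w_sum)
    return state

# At evaluation time, use the averaged parameters state["x"].
```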

What motivates these changes?

Previous work by the same group illustrated a connection between learning rate schedules and Polyak-Ruppert parameter averaging, a theoretically optimal technique for ensuring convergence in stochastic optimisation. Polyak-Ruppert parameter averaging is simple to compute (effectively just Lines 6-8 of our schedule-free algorithm), but appears to perform worse than cosine decay schedules in practice.

The authors propose combining Polyak-Ruppert averaging with Primal averaging. In Primal averaging, we evaluate gradients at a slow-moving average of the parameters rather than at the fast-moving immediate parameter value (the standard practice). Like Polyak-Ruppert averaging, Primal averaging on its own also appears to perform worse in practice, as the parameters change too slowly.

The combined solution is effectively to get the Primal-average parameters to move a bit faster, by interpolating them with a Polyak-Ruppert average. This interpolated parameter is our 𝑦𝑡 term computed on Line 1. Since 𝛽1=1 corresponds to pure Primal averaging and 𝛽1=0 to pure Polyak-Ruppert averaging, the authors’ recommended 𝛽1=0.9 is still pretty close to Primal averaging.
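In the notation sketched above, the Line 1 interpolation takes the form:

```latex
y_t \;=\; (1 - \beta_1)\, z_t \;+\; \beta_1\, x_t ,
```

so 𝛽1=1 evaluates gradients at the average 𝑥𝑡 and 𝛽1=0 evaluates them at the raw iterate 𝑧𝑡.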

Two other changes appear to be less theoretically motivated: using 𝑦𝑡 for the weight decay (rather than 𝑥𝑡 or 𝑧𝑡), and Polyak-Ruppert averaging coefficients 𝑐𝑡 that discount parameter states visited during learning rate warmup. Warmup-free optimisers are a step too far, it seems…

Results

The authors test schedule-free optimisers on a battery of small models of different types (transformers, RNNs, CNNs, GNNs, recommenders), across different datasets and objective functions. In each case they show convergence comparable to carefully tuned learning rate schedules, with faster training dynamics in many cases.


Takeaways

As hacks go, learning rate schedules are an enduring one. Given the drastic effect they can have on model performance, you omit them from a training pipeline at your peril. However, they have never seemed particularly well motivated beyond their empirical effect. This looks like a step in the right direction for hack-free optimisation in deep learning.

Full paper: The Road Less Scheduled

Better & Faster Large Language Models via Multi-token Prediction

Authors: Fabian Gloeckle, et al. (FAIR at Meta)

The key idea

Large language models are usually trained using a next-token prediction loss. The authors propose training the model to predict multiple tokens at a time instead, while still generating a single token at a time during inference as usual. By training models up to 13B parameters in size, they show that this can lead to models with better performance, particularly on coding tasks.

[Figure: multi-token prediction architecture and MBPP scaling results]

Multi-token prediction: each output head predicts a token (4-token prediction shown), while only the first head is employed during inference. The training scheme improves performance on the MBPP coding task as models get larger.

Their method

To enable multi-token prediction, the authors propose a simple modification to the standard transformer architecture. The final output embedding is fed into 𝑛 parallel output heads, each a single standard transformer layer; this effectively means the final transformer layer is replaced by 𝑛 parallel transformer layers. The outputs of each head are then passed through a shared unembedding projection, generating a probability distribution over the whole vocabulary for each head. During training, each head is trained to predict one of the next 𝑛 tokens at each position. To minimise peak memory usage during training, the forward and backward passes for each head are performed sequentially (Figure 2), as sketched in the code below.

[Figure 2: sequential forward/backward passes over the output heads to limit peak memory]
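A minimal PyTorch sketch of this sequential per-head training step; the `TransformerEncoderLayer` trunk and heads are stand-ins for the paper's actual layers, and the sizes are illustrative:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

V, D, T, B, n_pred = 1000, 64, 32, 4, 4   # vocab, width, sequence length, batch, predicted tokens

embed = nn.Embedding(V, D)
trunk = nn.TransformerEncoderLayer(D, nhead=4, batch_first=True)   # stand-in for the shared trunk
heads = nn.ModuleList([nn.TransformerEncoderLayer(D, nhead=4, batch_first=True) for _ in range(n_pred)])
unembed = nn.Linear(D, V, bias=False)                              # shared unembedding projection

tokens = torch.randint(0, V, (B, T))
causal = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)

h = trunk(embed(tokens), src_mask=causal)         # one shared forward pass through the trunk
h_detached = h.detach().requires_grad_(True)      # cut the graph so each head can backprop alone

for i, head in enumerate(heads):                  # sequential forward/backward per head
    logits = unembed(head(h_detached, src_mask=causal))
    loss = F.cross_entropy(                       # head i predicts token t + i + 1 from position t
        logits[:, : T - i - 1].reshape(-1, V),
        tokens[:, i + 1 :].reshape(-1),
    )
    loss.backward()                               # frees this head's activations immediately;
                                                  # gradients w.r.t. h accumulate in h_detached.grad

h.backward(h_detached.grad)                       # single backward pass through the shared trunk
```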

During inference, all but the output of the first head are discarded and tokens are generated one-by-one, as with the standard transformer architecture. However, multi-token prediction can also be used to speed up inference via self-speculative decoding: the 𝑛 generated tokens serve as an initial draft of the sequence, which is then validated in parallel using just the next-token head.
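A greedy sketch of that self-speculative loop; `multi_head_logits` is a hypothetical function returning logits of shape `[n_heads, seq_len, vocab]` for a token sequence, and a real implementation would reuse the KV cache rather than re-scoring the whole sequence:

```python
import torch

def self_speculative_step(multi_head_logits, seq):
    """Draft with all heads, then verify the draft with the next-token head (greedy)."""
    draft = multi_head_logits(seq)[:, -1].argmax(-1)      # head i proposes the token at position len(seq)+i
    extended = torch.cat([seq, draft])
    verify = multi_head_logits(extended)[0].argmax(-1)    # next-token head re-scores the drafted positions

    n_accept = 0
    for i in range(len(draft)):                           # accept the longest agreeing prefix
        if draft[i] == verify[len(seq) + i - 1]:
            n_accept += 1
        else:
            break

    accepted = extended[: len(seq) + n_accept]
    bonus = verify[len(accepted) - 1 : len(accepted)]     # the verifier also supplies one extra token
    return torch.cat([accepted, bonus])
```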

Results

  • Improvements were only observed at scale, and were strongest for the largest models.
  • A 3x inference speedup was observed using speculative decoding with the 7B 4-token prediction model.
  • Optimal 𝑛 was empirically found to be 4 for token-based models, and 8 for byte-based models.
  • Unlike on coding tasks, performance on natural language tasks does degrade compared to the next-token baseline.

Takeaways

The results of the paper are promising, as they show multi-token prediction can indeed lead to improved performance at scale, particularly at coding tasks, while at the same time providing a more suitable drafting model for speculative-sampling inference. The results hint at the possible benefits of teaching the model to “plan ahead” compared to the standard next-token prediction, and may lead to exciting alternatives to the widely-adopted token-by-token generation.

Full paper: Better & Faster Large Language Models via Multi-token Prediction

 

Discover more on the Graphcore Research team's Github, and subscribe to the Papers of the Month newsletter.