Perceiver AR: general-purpose, long-context autoregressive generation


In recent years, autoregressive Transformers have brought a steady stream of breakthroughs in generative modeling. These models generate each element of a sample (the pixels of an image, the characters of a text, typically grouped into “token” chunks, the samples of an audio waveform, and so on) by predicting one element after another. When predicting the next element, the model can look back at those that were generated earlier.

However, each layer of a Transformer becomes more expensive as more elements are used as input, and practitioners can only afford to train deep Transformers on sequences of no more than about 2,048 elements. As a result, most Transformer-based models ignore everything beyond the most recent past (about 1,500 words, or 1/6 of a small image) when making a prediction.
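The "1/6 of a small image" figure can be checked with a little arithmetic. The sketch below assumes that a small image means a 64×64 RGB image flattened to one sequence element per colour-channel value; the text does not spell this out, so treat it as an assumption.

```python
# Assumption: a "small image" is 64x64 RGB, flattened to one element per
# colour-channel value. The 2,048-element limit comes from the text above.
from fractions import Fraction

MAX_CONTEXT = 2048
image_elements = 64 * 64 * 3  # height * width * channels

visible_fraction = Fraction(MAX_CONTEXT, image_elements)
print(image_elements)    # 12288
print(visible_fraction)  # 1/6
```

Under that assumption, a model limited to 2,048 elements sees only the most recent sixth of the image when predicting the next value.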

In contrast, our recently developed Perceiver models deliver excellent results on a variety of real-world tasks with up to about 100,000 elements. Perceivers use cross-attention to encode inputs into a latent space, decoupling the compute requirements of the input from model depth. Perceivers also spend a fixed cost, regardless of input size, at nearly every layer.
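To make the decoupling concrete, here is a rough back-of-the-envelope sketch that counts query-key pairs per forward pass as a crude proxy for attention cost. The function names, the latent count of 1,024, and the depth of 24 layers are illustrative assumptions rather than values from this post, and the count ignores projections, MLP blocks, and causal masking.

```python
def transformer_attention_pairs(num_inputs: int, num_layers: int) -> int:
    """Standard Transformer: every layer attends from all inputs to all inputs."""
    return num_layers * num_inputs * num_inputs


def perceiver_attention_pairs(
    num_inputs: int, num_latents: int, num_latent_layers: int
) -> int:
    """One cross-attention (latents x inputs), then self-attention over latents only."""
    cross_attention = num_latents * num_inputs
    latent_self_attention = num_latent_layers * num_latents * num_latents
    return cross_attention + latent_self_attention


NUM_LATENTS = 1024  # illustrative assumption
NUM_LAYERS = 24     # illustrative assumption

for num_inputs in (2_048, 8_192, 100_000):
    t = transformer_attention_pairs(num_inputs, NUM_LAYERS)
    p = perceiver_attention_pairs(num_inputs, NUM_LATENTS, NUM_LAYERS)
    print(f"{num_inputs:>7} inputs: transformer={t:.2e}  perceiver={p:.2e}")
```

In this simplified model, the input length appears only in the single cross-attention term, so the deep latent stack costs the same regardless of how long the input is.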

While latent-space encoding handles all elements in a single pass, autoregressive generation assumes that processing happens one element at a time. To address this mismatch, Perceiver AR proposes a simple solution: align the latents one by one with the final elements of the input, and carefully mask the input so that each latent sees only earlier elements.
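A minimal sketch of that alignment and masking in plain Python follows. It assumes that latent i (out of M latents) is aligned with input position N - M + i and may attend to that position and everything before it, as in a standard decoder; `cross_attention_mask` is a hypothetical helper name, and the released implementation may differ in detail.

```python
def cross_attention_mask(num_inputs: int, num_latents: int) -> list[list[bool]]:
    """Return mask[i][j], True when latent i may attend to input position j."""
    if not 0 < num_latents <= num_inputs:
        raise ValueError("need 0 < num_latents <= num_inputs")
    offset = num_inputs - num_latents  # input index aligned with latent 0
    return [
        [j <= offset + i for j in range(num_inputs)]
        for i in range(num_latents)
    ]


# Six inputs, three latents: each row is one latent, "x" marks a visible input.
for row in cross_attention_mask(6, 3):
    print("".join("x" if visible else "." for visible in row))
# xxxx..
# xxxxx.
# xxxxxx
```

Each latent sees one more input than the previous one, so no latent can look at elements that come after the position it is aligned with.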

Perceiver AR maps an input sequence (P, e, r, c, e, i, v, e, r, A, R) onto a small latent space by cross-attention, producing one latent for each target token (three latents are shown, one each for the targets A, R, and <EOS>, for End Of Sequence). These latents are then processed by a deep stack of self-attention layers. Perceiver AR can be trained for end-to-end autoregressive generation while making use of very long input sequences.
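As a small illustration of the figure's example, the sketch below takes the character sequence "PerceiverAR" as input and pairs each of the last three inputs with the element its latent should predict. The helper name `latent_targets` and the `<EOS>` string are illustrative assumptions, not part of the released code.

```python
EOS = "<EOS>"  # end-of-sequence marker, represented here as a plain string


def latent_targets(inputs: list[str], num_latents: int) -> list[tuple[str, str]]:
    """Pair each of the last `num_latents` inputs with its next-element target."""
    targets = inputs[1:] + [EOS]  # next-element target for every position
    start = len(inputs) - num_latents
    return list(zip(inputs[start:], targets[start:]))


pairs = latent_targets(list("PerceiverAR"), 3)
print(pairs)  # [('r', 'A'), ('A', 'R'), ('R', '<EOS>')]
```

The three latents sit on the final inputs r, A, and R, and are trained to predict A, R, and the end-of-sequence marker respectively.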

The result is an architecture (see above) that attends to inputs up to 50 times longer than those of standard Transformers, while being as widely (and essentially as easily) deployable as standard decoder-only Transformers.

As context length or model size increases, so does the amount of compute required to train a model. We can quantify the compute budget for different models by measuring their speed on real hardware (steps per second on TPUv3) as input context length and model size increase. Unlike other generative models such as Transformer or Transformer-XL, Perceiver AR decouples input context length from model depth, allowing us to easily deploy the deep models needed to model long sequences on current-generation TPUs or GPUs.

Perceiver AR scales significantly better with size than both standard Transformers and Transformer-XL models across a range of sequence lengths in real, wall-clock terms. This property allows us to build very effective long-context models. For example, we find that a 60-layer Perceiver AR with a context length of 8,192 outperforms a 42-layer Transformer-XL on a book-length generation task while running faster in wall-clock time.

On standard long-context generation benchmarks for images (ImageNet 64×64), language (PG-19), and music (MAESTRO), Perceiver AR delivers state-of-the-art results. Increasing input context by decoupling input size from compute budget leads to several intriguing results:

  • The compute budget can be adjusted at evaluation time, allowing us to spend less and degrade quality smoothly, or spend more for improved generation.
  • A larger context allows Perceiver AR to outperform Transformer-XL, even when spending the same amount of compute. We find that greater context leads to improved model performance, even at an affordable scale (~1 billion parameters).
  • Perceiver AR's sample quality is much less sensitive to the order in which elements are generated. This makes Perceiver AR easy to apply to settings that don't have a natural left-to-right ordering, such as images, whose structure spans more than one dimension.

Using a dataset of piano music, we trained Perceiver AR to generate new pieces of music from scratch. Because each new note is predicted based on the full sequence of notes that came before, Perceiver AR is able to produce pieces with a high degree of melodic, harmonic, and rhythmic coherence.


Learn more about using Perceiver AR:

  • Download the JAX code for training Perceiver AR on GitHub
  • Read our paper on arXiv
  • Watch our Spotlight presentation at ICML 2022

Check out the Google Magenta blog post with more music!
