Advanced Transformer Architectures

Efficiency

In this session, our readings cover:

Required Readings:

Advancing Transformer Architecture in Long-Context Large Language Models: A Comprehensive Survey

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

Jamba: A Hybrid Transformer-Mamba Language Model

More readings:

Mamba: Linear-Time Sequence Modeling with Selective State Spaces

Efficient Memory Management for Large Language Model Serving with PagedAttention

Attention Mechanisms in Computer Vision: A Survey

State Space Model for New-Generation Network Alternative to Transformers: A Survey

Motivation

Pros and Cons of Attention

Attention captures long-range dependencies well and parallelizes nicely during training, but its computation and memory cost grow quadratically with sequence length. Hence, we need a model that requires less computational cost but is still able to capture long-range dependencies while maintaining high performance.

That is the problem the State Space Model (SSM) aims to solve.

Formulation of SSM

The SSM is a model commonly used in control theory, underlying tools such as the Kalman filter and the hidden Markov model. Its basic formulation is shown in the figure below.
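In equation form (with $h(t)$ the latent state, $x(t)$ the input, and $y(t)$ the output):

$$
h'(t) = A\,h(t) + B\,x(t), \qquad y(t) = C\,h(t) + D\,x(t)
$$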

Normally, we omit the parameter D (assuming D = 0, because the term $Dx(t)$ can be viewed as a skip connection and is easy to compute separately). So the more common formulation seen in most state space model papers is:
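$$
h'(t) = A\,h(t) + B\,x(t), \qquad y(t) = C\,h(t)
$$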

Discretization

As a continuous-time system, the SSM is hard to use directly in modern deep learning algorithms. In practice, we always deal with discrete data, such as text. This requires us to discretize the SSM, transforming the continuous parameters A, B, C into discrete parameters $\hat{A}, \hat{B}$ using the zero-order hold (ZOH) rule, as shown in the figure below. Readers can refer to the paper for the detailed derivation.
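With step size $\Delta$, the ZOH discretization used in the S4/Mamba line of work is:

$$
\hat{A} = \exp(\Delta A), \qquad \hat{B} = (\Delta A)^{-1}\left(\exp(\Delta A) - I\right)\Delta B
$$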

In conclusion, the discretized SSM takes the following recurrent form:
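$$
h_t = \hat{A}\,h_{t-1} + \hat{B}\,x_t, \qquad y_t = C\,h_t
$$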

Convolutional Form

Unlike an RNN, the SSM has no non-linearity between time steps. So we can expand $y_t$ and, perhaps surprisingly, find that the SSM can be written in convolutional form.
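Unrolling the recurrence from $h_{-1} = 0$:

$$
y_0 = C\hat{B}x_0, \qquad
y_1 = C\hat{A}\hat{B}x_0 + C\hat{B}x_1, \qquad
y_2 = C\hat{A}^2\hat{B}x_0 + C\hat{A}\hat{B}x_1 + C\hat{B}x_2, \;\dots
$$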

Looking at the result of the expansion above, we can see that the coefficient of each $x_t$ can be extracted to form a convolutional kernel:
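$$
\hat{K} = \left(C\hat{B},\; C\hat{A}\hat{B},\; \dots,\; C\hat{A}^{L-1}\hat{B}\right)
$$

where $L$ is the sequence length.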

Hence, we can write our SSM formulation as:
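$$
y = x * \hat{K}
$$

i.e., the output is a causal convolution of the input sequence with the kernel $\hat{K}$, which can be computed in parallel during training.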

It is easy to see that the SSM is very similar to an RNN. Comparing the formulations of the SSM and the RNN below, we find that the main reason the RNN cannot be written in convolutional form, and thus cannot be trained efficiently, is its non-linear function $f$.
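Writing a generic RNN cell schematically as $h_t = f(W h_{t-1} + U x_t)$ for some non-linearity $f$ and weight matrices $W, U$ (not any particular architecture):

$$
\text{SSM: } h_t = \hat{A}h_{t-1} + \hat{B}x_t \qquad\text{vs.}\qquad \text{RNN: } h_t = f\!\left(W h_{t-1} + U x_t\right)
$$

Because $f$ does not distribute over the sum, the RNN recurrence cannot be unrolled into a fixed convolution kernel.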

Structured State Space Model (S4)

To solve this problem, the HiPPO matrix is introduced, which combines the concepts of recurrent memory and optimal polynomial projections and thus can significantly improve the performance of recursive memory.

In practice, we use the HiPPO matrix to initialize the matrix A.
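The commonly used HiPPO-LegS matrix (see the HiPPO/S4 papers) has the form:

$$
A_{nk} = -\begin{cases} (2n+1)^{1/2}(2k+1)^{1/2} & n > k \\ n+1 & n = k \\ 0 & n < k \end{cases}
$$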

Note that the "Structured" comes from the HiPPO matrix. We usually call the vanilla SSM with a HiPPO matrix the S4 model for short, which is the name seen in most SSM-related papers.

From S4 to Mamba (S6)

The problem of S4: its parameters A, B, and C are time-invariant and do not depend on the input, so the model cannot selectively focus on or ignore particular tokens based on their content.

Mamba makes these parameters vary based on the input, as in the formulation below:
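$$
B_t = \mathrm{Linear}_B(x_t), \qquad C_t = \mathrm{Linear}_C(x_t), \qquad \Delta_t = \mathrm{softplus}\!\left(\mathrm{Linear}_\Delta(x_t)\right)
$$

(shown schematically: $\mathrm{Linear}$ denotes a learned linear projection of the current token $x_t$; see the Mamba paper for the exact parameterization).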

By doing so, the model gains the ability to focus on certain words, as shown in the figure below.

Parallelization of Mamba

Making the parameters input-dependent breaks the convolutional form, so the recurrence can no longer be trained as a convolution. Mamba is able to solve this problem through a parallel scan.

Parallel Scan

Whether an operation can be parallelized depends on its associativity. Mamba's recurrence is very similar to a scan algorithm, also known as a prefix sum.

We can verify its associative property with a new variable k:
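One way to see this (a sketch, assuming the recurrence $h_t = \hat{A}_t h_{t-1} + \hat{B}_t x_t$ with input-dependent parameters and $h_{-1} = 0$): package each step $k$ as a pair $e_k = (\hat{A}_k,\; \hat{B}_k x_k)$ and define the combine operator

$$
(a_1, b_1) \bullet (a_2, b_2) = (a_2 a_1,\; a_2 b_1 + b_2).
$$

Then $h_t$ is the second component of $e_0 \bullet e_1 \bullet \dots \bullet e_t$, and one can check directly that

$$
\big((a_1,b_1)\bullet(a_2,b_2)\big)\bullet(a_3,b_3) = (a_1,b_1)\bullet\big((a_2,b_2)\bullet(a_3,b_3)\big) = (a_3a_2a_1,\; a_3a_2b_1 + a_3b_2 + b_3),
$$

so the operator is associative and the recurrence can be evaluated with a parallel scan.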

The figure below shows how the parallel scan works. We can pick any vertical line, start from the top of that line, and move to the bottom, tracing each addition back to the array's first few items. By the time we reach the bottom, we have the sum of all items to the left of that line.
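To make this concrete, here is a minimal Python sketch of a prefix scan over the linear recurrence using the associative operator above (a toy illustration with scalar states and made-up values, not Mamba's hardware-aware CUDA implementation):

```python
import numpy as np

def combine(e1, e2):
    """Associative operator for the recurrence h_t = a_t * h_{t-1} + b_t."""
    a1, b1 = e1
    a2, b2 = e2
    return (a2 * a1, a2 * b1 + b2)

def sequential_scan(a, b):
    """Reference: plain left-to-right recurrence."""
    h, out = 0.0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out

def parallel_scan(elems):
    """Inclusive scan with `combine`, written recursively for clarity:
    combine adjacent pairs, scan the half-length list, then expand back."""
    if len(elems) == 1:
        return elems
    paired = [combine(elems[i], elems[i + 1]) for i in range(0, len(elems) - 1, 2)]
    scanned = parallel_scan(paired)
    out = [elems[0]]
    for i in range(1, len(elems)):
        if i % 2 == 1:
            out.append(scanned[i // 2])
        else:
            out.append(combine(scanned[i // 2 - 1], elems[i]))
    return out

# Toy check with made-up decay factors a_t and input terms b_t = B_t * x_t.
a = [0.9, 0.8, 0.7, 0.95, 0.5, 0.6, 0.85, 0.99]
b = [1.0, 0.5, -0.2, 0.3, 0.1, 0.0, 0.4, -0.1]
states = [h for _, h in parallel_scan(list(zip(a, b)))]
assert np.allclose(states, sequential_scan(a, b))
```

On parallel hardware the combines at each level of the recursion run simultaneously, so the depth is $O(\log n)$ rather than the $O(n)$ of the sequential recurrence.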

Variations of SSM

Language Modeling

S4+++:
- State Memory Relay.
- Integrates complex dependency bias via an interactive cross-validation mechanism.

Speech Tasks

DP-Mamba:
- Bidirectional Dependency Modeling: simultaneously models both short-term and long-term forward and backward dependencies of speech signals.
- Selective State Space: enhances model capability through a selectively utilized state space.
- Performance: achieves results comparable to the dual-path Transformer model Sepformer.

SP-Mamba:

Variations in Computer Vision

VMamba

VMamba captures global receptive fields with linear complexity, introduces traversal of spatial information with cross-scan blocks, and converts non-causal visual images into ordered patch sequences.

Vision Mamba

The Vim model first divides the input image into patches and projects the patches into tokens. These tokens are then fed into the Vim encoder. For tasks like ImageNet classification, an additional learnable classification token is appended to the token sequence (a convention followed ever since BERT). Unlike the Mamba model used for modeling text sequences, the Vim encoder processes the token sequence in both the forward and reverse directions.

The Vim encoder is shown in the figure below.

Mamba Variations in Different Tasks

Variations in Graph

GraphS4mer: uses the S4 architecture to capture long-range dependencies and includes a dynamic graph structure learning layer for spatial correlations.

GMN: based on selective state space models, tackles the limitations of traditional GNNs in capturing long-range dependencies and in computational efficiency.

Variations in Multi-modality and Multi-media

Variations for Time Series

TimeMachine
Purpose: addresses challenges in long-term time-series forecasting (LTSF).
Key Challenges:

Advancing Transformer Architecture in Long-Context Large Language Models: A Comprehensive Survey

Introduction to Long-Context LLMs

Challenges and Research Directions in Long-Context LLMs

Contributions of this Survey

Section 2: Overview

Preliminaries of Neural Language Modeling

Limitations of Transformer Architecture in Handling Long Contexts

Roadmap of Enhancements for Long-Context Capabilities in LLMs

Section 3: Efficient Attention Mechanisms

Local Attention

Hierarchical Attention

Approximated Attention

IO-Aware Attention

Section 4: Long-Term Memory

Because its working memory is limited to the current context, the Transformer architecture often struggles to capture long-term dependencies. The researchers propose two main avenues to address this challenge: (1) Internal MemoryCache; (2) External MemoryBank.


Internal MemoryCache

For Internal MemoryCache, there are different types:

External MemoryBank

For External MemoryBank, there are different types:

In summary, Internal MemoryCache trades space for time by using caching mechanisms to reduce computation. However, after model training is completed, it is difficult to update the internal knowledge, which is why such methods are rarely used nowadays. Instead, the External MemoryBank approach is mainly used.

Section 5: Extrapolative PEs

Here PEs stands for Positional Encodings; this section focuses on extrapolative PEs, which play an undeniable role in length generalization to more general scenarios.

Section 6: Context Processing

There are three different strategies:

Section 7: Miscellaneous Solution

The miscellaneous solutions discussed in this part are neither exhaustive nor specific to Transformer-based models. Many of these techniques apply universally to any deep neural network model, although they are particularly crucial for large-scale LLMs. Some solutions are as follows:

Section 8: Evaluation Necessity & Optimization Toolkit

The researchers explore evaluation necessities for assessing the long-context capabilities of LLMs, including datasets, metrics, and baseline models, and they investigate popular optimization toolkits, such as libraries, frameworks, and compilers, to enhance LLM efficiency and effectiveness during development.

For Datasets, detailed information on each dataset is available in Table 1, covering language, task types, length statistics, quality, splits, count and format.

For Metrics, Table 2 provides a summary of nine categories of general evaluation metrics commonly employed across ten NLP task types, encompassing language modeling, question answering, summarization, math solving, code generation, and open-ended writing, among others.

For Baselines, Table 3 gathers a list of commonly used pretrained/finetuned LLMs, serving as baselines for evaluating long-context capabilities across various downstream tasks.

For Toolkit, Table 4 collects a diverse array of valuable toolkits to optimize the efficiency and effectiveness of LLMs across their development lifecycle.

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

Motivation

Transformers based on current attention architectures do not perform well when the context length goes beyond a certain threshold. The first motivation of this work is that a Transformer architecture able to model longer sequences has the following potential applications:

  1. In NLP tasks, a large context allows the LLM to read books, plays and instruction manuals before generating a response.
  2. In computer vision, higher resolution images require the attention architecture to be capable of handling longer sequences. In the case of high resolution MRI as shown in the slide below, if the transformer is able to generate a high resolution image, it can improve the performance of downstream tasks such as pathology detection or tissue segmentation.
  3. Other types of natural sequence data such as time-series data, audio data, video data and medical imaging also require the transformer to perform well on much longer sequences.

The second motivation of the work is that the attention computation is bottlenecked by I/O from High Bandwidth Memory (HBM), which is large in size but relatively slow compared to SRAM. As an example, the A100 offers 40GB or 80GB of HBM, but its bandwidth is only about 1/10 of SRAM's. The standard attention computation, as shown in the slide below, requires numerous writes to and reads from HBM for intermediate values such as the attention matrix, which makes HBM I/O the bottleneck of the attention computation.
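To make those intermediates explicit, here is a minimal NumPy sketch of the standard (non-I/O-aware) attention computation; the $N \times N$ matrices $S$ and $P$ below are exactly the values that standard GPU kernels write to and read back from HBM:

```python
import numpy as np

def standard_attention(Q, K, V):
    """Naive attention: materializes the full N x N score and softmax matrices."""
    S = Q @ K.T / np.sqrt(Q.shape[-1])        # N x N score matrix (HBM-resident on a GPU)
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P = P / P.sum(axis=-1, keepdims=True)     # row-wise softmax, another N x N intermediate
    return P @ V                              # N x d output
```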

FlashAttention Algorithm

FlashAttention is fast, memory efficient, and an exact computation of attention. It is I/O-aware and aims to reduce the number of reads and writes to HBM. It computes the attention block by block: when computing each output block, the corresponding blocks of Q, K, V, and the output all fit in SRAM, so intermediate values do not need to be stored in HBM. In addition, the overall SRAM memory footprint depends only on the block size and the head dimension, not on the sequence length. Since only one block of the attention matrix is computed at a time rather than the entire matrix, FlashAttention can also handle longer sequences.

FlashAttention builds on safe softmax and online softmax, two simpler methods that help in understanding it. To avoid numerical overflow, safe softmax subtracts the maximum m over all inputs x from each exponent, so that every exponent is at most zero, every exponential is at most one, and the softmax is safe to compute.
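Concretely, for a vector $x$ with $m = \max_j x_j$:

$$
\mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}
$$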

Safe softmax requires a total of three passes. The first pass iteratively computes a running maximum of the softmax input, using the result from the previous iteration; when the loop ends, the result is the global maximum over all x. The second pass iteratively computes the denominator using that global maximum. The final pass computes the softmax values using the denominator and the global maximum.
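A minimal Python sketch of the three-pass procedure described above (plain lists and loops, just to make the passes explicit):

```python
import math

def safe_softmax(x):
    # Pass 1: running maximum; equals the global max when the loop ends.
    m = -math.inf
    for xi in x:
        m = max(m, xi)
    # Pass 2: denominator, using the global max.
    d = 0.0
    for xi in x:
        d += math.exp(xi - m)
    # Pass 3: normalized softmax values.
    return [math.exp(xi - m) / d for xi in x]
```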

The online softmax reduces the computation from three passes to two. When updating the softmax denominator, if we replace the global max with the running max at iteration i and apply a scaling factor to correct the earlier terms, we can compute the max and the denominator together in one pass.
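A sketch of the two-pass online softmax: whenever the running max changes, the running denominator is rescaled by $e^{m_{\text{old}} - m_{\text{new}}}$ so that the earlier terms stay consistent:

```python
import math

def online_softmax(x):
    m, d = -math.inf, 0.0
    for xi in x:                                 # pass 1: max and denominator together
        m_new = max(m, xi)
        d = d * math.exp(m - m_new) + math.exp(xi - m_new)
        m = m_new
    return [math.exp(xi - m) / d for xi in x]    # pass 2: normalize
```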

FlashAttention reduces the calculation to a single pass and outputs attention rather than just the softmax of the previous two algorithms. Attention requires an additional step: multiplying the softmax weights by the value matrix V to obtain the output O. FlashAttention performs this calculation by breaking the softmax down into smaller softmaxes. In the slide below, the output is updated as the sum of two terms: the first term is the output computed at the previous iteration times a scaling factor, and the second term can be viewed as a small softmax times a row of V. By updating the output in this iterative manner, FlashAttention further reduces the computation to one pass.
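The same idea for a single query row, in one pass (a hypothetical helper written for illustration, not the authors' CUDA kernel): alongside the running max m and denominator d, we also keep a running output o and rescale it whenever m and d change:

```python
import math

def one_pass_attention_row(q, K, V):
    m, d = -math.inf, 0.0
    o = [0.0] * len(V[0])
    for k_i, v_i in zip(K, V):
        s = sum(qj * kj for qj, kj in zip(q, k_i)) / math.sqrt(len(q))  # score
        m_new = max(m, s)
        alpha = math.exp(m - m_new)      # rescales all previous terms
        p = math.exp(s - m_new)          # un-normalized weight of this key
        d_new = d * alpha + p
        # rescale the previous output and add this value's contribution
        o = [(oj * d * alpha + p * vj) / d_new for oj, vj in zip(o, v_i)]
        m, d = m_new, d_new
    return o                             # equals softmax(q K^T / sqrt(d_k)) @ V
```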

So far this processes only one row of Q and one column of K and V at a time. To make full use of the fast SRAM, we can group many rows together as blocks and compute the attention block by block, using the largest block size for which the corresponding blocks of Q, K, V, and O all fit in SRAM. For a particular block of Q, we iterate through all blocks of K transpose and V while maintaining per-row running maxima and denominators. After the iterations, we obtain an exact block of the output. In this procedure, FlashAttention computes attention in a block-by-block manner.
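A block-wise NumPy sketch in the same spirit (the real algorithm runs inside a fused CUDA kernel with each tile held in SRAM; the tile sizes here are illustrative, not the paper's):

```python
import numpy as np

def blocked_flash_attention(Q, K, V, block_q=64, block_k=64):
    N, d = Q.shape
    O = np.zeros((N, V.shape[1]))
    for qs in range(0, N, block_q):
        Qb = Q[qs:qs + block_q]                    # one block of queries
        m = np.full(Qb.shape[0], -np.inf)          # running row maxima
        den = np.zeros(Qb.shape[0])                # running denominators
        Ob = np.zeros((Qb.shape[0], V.shape[1]))   # running rescaled output
        for ks in range(0, N, block_k):
            Kb, Vb = K[ks:ks + block_k], V[ks:ks + block_k]
            S = Qb @ Kb.T / np.sqrt(d)             # scores for this tile only
            m_new = np.maximum(m, S.max(axis=1))
            alpha = np.exp(m - m_new)              # rescale old statistics
            P = np.exp(S - m_new[:, None])         # un-normalized tile weights
            den_new = den * alpha + P.sum(axis=1)
            Ob = (Ob * (den * alpha)[:, None] + P @ Vb) / den_new[:, None]
            m, den = m_new, den_new
        O[qs:qs + block_q] = Ob
    return O   # matches standard_attention(Q, K, V) up to floating-point error
```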

Evaluation

When training a BERT-large model on a single node, FlashAttention is demonstrated to require 15% less training time than NVIDIA's attention implementation.

When training GPT-2 small, compared to Megatron-LM, FlashAttention supports a 4x longer context length while still being 30% faster and achieving 0.7 better perplexity.

Being an exact attention implementation, FlashAttention is not only faster than PyTorch Attention but also faster than OpenAI Sparse Attention when the context length is less than 4096. It is slower than Linformer Attention, which is an approximation method using low-rank matrices. In terms of memory usage, it requires 2x less memory than Linformer Attention and 20x less memory than PyTorch Attention.

Limitations and Future Directions

Compiling to CUDA. The current implementation requires writing a new CUDA kernel in a low-level language and may not transfer to other GPU architectures. These limitations suggest the need to write attention algorithms in a high-level language such as PyTorch and compile them to CUDA.

IO-Aware Deep Learning. The IO-aware approach can potentially be extended to every layer in a deep network.

Multi-GPU IO-Aware Methods. The current algorithm is designed for a single GPU node and does not take data transfer across multiple GPUs into consideration. The authors hope to inspire future work on attention computation that is parallelizable across multiple GPUs.