Explaining the Full Linear Attention paradigm for bi-directional sequence modeling
Transformers with Linear Attention enable fast and parallel training. Moreover, they can be formulated as Recurrent Neural Networks (RNNs) for efficient, linear-time inference. While extensively evaluated in causal sequence modeling, they have yet to be extended to the bi-directional setting. We introduce the LION framework, establishing new theoretical foundations for Linear Transformers in bi-directional sequence modeling. LION constructs a bi-directional RNN equivalent to full Linear Attention, extending the benefits of Linear Transformers (parallel training and efficient inference) to the bi-directional setting.
Existing memory-efficient bi-directional models require more than 2x the training time of a Transformer. Our Linear Attention framework retains memory-efficient inference while maintaining Transformer training speed. The table below reports training time relative to a Transformer (lower is better):
Task | LION-🔥 | LION-D | LION-S | Hydra | Vim |
---|---|---|---|---|---|
Vision | $\times 0.73$ | $\times 1.39$ | $\times 1.46$ | $\times 2.51$ | $\times 10.86$ |
MLM | $\times 0.95$ | $\times 1.10$ | $\times 1.32$ | $\times 3.13$ | N/A |
Using LION, we cast three Linear Transformers to their bi-directional form:
By replacing the attention block with LION (-🔥, -D, -S), we achieve performance on bi-directional tasks that is comparable to Transformers and State-Space Models (SSMs) while improving training speed.
Recently, Transformers with Linear Attention have emerged as an efficient alternative to softmax attention for causal sequence modeling, combining parallel training with RNN-style inference. We are curious to explore whether Linear Attention Transformers, including the vanilla Linear Transformer and its more recent decay-based and selective variants, can be extended to bi-directional sequence modeling.
Let's break this down with three key questions:
Given that Linear Transformers can be formulated as RNNs, offering efficiency benefits during inference and enabling parallel training for causal sequence modeling, can they also provide similar advantages for bi-directional processing? If so, what would the parallel form be, and how would the equivalent bi-directional RNN be structured?
Can simple Linear Transformers, like the vanilla Linear Transformer and its decay-based or selective variants, match the performance of softmax Transformers and SSMs on bi-directional tasks such as image classification and masked language modeling?
While bi-directional SSMs are performant, they tend to be difficult and slow to train compared to Transformers with Full Attention (e.g., ViT). Can we keep Transformer-level training speed while still enjoying memory-efficient, RNN-style inference?
Let's start with the Linear Attention recurrence:
\[\begin{aligned} & \mathbf{S}_i = \mathbf{S}_{i-1} + \mathbf{k}_i \mathbf{v}_i^\top, \quad \mathbf{z}_i = \mathbf{z}_{i-1} + \mathbf{k}_i, \\ & \text{Scaled: } \ \mathbf{y}_i = \frac{\mathbf{q}_i^\top \mathbf{S}_i}{\mathbf{q}_i^\top \mathbf{z}_i}, \qquad \text{Non-Scaled: } \ \mathbf{y}_i = \mathbf{q}_i^\top \mathbf{S}_i. \end{aligned}\]Above is the RNN form of Linear Attention, which has the following parallel form:
\[\mathbf{Y} = \text{Scale} \left(\mathbf{Q} \mathbf{K}^\top \odot \mathbf{M}^C \right)\]where the mask \(\mathbf{M}^C\) is a lower-triangular \(C\)ausal mask. Causal Linear Transformers are a class of models introduced following the original Linear Transformer shown above; they generalize its recurrence to \[\mathbf{S}_i = \boldsymbol{\Lambda}_i \star \mathbf{S}_{i-1} + \mathbf{k}_i \mathbf{v}_i^\top, \qquad \mathbf{z}_i = \gamma_i \mathbf{z}_{i-1} + \mathbf{k}_i.\]
Here, \(\boldsymbol{\Lambda}_i\) and \(\gamma_i\) are decay factors introduced after the Linear Transformer to enhance performance, and \(\star\) denotes an associative operator which depends on the specific model. (Spoiler alert ⚠️: the family of Linear Transformers has strong connections to SSMs, as explored in works like DeltaNet.)
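To make the causal building block concrete, here is a minimal NumPy sketch (an illustration, not the official LION code) that runs the scaled recurrence with scalar per-token decays and checks it against the masked parallel form, where the causal mask accumulates the decays between two positions (the analogous bi-directional construction is made explicit later in this post). The elu(x)+1 feature map, the decay values, and the use of the same decay for the state and the normalizer are assumptions of the sketch; setting all \(\lambda_i = 1\) recovers the vanilla Linear Transformer above.

```python
# Minimal NumPy sketch (not the official LION code): the causal recurrence and
# the masked parallel form give the same scaled outputs.
# Illustrative assumptions: elu(x) + 1 feature map, scalar per-token decays,
# and the same decay applied to the state S and the normalizer z.
import numpy as np

rng = np.random.default_rng(0)
L, d = 6, 4                                          # sequence length, head dimension
phi = lambda x: np.where(x > 0, x + 1.0, np.exp(x))  # elu(x) + 1, keeps q, k positive
Q = phi(rng.standard_normal((L, d)))
K = phi(rng.standard_normal((L, d)))
V = rng.standard_normal((L, d))
lam = rng.uniform(0.8, 1.0, size=L)                  # per-token decay factors

# Recurrent (RNN) form: S_i = lam_i * S_{i-1} + k_i v_i^T,  z_i = lam_i * z_{i-1} + k_i
S, z = np.zeros((d, d)), np.zeros(d)
Y_rec = np.zeros((L, d))
for i in range(L):
    S = lam[i] * S + np.outer(K[i], V[i])
    z = lam[i] * z + K[i]
    Y_rec[i] = (Q[i] @ S) / (Q[i] @ z)               # scaled output y_i

# Parallel form: Y = Scale(Q K^T ⊙ M^C) V, with
# M^C[i, j] = lam[j+1] * ... * lam[i] for i >= j (decays accumulated between j and i)
Mc = np.zeros((L, L))
for i in range(L):
    for j in range(i + 1):
        Mc[i, j] = np.prod(lam[j + 1 : i + 1])       # empty product = 1 on the diagonal
A = (Q @ K.T) * Mc
Y_par = (A / A.sum(axis=1, keepdims=True)) @ V       # Scale = row-normalization

print(np.allclose(Y_rec, Y_par))                     # True
```

Row-normalizing by the masked row sums is exactly the \(\text{Scale}\) operation of the scaled recurrence here, since \(\mathbf{q}_i^\top \mathbf{z}_i\) equals the \(i\)-th row sum of \(\mathbf{Q}\mathbf{K}^\top \odot \mathbf{M}^C\).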
The first goal is to extend the Causal Linear Attention parallel form
\[\mathbf{Y} = \text{Scale} \left(\mathbf{Q} \mathbf{K}^\top \odot \mathbf{M}^C \right)\]to a Scaled and Masked Full Linear Attention mechanism.
The first step is quite simple: the Masked and Scaled Attention can naturally take the following form, as suggested by its name:
Full Linear Attention
\[\mathbf{Y} = \text{Scale} \left(\mathbf{Q} \mathbf{K}^\top \odot \mathbf{M} \right)\]
The important part is how to properly define the matrix \(\mathbf{M}\). A natural choice is to extend the causal mask \(\mathbf{M}^C\), where the causal mask between tokens \(i,j\) is given by \(\mathbf{M}^C_{ij} = \lambda_{j+1} \lambda_{j+2} \dots \lambda_i\), representing the product of all selective scalars between \(i\) and \(j\). In the bi-directional case, the full mask should preserve this desirable property. One can interpret the mask entries as a relative positional encoding between two tokens, taking the following form:
\[\mathbf{M}_{ij} = \begin{cases} \prod_{k=j}^{i-1}\lambda_k, & i > j, \\ 1, & i = j, \\ \prod_{k=i+1}^{j}\lambda_k, & i < j. \end{cases}\]To recap, the full output of Full Linear Attention can be presented as:
\(\mathbf{Y} = \text{Scale} \left( \underbrace{\left( \begin{array}{cccc} \mathbf{q}_1^{\top}\mathbf{k}_1 & \mathbf{q}_1^{\top}\mathbf{k}_2 & \cdots & \mathbf{q}_1^{\top}\mathbf{k}_L \\ \mathbf{q}_2^{\top}\mathbf{k}_1 & \mathbf{q}_2^{\top}\mathbf{k}_2 & \cdots & \mathbf{q}_2^{\top}\mathbf{k}_L\\ \vdots & \vdots & \ddots & \vdots \\ \mathbf{q}_L^{\top}\mathbf{k}_1 & \mathbf{q}_L^{\top}\mathbf{k}_2 & \cdots & \mathbf{q}_L^{\top}\mathbf{k}_L\\ \end{array} \right)}_{\hspace{1mm} \mathbf{A} = \mathbf{Q} \mathbf{K}^{\top}} \odot \underbrace{ \left( \begin{array}{ccccc} 1 & \lambda_2 & \lambda_2 \lambda_3 & \cdots & \lambda_2 \cdots \lambda_L \\ \lambda_1 & 1 & \lambda_3 & \cdots & \lambda_3 \cdots \lambda_L \\ \lambda_2 \lambda_1 & \lambda_2 & 1 & \cdots & \lambda_4 \cdots \lambda_L \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ \lambda_{L-1} \cdots \lambda_1 & \lambda_{L-1} \cdots \lambda_2 & \lambda_{L-1} \cdots \lambda_3 & \cdots & 1 \\ \end{array} \right) }_{\hspace{1mm} \mathbf{M}} \right) \left( \begin{array}{c} \mathbf{v}_1^\top \\ \mathbf{v}_2^\top \\ \mathbf{v}_3^\top \\ \vdots \\ \mathbf{v}_L^\top \\ \end{array} \right)\)
The equation above represents Full Linear Attention in parallel form. Now that we have established Full Linear Attention for bi-directional sequence modeling, it's time to derive its equivalent bi-directional RNN.
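As a concrete companion to the recap above, the sketch below (again NumPy, illustrative rather than the official LION kernels) builds the mask \(\mathbf{M}\) entry by entry from the piecewise definition and applies the parallel form \(\mathbf{Y} = \text{Scale}\left(\mathbf{Q}\mathbf{K}^\top \odot \mathbf{M}\right)\mathbf{V}\). As in the causal sketch, \(\text{Scale}\) is taken to be row-normalization by the masked row sums, and the feature map and decay values are placeholder choices.

```python
# Illustrative NumPy sketch of Full Linear Attention in parallel form
# (placeholder choices for the feature map and decays; not the LION kernels).
import numpy as np

def full_mask(lmbda):
    """Bi-directional mask from the recap: M[i, j] = λ_j ⋯ λ_{i-1} for i > j,
    1 for i = j, and λ_{i+1} ⋯ λ_j for i < j (positions i, j are 1-indexed as in the text)."""
    L = len(lmbda)
    lam = np.concatenate(([1.0], lmbda))              # lam[k] = λ_k, 1-indexed for clarity
    M = np.ones((L, L))
    for i in range(1, L + 1):
        for j in range(1, L + 1):
            if i > j:
                M[i - 1, j - 1] = np.prod(lam[j:i])           # λ_j ... λ_{i-1}
            elif i < j:
                M[i - 1, j - 1] = np.prod(lam[i + 1:j + 1])   # λ_{i+1} ... λ_j
    return M

def full_linear_attention(Q, K, V, lmbda):
    A = (Q @ K.T) * full_mask(lmbda)                  # Q K^T ⊙ M
    return (A / A.sum(axis=1, keepdims=True)) @ V     # Scale (row-normalize), then mix values

# Toy usage
rng = np.random.default_rng(0)
L, d = 6, 4
phi = lambda x: np.where(x > 0, x + 1.0, np.exp(x))  # elu(x) + 1 keeps the scaling well-behaved
Q, K = phi(rng.standard_normal((L, d))), phi(rng.standard_normal((L, d)))
V = rng.standard_normal((L, d))
lam = rng.uniform(0.8, 1.0, size=L)                  # λ_1, ..., λ_L
Y = full_linear_attention(Q, K, V, lam)
print(Y.shape)                                       # (L, d)
```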
Question: Is it worth training with Full Attention on bi-directional tasks, considering its quadratic complexity in the sequence length, \(O(L^2)\)?
The answer is yes! Unlike causal language modeling, bi-directional tasks such as Vision ($L=196$) and Masked Language Modeling (MLM) ($L=128$) use relatively short sequence lengths in practice. This means we can usually fit Full Attention in memory, enabling higher throughput without a significant trade-off in complexity.
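For a rough sense of scale, consider a back-of-the-envelope estimate assuming 32-bit attention scores and a 12-head, roughly ViT-Base-sized model: \[196^2 \times 4\ \text{bytes} \approx 0.15\ \text{MB per head per sequence}, \qquad 12 \times 0.15\ \text{MB} \approx 1.8\ \text{MB per layer},\] which is negligible next to the model weights and activations.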
We believe that architectures designed for causal tasks can really benefit from modifications to adapt them to the bi-directional domain.
We introduce our framework, LION, which derives an equivalent bi-directional RNN for Full Linear Attention.
Within this framework, we demonstrate how different Linear Transformers can be extended to their bi-directional counterparts.
We explore the construction of stable masks \(\mathbf{M}\), enabling models built with LION to TRAIN IN PARALLEL using Full Attention and INFER EFFICIENTLY like an RNN.
Finally, we introduce a chunkwise parallel variant of LION to balance recurrence and parallelism.
Continue reading to Part II - Bi-directional RNN
Acknowledgement: We thank Albert Gu and Tri Dao for their insightful blog posts, which helped shape our own.