Explaining LION-Chunk for Balancing Memory-Speed Tradeoffs During Inference
Having established the LION framework in Part II of this series, which maps Full Linear Attention to a bi-directional RNN, a key question arises:
Given that the RNN form is memory-efficient and the Attention form is fast and parallel, can we strike a balance between the two?
For causal Transformers like DeltaNet, this balance is commonly achieved through chunkwise parallelism, and LION-Chunk applies the same idea to the bi-directional setting.
The key idea of chunking is that instead of processing the entire sequence of length $L$ at once, we divide it into $N$ subsequences (chunks) of length $C$, where $N \times C = L$.
To achieve this, we start with the Full Linear Attention formulation:
We first chunk the queries, keys, and values into submatrices
\[\mathbf{Q}_{[i]}, \mathbf{K}_{[i]}, \mathbf{V}_{[i]} \in \mathbb{R}^{C \times d}.\]Now, given the form \((\mathbf{A} \odot \mathbf{M})\), where \(\mathbf{A} = \mathbf{Q} \mathbf{K}^\top\), we can construct the chunkwise form in four parts.
Using these chunked matrices, we write Full Linear Attention in chunked form as below:
LION Chunk
\[\begin{aligned} \mathbf{A}_{[ij]} & = \mathbf{Q}_{[i]}\mathbf{K}_{[j]}^\top \odot \mathbf{M}_{[ij]}, \\ \mathbf{C}_{[ij]} &= \mathbf{C}_{[i(j-1)]} + \text{Sum}(\mathbf{A}_{[ij]}), \\ \mathbf{S}_{[ij]} & = \mathbf{S}_{[i(j-1)]} + \mathbf{A}_{[ij]} \mathbf{V}_{[j]}, \\ \mathbf{Y}_{[i]} & = \frac{\mathbf{S}_{[iN]}}{\mathbf{C}_{[iN]}} \end{aligned}\]
where the $\text{Sum}$ operation computes the row-wise sum of its input matrix, and $\mathbf{M}_{[ij]}$ is the submatrix of the full mask $\mathbf{M}$ at chunk $[ij]$:
\[\mathbf{M}_{[ij]} = \mathbf{M}_{(i-1)C+1:iC,\; (j-1)C+1:jC} \in \mathbb{R}^{C \times C}.\]Let’s start with an example, chunking the Attention matrix $\mathbf{A}$ for a sequence of length $L=9$ with chunk size $C=3$, shown in detail below:
Chunking simply involves computing the product of the queries and keys for each boxed sub-matrix, as illustrated for the upper, lower, and diagonal chunks. Every Attention chunk $[ij]$ follows the same pattern: multiply the queries and keys corresponding to that chunk and apply the matching mask chunk.
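To make the recurrence concrete, here is a minimal single-head PyTorch sketch of the chunkwise computation (our own illustrative code, not the official LION implementation; it assumes $\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{L \times d}$ and a precomputed full mask $\mathbf{M} \in \mathbb{R}^{L \times L}$, and the name lion_chunk is ours):

import torch

def lion_chunk(Q, K, V, M, C):
    # Minimal sketch of the LION chunkwise recurrence for a single head.
    # Q, K, V: (L, d); M: (L, L) full mask; C: chunk size with L = N * C.
    L, d = Q.shape
    N = L // C
    Y = torch.empty_like(V)
    for i in range(N):
        Qi = Q[i * C:(i + 1) * C]                                  # queries of chunk i
        S = torch.zeros(C, d, dtype=V.dtype, device=V.device)      # running numerator S_[ij]
        c = torch.zeros(C, 1, dtype=V.dtype, device=V.device)      # running normalizer C_[ij]
        for j in range(N):
            Kj = K[j * C:(j + 1) * C]
            Vj = V[j * C:(j + 1) * C]
            Mij = M[i * C:(i + 1) * C, j * C:(j + 1) * C]
            Aij = (Qi @ Kj.T) * Mij                # A_[ij] = Q_[i] K_[j]^T  (elementwise) M_[ij]
            S = S + Aij @ Vj                       # S_[ij] = S_[i(j-1)] + A_[ij] V_[j]
            c = c + Aij.sum(dim=-1, keepdim=True)  # C_[ij] = C_[i(j-1)] + Sum(A_[ij])
        Y[i * C:(i + 1) * C] = S / c               # Y_[i] = S_[iN] / C_[iN]
    return Y

Only one $C \times C$ attention chunk and a $C \times d$ running state live in memory at a time, which is exactly the memory-speed trade-off LION-Chunk targets: a larger $C$ behaves more like Attention, a smaller $C$ more like the RNN form.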
But does the same approach apply to Selective and Fixed masks?
In practice, chunking the Attention mask is slightly different from, and even more critical than, chunking the Attention matrix itself, due to the mask's unique structure. Below, we provide a detailed explanation of how to chunk the Attention mask for LION-D and LION-S.
🚀 Note: The chunking visualizations and further details for this part are available exclusively in the blog post version.
Let’s start with the decay mask, as it is simpler and easier to visualize. For LION-D, the final mask is a Toeplitz mask constructed using the scalar decay factor $\lambda$. We can visualize how the mask is structured.
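As a small illustration (assuming, consistently with the chunk-mask code below, that the LION-D mask entry at positions $n, m$ is $\lambda^{|n-m|}$), the full mask for a sequence of length $L=4$ looks like:

\[\mathbf{M} = \begin{pmatrix} 1 & \lambda & \lambda^2 & \lambda^3 \\ \lambda & 1 & \lambda & \lambda^2 \\ \lambda^2 & \lambda & 1 & \lambda \\ \lambda^3 & \lambda^2 & \lambda & 1 \end{pmatrix}, \qquad \mathbf{M}_{nm} = \lambda^{|n-m|}.\]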
The full mask of LION-D (or the full RetNet mask) is built from copies of $\Gamma$, which is itself a Toeplitz matrix. Regardless of where a chunk is located, whether in the upper or lower part of the mask matrix $\mathbf{M}$, it retains the same property of being a scaled version of the Toeplitz matrix $\Gamma$, as below:
\[\mathbf{M}_{[ij]} = \Gamma\, \lambda^{|i-j|}.\]A PyTorch implementation of the LION-D chunk mask is provided below:
import torch

def mask_decay_partial(a, length, start, end):
    # Materializes the columns [start:end] of the full length x length decay mask,
    # i.e. M[i, j] = lambda^{|i - j|}, with one decay factor lambda = sigmoid(a) per head.
    idx = torch.arange(length, device=a.device)
    i, j = torch.meshgrid(idx, idx[start:end], indexing="ij")
    e = torch.abs(i - j).float().view(1, 1, length, len(idx[start:end]))
    m = torch.sigmoid(a).view(1, -1, 1, 1) ** e  # shape: (1, heads, length, end - start)
    return m
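As a quick usage sketch (with hypothetical shapes: 8 heads, $L=9$, $C=3$), the mask columns belonging to the second chunk can be materialized as:

a = torch.randn(8)                                        # per-head decay logits, lambda = sigmoid(a)
M_cols = mask_decay_partial(a, length=9, start=3, end=6)  # columns 3..5 of the 9 x 9 mask
print(M_cols.shape)                                       # torch.Size([1, 8, 9, 3])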
The full mask of LION-S is trickier than that of LION-D, since the upper, lower, and diagonal parts of the mask are shaped differently:
Let’s visualize the LION-S mask as well:
For example, chunk $[1,3]$ contains only the cumulative decay factors multiplied from the beginning of the sequence up to its last three elements, while chunk $[3,1]$ contains only the decay factors multiplied from the end of the sequence up to its first three elements. This is why we use the matrices $\mathbf{L}^F$ and $\mathbf{L}^B$ to store the cumulative products of the decay factors, progressing from the beginning of the sequence to the end and in reverse; they can be created simply as L^F = cumprod(a) and L^B = cumprod(flip(a)). A PyTorch implementation of the LION-S chunk mask is provided below:
def mask_forward(tensor, chunk_index, chunk_length):
    # Lower-triangular (forward) part of the LION-S mask, restricted to the columns of the
    # selected chunk: each entry is a ratio of cumulative decay products.
    cumprod = torch.clamp(tensor.cumprod(dim=-1), 1e-6)
    a = (
        cumprod.unsqueeze(-1)
        / cumprod.unsqueeze(-2)[
            ..., chunk_index * chunk_length : (chunk_index + 1) * chunk_length
        ]
    )
    # Keep only entries on or below the global diagonal of this column chunk.
    return torch.tril(a, diagonal=-chunk_index * chunk_length)

def mask_backward(tensor, chunk_index, chunk_length):
    # Upper-triangular (backward) counterpart: the chunk rows of cumulative products are
    # divided by the full cumulative products and transposed back to (row, column) layout.
    cumprod = torch.clamp(tensor.cumprod(dim=-1), 1e-6)
    a = cumprod.unsqueeze(-1)[
        ..., chunk_index * chunk_length : (chunk_index + 1) * chunk_length, :
    ] / cumprod.unsqueeze(-2)
    return torch.triu(a.transpose(-1, -2), diagonal=-chunk_index * chunk_length)
def mask_selective_partial(vec, chunk_index, chunk_length):
    # Assembles the columns of the full LION-S mask that belong to the selected chunk by
    # combining the forward (lower-triangular) and backward (upper-triangular) parts.
    b, h, l = vec.shape
    a_for = mask_forward(
        torch.cat((torch.ones_like(vec[..., :2]), vec[..., 1:-1]), dim=-1),
        chunk_index,
        chunk_length,
    )
    a_back = mask_backward(
        torch.cat((torch.ones_like(vec[..., :1]), vec[..., 1:]), dim=-1),
        chunk_index,
        chunk_length,
    )
    # Both parts include the (all-ones) diagonal, so subtract it once.
    i = torch.diag_embed(
        torch.ones((b, h, l - chunk_index * chunk_length)),
        offset=-chunk_index * chunk_length,
    )[..., : a_for.shape[-1]]
    return a_for + a_back - i.to(a_for.device)
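A quick usage sketch with hypothetical shapes (batch 2, 4 heads, $L=9$, $C=3$), assuming vec holds the per-token selective decay factors in $(0,1)$:

vec = torch.sigmoid(torch.randn(2, 4, 9))                  # per-token selective decays in (0, 1)
M_cols = mask_selective_partial(vec, chunk_index=1, chunk_length=3)
print(M_cols.shape)                                        # torch.Size([2, 4, 9, 3])

In the chunkwise recurrence, these partial masks replace slicing a fully materialized $L \times L$ mask, so only $C$ mask columns need to exist in memory at a time.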
Now that we have all the elements in place, let's see how these models work in practice on real-world datasets for masked language modeling and image classification.
In the final part of this series, we present the advantages of using LION compared to other methods for training SSMs or Linear Transformers.
We also present the trade-offs for different LION 🦁 models and compare them with other well-known SSMs and Softmax Transformers.