Multi-headed attention as matrix multiplication
Today I’ll walk through how to implement multi-headed attention with a series of matrix multiplications. Computing attention in this way is more efficient than iteratively visiting each input vector and computing its associated output. Matrix multiplication operations are highly optimized and designed to be parallelized across threads on the GPU.
To help us, I’ll also introduce two interpretations of matrix multiplication, both of which are super useful even outside the context of computing multi-headed attention.
Computing the attention weights
Suppose you had two sets of vectors of identical dimension \(\vec{a}_1, \vec{a}_2, ..., \vec{a}_n\) and \(\vec{b}_1, \vec{b}_2, ..., \vec{b}_m\) and you construct a matrix \(A\) by stacking the \(\vec{a}_i\) vectors as rows, and a matrix \(B\) by stacking the \(\vec{b}_i\) vectors as columns. You can then interpret the multiplication \(AB\) as a pairwise similarity matrix:
\[ \left(AB\right)_{i,j} = \vec{a}_i \, \cdot \, \vec{b}_j \]We can use this to help us compute our attention weights. Recall that in computing a self-attention layer, we iterate through each query vector, and compute the scaled dot product against all of the key vectors. For the ith query vector, we compute the weights with:
\[ w_j \propto \exp\left(\frac{\left(W^{Q}\cdot\vec{v}_i\right) \cdot \left(W^{K}\cdot\vec{v}_j\right)}{\sqrt{d}}\right) \]Note the \(\left(W^{Q} \, \cdot \, \vec{v}_i\right) \, \cdot \,\left(W^{K} \, \cdot \, \vec{v}_j\right)\) term. If we were able to compute all pairwise dot-product similarities, we would be well on our way to computing the attention weights. Let’s compute the similarities with a few matrix multiplications:
\[ \begin{gather*} Q = IW^Q \\ K = IW^K \\ S = \frac{1}{\sqrt{d}} \cdot QK^T \end{gather*} \]Here \(I\) is the matrix constructed by stacking our input \(\vec{v}_i\) vectors as rows, the matrices \(Q\) and \(K\) are the input vectors embedded into query and key space, respectively, and \(S\) is the scaled pairwise dot-product similarity matrix (\(d\) being the dimension of the query/key embedding space). That is, the ith row of \(IW^Q\) is \(W^Q\cdot \vec{v}_i\), and the ith row of \(IW^K\) is \(W^{K}\cdot\vec{v}_i\). Note that we need to take the transpose of \(K\) when computing \(S\) in order to stack the key vectors along columns rather than rows.
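To make the shapes concrete, here is a minimal NumPy sketch of this step. The sizes n, d_model, and d and the random weight matrices are placeholder choices for illustration (they are not from the original); a real \(W^Q\) and \(W^K\) would be learned parameters.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_model, d = 4, 8, 16

I = rng.standard_normal((n, d_model))    # input vectors stacked as rows
W_Q = rng.standard_normal((d_model, d))  # placeholder for a learned query projection
W_K = rng.standard_normal((d_model, d))  # placeholder for a learned key projection

Q = I @ W_Q                  # ith row: the query embedding of v_i
K = I @ W_K                  # ith row: the key embedding of v_i
S = (Q @ K.T) / np.sqrt(d)   # S[i, j] = scaled dot product of query i with key j

# The pairwise-similarity interpretation: each entry of Q @ K.T is one dot product.
assert np.allclose(S[1, 2], (Q[1] @ K[2]) / np.sqrt(d))
```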
Computing the similarity matrix isn’t enough; we need to turn it into an attention weight vector for each query by exponentiating all of the entries and then normalizing each row to sum to 1. This is the same as taking the softmax of each row of the matrix.
\[ \begin{gather*} A = Softmax\left(S\right) \end{gather*} \]After taking this softmax, we have the property that the ith row of \(A\) is the attention weight vector for the ith input vector. Concretely, \(A_{i,j}\) tells us, when computing the output for input vector \(\vec{v}_i\), how much weight \(w_j\) to assign input vector \(\vec{v}_j\).
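As a quick sketch, the row-wise softmax can be written with a small helper (the name row_softmax and the example matrix are mine for illustration; any softmax routine that accepts an axis argument behaves the same way):

```python
import numpy as np

def row_softmax(S: np.ndarray) -> np.ndarray:
    """Softmax applied independently to each row of S."""
    # Subtracting the row max before exponentiating keeps the values numerically stable.
    e = np.exp(S - S.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

S = np.array([[1.0, 2.0, 0.5],
              [0.0, 0.0, 0.0]])
A = row_softmax(S)
assert np.allclose(A.sum(axis=1), 1.0)  # each row of A is a valid weight distribution
```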
Computing the outputs
Now, we will leverage a second useful interpretation of matrix multiplication in order to use our attention matrix to compute output vectors. Suppose we had vectors \(\vec{a}_1, \vec{a}_2, ..., \vec{a}_n\) and \(\vec{b}_1, \vec{b}_2, ..., \vec{b}_m\), with the dimension of the \(\vec{a}\) vectors being \(m\). Construct a matrix \(A\) by stacking the \(\vec{a}_i\) vectors as rows, and a matrix \(B\) by stacking the \(\vec{b}_i\) vectors also as rows. Note that this is different from the setup we had for the other matrix multiplication interpretation! The key difference is that we are stacking the \(\vec{b}\) vectors as rows rather than columns.
Now, we can interpret each row of the product \(AB\) as the weighted sum of the rows of \(B\), with the weights coming from the corresponding row in \(A\):
\[ \vec{\left(AB\right)}_{i, :} = \sum_{j=1}^{m} A_{i, j} \vec{B}_{j, :} \]This means that we can use a single matrix multiplication in order to compute many weighted sums of vectors in parallel, as long as we pack those vectors as rows of a matrix. You can find an excellent visualization of this here.
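Here is a tiny NumPy check of this interpretation (the 3 × 5 and 5 × 7 shapes are arbitrary choices for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.random((3, 5))            # 3 weight vectors of length 5, stacked as rows
B = rng.standard_normal((5, 7))   # 5 vectors of dimension 7, stacked as rows

AB = A @ B

# Row i of AB is the weighted sum of the rows of B, with weights from row i of A.
i = 1
weighted_sum = sum(A[i, j] * B[j] for j in range(B.shape[0]))
assert np.allclose(AB[i], weighted_sum)
```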
We can use this trick to compute our self-attention output vectors. Recall that our output vectors are computed as the weighted sum of the value-embedded input vectors:
\[ \vec{o}_i = \sum_{j} w_j \left(W^{V}\cdot\vec{v}_j\right) \]Using the trick we just learned, we can represent this as follows:
\[ \begin{gather*} V = IW^V \\ O = AV \end{gather*} \]Here the matrix \(V\) contains the value-embedded input vectors stacked as rows, and the matrix \(A\) is our attention weight matrix from the previous section, with the attention weight distribution for each input vector stacked as a row. The matrix \(O\) contains our output vectors, with the ith row of \(O\) being the output for the ith input vector.
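A minimal sketch of this last step, again with placeholder sizes, a random stand-in for the value projection, and a normalized random matrix standing in for the softmaxed attention matrix \(A\):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_model, d_v = 4, 8, 16

I = rng.standard_normal((n, d_model))      # input vectors stacked as rows
W_V = rng.standard_normal((d_model, d_v))  # placeholder for a learned value projection
A = rng.random((n, n))
A /= A.sum(axis=1, keepdims=True)          # stand-in for the softmaxed attention matrix

V = I @ W_V   # value-embedded input vectors, stacked as rows
O = A @ V     # row i: attention-weighted sum of the value vectors, i.e. the output for v_i

# Row i of O is exactly the weighted sum written out term by term.
i = 0
assert np.allclose(O[i], sum(A[i, j] * V[j] for j in range(n)))
```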
Finishing thoughts
We can summarize the whole computation as follows:
\[ \begin{gather*} Q, K, V = IW^Q, IW^K, IW^V \\ O = Softmax\left(\frac{1}{\sqrt{d}} \cdot QK^T\right)V \end{gather*} \]
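Putting the pieces together, here is a minimal single-head sketch of the whole pipeline in NumPy. The function name, shapes, and random weights are illustrative choices of mine, not the author’s code; a multi-headed layer would run several copies of this with separate projection matrices and concatenate the resulting outputs.

```python
import numpy as np

def self_attention(I, W_Q, W_K, W_V):
    """Single-head self-attention as a handful of matrix multiplications.

    I: (n, d_model) input vectors stacked as rows.
    W_Q, W_K: (d_model, d) query/key projections. W_V: (d_model, d_v) value projection.
    Returns O of shape (n, d_v); the ith row is the output for the ith input vector.
    """
    Q, K, V = I @ W_Q, I @ W_K, I @ W_V
    S = (Q @ K.T) / np.sqrt(K.shape[-1])
    E = np.exp(S - S.max(axis=1, keepdims=True))  # row-wise softmax
    A = E / E.sum(axis=1, keepdims=True)
    return A @ V

# Example usage with random, untrained weights.
rng = np.random.default_rng(0)
n, d_model, d, d_v = 4, 8, 16, 16
I = rng.standard_normal((n, d_model))
O = self_attention(I,
                   rng.standard_normal((d_model, d)),
                   rng.standard_normal((d_model, d)),
                   rng.standard_normal((d_model, d_v)))
print(O.shape)  # (4, 16)
```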