# Multi-headed attention as matrix multiplication

Today I’ll walk through how to implement multi-headed attention as a series of matrix multiplications. Computing attention this way is more efficient than iteratively visiting each input vector and computing its associated output: matrix multiplication routines are highly optimized and are designed to be parallelized across threads on the GPU.

To help us, I’ll also introduce two interpretations of matrix multiplication, both of which are super useful even outside the context of computing multi-headed attention.

# Computing the attention weights

Suppose you have two sets of vectors of identical dimension
\(\vec{a}_1, \vec{a}_2, ..., \vec{a}_n\) and \(\vec{b}_1, \vec{b}_2,
..., \vec{b}_m\), and you construct a matrix \(A\) by stacking the
\(\vec{a}_i\) vectors as *rows* and a matrix \(B\) by stacking the
\(\vec{b}_j\) vectors as *columns*. You can then interpret the
product \(AB\) as a pairwise similarity matrix: entry \((i, j)\) is the
dot product of \(\vec{a}_i\) and \(\vec{b}_j\).

\[ (AB)_{i,j} = \vec{a}_i \cdot \vec{b}_j \]
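A quick sanity check of this interpretation, written as a minimal NumPy sketch. The random vectors and the sizes `n`, `m`, `dim` are arbitrary choices for illustration: we stack the \(\vec{a}_i\) as rows, the \(\vec{b}_j\) as columns, and confirm that every entry of the product is the corresponding dot product.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, dim = 3, 4, 5  # illustrative sizes

a_vectors = rng.normal(size=(n, dim))  # a_1 .. a_n, one per row
b_vectors = rng.normal(size=(m, dim))  # b_1 .. b_m, one per row

A = a_vectors       # a_i stacked as rows: shape (n, dim)
B = b_vectors.T     # b_j stacked as columns: shape (dim, m)

similarity = A @ B  # shape (n, m)

# Entry (i, j) of the product is the dot product a_i . b_j
for i in range(n):
    for j in range(m):
        assert np.isclose(similarity[i, j], a_vectors[i] @ b_vectors[j])
```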

We can use this to help us compute our attention weights. Recall that in computing a self-attention layer, we iterate through each query vector, and compute the scaled dot product against all of the key vectors. For the ith query vector, we compute the weights with:

\[ w_j \propto \exp\left(\frac{\left(W^{Q}\cdot\vec{v}_i\right) \cdot \left(W^{K}\cdot\vec{v}_j\right)}{\sqrt{d}}\right) \]

Note the \(\left(W^{Q}\cdot\vec{v}_i\right) \cdot \left(W^{K}\cdot\vec{v}_j\right)\) term. If we could compute all of these pairwise dot-product similarities at once, we would be well on our way to computing the attention weights. Let's compute them with a few matrix multiplications:

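\[ \begin{gather*} Q = IW^Q \\ K = IW^K \\ S = \frac{1}{\sqrt{d}} \cdot QK^T \end{gather*} \]

Here \(S\) is a scaled pairwise dot-product similarity matrix (\(d\) being the dimension of the query/key embedding space), \(I\) is the matrix constructed by stacking our input vectors \(\vec{v}_i\) as rows, and \(Q\) and \(K\) hold the input vectors embedded into query and key space, respectively. That is, the ith row of \(Q\) is the query embedding of \(\vec{v}_i\), written \(W^Q\cdot\vec{v}_i\) above, and the ith row of \(K\) is its key embedding \(W^{K}\cdot\vec{v}_i\). (Because the inputs are stacked as rows here, the weight matrices multiply from the right, so \(W^Q\) and \(W^K\) in this form are the transposes of the ones applied to column vectors above.) Note that we take the transpose of \(K\) in computing \(S\) so that the key vectors are stacked as columns rather than rows; by the first interpretation, \(S_{i,j}\) is then the scaled dot product of query \(i\) with key \(j\).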
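A minimal NumPy sketch of this step, assuming random inputs and weights of illustrative sizes. `inputs` plays the role of \(I\), and the weight matrices are shaped to multiply from the right, as in the row-stacked equations. It also recomputes every score one pair at a time, the iterative way, to check that the two approaches agree.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_in, d = 4, 6, 3  # number of inputs, input dimension, query/key dimension

inputs = rng.normal(size=(n, d_in))  # input vectors v_i stacked as rows (I)
W_Q = rng.normal(size=(d_in, d))     # query projection, applied on the right
W_K = rng.normal(size=(d_in, d))     # key projection, applied on the right

Q = inputs @ W_Q             # row i is the query embedding of v_i
K = inputs @ W_K             # row i is the key embedding of v_i
S = (Q @ K.T) / np.sqrt(d)   # S[i, j] = (q_i . k_j) / sqrt(d)

# The same scores computed one (query, key) pair at a time
S_loop = np.empty((n, n))
for i in range(n):
    for j in range(n):
        S_loop[i, j] = (inputs[i] @ W_Q) @ (inputs[j] @ W_K) / np.sqrt(d)

assert np.allclose(S, S_loop)
```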

Computing the similarity matrix isn’t enough; we need to turn it into an attention weight vector for each query by exponentiating all of the entries and then normalizing each row to sum to 1. This is exactly a softmax applied to each row of the matrix.

\[ A = \mathrm{Softmax}\left(S\right) \]

where the softmax is applied to each row of \(S\) independently. After taking this softmax, the ith row of \(A\) is the attention weight vector for the ith input vector. Concretely, \(A_{i,j}\) tells us how much weight \(w_j\) to assign to input vector \(\vec{v}_j\) when computing the output for input vector \(\vec{v}_i\).
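A row-wise softmax takes only a few lines of NumPy. The small score matrix below is made up for illustration. Subtracting each row's maximum before exponentiating is a standard numerical-stability trick that is not part of the original derivation. It leaves the result unchanged, because the shift cancels in the normalization.

```python
import numpy as np

def row_softmax(S):
    """Apply softmax independently to each row of S."""
    # Shifting by the row max avoids overflow in exp; the shift cancels when normalizing
    shifted = S - S.max(axis=1, keepdims=True)
    exp_S = np.exp(shifted)
    return exp_S / exp_S.sum(axis=1, keepdims=True)

S = np.array([[1.0, 2.0, 3.0],
              [0.0, 0.0, 0.0],
              [5.0, 1.0, -2.0]])
A = row_softmax(S)

print(A.sum(axis=1))  # each row sums to 1, up to floating-point rounding
```

An all-zero row of scores (the second row) becomes a uniform distribution, which is a handy check.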

# Computing the outputs

Now, we will leverage a second useful interpretation of matrix
multiplication in order to use our attention matrix to compute output
vectors. Suppose we had vectors \(\vec{a}_1, \vec{a}_2, ...,
\vec{a}_n\) and \(\vec{b}_1, \vec{b}_2, ..., \vec{b}_m\), with the
dimension of the \(\vec{a}\) vectors being \(m\). Construct a matrix
\(A\) by stacking \(\vec{a}_i\) vectors as *rows*, and a matrix \(B\)
by stacking the \(\vec{b}_i\) vectors *also as rows*. Note that this
is different from the setup we had for the other matrix
multiplication interpretation! The key difference is that we are
stacking the \(\vec{b}\) vectors as rows rather than columns.

Now, we can interpret each *row* of the product \(AB\) as a weighted
sum of the rows of \(B\), with the weights coming from the
corresponding *row* of \(A\). Writing \((AB)_{i,:}\) for row \(i\):

\[ (AB)_{i,:} = \sum_{j=1}^{m} A_{i,j}\,\vec{b}_j \]
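A small NumPy sketch of this second interpretation, using hand-picked numbers so that the arithmetic is easy to follow. Each row of `A` holds weights for the three vectors stacked as rows of `B`, and each row of the product is checked against the explicit weighted sum.

```python
import numpy as np

# Two weight vectors (rows of A), each with m = 3 entries
A = np.array([[0.5, 0.25, 0.25],
              [0.0, 1.0, 0.0]])
# Three vectors b_1, b_2, b_3 stacked as rows of B
B = np.array([[1.0, 0.0],
              [0.0, 2.0],
              [4.0, 4.0]])

AB = A @ B

# Row i of AB equals sum_j A[i, j] * b_j
for i in range(A.shape[0]):
    weighted_sum = sum(A[i, j] * B[j] for j in range(B.shape[0]))
    assert np.allclose(AB[i], weighted_sum)

print(AB)  # rows: [1.5, 1.5] and [0.0, 2.0]
```

The first row mixes all three vectors (\(0.5\cdot[1,0] + 0.25\cdot[0,2] + 0.25\cdot[4,4] = [1.5, 1.5]\)), while the second row puts all of its weight on \(\vec{b}_2\) and simply copies it.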

This means that we can use a single matrix multiplication in order to compute many weighted sums of vectors in parallel, as long as we pack those vectors as rows of a matrix. You can find an excellent visualization of this here.

We can use this trick to compute our self-attention output vectors. Recall that our output vectors are computed as the weighted sum of the value-embedded input vectors:

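\[ \vec{o}_i = \sum_{j=1}^{n} w_j \left(W^{V}\cdot\vec{v}_j\right) \]

Here the \(w_j\) are the weights computed for query \(i\), and the sum runs over all input vectors, including \(\vec{v}_i\) itself. Using the trick we just learned, we can represent this as follows: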

\[ \begin{gather*} V = IW^V \\ O = AV \end{gather*} \]

Here the matrix \(V\) holds the value-embedded input vectors stacked as rows, and \(A\) is our attention weight matrix from the previous section, with the attention weight distribution for each input vector stacked as rows. The matrix \(O\) contains our output vectors: the ith row of \(O\) is the output for the ith input vector.
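A minimal NumPy sketch of the output step, assuming random inputs and a random value projection of illustrative sizes. To keep the example focused on \(O = AV\), a random matrix whose rows are normalized to sum to 1 stands in for the softmax output. The loop recomputes each output as an explicit weighted sum for comparison.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d_in, d_v = 4, 6, 2  # number of inputs, input dimension, value dimension

inputs = rng.normal(size=(n, d_in))  # input vectors v_i stacked as rows (I)
W_V = rng.normal(size=(d_in, d_v))   # value projection, applied on the right

# Stand-in attention matrix: any matrix whose rows are nonnegative and sum to 1
A = rng.random(size=(n, n))
A = A / A.sum(axis=1, keepdims=True)

V = inputs @ W_V  # value-embedded inputs, stacked as rows
O = A @ V         # row i is the output for input i

# The same outputs computed one at a time as weighted sums over all j
O_loop = np.zeros((n, d_v))
for i in range(n):
    for j in range(n):
        O_loop[i] += A[i, j] * V[j]

assert np.allclose(O, O_loop)
```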

# Finishing thoughts

We can summarize the whole computation as follows:

\[ \begin{gather*} Q, K, V = IW^Q, IW^K, IW^V \\ O = \mathrm{Softmax}\left(\frac{1}{\sqrt{d}} \cdot QK^T\right)V \end{gather*} \]
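Putting it together, here is a minimal NumPy sketch of the summary formula for a single attention head. The function names `self_attention` and `self_attention_loop` are hypothetical, chosen for illustration. As before, inputs are stacked as rows and the weight matrices multiply from the right. The loop version visits each input vector one at a time, as described in the introduction, and serves as a reference to check the matrix version against.

```python
import numpy as np

def self_attention(inputs, W_Q, W_K, W_V):
    """Single-head self-attention: O = softmax(Q K^T / sqrt(d)) V.

    inputs: (n, d_in) array with one input vector per row (the document's I).
    W_Q, W_K: (d_in, d) projections; W_V: (d_in, d_v) projection.
    """
    Q, K, V = inputs @ W_Q, inputs @ W_K, inputs @ W_V
    d = Q.shape[1]
    S = Q @ K.T / np.sqrt(d)              # scaled pairwise similarities
    S = S - S.max(axis=1, keepdims=True)  # numerical stability; softmax unchanged
    A = np.exp(S)
    A = A / A.sum(axis=1, keepdims=True)  # row-wise softmax
    return A @ V                          # row i: weighted sum of value rows

def self_attention_loop(inputs, W_Q, W_K, W_V):
    """Reference version that computes each output vector separately."""
    n = inputs.shape[0]
    d = W_Q.shape[1]
    outputs = []
    for i in range(n):
        q = inputs[i] @ W_Q
        scores = np.array([q @ (inputs[j] @ W_K) / np.sqrt(d) for j in range(n)])
        w = np.exp(scores - scores.max())
        w = w / w.sum()
        outputs.append(sum(w[j] * (inputs[j] @ W_V) for j in range(n)))
    return np.stack(outputs)

rng = np.random.default_rng(0)
inputs = rng.normal(size=(5, 8))
W_Q, W_K, W_V = (rng.normal(size=(8, 4)) for _ in range(3))

O = self_attention(inputs, W_Q, W_K, W_V)
assert O.shape == (5, 4)
assert np.allclose(O, self_attention_loop(inputs, W_Q, W_K, W_V))
```

Both functions compute the same result. The matrix version replaces the Python-level loops with three projections, one \(QK^T\) product, a row-wise softmax, and one product with \(V\), which is what lets an optimized matrix multiplication library do the heavy lifting.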