May 13, 2019

Multi-headed attention as matrix multiplication

Today I’ll walk through how to implement multi-headed attention with a series of matrix multiplications. Computing attention in this way is more efficient than iteratively visiting each input vector and computing its associated output: matrix multiplication operations are highly optimized and are designed to be parallelized across threads on the GPU.

To help us, I’ll also introduce two interpretations of matrix multiplication, both of which are super useful even outside the context of computing multi-headed attention.

Computing the attention weights

Suppose you had two sets of vectors of identical dimension, $a_1, a_2, \ldots, a_n$ and $b_1, b_2, \ldots, b_m$, and you construct a matrix $A$ by stacking the $a_i$ vectors as rows, and a matrix $B$ by stacking the $b_i$ vectors as columns. You can then interpret the multiplication $AB$ as a pairwise similarity matrix:

$$(AB)_{i,j} = a_i \cdot b_j$$
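
To make this interpretation concrete, here is a tiny NumPy sketch (the variable names and sizes are just for illustration) verifying that entry $(i, j)$ of the product is the dot product between the $i$th $a$ vector and the $j$th $b$ vector:

```python
import numpy as np

# Two small sets of 4-dimensional vectors.
a_vectors = np.random.randn(3, 4)   # a_1 ... a_3
b_vectors = np.random.randn(5, 4)   # b_1 ... b_5

A = a_vectors        # a_i vectors stacked as rows
B = b_vectors.T      # b_j vectors stacked as columns

similarities = A @ B                 # shape (3, 5)

# (AB)_{i,j} equals the dot product a_i . b_j.
assert np.allclose(similarities[1, 2], a_vectors[1] @ b_vectors[2])
```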

We can use this to help us compute our attention weights. Recall that in computing a self-attention layer, we iterate through each query vector, and compute the scaled dot product against all of the key vectors. For the ith query vector, we compute the weights with:

$$w_j \propto \exp\left(\frac{(W^Q v_i) \cdot (W^K v_j)}{\sqrt{d}}\right)$$
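
For reference, that per-query computation could be written as a naive loop, something like the sketch below (a hypothetical helper of my own, not code from this series); the rest of the post replaces this loop with matrix multiplications:

```python
import numpy as np

def attention_weights_for_query(i, inputs, W_Q, W_K):
    """Naive per-query computation of the attention weights w_j for input i."""
    d = W_Q.shape[1]                      # dimension of the query/key space
    query = inputs[i] @ W_Q               # query embedding of v_i
    scores = np.array([query @ (inputs[j] @ W_K) for j in range(len(inputs))])
    scores = scores / np.sqrt(d)          # scaled dot products
    weights = np.exp(scores)
    return weights / weights.sum()        # normalize the weights to sum to 1
```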

Note the $(W^Q v_i) \cdot (W^K v_j)$ term. If we were able to compute all pairwise dot-product similarities, we would be well on our way to computing the attention weights. Let's compute the similarities with a few matrix multiplications:

$$Q = IW^Q$$
$$K = IW^K$$
$$S = \frac{1}{\sqrt{d}} QK^T$$

$S$ is a scaled pairwise dot-product similarity matrix ($d$ being the dimension of the query/key embedding space), $I$ is a matrix constructed by taking our input vectors $v_i$ as rows, and the matrices $Q$ and $K$ are the input vectors embedded into query and key space, respectively. That is, the $i$th row of $IW^Q$ is $W^Q v_i$, and the $i$th row of $IW^K$ is $W^K v_i$. Note that we need to take the transpose of $K$ when computing $S$ in order to stack the key vectors along columns rather than rows.
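
In NumPy, these three multiplications might look like the following sketch (the variable names and sizes are arbitrary choices of mine):

```python
import numpy as np

n, d_model, d = 10, 16, 8            # number of inputs, input dimension, query/key dimension

I = np.random.randn(n, d_model)      # input vectors v_i stacked as rows
W_Q = np.random.randn(d_model, d)    # query projection
W_K = np.random.randn(d_model, d)    # key projection

Q = I @ W_Q                          # ith row: query embedding of v_i
K = I @ W_K                          # ith row: key embedding of v_i
S = (Q @ K.T) / np.sqrt(d)           # S[i, j]: scaled similarity between query i and key j
```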

Computing the similarity matrix isn’t enough; we need to turn this similarity matrix into an attention weight vector for each query by exponentiating all of the entries and then normalizing each row to sum up to 1. This is the same as taking the softmax of the rows of the matrix.

$$A = \text{Softmax}(S)$$

After taking this softmax, we have the property that the $i$th row of $A$ is the attention weight vector for the $i$th input vector. Concretely, $A_{i,j}$ tells us, when computing the output for input vector $v_i$, how much weight $w_j$ to assign input vector $v_j$.
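
A row-wise softmax is easy to write by hand; here is one possible sketch (subtracting the row maximum is a standard numerical-stability trick, not something required by the math above, and the random $S$ is just a stand-in for the similarity matrix):

```python
import numpy as np

def softmax_rows(S):
    """Apply the softmax independently to each row of S."""
    shifted = S - S.max(axis=1, keepdims=True)   # subtract the row max for numerical stability
    exp_S = np.exp(shifted)
    return exp_S / exp_S.sum(axis=1, keepdims=True)

S = np.random.randn(10, 10)      # stand-in for the scaled similarity matrix
A = softmax_rows(S)              # A[i, :] is the attention weight vector for input i
assert np.allclose(A.sum(axis=1), 1.0)   # each row sums to 1
```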

Computing the outputs

Now, we will leverage a second useful interpretation of matrix multiplication in order to use our attention matrix to compute output vectors. Suppose we had vectors $a_1, a_2, \ldots, a_n$ and $b_1, b_2, \ldots, b_m$, with the dimension of the $a$ vectors being $m$. Construct a matrix $A$ by stacking the $a_i$ vectors as rows, and a matrix $B$ by stacking the $b_i$ vectors also as rows. Note that this is different from the setup we had for the other matrix multiplication interpretation! The key difference is that we are stacking the $b$ vectors as rows rather than columns.

Now, we can interpret each row of the product $AB$ as the weighted sum of the rows of $B$, with the weights coming from the corresponding row in $A$:

$$(AB)_{i,:} = \sum_{j=1}^{m} A_{i,j} B_{j,:}$$

This means that we can use a single matrix multiplication in order to compute many weighted sums of vectors in parallel, as long as we pack those vectors as rows of a matrix. You can find an excellent visualization of this here.
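
Here is a small NumPy sketch of that interpretation (names are illustrative), checking that a row of the product matches an explicit weighted sum of the rows of the second matrix:

```python
import numpy as np

weights = np.random.rand(3, 5)     # A: each row holds 5 weights
vectors = np.random.randn(5, 7)    # B: 5 vectors of dimension 7, stacked as rows

product = weights @ vectors        # shape (3, 7)

# Row i of the product is the sum of the rows of `vectors`,
# weighted by the entries in row i of `weights`.
row_0 = sum(weights[0, j] * vectors[j] for j in range(5))
assert np.allclose(product[0], row_0)
```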

We can use this trick to compute our self-attention output vectors. Recall that our output vectors are computed as the weighted sum of the value-embedded input vectors:

$$o_i = \sum_{j \neq i} w_j (W^V v_j)$$

Using the trick we just learned, we can represent this as follows:

$$V = IW^V$$
$$O = AV$$

Here, the matrix $V$ is the value-embedded input vectors stacked as rows, and the matrix $A$ is our attention weight matrix from the previous section, with the attention weight distributions for each input vector stacked as rows. The matrix $O$ contains our output vectors, with the $i$th row of $O$ being the output for the $i$th input vector.
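
As a sketch, the output step is just two more matrix multiplications (the uniform attention matrix below is only a stand-in; in practice $A$ would come from the softmax step above):

```python
import numpy as np

n, d_model, d = 10, 16, 8

I = np.random.randn(n, d_model)       # input vectors stacked as rows
W_V = np.random.randn(d_model, d)     # value projection, analogous to W_Q and W_K
A = np.full((n, n), 1.0 / n)          # stand-in attention matrix (each row sums to 1)

V = I @ W_V                           # ith row: value embedding of v_i
O = A @ V                             # ith row: output vector for the ith input
```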

Finishing thoughts

We can summarize the whole computation as follows:

$$Q, K, V = IW^Q, IW^K, IW^V$$
$$O = \text{Softmax}\left(\frac{1}{\sqrt{d}} QK^T\right) V$$

This finally gets us to how multi-headed attention is expressed in the Attention is all you need paper. It took three posts to unpack the intuition and reasoning behind the equation!
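
For completeness, here is one possible end-to-end sketch of a single attention head expressed purely as matrix multiplications (running several such heads in parallel, each with its own $W^Q$, $W^K$, $W^V$, gives the multi-headed version from the previous post):

```python
import numpy as np

def softmax_rows(S):
    """Apply the softmax independently to each row of S."""
    shifted = S - S.max(axis=1, keepdims=True)
    exp_S = np.exp(shifted)
    return exp_S / exp_S.sum(axis=1, keepdims=True)

def attention(I, W_Q, W_K, W_V):
    """One attention head, computed entirely with matrix multiplications."""
    d = W_Q.shape[1]                        # dimension of the query/key space
    Q, K, V = I @ W_Q, I @ W_K, I @ W_V     # embed inputs into query/key/value space
    A = softmax_rows((Q @ K.T) / np.sqrt(d))
    return A @ V                            # ith row: output for the ith input vector

# Example usage with arbitrary sizes.
n, d_model, d = 10, 16, 8
I = np.random.randn(n, d_model)
W_Q, W_K, W_V = (np.random.randn(d_model, d) for _ in range(3))
O = attention(I, W_Q, W_K, W_V)             # shape (n, d)
```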



