Multi-Headed Attention
Let’s build off of the vanilla self-attention layer I described in my last post to construct a multi-headed attention layer. The intuition behind multi-headed attention is that different input vectors might relate to each other semantically in multiple ways. Consider the sentence “I am going to deposit my money at the bank”. When computing an output vector for “deposit”, it is likely important to attend to “bank” as the other side of the connecting preposition “at”. Similarly, we’d also likely need to attend to “I”, since “I” is the actor that is performing the action “deposit”.
But here’s the rub. The relationship between “deposit” and “bank” and the relationship between “deposit” and “I” are quite different. One is a prepositional relationship, while the other is a subject-action relationship. Therefore, we might like to produce two separate output vectors that capture these two different senses, and use distinct attention distributions in computing them. The heads of a multi-headed attention layer are designed to compute exactly these distinct output vectors and attention distributions.
A single head, attempt 1
To start off simply, let us first construct a single head by re-using the machinery we already have from self-attention. Recall that this head wants to construct attention distributions that capture a specific kind of semantic relationship between words. This means, intuitively, that before running self-attention on the vectors, we could embed them in a new space in which closeness represents that semantic relationship. Let’s represent that embedding operation with a matrix \(W\). Then, we could embed all the input \(\vec{v}_i\)’s using the matrix \(W\), and then run self-attention on those intermediate embeddings. That would look something like this:
\[ \begin{gather*} \vec{o}_i = \sum_{j \ne i} w_j \vec{v}_j \\ w_j \propto \exp\left(\frac{\left(W\cdot\vec{v}_i\right) \cdot \left(W\cdot\vec{v}_j\right)}{\sqrt{d}}\right) \end{gather*} \]

Since we don’t know which semantic embeddings are useful for solving the underlying task, the parameters of \(W\) would be trainable. Also, note that \(d\) now refers to the dimension of the embedding space, since that is the space in which the dot product is taken, consistent with its original purpose.
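Before moving on, here is a rough NumPy sketch of this attempt-1 head, assuming the input vectors are stacked as rows of a matrix (the function name and the toy dimensions are just for illustration):

```python
import numpy as np

def single_head_attempt1(V, W):
    """Attempt 1: embed every input vector with the same trainable matrix W,
    then run scaled dot-product self-attention on the embeddings.
    V: (n, d_in) input vectors as rows; W: (d, d_in) embedding matrix."""
    E = V @ W.T                                  # embedded vectors, shape (n, d)
    scores = (E @ E.T) / np.sqrt(E.shape[1])     # score[i, j] = (W v_i) . (W v_j) / sqrt(d)
    np.fill_diagonal(scores, -np.inf)            # exclude j = i, matching the sum over j != i
    scores -= scores.max(axis=1, keepdims=True)  # numerical stability for the softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)  # row i holds the attention weights for query i
    return weights @ V                           # o_i = sum_{j != i} w_j * v_j

rng = np.random.default_rng(0)
V = rng.normal(size=(5, 8))            # 5 tokens, 8-dimensional input vectors
W = rng.normal(size=(4, 8))            # embed into a 4-dimensional space, so d = 4
outputs = single_head_attempt1(V, W)   # shape (5, 8): outputs are still in the input space
```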
A single head, attempt 2
This gets us close to how multi-headed attention is computed in the literature. But we aren’t quite there yet. Notice that many relationships between words are not symmetric. For example, in our example sentence, the tokens “I” and “deposit” are subject-verb related. But “I” plays the distinct role of “subject”, and “deposit” plays the role of “verb”. Flipping those roles would create a sentence with a very different meaning. Unfortunately, our current embedding forces these semantic relationships to be symmetric, since the query vector and the non-query vectors are all embedded with the same matrix, and the dot-product operation is symmetric:
\[ \begin{gather*} w_j \propto \exp\left(\frac{\left(W\cdot\vec{v}_i\right) \cdot \left(W\cdot\vec{v}_j\right)}{\sqrt{d}}\right) \end{gather*} \]

If we want to break this symmetry, we can replace the single \(W\) matrix with two: one matrix \(W^{Q}\) that transforms only the query vectors, and a second matrix \(W^{K}\) that transforms the non-query vectors. In the literature, you’ll find that these non-query vectors are referred to as keys – hence the K superscript for the matrix.
This gives us the following computation for our single attention head:
\[ \begin{gather*} \vec{o}_i = \sum_{j \ne i} w_j \vec{v}_j \\ w_j \propto \exp\left(\frac{\left(W^{Q}\cdot\vec{v}_i\right) \cdot \left(W^{K}\cdot\vec{v}_j\right)}{\sqrt{d}}\right) \end{gather*} \]
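Relative to attempt 1, only the weight computation changes. Here is a small illustrative NumPy snippet (the function name and toy dimensions are my own) that also checks that a shared matrix always yields symmetric scores, while separate query and key matrices need not:

```python
import numpy as np

def attention_scores(V, W_Q, W_K):
    """score[i, j] = (W_Q v_i) . (W_K v_j) / sqrt(d)."""
    Q = V @ W_Q.T                           # query embeddings, shape (n, d)
    K = V @ W_K.T                           # key embeddings, shape (n, d)
    return (Q @ K.T) / np.sqrt(Q.shape[1])

rng = np.random.default_rng(0)
V = rng.normal(size=(5, 8))                 # 5 tokens, 8-dimensional input vectors
W_Q, W_K = rng.normal(size=(2, 4, 8))       # separate 4x8 query and key matrices

S = attention_scores(V, W_Q, W_K)
print(np.allclose(S, S.T))                  # False: separate matrices can break the symmetry
print(np.allclose(attention_scores(V, W_Q, W_Q),
                  attention_scores(V, W_Q, W_Q).T))  # True: a single shared matrix cannot
```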
A single head, final attempt

Now that we have broken the symmetry between queries and keys, our last step is to take a closer look at the computation of the output vector:
\[ \vec{o}_i = \sum_{j \ne i} w_j \vec{v}_j \]

As a simple weighted sum of the input vectors, this output doesn’t have much expressive flexibility. What if the vectors need to participate in the output in different ways? In our example sentence, the tokens “I”, “deposit”, and “money” all participate in a “subject-verb-object” relationship. If we imagine the embedding for “I” as the query vector, it might be the case that “deposit” and “money” are equally important in terms of attention. However, we may need to push forward into the output the fact that we are using the token “deposit” with its verb sense rather than, say, its noun sense, and that we are using “money” as the object rather than the subject.
A solution to this is yet another transformation matrix \(W^{V}\) that takes the input vectors and transforms them into an embedding space suitable for combination in the final weighted sum. In the literature, you’ll find that these vectors are referred to as values, giving us the “V” superscript for the matrix. This leaves us with the following final set of equations for computing a single attention head:
\[ \begin{gather*} \vec{o}_i = \sum_{j \ne i} w_j \left(W^{V}\cdot\vec{v}_j\right) \\ w_j \propto \exp\left(\frac{\left(W^{Q}\cdot\vec{v}_i\right) \cdot \left(W^{K}\cdot\vec{v}_j\right)}{\sqrt{d}}\right) \end{gather*} \]

We will define this entire computation as:
\[ \begin{gather*} o_1, o_2, ..., o_n = Attention\left(v_1, v_2, ..., v_n; \left[W^Q, W^K, W^V\right]\right) \end{gather*} \]

To summarize, we introduced three new learnable matrices \(W^{K}\), \(W^{Q}\), and \(W^{V}\) that transform the input vectors into key, query, and value embedding spaces. Closeness between vectors in the query and key spaces is used to compute the attention weights, while the value embedding space is used to compute the final output vector.
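Here is a rough NumPy sketch of this full single-head computation, again with row-per-token inputs and purely illustrative names and dimensions:

```python
import numpy as np

def attention(V_in, W_Q, W_K, W_V):
    """Single attention head: o_i = sum_{j != i} w_j (W_V v_j), with
    w_j proportional to exp((W_Q v_i) . (W_K v_j) / sqrt(d)).
    V_in: (n, d_in) input vectors as rows; W_Q, W_K, W_V: (d, d_in) matrices."""
    Q = V_in @ W_Q.T                             # queries, shape (n, d)
    K = V_in @ W_K.T                             # keys,    shape (n, d)
    Vv = V_in @ W_V.T                            # values,  shape (n, d)
    scores = (Q @ K.T) / np.sqrt(Q.shape[1])     # scaled query-key dot products
    np.fill_diagonal(scores, -np.inf)            # keep the j != i convention from the equations
    scores -= scores.max(axis=1, keepdims=True)  # numerical stability for the softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ Vv                          # rows are the output vectors o_i, shape (n, d)

rng = np.random.default_rng(0)
V_in = rng.normal(size=(5, 8))              # 5 tokens, 8-dimensional inputs
W_Q, W_K, W_V = rng.normal(size=(3, 4, 8))  # one 4x8 matrix each for queries, keys, values
outputs = attention(V_in, W_Q, W_K, W_V)    # shape (5, 4): outputs live in the value space
```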
Putting our heads together
Now that we’ve described how to construct a single head, bringing this back to multi-headed attention is fairly straightforward. The idea is to define \(K\) single-headed attention layers, each with its own query, key, and value matrices, and execute them in parallel against the same set of input vectors. Finally, for each input position, we concatenate the output vectors from the different heads into a single output vector. We’ll introduce a subscript to the matrices indicating the index of the head to which they belong. We’ll also introduce a superscript to the output vectors indicating the head from which they were computed.
For \(i = 1...K\):
\[ \begin{gather*} o^{(i)}_1, o^{(i)}_2, ..., o^{(i)}_n = Attention\left(v_1, v_2, ..., v_n; \left[W^Q_i, W^K_i, W^V_i\right]\right) \end{gather*} \]

\[ o_1, o_2, ..., o_n = Concat\left(o^{(1)}_1, ..., o^{(K)}_1\right), Concat\left(o^{(1)}_2, ..., o^{(K)}_2\right), ..., Concat\left(o^{(1)}_n, ..., o^{(K)}_n\right) \]
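Here is a sketch of the multi-headed version, reusing the attention function from the previous sketch; each head runs independently on the same inputs, and the per-position outputs are concatenated (names and dimensions are again illustrative):

```python
import numpy as np

def multi_head_attention(V_in, heads):
    """Run several single attention heads, each with its own [W_Q, W_K, W_V],
    on the same inputs and concatenate their outputs position by position.
    Assumes the attention(...) function from the previous sketch is in scope."""
    per_head = [attention(V_in, W_Q, W_K, W_V) for (W_Q, W_K, W_V) in heads]
    return np.concatenate(per_head, axis=1)  # o_j = Concat(o_j^(1), ..., o_j^(K))

rng = np.random.default_rng(0)
V_in = rng.normal(size=(5, 8))                           # 5 tokens, 8-dimensional inputs
heads = [rng.normal(size=(3, 4, 8)) for _ in range(2)]   # K = 2 heads, 4-dim values each
outputs = multi_head_attention(V_in, heads)              # shape (5, 2 * 4)
```

Note that the concatenated output dimension is \(K\) times the per-head value dimension, so the per-head dimension is often chosen so that the concatenation matches the model’s working dimension.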
Conclusion

That is all for multi-headed attention! In a future post I discuss how to pack the attention computations into matrix multiplications for GPU optimization. I’ll likely also write a post showing how to actually build these attention layers in TensorFlow. Thank you for reading!