In the world of artificial intelligence and deep learning, the self-attention mechanism has revolutionized how we process sequential data, particularly in natural language processing tasks. This blog post will explore the self-attention mechanism, implement it using PyTorch, and discuss its applications and benefits.
📘 What is Self-Attention?
Self-attention, also known as intra-attention, is a mechanism that allows a model to focus on different parts of the input sequence when producing an output. It enables the model to weigh the importance of different elements in the input dynamically, rather than treating all inputs equally.
📘 Self-Attention vs. Weighted Graph Algorithms
While self-attention shares some similarities with weighted graph algorithms, it has several unique characteristics that make it particularly powerful for deep learning tasks:
- Dynamic Weighting: Unlike fixed weights in graph algorithms, self-attention computes weights (attention scores) dynamically based on the input. This allows the model to adjust its focus for each specific input (see the short sketch after this list).
- Learned Representations: Self-attention projects the input into learned query, key, and value spaces. This allows the model to develop sophisticated ways of relating different parts of the input.
- Parallelization: Self-attention computes relationships between all pairs of elements simultaneously, making it highly efficient on modern hardware like GPUs.
- Contextual Understanding: In self-attention, every element can potentially attend to every other element, allowing for a rich, contextual understanding of the input.
- Differentiability: Self-attention is fully differentiable, making it easy to integrate into neural networks and train end-to-end.
- Multi-Head Attention: Many self-attention systems use multi-head attention, allowing the model to focus on different types of relationships simultaneously.
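To make the "dynamic weighting" point concrete, here is a minimal, self-contained sketch (toy sizes, no learned projections yet; the function name toy_attention_weights is just illustrative) showing that the attention weights are recomputed from the input itself, so different inputs produce different weight matrices:

import torch
import torch.nn.functional as F

def toy_attention_weights(x):
    # Pairwise dot-product scores, scaled, then normalized row-wise
    scores = x @ x.transpose(0, 1) / x.shape[-1] ** 0.5
    return F.softmax(scores, dim=-1)

x1 = torch.randn(4, 8)   # a sequence of 4 token vectors of size 8
x2 = torch.randn(4, 8)   # a different sequence of the same shape

print(toy_attention_weights(x1))   # 4x4 weight matrix derived from x1
print(toy_attention_weights(x2))   # same shape, but different weights

A fixed weighted graph would carry the same edge weights regardless of the input; here the weights are a function of the data itself.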
The Attention Class
We start by defining an Attention class that inherits from nn.Module:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.Q = nn.Linear(d_in, d_out)
        self.K = nn.Linear(d_in, d_out)
        self.V = nn.Linear(d_in, d_out)
Here’s what each part does:
- d_in: The size of the input vectors.
- d_out: The size of the output vectors.
- nn.Linear: This creates a linear transformation layer. In PyTorch, nn.Linear(a, b) creates a layer that transforms input of size a to output of size b.
- self.Q, self.K, self.V: These are linear transformations for creating Queries, Keys, and Values respectively.
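As a quick sanity check, here is a small sketch of the projections in isolation (d_in=8, d_out=4, batch size 2, and sequence length 5 are arbitrary toy values):

# Shape check for the three projections (toy sizes)
attn = Attention(d_in=8, d_out=4)
x = torch.randn(2, 5, 8)      # (batch, sequence length, d_in)

print(attn.Q(x).shape)        # torch.Size([2, 5, 4]) -- queries
print(attn.K(x).shape)        # torch.Size([2, 5, 4]) -- keys
print(attn.V(x).shape)        # torch.Size([2, 5, 4]) -- values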
The Forward Pass
Now, let’s implement the forward pass of our attention mechanism:
def forward(self, x):
    queries = self.Q(x)
    keys = self.K(x)
    values = self.V(x)
    scores = torch.bmm(queries, keys.transpose(1, 2))
    scores = scores / (self.d_out ** 0.5)
    attention = F.softmax(scores, dim=2)
    hidden_states = torch.bmm(attention, values)
    return hidden_states
Let’s break this down:
- Projecting inputs: We project the input x into Queries, Keys, and Values using our linear transformations.
- Computing attention scores: torch.bmm performs batch matrix multiplication. It's used here to compute the dot product between queries and keys. keys.transpose(1, 2) transposes the last two dimensions of the keys tensor, allowing for correct matrix multiplication.
- Scaling: We scale the scores by sqrt(d_out) to counteract the effect of large dot products in high dimensions.
- Softmax: F.softmax applies the softmax function along dimension 2 (the last dimension), converting scores to probabilities.
- Computing weighted sum: Another torch.bmm is used to compute the weighted sum of values, where the weights are the attention probabilities.
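To see the whole module in action, here is a minimal usage sketch (batch size, sequence length, and dimensions are arbitrary toy values):

# Full forward pass on random data (toy sizes)
attn = Attention(d_in=8, d_out=4)
x = torch.randn(2, 5, 8)                  # (batch, sequence length, d_in)

out = attn(x)
print(out.shape)                          # torch.Size([2, 5, 4])

# The attention weights are probabilities: each row sums to 1
queries, keys = attn.Q(x), attn.K(x)
scores = torch.bmm(queries, keys.transpose(1, 2)) / (attn.d_out ** 0.5)
weights = F.softmax(scores, dim=2)
print(weights.sum(dim=2))                 # all ones, shape (2, 5)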
Understanding the Mathematics
The self-attention mechanism can be summarized by this equation:
Attention(Q, K, V) = softmax((QK^T) / sqrt(d_k)) V
where:
- Q ∈ ℝ^(n×d_k), K ∈ ℝ^(m×d_k), and V ∈ ℝ^(m×d_v) are the query, key, and value matrices
- QK^T ∈ ℝ^(n×m) is the matrix product of Q and the transpose of K
- softmax((QK^T) / sqrt(d_k)) ∈ ℝ^(n×m) is the softmax function applied row-wise to the scaled dot products, with softmax(x)_i = exp(x_i) / (∑_j exp(x_j))
- Attention(Q, K, V) ∈ ℝ^(n×d_v) is the output of the attention mechanism
This formula shows how each output element is a weighted sum of the values, where the weights are determined by the compatibility of the query with all the keys. In our self-attention implementation, n = m is the sequence length and d_k = d_v = d_out.
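As a sanity check, here is a small sketch (with made-up 2×2 matrices) that evaluates the formula directly, without the learned projections:

# Evaluate Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V on tiny matrices
import torch
import torch.nn.functional as F

Q = torch.tensor([[1.0, 0.0], [0.0, 1.0]])   # n=2 queries, d_k=2
K = torch.tensor([[1.0, 0.0], [1.0, 1.0]])   # m=2 keys,    d_k=2
V = torch.tensor([[1.0, 2.0], [3.0, 4.0]])   # m=2 values,  d_v=2

d_k = Q.shape[-1]
scores = Q @ K.T / d_k ** 0.5                # (n, m) scaled dot products
weights = F.softmax(scores, dim=-1)          # each row sums to 1
output = weights @ V                         # (n, d_v) weighted sum of values

print(weights)
print(output)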
Example: Multi-Head Attention
Let’s extend our Attention class to implement multi-head attention:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, num_heads):
        super().__init__()
        self.heads = nn.ModuleList([Attention(d_in, d_out) for _ in range(num_heads)])
        self.linear = nn.Linear(d_out * num_heads, d_out)

    def forward(self, x):
        return self.linear(torch.cat([h(x) for h in self.heads], dim=-1))
This implementation allows the model to compute multiple different attention patterns in parallel and then combine them, capturing various types of relationships within the data.
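A brief usage sketch, again with arbitrary toy sizes:

# Run the multi-head module on random data (toy sizes)
mha = MultiHeadAttention(d_in=8, d_out=4, num_heads=3)
x = torch.randn(2, 5, 8)      # (batch, sequence length, d_in)

out = mha(x)
print(out.shape)              # torch.Size([2, 5, 4]) -- head outputs concatenated, then projected back to d_out

Each head produces its own (batch, sequence, d_out) output; concatenating along the last dimension gives d_out * num_heads features, which the final linear layer projects back down to d_out.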