Week 5 Session 3: Practical Implementation and Visualization of Transformers#
Introduction#
In this session, we will focus on the practical aspects of the Transformer architecture. We’ll visualize how attention mechanisms work and walk through implementing a Transformer model using Python and PyTorch. This hands-on approach will deepen your understanding of how Transformers process and generate language.
Visualization of Attention Mechanisms#
Understanding how attention works internally is crucial. Visualizations can help demystify the “black box” nature of neural networks.
Visualizing Attention Scores#
Attention scores indicate how much focus the model places on different parts of the input when generating each part of the output.
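Concretely, for scaled dot-product attention (which we implement later in this session), the attention weights for query, key, and value matrices Q, K, V are

Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V

where d_k is the key dimension. The softmax is applied row-wise, so each row of an attention map sums to 1.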
Example: Attention Heatmap#
Suppose we have an input sentence:
Input: “The cat sat on the mat.”
When processing this sentence, the model computes attention scores between each pair of words.
Diagram: Attention Heatmap Matrix#
        The    cat    sat    on     the    mat
      +------------------------------------------
The   | 0.1    0.2    0.3    0.1    0.2    0.1
cat   | 0.2    0.1    0.3    0.1    0.2    0.1
sat   | 0.1    0.2    0.1    0.3    0.2    0.1
on    | 0.1    0.1    0.2    0.1    0.3    0.2
the   | 0.2    0.1    0.1    0.2    0.1    0.3
mat   | 0.1    0.2    0.1    0.2    0.1    0.3
Figure 1: An example attention heatmap showing attention scores between words.
Interpretation#
High Scores: Indicate that the row word (the query) attends strongly to the column word (the key).
Asymmetry: The matrix is generally not symmetric, because queries and keys use different learned projections; how much “cat” attends to “sat” need not equal how much “sat” attends to “cat”.
Focus: Common function words like “the” often receive lower attention weights than content words.
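If you want to produce a heatmap like Figure 1 yourself, you can plot any attention-weight matrix. Below is a minimal sketch, assuming matplotlib is available; the weights here are random placeholders, and in practice you would take them from a model.

import torch
import matplotlib.pyplot as plt

tokens = ["The", "cat", "sat", "on", "the", "mat"]
# Hypothetical attention weights; in practice, take these from a trained model.
weights = torch.softmax(torch.rand(len(tokens), len(tokens)), dim=-1)

fig, ax = plt.subplots()
im = ax.imshow(weights.numpy(), cmap="viridis")  # rows = queries, columns = keys
ax.set_xticks(range(len(tokens)))
ax.set_xticklabels(tokens)
ax.set_yticks(range(len(tokens)))
ax.set_yticklabels(tokens)
fig.colorbar(im, ax=ax)
plt.title("Self-attention weights")
plt.show()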
Interpreting Attention Maps#
Visual attention maps help in understanding which parts of the input influence the output.
Example: Translation Task#
Source Sentence: “Das ist ein Beispiel.”
Target Sentence: “This is an example.”
Diagram: Cross-Attention Map#
          Das    ist    ein    Beispiel
        +---------------------------------
This    | 0.7    0.1    0.1    0.1
is      | 0.1    0.7    0.1    0.1
an      | 0.1    0.1    0.7    0.1
example | 0.1    0.1    0.1    0.7
Figure 2: Cross-attention map; each target word (row) attends over the source words (columns).
Interpretation#
Alignment: High attention scores align words with their translations.
Reordering: The map reveals how the model handles differences in word order and syntax between the two languages.
Practical Implementation of a Transformer#
Let’s dive into building a simple Transformer model to solidify our understanding.
Building a Simple Transformer from Scratch#
We will implement a minimal Transformer model using PyTorch, focusing on key components:
Embedding Layer
Positional Encoding
Multi-Head Attention
Feed-Forward Network
Encoder and Decoder Layers
Implementing with PyTorch#
Libraries Needed#
import torch
import torch.nn as nn
import math
Code Walkthrough#
1. Positional Encoding#
Positional encoding injects sequence order information.
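Concretely, the sinusoidal encoding implemented below gives position pos and dimension index i the values

PE(pos, 2i) = sin(pos / 10000^(2i / d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

so that each dimension oscillates at a different frequency and nearby positions receive similar encodings.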
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # Shape: (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :].to(x.device)
        return x
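A quick shape check (a minimal sketch; the batch size, sequence length, and model dimension are arbitrary):

pos_enc = PositionalEncoding(d_model=512)
dummy = torch.zeros(2, 10, 512)   # (batch_size, seq_len, d_model)
print(pos_enc(dummy).shape)       # torch.Size([2, 10, 512])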
2. Multi-Head Attention Module#
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        bs = q.size(0)
        # Linear projections, split into heads: (bs, n_heads, seq_len, d_k)
        q = self.q_linear(q).view(bs, -1, self.n_heads, self.d_k).transpose(1, 2)
        k = self.k_linear(k).view(bs, -1, self.n_heads, self.d_k).transpose(1, 2)
        v = self.v_linear(v).view(bs, -1, self.n_heads, self.d_k).transpose(1, 2)
        # Scaled Dot-Product Attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)  # block positions where mask == 0
        attention = torch.softmax(scores, dim=-1)
        x = torch.matmul(attention, v)
        # Concatenate heads
        x = x.transpose(1, 2).contiguous().view(bs, -1, self.d_k * self.n_heads)
        return self.out_proj(x)
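As a sanity check, self-attention (queries, keys, and values all taken from the same tensor) preserves the input shape:

mha = MultiHeadAttention(d_model=512, n_heads=8)
x = torch.rand(2, 10, 512)        # (batch_size, seq_len, d_model)
print(mha(x, x, x).shape)         # torch.Size([2, 10, 512])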
3. Feed-Forward Network#
class FeedForwardNetwork(nn.Module):
    def __init__(self, d_model, d_ff=2048):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))
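The feed-forward network is applied independently at every position, so it too preserves the input shape:

ffn = FeedForwardNetwork(d_model=512)
x = torch.rand(2, 10, 512)
print(ffn(x).shape)               # torch.Size([2, 10, 512])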
4. Encoder Layer#
class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = FeedForwardNetwork(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        # Self-Attention
        attn_output = self.attention(x, x, x, mask)
        x = self.norm1(x + attn_output)
        # Feed-Forward Network
        ffn_output = self.ffn(x)
        x = self.norm2(x + ffn_output)
        return x
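One encoder layer maps a sequence of embeddings to a sequence of contextualized embeddings of the same shape:

enc_layer = EncoderLayer(d_model=512, n_heads=8)
x = torch.rand(2, 10, 512)        # (batch_size, src_len, d_model)
print(enc_layer(x).shape)         # torch.Size([2, 10, 512])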
5. Decoder Layer#
class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, n_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.cross_attention = MultiHeadAttention(d_model, n_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = FeedForwardNetwork(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        # Masked Self-Attention
        self_attn_output = self.self_attention(x, x, x, tgt_mask)
        x = self.norm1(x + self_attn_output)
        # Cross-Attention (queries from the decoder, keys/values from the encoder)
        cross_attn_output = self.cross_attention(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + cross_attn_output)
        # Feed-Forward Network
        ffn_output = self.ffn(x)
        x = self.norm3(x + ffn_output)
        return x
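A decoder layer takes both the target embeddings and the encoder output. A minimal shape check, using a lower-triangular causal mask that matches the masked_fill convention above (the sequence lengths are arbitrary):

dec_layer = DecoderLayer(d_model=512, n_heads=8)
enc_out = torch.rand(2, 12, 512)            # encoder output: (batch_size, src_len, d_model)
tgt = torch.rand(2, 10, 512)                # decoder input:  (batch_size, tgt_len, d_model)
tgt_mask = torch.tril(torch.ones(10, 10))   # 1 = may attend, 0 = future position (blocked)
print(dec_layer(tgt, enc_out, src_mask=None, tgt_mask=tgt_mask).shape)  # torch.Size([2, 10, 512])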
6. Assembling the Transformer#
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, n_heads=8, num_layers=6):
        super().__init__()
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, n_heads) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, n_heads) for _ in range(num_layers)])
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        # Embedding and Positional Encoding
        src = self.pos_encoder(self.src_embedding(src))
        tgt = self.pos_encoder(self.tgt_embedding(tgt))
        # Encoder
        for layer in self.encoder_layers:
            src = layer(src, src_mask)
        # Decoder (attends to the final encoder output)
        for layer in self.decoder_layers:
            tgt = layer(tgt, src, src_mask, tgt_mask)
        output = self.fc_out(tgt)
        return output
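To sanity-check the assembled model, we can run a forward pass on random token ids (a minimal sketch; the vocabulary sizes and dimensions are arbitrary, and the causal mask is built inline here — a helper for it is defined in the next step):

model = Transformer(src_vocab_size=1000, tgt_vocab_size=1000,
                    d_model=128, n_heads=4, num_layers=2)
src = torch.randint(0, 1000, (2, 12))       # (batch_size, src_len) token ids
tgt = torch.randint(0, 1000, (2, 10))       # (batch_size, tgt_len) token ids
tgt_mask = torch.tril(torch.ones(10, 10))   # causal mask (see the next step)
logits = model(src, tgt, tgt_mask=tgt_mask)
print(logits.shape)                          # torch.Size([2, 10, 1000])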
7. Generating Masks#
Masks prevent the decoder from attending to future tokens during training. In our MultiHeadAttention module, positions where the mask is 0 are blocked (filled with -1e9 before the softmax), so a causal mask is simply a lower-triangular matrix of ones.
def generate_square_subsequent_mask(sz):
    # 1 = the position may be attended to, 0 = a future position (blocked).
    # This matches the masked_fill(mask == 0, -1e9) convention used above.
    mask = torch.tril(torch.ones(sz, sz))
    return mask
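Printing a small mask makes the convention explicit; this is the tensor you would pass as tgt_mask in the forward pass above.

print(generate_square_subsequent_mask(4))
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])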
Advanced Topics#
Transformer Variants#
BERT (Bidirectional Encoder Representations from Transformers)#
Architecture: Uses only the encoder part of the Transformer.
Objective: Masked Language Modeling (MLM) and Next Sentence Prediction (NSP).
Usage: Excellent for tasks that require understanding context, such as question answering and sentiment analysis (see the sketch below).
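A minimal usage sketch, assuming the Hugging Face transformers library is installed (bert-base-uncased is a standard published checkpoint):

from transformers import AutoTokenizer, AutoModel
import torch

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("Transformers are powerful models.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)   # (1, num_tokens, 768) contextual embeddings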
GPT (Generative Pre-trained Transformer)#
Architecture: Uses a decoder-style stack of the Transformer (causal self-attention, with no encoder or cross-attention).
Objective: Language modeling (predicting the next token).
Usage: Effective for text generation tasks (see the sketch below).
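A matching generation sketch, again assuming the Hugging Face transformers library (gpt2 is a standard published checkpoint):

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("The Transformer architecture", return_tensors="pt").input_ids
output_ids = model.generate(input_ids, max_length=30, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))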
Applications of Transformers#
Machine Translation: High-quality translation between languages.
Text Summarization: Generating concise summaries of documents.
Question Answering: Providing answers based on context.
Text Generation: Creative writing, code generation, and more.
Optional Coding Task Walkthrough#
Implementing a Simple Attention Mechanism#
Self-Attention Function#
def simple_self_attention(Q, K, V):
    # Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (K.size(-1) ** 0.5)
    attention_weights = torch.softmax(scores, dim=-1)
    output = torch.matmul(attention_weights, V)
    return output
Example Usage#
# Sample input tensor
x = torch.rand(1, 5, 64) # (batch_size, sequence_length, d_model)
# Assuming Q, K, V are all x in self-attention
output = simple_self_attention(x, x, x)
print(output.shape) # Should output (1, 5, 64)
Conclusion#
In this session, we’ve taken a hands-on approach to understanding Transformers. By visualizing attention mechanisms and implementing a Transformer model from scratch, you should now have a deeper appreciation for how these models function internally. This practical experience is invaluable as we move toward working with large language models and APIs in the coming weeks.
References#
Vaswani et al., “Attention Is All You Need”, 2017
The Illustrated Transformer by Jay Alammar