
Commit 9b95866

Multi-Head Latent Attention (#876)

* Multi-Head Latent Attention
* update

1 parent bf27ad1 commit 9b95866

15 files changed: +1164 -233 lines

.gitignore

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@ appendix-D/01_main-chapter-code/3.pdf
 appendix-E/01_main-chapter-code/loss-plot.pdf
 
 ch04/04_gqa/kv_bytes_vs_context_length.pdf
-ch04/04_gqa/savings_vs_n_kv_groups.pdf
+ch05/05_mla/kv_bytes_vs_context_length.pdf
 
 ch05/01_main-chapter-code/loss-plot.pdf
 ch05/01_main-chapter-code/temperature-plot.pdf

README.md

Lines changed: 1 addition & 0 deletions

@@ -169,6 +169,7 @@ Several folders contain optional materials as a bonus for interested readers:
   - [FLOPS Analysis](ch04/02_performance-analysis/flops-analysis.ipynb)
   - [KV Cache](ch04/03_kv-cache)
   - [Grouped-Query Attention](ch04/04_gqa)
+  - [Multi-Head Latent Attention](ch04/05_mla)
 - **Chapter 5: Pretraining on unlabeled data:**
   - [Alternative Weight Loading Methods](ch05/02_alternative_weight_loading/)
   - [Pretraining GPT on the Project Gutenberg Dataset](ch05/03_bonus_pretraining_on_gutenberg)

ch04/04_gqa/README.md

Lines changed: 5 additions & 20 deletions

@@ -2,12 +2,9 @@
 
 This bonus material illustrates the memory savings when using Grouped-Query Attention (GQA) over regular Multi-Head Attention (MHA).
 
-
-
 &nbsp;
 ## Introduction
 
-
 Grouped-Query Attention (GQA) has become the standard, more compute- and parameter-efficient replacement for Multi-Head Attention (MHA) in recent years. Note that it isn't new; it goes back to the 2023 paper [GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints](https://arxiv.org/abs/2305.13245), and even the larger variants in the good old Llama 2 series used it.
 
 Here's a brief GQA summary. Unlike MHA, where each head has its own set of keys and values, GQA lets multiple query heads share the same key and value projections, which reduces memory usage.
@@ -28,19 +25,17 @@ While GQA is mainly a computational-efficiency workaround for MHA, ablation stud
 
 However, this assumes that the number of key-value groups is chosen carefully. If we reduce it all the way to a single key-value head shared across all query heads (this extreme case is known as multi-query attention), it will negatively affect the modeling performance.
 
-
-
 &nbsp;
 ## GQA Memory Savings
 
 The memory savings are mostly reflected in the KV storage. We can compute the KV storage size with the following formula:
 
 bytes ≈ batch_size × seqlen × (embed_dim / n_heads) × n_layers × 2 (K,V) × bytes_per_elem × n_kv_heads
 
-You can use the [memory_estimator.py](memory_estimator.py) script in this folder to apply this for different model configs to see how much memory you can save by using GQA over MHA:
+You can use the [memory_estimator_gqa.py](memory_estimator_gqa.py) script in this folder to apply this for different model configs to see how much memory you can save by using GQA over MHA:
 
 ```bash
-➜ uv run memory_estimator.py \
+➜ uv run memory_estimator_gqa.py \
     --emb_dim 4096 --n_heads 32 --n_layers 32 \
     --context_length 32768 --n_kv_groups 4 \
     --batch_size 1 --dtype bf16
@@ -62,25 +57,15 @@ Ratio (MHA / GQA) : 4.00x
 Savings (GQA vs MHA): 75.00%
 ```
 
-The savings when using GQA over MHA are further shown in the plot below for different key-value group sizes:
+The savings when using GQA over MHA are further shown in the plot below for different key-value group sizes as a function of the context length:
 
 &nbsp;
 
-<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gqa-memory/2.webp?2" alt="GQA" width="500px" />
+<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gqa-memory/3.webp?4" alt="GQA" width="500px" />
 
 &nbsp;
 
-And the following plot shows how the KV cache size grows with an increasing context length:
-
-&nbsp;
-
-<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gqa-memory/3.webp?2" alt="GQA" width="500px" />
-
-&nbsp;
-
-You can reproduce these plots via `uv run plot_memory_estimates.py`.
-
-
+You can reproduce the plot via `uv run plot_memory_estimates_gqa.py`.
 
 &nbsp;
 ## GQA Code Examples
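
As a quick sanity check of the KV-storage formula quoted in the README hunk above, here is a minimal Python sketch. It is not the repo's `memory_estimator_gqa.py`; the helper name is made up, and reading the `--n_kv_groups 4` example as 8 KV heads shared by 32 query heads is an assumption inferred from the 4.00x ratio in the sample output. Under those assumptions it reproduces the 4.00x ratio and 75% savings shown above:

```python
# Hypothetical helper (not part of the commit) that evaluates the README formula:
# bytes ≈ batch_size × seqlen × (embed_dim / n_heads) × n_layers × 2 (K,V) × bytes_per_elem × n_kv_heads
def kv_cache_bytes(batch_size, seqlen, emb_dim, n_heads, n_layers,
                   n_kv_heads, bytes_per_elem=2):  # 2 bytes per element for bf16
    head_dim = emb_dim // n_heads
    return batch_size * seqlen * head_dim * n_layers * 2 * bytes_per_elem * n_kv_heads


cfg = dict(batch_size=1, seqlen=32768, emb_dim=4096, n_heads=32, n_layers=32)
mha = kv_cache_bytes(**cfg, n_kv_heads=32)  # MHA: one KV head per query head
gqa = kv_cache_bytes(**cfg, n_kv_heads=8)   # GQA: 8 KV heads shared by 32 query heads (assumed)
print(f"MHA KV cache: {mha / 1024**3:.2f} GiB")      # 16.00 GiB
print(f"GQA KV cache: {gqa / 1024**3:.2f} GiB")      # 4.00 GiB
print(f"Ratio (MHA / GQA): {mha / gqa:.2f}x")        # 4.00x
print(f"Savings (GQA vs MHA): {1 - gqa / mha:.2%}")  # 75.00%
```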

ch04/04_gqa/gpt_with_kv_gqa.py

Lines changed: 11 additions & 38 deletions

@@ -4,7 +4,7 @@
 # Code: https://github.com/rasbt/LLMs-from-scratch
 
 # This file collects all the relevant code that we covered thus far
-# throughout Chapters 3-4.
+# throughout Chapters 3-4, adapted to use Grouped-Query Attention (GQA).
 # This file can be run as a standalone script.
 
 import argparse
@@ -83,7 +83,8 @@ def forward(self, x, use_cache=False):
         # Shape: (b, num_heads, num_tokens, num_tokens)
         attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head
 
-        # Use the mask to fill attention scores
+        ####################################################
+        # causal mask
         num_tokens_Q = queries.shape[-2]
         num_tokens_K = keys.shape[-2]
         device = queries.device
@@ -101,6 +102,7 @@ def forward(self, x, use_cache=False):
         k_positions = torch.arange(num_tokens_K, device=device, dtype=torch.long)
         mask = q_positions.unsqueeze(-1) < k_positions.unsqueeze(0)
 
+        # Use the mask to fill attention scores
         attn_scores = attn_scores.masked_fill(mask, -torch.inf)
 
         attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
@@ -111,7 +113,7 @@ def forward(self, x, use_cache=False):
         context_vec = (attn_weights @ values).transpose(1, 2)
 
         # Combine heads, where self.d_out = self.num_heads * self.head_dim
-        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
+        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
         context_vec = self.out_proj(context_vec)  # optional projection
 
         return context_vec
@@ -184,7 +186,7 @@ def forward(self, x, use_cache=False):
 
         # x = self.att(x)  # Shape [batch_size, num_tokens, emb_size]
         ####################################################
-        # NEW
+        # KV cache-related
         x = self.att(x, use_cache=use_cache)
         ####################################################
 
@@ -211,7 +213,7 @@ def __init__(self, cfg):
         # self.trf_blocks = nn.Sequential(
         #     *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
         ####################################################
-        # NEW
+        # KV cache-related
         self.trf_blocks = nn.ModuleList(
             [TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
 
@@ -228,8 +230,7 @@ def forward(self, in_idx, use_cache=False):
         # pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
 
         ####################################################
-        # NEW
-
+        # KV cache-related
         if use_cache:
             pos_ids = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device, dtype=torch.long)
             self.current_pos += seq_len
@@ -243,7 +244,7 @@ def forward(self, in_idx, use_cache=False):
 
         # x = self.trf_blocks(x)
         ####################################################
-        # NEW
+        # KV cache-related
         for blk in self.trf_blocks:
             x = blk(x, use_cache=use_cache)
         ####################################################
@@ -253,42 +254,14 @@ def forward(self, in_idx, use_cache=False):
         return logits
 
     ####################################################
-    # NEW
+    # KV cache-related
     def reset_kv_cache(self):
         for blk in self.trf_blocks:
             blk.att.reset_cache()
         self.current_pos = 0
     ####################################################
 
 
-def generate_text_simple(model, idx, max_new_tokens, context_size):
-    # idx is (B, T) array of indices in the current context
-    for _ in range(max_new_tokens):
-
-        # Crop current context if it exceeds the supported context size
-        # E.g., if LLM supports only 5 tokens, and the context size is 10
-        # then only the last 5 tokens are used as context
-        idx_cond = idx[:, -context_size:]
-
-        # Get the predictions
-        with torch.no_grad():
-            logits = model(idx_cond)
-
-        # Focus only on the last time step
-        # (batch, n_token, vocab_size) becomes (batch, vocab_size)
-        logits = logits[:, -1, :]
-
-        # Get the idx of the vocab entry with the highest logits value
-        idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch, 1)
-
-        # Append sampled index to the running sequence
-        idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)
-
-    return idx
-
-
-####################################################
-# NEW
 def generate_text_simple_cached(model, idx, max_new_tokens,
                                 context_size=None, use_cache=True):
     model.eval()
@@ -314,7 +287,6 @@ def generate_text_simple_cached(model, idx, max_new_tokens,
         idx = torch.cat([idx, next_idx], dim=1)
 
     return idx
-####################################################
 
 
 def main():
@@ -324,6 +296,7 @@ def main():
     parser.add_argument("--n_layers", type=int, default=12, help="Number of transformer blocks.")
     parser.add_argument("--n_kv_groups", type=int, default=2, help="Number of key/value groups.")
     parser.add_argument("--max_new_tokens", type=int, default=200, help="Number of tokens to generate.")
+
     args = parser.parse_args()
 
     start_context = "Hello, I am"
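
As a side note for readers skimming this diff: the core GQA trick in `gpt_with_kv_gqa.py` is that keys and values are projected with fewer heads than the queries and then expanded (for example with `repeat_interleave`) so that every query head has a matching key/value head. The snippet below is a standalone illustration with made-up toy dimensions, not code taken from the file:

```python
import torch

# Toy dimensions, assumed for illustration only
b, num_tokens, num_heads, num_kv_heads, head_dim = 1, 5, 8, 2, 16
group_size = num_heads // num_kv_heads          # 4 query heads share each KV head

queries = torch.randn(b, num_heads, num_tokens, head_dim)
keys = torch.randn(b, num_kv_heads, num_tokens, head_dim)
values = torch.randn(b, num_kv_heads, num_tokens, head_dim)

# Expand K/V from num_kv_heads to num_heads by repeating each KV head group_size times
keys = keys.repeat_interleave(group_size, dim=1)      # (b, num_heads, num_tokens, head_dim)
values = values.repeat_interleave(group_size, dim=1)

attn_scores = queries @ keys.transpose(2, 3)          # (b, num_heads, num_tokens, num_tokens)
print(attn_scores.shape)                              # torch.Size([1, 8, 5, 5])
```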

ch04/04_gqa/gpt_with_kv_mha.py

Lines changed: 9 additions & 41 deletions

@@ -33,7 +33,7 @@ def __init__(self, d_in, d_out, dropout, num_heads, qkv_bias=False):
         self.dropout = nn.Dropout(dropout)
 
         ####################################################
-        # NEW
+        # KV cache-related code
         self.register_buffer("cache_k", None, persistent=False)
         self.register_buffer("cache_v", None, persistent=False)
         self.ptr_current_pos = 0
@@ -53,7 +53,7 @@ def forward(self, x, use_cache=False):
         queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
 
         ####################################################
-        # NEW
+        # KV cache-related
         if use_cache:
             if self.cache_k is None:
                 self.cache_k, self.cache_v = keys_new, values_new
@@ -74,7 +74,7 @@ def forward(self, x, use_cache=False):
         attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head
 
         ####################################################
-        # NEW
+        # causal mask
         num_tokens_Q = queries.shape[-2]
         num_tokens_K = keys.shape[-2]
         device = queries.device
@@ -107,12 +107,9 @@ def forward(self, x, use_cache=False):
 
         return context_vec
 
-    ####################################################
-    # NEW
     def reset_cache(self):
         self.cache_k, self.cache_v = None, None
         self.ptr_current_pos = 0
-    ####################################################
 
 
 #####################################
@@ -177,7 +174,7 @@ def forward(self, x, use_cache=False):
 
         # x = self.att(x)  # Shape [batch_size, num_tokens, emb_size]
         ####################################################
-        # NEW
+        # KV cache-related
         x = self.att(x, use_cache=use_cache)
         ####################################################
 
@@ -204,7 +201,7 @@ def __init__(self, cfg):
         # self.trf_blocks = nn.Sequential(
         #     *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
         ####################################################
-        # NEW
+        # KV cache-related
         self.trf_blocks = nn.ModuleList(
             [TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
 
@@ -221,8 +218,7 @@ def forward(self, in_idx, use_cache=False):
         # pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
 
         ####################################################
-        # NEW
-
+        # KV cache-related
         if use_cache:
             pos_ids = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device, dtype=torch.long)
             self.current_pos += seq_len
@@ -236,7 +232,7 @@ def forward(self, in_idx, use_cache=False):
 
         # x = self.trf_blocks(x)
         ####################################################
-        # NEW
+        # KV cache-related
         for blk in self.trf_blocks:
             x = blk(x, use_cache=use_cache)
         ####################################################
@@ -246,42 +242,14 @@ def forward(self, in_idx, use_cache=False):
         return logits
 
     ####################################################
-    # NEW
+    # KV cache-related
    def reset_kv_cache(self):
         for blk in self.trf_blocks:
             blk.att.reset_cache()
         self.current_pos = 0
     ####################################################
 
 
-def generate_text_simple(model, idx, max_new_tokens, context_size):
-    # idx is (B, T) array of indices in the current context
-    for _ in range(max_new_tokens):
-
-        # Crop current context if it exceeds the supported context size
-        # E.g., if LLM supports only 5 tokens, and the context size is 10
-        # then only the last 5 tokens are used as context
-        idx_cond = idx[:, -context_size:]
-
-        # Get the predictions
-        with torch.no_grad():
-            logits = model(idx_cond)
-
-        # Focus only on the last time step
-        # (batch, n_token, vocab_size) becomes (batch, vocab_size)
-        logits = logits[:, -1, :]
-
-        # Get the idx of the vocab entry with the highest logits value
-        idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch, 1)
-
-        # Append sampled index to the running sequence
-        idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)
-
-    return idx
-
-
-####################################################
-# NEW
 def generate_text_simple_cached(model, idx, max_new_tokens,
                                 context_size=None, use_cache=True):
     model.eval()
@@ -307,7 +275,6 @@ def generate_text_simple_cached(model, idx, max_new_tokens,
         idx = torch.cat([idx, next_idx], dim=1)
 
     return idx
-####################################################
 
 
 def main():
@@ -316,6 +283,7 @@ def main():
     parser.add_argument("--n_heads", type=int, default=12, help="Number of attention heads.")
     parser.add_argument("--n_layers", type=int, default=12, help="Number of transformer blocks.")
    parser.add_argument("--max_new_tokens", type=int, default=200, help="Number of tokens to generate.")
+
     args = parser.parse_args()
 
     start_context = "Hello, I am"
File renamed without changes.
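
The hunks in both `gpt_with_kv_gqa.py` and `gpt_with_kv_mha.py` revolve around the same KV-cache pattern: `cache_k`/`cache_v` are registered as non-persistent buffers, filled on the first cached forward pass, and cleared by a reset method. The following is a minimal sketch of that pattern only; the class name and the (batch, seq, heads, head_dim) layout are assumptions for illustration, and the `torch.cat` growth branch is inferred, since it lies outside the visible hunks:

```python
import torch
import torch.nn as nn

class TinyKVCache(nn.Module):
    # Illustrative sketch, not the repo's attention classes.
    def __init__(self):
        super().__init__()
        # Non-persistent buffers: start empty, excluded from state_dict
        self.register_buffer("cache_k", None, persistent=False)
        self.register_buffer("cache_v", None, persistent=False)

    def update(self, keys_new, values_new):
        if self.cache_k is None:
            # First cached forward pass: store the prompt's K/V directly
            self.cache_k, self.cache_v = keys_new, values_new
        else:
            # Later passes: append along the sequence dimension (assumed dim=1)
            self.cache_k = torch.cat([self.cache_k, keys_new], dim=1)
            self.cache_v = torch.cat([self.cache_v, values_new], dim=1)
        return self.cache_k, self.cache_v

    def reset_cache(self):
        self.cache_k, self.cache_v = None, None


cache = TinyKVCache()
k, v = cache.update(torch.randn(1, 4, 2, 8), torch.randn(1, 4, 2, 8))  # 4 prompt tokens
k, v = cache.update(torch.randn(1, 1, 2, 8), torch.randn(1, 1, 2, 8))  # 1 generated token
print(k.shape)  # torch.Size([1, 5, 2, 8])
cache.reset_cache()
```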
