
Commit c814814

Grouped-Query Attention memory (#874)

* GQA memory
* remove redundant code
* update links
* update

1 parent b8e12e1 commit c814814

File tree

7 files changed, +1114 −0 lines changed


.gitignore

Lines changed: 3 additions & 0 deletions
@@ -11,6 +11,9 @@ appendix-D/01_main-chapter-code/3.pdf
 
 appendix-E/01_main-chapter-code/loss-plot.pdf
 
+ch04/04_gqa/kv_bytes_vs_context_length.pdf
+ch04/04_gqa/savings_vs_n_kv_groups.pdf
+
 ch05/01_main-chapter-code/loss-plot.pdf
 ch05/01_main-chapter-code/temperature-plot.pdf
 ch05/01_main-chapter-code/the-verdict.txt

README.md

Lines changed: 1 addition & 0 deletions
@@ -168,6 +168,7 @@ Several folders contain optional materials as a bonus for interested readers:
 - **Chapter 4: Implementing a GPT model from scratch**
   - [FLOPS Analysis](ch04/02_performance-analysis/flops-analysis.ipynb)
   - [KV Cache](ch04/03_kv-cache)
+  - [Grouped-Query Attention](ch04/04_gqa)
 - **Chapter 5: Pretraining on unlabeled data:**
   - [Alternative Weight Loading Methods](ch05/02_alternative_weight_loading/)
   - [Pretraining GPT on the Project Gutenberg Dataset](ch05/03_bonus_pretraining_on_gutenberg)

ch04/04_gqa/README.md

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
# Grouped-Query Attention (GQA)

This bonus material illustrates the memory savings when using Grouped-Query Attention (GQA) over regular Multi-Head Attention (MHA).

&nbsp;
## Introduction

Grouped-Query Attention (GQA) has become the standard, more compute- and parameter-efficient replacement for Multi-Head Attention (MHA) in recent years. Note that it's not new: it goes back to the 2023 paper [GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints](https://arxiv.org/abs/2305.13245), and even the larger variants of the good old Llama 2 series used it.

Here's a brief GQA summary. Unlike MHA, where each head has its own set of keys and values, GQA reduces memory usage by grouping multiple query heads so that they share the same key and value projections.

For example, as further illustrated in the figure below, if there are 3 key-value groups and 6 attention heads, then heads 1 and 2 share one set of keys and values, heads 3 and 4 share another, and heads 5 and 6 share a third.

&nbsp;

![GQA](https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gqa-memory/1.webp?1)

&nbsp;

This sharing of keys and values reduces the total number of key and value computations, which leads to lower memory usage and improved efficiency.

So, to summarize, the core idea behind GQA is to reduce the number of key and value heads by sharing them across multiple query heads. This (1) lowers the model's parameter count and (2) reduces the memory bandwidth usage for key and value tensors during inference since fewer keys and values need to be stored and retrieved from the KV cache.
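
To make this concrete, below is a minimal PyTorch sketch of the idea. It is illustrative only: the class and argument names (`GroupedQueryAttentionSketch`, `n_kv_heads`, and so on) are made up for this example, and the full, KV-cache-enabled implementation lives in [gpt_with_kv_gqa.py](gpt_with_kv_gqa.py). The key point is that only `n_kv_heads` key/value projections are created, and each of them is repeated so that several query heads attend over the same keys and values.

```python
import torch
import torch.nn as nn


class GroupedQueryAttentionSketch(nn.Module):
    # Illustrative sketch only; not the implementation used by the scripts in this folder.
    def __init__(self, d_in, d_out, n_heads, n_kv_heads):
        super().__init__()
        assert d_out % n_heads == 0 and n_heads % n_kv_heads == 0
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = d_out // n_heads
        self.group_size = n_heads // n_kv_heads  # query heads per shared K/V head

        self.W_query = nn.Linear(d_in, d_out, bias=False)
        # Only n_kv_heads sets of keys/values instead of n_heads sets
        self.W_key = nn.Linear(d_in, n_kv_heads * self.head_dim, bias=False)
        self.W_value = nn.Linear(d_in, n_kv_heads * self.head_dim, bias=False)
        self.out_proj = nn.Linear(d_out, d_out, bias=False)

    def forward(self, x):
        b, num_tokens, _ = x.shape

        q = self.W_query(x).view(b, num_tokens, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.W_key(x).view(b, num_tokens, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.W_value(x).view(b, num_tokens, self.n_kv_heads, self.head_dim).transpose(1, 2)

        # Share each K/V head across group_size query heads
        k = k.repeat_interleave(self.group_size, dim=1)
        v = v.repeat_interleave(self.group_size, dim=1)

        # Standard causal attention from here on
        attn_scores = q @ k.transpose(2, 3) / self.head_dim**0.5
        mask = torch.triu(
            torch.ones(num_tokens, num_tokens, dtype=torch.bool, device=x.device), diagonal=1
        )
        attn_scores = attn_scores.masked_fill(mask, -torch.inf)
        attn_weights = torch.softmax(attn_scores, dim=-1)

        context = (attn_weights @ v).transpose(1, 2).reshape(b, num_tokens, -1)
        return self.out_proj(context)


# Matches the figure above: 6 query heads sharing 3 key/value heads
gqa = GroupedQueryAttentionSketch(d_in=48, d_out=48, n_heads=6, n_kv_heads=3)
print(gqa(torch.randn(1, 10, 48)).shape)  # torch.Size([1, 10, 48])
```

Setting `n_kv_heads=n_heads` recovers regular MHA, and `n_kv_heads=1` corresponds to the multi-query attention special case mentioned below.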

While GQA is mainly a computational-efficiency workaround for MHA, ablation studies (such as those in the [original GQA paper](https://arxiv.org/abs/2305.13245) and the [Llama 2 paper](https://arxiv.org/abs/2307.09288)) show that it performs comparably to standard MHA in terms of LLM modeling performance.

This assumes, however, that the number of key-value groups is chosen carefully. If we reduce the number of key-value heads all the way to one, so that all query heads share a single key and value head (this special case is known as multi-query attention), modeling performance tends to suffer.

&nbsp;
## GQA Memory Savings

The memory savings are mostly reflected in the KV storage. We can compute the KV storage size with the following formula:

bytes ≈ batch_size × seqlen × (emb_dim / n_heads) × n_layers × 2 (K,V) × bytes_per_elem × n_kv_heads

where emb_dim / n_heads is the head dimension, and n_kv_heads equals n_heads for MHA and n_heads / n_kv_groups for GQA.

You can use the [memory_estimator.py](memory_estimator.py) script in this folder to apply this formula to different model configs and see how much memory you can save by using GQA over MHA:

```bash
➜ uv run memory_estimator.py \
--emb_dim 4096 --n_heads 32 --n_layers 32 \
--context_length 32768 --n_kv_groups 4 \
--batch_size 1 --dtype bf16
==== Config ====
context_length : 32768
emb_dim : 4096
n_heads : 32
n_layers : 32
n_kv_groups : 4
batch_size : 1
dtype : bf16 (2 Bytes/elem)
head_dim : 128
GQA n_kv_heads : 8

==== KV-cache totals across all layers ====
MHA total KV cache : 17.18 GB
GQA total KV cache : 4.29 GB
Ratio (MHA / GQA) : 4.00x
Savings (GQA vs MHA): 75.00%
```
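
As a quick sanity check, the snippet below plugs the formula above into plain Python and reproduces the totals shown in the output. It is a simplified, hypothetical stand-in (the `kv_cache_bytes` helper is made up here), not the actual `memory_estimator.py` script:

```python
def kv_cache_bytes(batch_size, context_length, emb_dim, n_heads,
                   n_layers, n_kv_heads, bytes_per_elem=2):
    # bytes ≈ batch_size × seqlen × head_dim × n_layers × 2 (K,V)
    #         × bytes_per_elem × n_kv_heads
    head_dim = emb_dim // n_heads
    return batch_size * context_length * head_dim * n_layers * 2 * bytes_per_elem * n_kv_heads


cfg = dict(batch_size=1, context_length=32768, emb_dim=4096, n_heads=32, n_layers=32)

mha = kv_cache_bytes(**cfg, n_kv_heads=32)       # MHA: one K/V head per query head
gqa = kv_cache_bytes(**cfg, n_kv_heads=32 // 4)  # GQA with n_kv_groups=4 -> 8 K/V heads

print(f"MHA total KV cache : {mha / 1e9:.2f} GB")             # 17.18 GB
print(f"GQA total KV cache : {gqa / 1e9:.2f} GB")             # 4.29 GB
print(f"Savings (GQA vs MHA): {100 * (1 - gqa / mha):.2f}%")  # 75.00%
```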

The savings when using GQA over MHA are further shown in the plot below for different key-value group sizes:

&nbsp;

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gqa-memory/2.webp?2" alt="GQA" width="500px" />

&nbsp;

And the following plot shows how the KV cache size grows with an increasing context length:

&nbsp;

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gqa-memory/3.webp?2" alt="GQA" width="500px" />

&nbsp;

You can reproduce these plots via `uv run plot_memory_estimates.py`.
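
If you just want the raw numbers behind such plots, the rough sketch below prints the same two relationships to the console: the relative savings depend only on the number of key-value groups (1 − 1/n_kv_groups), while the absolute KV-cache size grows linearly with the context length. The `kv_bytes` helper is a hypothetical simplification, independent of `plot_memory_estimates.py`:

```python
def kv_bytes(context_length, n_kv_heads, emb_dim=4096, n_heads=32, n_layers=32,
             batch_size=1, bytes_per_elem=2):
    # Same formula as above, with the Llama-2-7B-like defaults used in this README
    head_dim = emb_dim // n_heads
    return batch_size * context_length * head_dim * n_layers * 2 * bytes_per_elem * n_kv_heads


n_heads = 32

# Savings vs. number of KV groups (independent of context length): 1 - 1/n_kv_groups
for g in (1, 2, 4, 8, 16, 32):
    print(f"n_kv_groups={g:2d} -> savings vs MHA: {100 * (1 - 1 / g):6.2f}%")

# KV-cache size vs. context length for MHA and for GQA with n_kv_groups=4
for ctx in (1024, 8192, 32768, 131072):
    mha = kv_bytes(ctx, n_kv_heads=n_heads)
    gqa = kv_bytes(ctx, n_kv_heads=n_heads // 4)
    print(f"context={ctx:6d}   MHA: {mha / 1e9:6.2f} GB   GQA: {gqa / 1e9:6.2f} GB")
```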

&nbsp;
## GQA Code Examples

The [gpt_with_kv_mha.py](gpt_with_kv_mha.py) and [gpt_with_kv_gqa.py](gpt_with_kv_gqa.py) scripts in this folder provide hands-on examples for comparing the MHA and GQA memory usage in the context of a GPT model implementation.

Note that GQA is also used in the [Llama 3](../../ch05/07_gpt_to_llama), [Gemma 3](../../ch05/12_gemma3), and [Qwen3](../../ch05/11_qwen3) bonus materials. However, for simplicity, the scripts in this folder modify the GPT architecture, which traditionally didn't use GQA.

Note that the model is not trained and thus generates nonsensical text. However, you can use it as a drop-in replacement for the standard GPT model in chapters 5-7 and train it.

Also, this implementation uses the KV cache explained in [another bonus section](../03_kv-cache), so the memory savings are more pronounced.

```bash
uv run gpt_with_kv_mha.py \
--max_new_tokens 32768 \
--n_heads 24 \
--n_layers 12

...

Time: 453.81 sec
72 tokens/sec
Max memory allocated: 1.54 GB
```

```bash
uv run gpt_with_kv_gqa.py \
--max_new_tokens 32768 \
--n_heads 24 \
--n_layers 12 \
--n_kv_groups 4

...

Time: 516.33 sec
63 tokens/sec
Max memory allocated: 0.63 GB
```

The reason we are not seeing savings as large as in the plots above is twofold:

1. I use a smaller configuration so that the model finishes the generation in a reasonable time.
2. More importantly, we are looking at the whole model here, not just the attention mechanism; the fully connected layers in the model take up most of the memory (but this is a topic for a separate analysis).
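
If you want to instrument your own runs in the same way, numbers like the "Max memory allocated" lines above can be obtained with PyTorch's built-in CUDA memory statistics. The following is a minimal sketch of the measurement pattern, not code taken from the scripts in this folder:

```python
import time
import torch

# Minimal timing / peak-memory measurement pattern (assumes a CUDA device);
# replace the placeholder comment with your own generation loop.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    torch.cuda.reset_peak_memory_stats(device)

start = time.time()
# ... run token-by-token generation with the KV cache here ...
elapsed = time.time() - start

print(f"Time: {elapsed:.2f} sec")
if device.type == "cuda":
    print(f"Max memory allocated: {torch.cuda.max_memory_allocated(device) / 1e9:.2f} GB")
```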
