Skip to content

Commit bf27ad1

Browse files
authored
Use GB instead of GiB consistently (#875)
1 parent c814814 commit bf27ad1

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

ch04/04_gqa/memory_estimator.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,11 @@ def main():
4949
"n_kv_groups": args.n_kv_groups,
5050
}
5151

52+
if cfg["n_heads"] % cfg["n_kv_groups"] != 0:
53+
raise ValueError("n_kv_groups must divide n_heads exactly.")
54+
5255
bytes_per_elem = DTYPE_BYTES[args.dtype]
53-
head_dim = cfg["emb_dim"] / cfg["n_heads"]
56+
head_dim = math.ceil(cfg["emb_dim"] / cfg["n_heads"])
5457

5558
n_kv_heads_mha = cfg["n_heads"]
5659
n_kv_heads_gqa = cfg["n_heads"] // cfg["n_kv_groups"]
@@ -83,7 +86,7 @@ def main():
8386
print(f"{k:17}: {v}")
8487
print(f"batch_size : {args.batch_size}")
8588
print(f"dtype : {args.dtype} ({bytes_per_elem} Bytes/elem)")
86-
print(f"head_dim : {int(head_dim)}")
89+
print(f"head_dim : {head_dim}")
8790
print(f"GQA n_kv_heads : {n_kv_heads_gqa}")
8891
print()
8992

ch04/04_gqa/plot_memory_estimates.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
from memory_estimator import kv_bytes_total, DTYPE_BYTES
1313

1414

15+
def bytes_convert(n):
16+
gb = n / (1000 ** 3)
17+
return f"{gb:.2f}"
18+
19+
1520
def savings_percent(total_mha, total_gqa):
1621
return (1.0 - (total_gqa / total_mha)) * 100.0
1722

@@ -83,8 +88,8 @@ def plot_abs_kv_vs_context():
8388
]
8489

8590
xs = []
86-
mha_gib = []
87-
gqa_gib = []
91+
mha_gb = []
92+
gqa_gb = []
8893
savings_pct = None
8994

9095
for L in context_lengths:
@@ -97,14 +102,14 @@ def plot_abs_kv_vs_context():
97102
n_kv_heads_gqa, n_layers, bytes_per_elem
98103
)
99104
xs.append(L)
100-
mha_gib.append(total_mha / (1024**3))
101-
gqa_gib.append(total_gqa / (1024**3))
105+
mha_gb.append(float(bytes_convert(total_mha)))
106+
gqa_gb.append(float(bytes_convert(total_gqa)))
102107
if savings_pct is None:
103108
savings_pct = savings_percent(total_mha, total_gqa)
104109

105110
plt.figure()
106-
plt.plot(xs, mha_gib, marker="o", label="MHA (KV total)")
107-
plt.plot(xs, gqa_gib, marker="o", label=f"GQA (n_kv_groups={n_kv_groups})")
111+
plt.plot(xs, mha_gb, marker="o", label="MHA (KV total)")
112+
plt.plot(xs, gqa_gb, marker="o", label=f"GQA (n_kv_groups={n_kv_groups})")
108113
plt.xscale("log")
109114
plt.xlabel("context_length (log scale)")
110115
plt.ylabel("Total KV cache (GB)")
@@ -118,6 +123,9 @@ def plot_abs_kv_vs_context():
118123
plt.tight_layout()
119124
plt.savefig("kv_bytes_vs_context_length.pdf")
120125
print(f"Savings is constant across lengths: ~{savings_pct:.2f}%.")
126+
print(f"Example (context_length={context_lengths[-1]}): "
127+
f"MHA={bytes_convert(total_mha)} GB, "
128+
f"GQA={bytes_convert(total_gqa)} GB")
121129

122130

123131
if __name__ == "__main__":

0 commit comments

Comments
 (0)