1212from memory_estimator import kv_bytes_total , DTYPE_BYTES
1313
1414
def bytes_convert(n):
    """Render a byte count as a decimal-gigabyte figure.

    The count is divided by ``1000 ** 3`` (decimal GB, not the 1024-based
    GiB) and formatted with exactly two digits after the decimal point.

    Args:
        n: Number of bytes.

    Returns:
        The gigabyte value as a string with two decimals, followed by a
        single trailing space. No unit suffix is appended.
    """
    bytes_per_gb = 1000 ** 3
    gigabytes = n / bytes_per_gb
    return format(gigabytes, ".2f") + " "
18+
19+
def savings_percent(total_mha, total_gqa):
    """Return how much smaller ``total_gqa`` is than ``total_mha``, in percent.

    Computed as ``(1 - total_gqa / total_mha) * 100``: equal totals give
    0.0, and a ``total_gqa`` of zero gives 100.0.

    Args:
        total_mha: Baseline total that the saving is measured against.
        total_gqa: Total being compared with the baseline.

    Returns:
        The relative reduction as a float percentage.

    Raises:
        ZeroDivisionError: If ``total_mha`` is zero (numeric inputs).
    """
    gqa_share = total_gqa / total_mha
    saved_fraction = 1.0 - gqa_share
    return saved_fraction * 100.0
1722
@@ -83,8 +88,8 @@ def plot_abs_kv_vs_context():
8388 ]
8489
8590 xs = []
86- mha_gib = []
87- gqa_gib = []
91+ mha_gb = []
92+ gqa_gb = []
8893 savings_pct = None
8994
9095 for L in context_lengths :
@@ -97,14 +102,14 @@ def plot_abs_kv_vs_context():
97102 n_kv_heads_gqa , n_layers , bytes_per_elem
98103 )
99104 xs .append (L )
100- mha_gib .append (total_mha / ( 1024 ** 3 ))
101- gqa_gib .append (total_gqa / ( 1024 ** 3 ))
105+ mha_gb .append (float ( bytes_convert ( total_mha ) ))
106+ gqa_gb .append (float ( bytes_convert ( total_gqa ) ))
102107 if savings_pct is None :
103108 savings_pct = savings_percent (total_mha , total_gqa )
104109
105110 plt .figure ()
106- plt .plot (xs , mha_gib , marker = "o" , label = "MHA (KV total)" )
107- plt .plot (xs , gqa_gib , marker = "o" , label = f"GQA (n_kv_groups={ n_kv_groups } )" )
111+ plt .plot (xs , mha_gb , marker = "o" , label = "MHA (KV total)" )
112+ plt .plot (xs , gqa_gb , marker = "o" , label = f"GQA (n_kv_groups={ n_kv_groups } )" )
108113 plt .xscale ("log" )
109114 plt .xlabel ("context_length (log scale)" )
110115 plt .ylabel ("Total KV cache (GB)" )
@@ -118,6 +123,9 @@ def plot_abs_kv_vs_context():
118123 plt .tight_layout ()
119124 plt .savefig ("kv_bytes_vs_context_length.pdf" )
120125 print (f"Savings is constant across lengths: ~{ savings_pct :.2f} %." )
126+ print (f"Example (context_length={ context_lengths [- 1 ]} ): "
127+ f"MHA={ bytes_convert (total_mha )} GB, "
128+ f"GQA={ bytes_convert (total_gqa )} GB" )
121129
122130
123131if __name__ == "__main__" :
0 commit comments