
Commit 1174780

[mcore-bridge] optimize gpt_bridge comm (#6659)
1 parent 4da83ad commit 1174780

File tree

2 files changed: +50 -29 lines changed

swift/llm/argument/export_args.py

Lines changed: 5 additions & 0 deletions
@@ -67,6 +67,11 @@ class ExportArguments(MergeArguments, BaseArguments):
     to_peft_format: bool = False
     exist_ok: bool = False
 
+    def load_args_from_ckpt(self) -> None:
+        if self.to_cached_dataset:
+            return
+        super().load_args_from_ckpt()
+
     def _init_output_dir(self):
         if self.output_dir is None:
             ckpt_dir = self.ckpt_dir or f'./{self.model_suffix}'
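
The new load_args_from_ckpt override simply skips loading the arguments stored with the checkpoint when the export only produces a cached dataset (to_cached_dataset). A minimal standalone sketch of that guard pattern, using simplified stand-in class names rather than the real swift classes:

class BaseArgs:
    def load_args_from_ckpt(self) -> None:
        print('loading arguments stored with the checkpoint ...')


class ExportArgs(BaseArgs):
    def __init__(self, to_cached_dataset: bool = False) -> None:
        self.to_cached_dataset = to_cached_dataset

    def load_args_from_ckpt(self) -> None:
        # Exporting only a cached dataset does not need the checkpoint's
        # training arguments, so return before the parent loader runs.
        if self.to_cached_dataset:
            return
        super().load_args_from_ckpt()


ExportArgs(to_cached_dataset=True).load_args_from_ckpt()   # prints nothing
ExportArgs(to_cached_dataset=False).load_args_from_ckpt()  # loads as before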

swift/megatron/model/gpt_bridge.py

Lines changed: 45 additions & 29 deletions
@@ -64,6 +64,27 @@ def __init__(self, disable_tqmd: bool = False):
         self.etp_rank = mpu.get_expert_tensor_parallel_rank()
         self.ep_rank = mpu.get_expert_model_parallel_rank()
 
+        dp_size = dist.get_world_size() // self.etp_size // self.ep_size // self.pp_size
+        expert_decoder_rank_generator = mpu.RankGenerator(
+            tp=self.etp_size,
+            ep=self.ep_size,
+            dp=dp_size,
+            pp=self.pp_size,
+            cp=1,
+            order='tp-cp-ep-dp-pp',
+            rank_offset=0,
+        )
+        rank = dist.get_rank()
+        for ranks in expert_decoder_rank_generator.get_ranks('ep-pp'):
+            group = mpu.create_group(
+                ranks,
+                group_desc='EP-PP-GROUP',
+            )
+            if rank in ranks:
+                self.ep_pp_size = self.ep_size * self.pp_size
+                self.ep_pp_group = group
+                self.ep_pp_rank = dist.get_rank(group)
+
     def _init_meta_hf_model(self):
         with torch.device('meta'):
             self.hf_model, self.processor = get_model_tokenizer(
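
The hunk above is the core of the optimization: __init__ now builds a fused EP-PP process group from Megatron's RankGenerator, so an expert weight can later be shared with a single collective over ep_size * pp_size ranks instead of one pass over the EP group followed by another over the PP group. A rough standalone sketch of that idea, assuming a toy layout where world_size equals ep_size * pp_size and there is no TP/DP/CP (run with torchrun --nproc_per_node=4 sketch.py):

import torch
import torch.distributed as dist


def main(ep_size: int = 2, pp_size: int = 2) -> None:
    dist.init_process_group(backend='gloo')
    rank = dist.get_rank()
    # In this toy layout every rank shares one fused EP-PP group; in the real
    # bridge, RankGenerator.get_ranks('ep-pp') yields one such rank list per
    # (tp, dp) coordinate and each rank keeps only the group it belongs to.
    fused_group = dist.new_group(list(range(ep_size * pp_size)))
    weight = torch.zeros(4)
    if rank == 0:
        weight += 1.0  # pretend only rank 0 owns this expert weight
    # One broadcast over the fused group replaces the old EP broadcast
    # followed by a separate PP broadcast.
    dist.broadcast(weight, src=0, group=fused_group)
    assert torch.allclose(weight, torch.ones(4))
    dist.destroy_process_group()


if __name__ == '__main__':
    main()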
@@ -198,6 +219,9 @@ def _get_weight(self, mg_weight: torch.Tensor, mg_key: Optional[str], offset: fl
         tensor = None if mg_weight is None else mg_weight.to('cuda')
         tp_size = self.etp_size if is_expert else self.tp_size
         tp_group = self.etp_group if is_expert else self.tp_group
+        pp_group = self.ep_pp_group if is_expert else self.pp_group
+        pp_size = self.ep_pp_size if is_expert else self.pp_size
+        pp_rank = self.ep_pp_rank if is_expert else self.pp_rank
         if tensor is not None and tp_dim is not None and tp_size > 1:
             if tp_dim == 0:
                 # save memory
@@ -220,34 +244,26 @@ def _get_weight(self, mg_weight: torch.Tensor, mg_key: Optional[str], offset: fl
                 tensor = torch.cat(output, dim=tp_dim)
                 del output
         # pp/ep
-        for parallel_state in ['ep', 'pp']:
-            if parallel_state == 'pp' and self.pp_size > 1:
-                parallel_group = self.pp_group
-                parallel_rank = self.pp_rank
-            elif parallel_state == 'ep' and is_expert and self.ep_size > 1:
-                parallel_group = self.ep_group
-                parallel_rank = self.ep_rank
-            else:
-                continue
-            src_rank = torch.tensor([0 if tensor is None else parallel_rank], dtype=torch.int64, device='cuda')
-            dist.all_reduce(src_rank, group=parallel_group)
-            src_rank = dist.get_global_rank(parallel_group, src_rank.item())
+        if pp_size > 1:
+            src_rank = torch.tensor([0 if tensor is None else pp_rank], dtype=torch.int64, device='cuda')
+            dist.all_reduce(src_rank, group=pp_group)
+            src_rank = dist.get_global_rank(pp_group, src_rank.item())
             meta_data = torch.zeros(10, dtype=torch.int64, device='cuda')
             dtype_mapping = {torch.float64: 0, torch.float32: 1, torch.float16: 2, torch.bfloat16: 3}
             dtype_mapping_r = {v: k for k, v in dtype_mapping.items()}
             if tensor is None:
-                dist.broadcast(meta_data, src=src_rank, group=parallel_group)
-                if meta_data[0].item() > 0:
-                    shape = meta_data[1:1 + meta_data[0]].tolist()
-                    dtype = dtype_mapping_r[meta_data[-1].item()]
-                    tensor = torch.empty(shape, device='cuda', dtype=dtype)
-                    dist.broadcast(tensor, src=src_rank, group=parallel_group)
+                dist.broadcast(meta_data, src=src_rank, group=pp_group)
+                assert meta_data[0].item() > 0, f'meta_data: {meta_data}'
+                shape = meta_data[1:1 + meta_data[0]].tolist()
+                dtype = dtype_mapping_r[meta_data[-1].item()]
+                tensor = torch.empty(shape, device='cuda', dtype=dtype)
+                dist.broadcast(tensor, src=src_rank, group=pp_group)
             else:
                 meta_data[0] = tensor.ndim
                 meta_data[1:1 + tensor.ndim] = torch.tensor(tensor.shape, dtype=torch.int64, device='cuda')
                 meta_data[-1] = dtype_mapping[tensor.dtype]
-                dist.broadcast(meta_data, src=src_rank, group=parallel_group)
-                dist.broadcast(tensor, src=src_rank, group=parallel_group)
+                dist.broadcast(meta_data, src=src_rank, group=pp_group)
+                dist.broadcast(tensor, src=src_rank, group=pp_group)
         assert tensor is not None, f'mg_key: {mg_key}'
         if offset:
             tensor = tensor + offset
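
The rewritten branch keeps the original wire protocol but runs it once over the fused group: the rank that owns the weight contributes its rank to an all_reduce (all others contribute 0) so everyone learns the source, then ndim, shape and a dtype code are packed into a small int64 tensor and broadcast ahead of the weight so receivers can allocate a matching buffer. A rough CPU-only sketch of that protocol, assuming two ranks where only rank 0 holds the tensor (torchrun --nproc_per_node=2):

import torch
import torch.distributed as dist

DTYPE_CODES = {torch.float64: 0, torch.float32: 1, torch.float16: 2, torch.bfloat16: 3}
DTYPE_CODES_R = {v: k for k, v in DTYPE_CODES.items()}


def broadcast_with_meta(tensor, group):
    # Ranks holding the tensor vote with their group rank; everyone else adds 0.
    src = torch.tensor([0 if tensor is None else dist.get_rank(group)], dtype=torch.int64)
    dist.all_reduce(src, group=group)
    src_rank = dist.get_global_rank(group, src.item())
    meta = torch.zeros(10, dtype=torch.int64)
    if tensor is None:
        # Receive ndim/shape/dtype first, then allocate a matching buffer.
        dist.broadcast(meta, src=src_rank, group=group)
        shape = meta[1:1 + meta[0]].tolist()
        tensor = torch.empty(shape, dtype=DTYPE_CODES_R[meta[-1].item()])
    else:
        meta[0] = tensor.ndim
        meta[1:1 + tensor.ndim] = torch.tensor(tensor.shape, dtype=torch.int64)
        meta[-1] = DTYPE_CODES[tensor.dtype]
        dist.broadcast(meta, src=src_rank, group=group)
    dist.broadcast(tensor, src=src_rank, group=group)
    return tensor


if __name__ == '__main__':
    dist.init_process_group(backend='gloo')
    group = dist.new_group(list(range(dist.get_world_size())))
    local = torch.randn(3, 4) if dist.get_rank() == 0 else None
    full = broadcast_with_meta(local, group)
    print(dist.get_rank(), tuple(full.shape))
    dist.destroy_process_group()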
@@ -273,10 +289,10 @@ def _set_state_dict(self,
         is_modules_to_save = isinstance(sub_module, ModulesToSaveWrapper)
         if not to_mcore:
             state = torch.tensor([is_lora, is_modules_to_save], dtype=torch.bool, device='cuda')
-            if self.pp_size > 1:
+            if is_expert and self.ep_pp_size > 1:
+                dist.all_reduce(state, group=self.ep_pp_group)
+            elif not is_expert and self.pp_size > 1:
                 dist.all_reduce(state, group=self.pp_group)
-            if is_expert and self.ep_size > 1:
-                dist.all_reduce(state, group=self.ep_group)
             is_lora, is_modules_to_save = state
         if is_lora and self._is_peft_format and param_key != 'layer_norm_weight':
             if to_mcore:
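
The same rewrite is applied twice more in _set_mlp_state below (for linear_fc1 and linear_fc2): expert parameters now reconcile their LoRA / modules-to-save flags with a single all_reduce over the fused ep_pp_group, while non-expert parameters keep the plain PP reduction. A small illustrative helper, not in the repo, that captures the selection logic; `bridge` stands in for the bridge instance holding the groups and sizes built in __init__ above:

import torch
import torch.distributed as dist


def sync_flags(flags: torch.Tensor, is_expert: bool, bridge) -> torch.Tensor:
    if is_expert and bridge.ep_pp_size > 1:
        dist.all_reduce(flags, group=bridge.ep_pp_group)  # one fused collective
    elif not is_expert and bridge.pp_size > 1:
        dist.all_reduce(flags, group=bridge.pp_group)
    return flags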
@@ -627,10 +643,10 @@ def _set_mlp_state(self,
         is_lora = False if mg_mlp is None else isinstance(mg_mlp.linear_fc1,
                                                           LoraParallelLinear) and self._is_peft_format
         is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda')
-        if self.pp_size > 1:
+        if is_expert and self.ep_pp_size > 1:
+            dist.all_reduce(is_lora, group=self.ep_pp_group)
+        elif not is_expert and self.pp_size > 1:
             dist.all_reduce(is_lora, group=self.pp_group)
-        if is_expert and self.ep_size > 1:
-            dist.all_reduce(is_lora, group=self.ep_group)
         if is_lora:
             assert not hf_grouped, 'Currently, hf_grouped with LoRA is not supported.'
         if mg_mlp is None:
@@ -779,10 +795,10 @@ def _set_mlp_state(self,
         is_lora = False if mg_mlp is None else isinstance(mg_mlp.linear_fc2,
                                                           LoraParallelLinear) and self._is_peft_format
         is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda')
-        if self.pp_size > 1:
+        if is_expert and self.ep_pp_size > 1:
+            dist.all_reduce(is_lora, group=self.ep_pp_group)
+        elif not is_expert and self.pp_size > 1:
             dist.all_reduce(is_lora, group=self.pp_group)
-        if is_expert and self.ep_size > 1:
-            dist.all_reduce(is_lora, group=self.ep_group)
         if is_lora:
             assert not hf_grouped, 'Currently, hf_grouped with LoRA is not supported.'
         if mg_mlp is None:
