@@ -64,6 +64,27 @@ def __init__(self, disable_tqmd: bool = False):
         self.etp_rank = mpu.get_expert_tensor_parallel_rank()
         self.ep_rank = mpu.get_expert_model_parallel_rank()
 
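+        # Build a combined expert-parallel + pipeline-parallel group; expert weights
+        # are later broadcast over this single group instead of separate EP and PP
+        # collectives (see _get_weight).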
+        dp_size = dist.get_world_size() // self.etp_size // self.ep_size // self.pp_size
+        expert_decoder_rank_generator = mpu.RankGenerator(
+            tp=self.etp_size,
+            ep=self.ep_size,
+            dp=dp_size,
+            pp=self.pp_size,
+            cp=1,
+            order='tp-cp-ep-dp-pp',
+            rank_offset=0,
+        )
+        rank = dist.get_rank()
+        for ranks in expert_decoder_rank_generator.get_ranks('ep-pp'):
+            group = mpu.create_group(
+                ranks,
+                group_desc='EP-PP-GROUP',
+            )
+            if rank in ranks:
+                self.ep_pp_size = self.ep_size * self.pp_size
+                self.ep_pp_group = group
+                self.ep_pp_rank = dist.get_rank(group)
+
     def _init_meta_hf_model(self):
         with torch.device('meta'):
             self.hf_model, self.processor = get_model_tokenizer(
@@ -198,6 +219,9 @@ def _get_weight(self, mg_weight: torch.Tensor, mg_key: Optional[str], offset: fl
         tensor = None if mg_weight is None else mg_weight.to('cuda')
         tp_size = self.etp_size if is_expert else self.tp_size
         tp_group = self.etp_group if is_expert else self.tp_group
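+        # For expert weights the owning rank is identified by both its EP and PP
+        # coordinates, so the broadcast below runs over the combined EP-PP group.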
+        pp_group = self.ep_pp_group if is_expert else self.pp_group
+        pp_size = self.ep_pp_size if is_expert else self.pp_size
+        pp_rank = self.ep_pp_rank if is_expert else self.pp_rank
         if tensor is not None and tp_dim is not None and tp_size > 1:
             if tp_dim == 0:
                 # save memory
@@ -220,34 +244,26 @@ def _get_weight(self, mg_weight: torch.Tensor, mg_key: Optional[str], offset: fl
                 tensor = torch.cat(output, dim=tp_dim)
                 del output
         # pp/ep
-        for parallel_state in ['ep', 'pp']:
-            if parallel_state == 'pp' and self.pp_size > 1:
-                parallel_group = self.pp_group
-                parallel_rank = self.pp_rank
-            elif parallel_state == 'ep' and is_expert and self.ep_size > 1:
-                parallel_group = self.ep_group
-                parallel_rank = self.ep_rank
-            else:
-                continue
-            src_rank = torch.tensor([0 if tensor is None else parallel_rank], dtype=torch.int64, device='cuda')
-            dist.all_reduce(src_rank, group=parallel_group)
-            src_rank = dist.get_global_rank(parallel_group, src_rank.item())
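+        # Ranks that do not hold the weight contribute 0, so the all_reduce resolves
+        # to the owning rank within the group, which then acts as the broadcast source.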
+        if pp_size > 1:
+            src_rank = torch.tensor([0 if tensor is None else pp_rank], dtype=torch.int64, device='cuda')
+            dist.all_reduce(src_rank, group=pp_group)
+            src_rank = dist.get_global_rank(pp_group, src_rank.item())
             meta_data = torch.zeros(10, dtype=torch.int64, device='cuda')
             dtype_mapping = {torch.float64: 0, torch.float32: 1, torch.float16: 2, torch.bfloat16: 3}
             dtype_mapping_r = {v: k for k, v in dtype_mapping.items()}
             if tensor is None:
-                dist.broadcast(meta_data, src=src_rank, group=parallel_group)
-                if meta_data[0].item() > 0:
-                    shape = meta_data[1:1 + meta_data[0]].tolist()
-                    dtype = dtype_mapping_r[meta_data[-1].item()]
-                    tensor = torch.empty(shape, device='cuda', dtype=dtype)
-                    dist.broadcast(tensor, src=src_rank, group=parallel_group)
+                dist.broadcast(meta_data, src=src_rank, group=pp_group)
+                assert meta_data[0].item() > 0, f'meta_data: {meta_data}'
+                shape = meta_data[1:1 + meta_data[0]].tolist()
+                dtype = dtype_mapping_r[meta_data[-1].item()]
+                tensor = torch.empty(shape, device='cuda', dtype=dtype)
+                dist.broadcast(tensor, src=src_rank, group=pp_group)
             else:
                 meta_data[0] = tensor.ndim
                 meta_data[1:1 + tensor.ndim] = torch.tensor(tensor.shape, dtype=torch.int64, device='cuda')
                 meta_data[-1] = dtype_mapping[tensor.dtype]
-                dist.broadcast(meta_data, src=src_rank, group=parallel_group)
-                dist.broadcast(tensor, src=src_rank, group=parallel_group)
+                dist.broadcast(meta_data, src=src_rank, group=pp_group)
+                dist.broadcast(tensor, src=src_rank, group=pp_group)
         assert tensor is not None, f'mg_key: {mg_key}'
         if offset:
             tensor = tensor + offset
@@ -273,10 +289,10 @@ def _set_state_dict(self,
         is_modules_to_save = isinstance(sub_module, ModulesToSaveWrapper)
         if not to_mcore:
             state = torch.tensor([is_lora, is_modules_to_save], dtype=torch.bool, device='cuda')
-            if self.pp_size > 1:
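+            # Sync the adapter flags from the ranks that hold this module: expert
+            # modules use the combined EP-PP group, dense modules the PP group.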
+            if is_expert and self.ep_pp_size > 1:
+                dist.all_reduce(state, group=self.ep_pp_group)
+            elif not is_expert and self.pp_size > 1:
                 dist.all_reduce(state, group=self.pp_group)
-            if is_expert and self.ep_size > 1:
-                dist.all_reduce(state, group=self.ep_group)
             is_lora, is_modules_to_save = state
         if is_lora and self._is_peft_format and param_key != 'layer_norm_weight':
             if to_mcore:
@@ -627,10 +643,10 @@ def _set_mlp_state(self,
         is_lora = False if mg_mlp is None else isinstance(mg_mlp.linear_fc1,
                                                           LoraParallelLinear) and self._is_peft_format
         is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda')
-        if self.pp_size > 1:
+        if is_expert and self.ep_pp_size > 1:
+            dist.all_reduce(is_lora, group=self.ep_pp_group)
+        elif not is_expert and self.pp_size > 1:
             dist.all_reduce(is_lora, group=self.pp_group)
-        if is_expert and self.ep_size > 1:
-            dist.all_reduce(is_lora, group=self.ep_group)
         if is_lora:
             assert not hf_grouped, 'Currently, hf_grouped with LoRA is not supported.'
             if mg_mlp is None:
@@ -779,10 +795,10 @@ def _set_mlp_state(self,
         is_lora = False if mg_mlp is None else isinstance(mg_mlp.linear_fc2,
                                                           LoraParallelLinear) and self._is_peft_format
         is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda')
-        if self.pp_size > 1:
+        if is_expert and self.ep_pp_size > 1:
+            dist.all_reduce(is_lora, group=self.ep_pp_group)
+        elif not is_expert and self.pp_size > 1:
             dist.all_reduce(is_lora, group=self.pp_group)
-        if is_expert and self.ep_size > 1:
-            dist.all_reduce(is_lora, group=self.ep_group)
         if is_lora:
             assert not hf_grouped, 'Currently, hf_grouped with LoRA is not supported.'
             if mg_mlp is None: