 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import trl
 from accelerate.utils import gather_object, is_peft_model
+from packaging import version
 from transformers import PreTrainedModel
 from trl import GKDTrainer as HFGKDTrainer
 from trl import SFTTrainer as HFSFTTrainer
                      unwrap_model_for_generation)
 from ..mixin import SwiftMixin
 from .rollout_mixin import DataType, RolloutTrainerMixin
-from .utils import identity_data_collator, patch_profiling_context, patch_profiling_decorator, prepare_deepspeed
+from .utils import (get_gather_if_zero3_context, identity_data_collator, patch_profiling_context,
+                    patch_profiling_decorator, prepare_deepspeed)

 try:
     from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss
@@ -61,10 +64,12 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non
         self._prepare_liger_loss()

         self.teacher_ds3_gather_for_generation = args.ds3_gather_for_generation
+        self.is_teacher_ds3 = None
         # Initialize teacher model
         if self.is_deepspeed_enabled:
             if teacher_deepspeed_config is not None:
-                if teacher_deepspeed_config.get('zero_optimization', {}).get('stage') != 3:
+                self.is_teacher_ds3 = teacher_deepspeed_config.get('zero_optimization', {}).get('stage') == 3
+                if not self.is_teacher_ds3:
                     self.teacher_ds3_gather_for_generation = False
                 self.teacher_model = prepare_deepspeed(
                     teacher_model, self.accelerator, deepspeed_config=teacher_deepspeed_config, training_args=args)
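
The `is_teacher_ds3` flag above is derived purely from the teacher's DeepSpeed config dict, so it can be checked without touching the engine. A minimal standalone sketch of that lookup (the helper name and example configs are illustrative, not part of the patch):

```python
from typing import Any, Dict, Optional


def is_zero3_config(ds_config: Optional[Dict[str, Any]]) -> bool:
    # Hypothetical helper: True only when the config explicitly requests ZeRO stage 3.
    if not ds_config:
        return False
    return ds_config.get('zero_optimization', {}).get('stage') == 3


# ZeRO-3 partitions parameters, so generation-time gathering only matters in this case.
assert is_zero3_config({'zero_optimization': {'stage': 3}})
assert not is_zero3_config({'zero_optimization': {'stage': 2}})
assert not is_zero3_config(None)
```
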
@@ -88,6 +93,7 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non
             self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model)
         else:
             self.maybe_activation_offload_context = nullcontext()
+        self._trl_version_gte_0_24 = version.parse(trl.__version__) >= version.parse('0.24')

     # Code borrowed from huggingface/trl
     def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token_id=None):
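
`_trl_version_gte_0_24` is a plain `packaging.version` comparison against the installed trl release; a small sketch of the same gating pattern (the constant name and printed messages are only illustrative):

```python
import trl
from packaging import version

# Note: pre-releases such as '0.24.0.dev0' compare as older than '0.24'.
TRL_GTE_0_24 = version.parse(trl.__version__) >= version.parse('0.24')

if TRL_GTE_0_24:
    print(f'trl {trl.__version__}: apply the extra sequence-length normalization')
else:
    print(f'trl {trl.__version__}: keep the loss as returned')
```
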
@@ -131,7 +137,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
         model_inputs = {k: v for k, v in inputs.items() if k not in {'prompt', 'labels'}}
         # If generate is used, then use_logits_to_keep must be set to False.
         use_logits_to_keep = self.get_use_logits_to_keep(True)
-        if use_logits_to_keep:
+        if use_logits_to_keep and not self.use_liger_gkd_loss:
             self.prepare_logits_to_keep(inputs)
             model_inputs['logits_to_keep'] = inputs['logits_to_keep']

@@ -176,17 +182,24 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
                 student_head = unwrapped_student.get_output_embeddings()
                 teacher_head = unwrapped_teacher.get_output_embeddings()

-                # Compute liger fused JSD loss
-                loss = self.liger_jsd_loss(
-                    student_input=student_hidden,
-                    student_weight=student_head.weight,
-                    teacher_input=teacher_hidden,
-                    teacher_weight=teacher_head.weight,
-                    true_labels=true_labels,
-                    student_bias=getattr(student_head, 'bias', None),
-                    teacher_bias=getattr(teacher_head, 'bias', None),
-                )
-
+                # Prepare context managers for gathering parameters in zero3
+                teacher_context = get_gather_if_zero3_context(self, is_zero3=self.is_teacher_ds3)(teacher_head.weight)
+                student_context = get_gather_if_zero3_context(self)(student_head.weight)
+
+                with teacher_context, student_context:
+                    # Compute liger fused JSD loss
+                    loss = self.liger_jsd_loss(
+                        student_input=student_hidden,
+                        student_weight=student_head.weight,
+                        teacher_input=teacher_hidden,
+                        teacher_weight=teacher_head.weight,
+                        true_labels=true_labels,
+                        student_bias=getattr(student_head, 'bias', None),
+                        teacher_bias=getattr(teacher_head, 'bias', None),
+                    )
+                # loss / grad norm is unexpectedly large, normalize by sequence length
+                # https://github.com/linkedin/Liger-Kernel/blob/v0.6.3/src/liger_kernel/chunked_loss/jsd_loss.py#L9-L39
+                loss /= student_hidden.shape[1]
                 # Release hidden states after loss computation
                 del student_hidden, teacher_hidden, true_labels
             else:
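
`get_gather_if_zero3_context` comes from `.utils` and is not shown in this diff; the point of the two context managers is that under ZeRO-3 the output-embedding weights are partitioned across ranks and must be materialized before Liger's fused linear+JSD kernel reads them. A rough sketch of what such a helper could look like, assuming DeepSpeed's `GatheredParameters` context manager (the actual helper in `.utils` may be implemented differently):

```python
from contextlib import nullcontext


def gather_if_zero3(is_zero3: bool):
    """Hypothetical factory: returns a callable that gathers ZeRO-3 partitioned params, or a no-op."""

    def make_context(*params):
        if not is_zero3:
            return nullcontext()
        import deepspeed

        # Temporarily materializes the full parameters on every rank inside the `with` block.
        return deepspeed.zero.GatheredParameters(list(params), modifier_rank=None)

    return make_context


# Usage mirroring the call sites above (student_head / teacher_head are output-embedding modules):
# with gather_if_zero3(teacher_is_zero3)(teacher_head.weight), gather_if_zero3(student_is_zero3)(student_head.weight):
#     loss = liger_jsd_loss(student_weight=student_head.weight, teacher_weight=teacher_head.weight, ...)
```
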
@@ -222,7 +235,8 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
                 teacher_logits=shifted_teacher_logits,
                 beta=self.beta,
             )
-
+            if self._trl_version_gte_0_24:
+                loss /= shifted_student_logits.shape[1]
         # Add SFT loss if enabled (common for both paths)
         if self.args.sft_alpha > 0:
             loss = loss + self.args.sft_alpha * outputs_student.loss
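
For reference, the generalized JSD behind both code paths interpolates the teacher and student distributions with `beta`: with `M = beta * P_T + (1 - beta) * P_S`, the per-token loss is `beta * KL(P_T || M) + (1 - beta) * KL(P_S || M)`. The sketch below is a plain, unfused version of that formula (not the Liger kernel or trl's implementation); it also shows why dividing a token-summed loss by the sequence length, as both normalizations above do, brings it back to a per-token scale:

```python
import math

import torch
import torch.nn.functional as F


def generalized_jsd(student_logits, teacher_logits, beta=0.5):
    """Interpolated JSD per token, averaged over batch and sequence length (assumes 0 < beta < 1)."""
    s_logp = F.log_softmax(student_logits, dim=-1)
    t_logp = F.log_softmax(teacher_logits, dim=-1)
    # Mixture distribution M computed in log space for numerical stability.
    m_logp = torch.logsumexp(torch.stack([s_logp + math.log(1 - beta), t_logp + math.log(beta)]), dim=0)
    kl_t = F.kl_div(m_logp, t_logp, reduction='none', log_target=True).sum(-1)  # KL(P_T || M) per token
    kl_s = F.kl_div(m_logp, s_logp, reduction='none', log_target=True).sum(-1)  # KL(P_S || M) per token
    per_token = beta * kl_t + (1 - beta) * kl_s  # shape: [batch, seq_len]
    # Summing over tokens and dividing by batch * seq_len is the "normalize by sequence length" step.
    return per_token.sum() / per_token.numel()


# Toy usage: batch of 2, sequence length 5, vocabulary of 11.
student = torch.randn(2, 5, 11)
teacher = torch.randn(2, 5, 11)
print(generalized_jsd(student, teacher, beta=0.5))
```
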