Commit 6232492

[bugfix] Fix GKD with TRL >= 0.24 & GKD Liger (#6663)

1 parent 5423f95, commit 6232492

7 files changed: +38 -22 lines changed

docs/source/BestPractices/Qwen3-Best-Practice.md

Lines changed: 1 addition & 1 deletion

@@ -254,7 +254,7 @@ ms-swift supports RLHF methods such as DPO, GRPO, DAPO, PPO, KTO, and GKD. This chapter

 In addition to the ms-swift dependencies introduced above, the following dependencies also need to be installed:
 ```
-pip install "math_verify==0.5.2"
+pip install "math_verify"
 pip install vllm==0.8.5.post1
 ```

docs/source_en/BestPractices/Qwen3-Best-Practice.md

Lines changed: 1 addition & 1 deletion

@@ -258,7 +258,7 @@ ms-swift supports RLHF methods such as DPO, GRPO, DAPO, PPO, KTO, and GKD. This
 In addition to installing the dependencies related to ms-swift mentioned above, you also need to install the following:

 ```shell
-pip install "math_verify==0.5.2"
+pip install "math_verify"
 pip install vllm==0.8.5.post1
 ```

examples/train/grpo/plugin/deepeyes/deepeyes_plugin.py

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@
 try:
     from math_verify import parse, verify
 except ImportError as e:
-    raise ImportError('please install math_verify by `pip install math_verify==0.5.2`') from e
+    raise ImportError('please install math_verify by `pip install math_verify`') from e
 """
 3 dataset file
 1. data_v0.8_visual_toolbox_v2.parquet: data_source == 'chart' (vl_agent.compute_score)

requirements/install_all.sh

Lines changed: 1 addition & 1 deletion

@@ -9,5 +9,5 @@ pip install git+https://github.com/modelscope/ms-swift.git#egg=ms-swift[all]
 pip install timm "deepspeed<0.18" -U
 pip install qwen_vl_utils qwen_omni_utils keye_vl_utils -U
 pip install decord librosa icecream soundfile -U
-pip install liger_kernel nvitop pre-commit math_verify==0.5.2 py-spy wandb swanlab -U
+pip install liger_kernel nvitop pre-commit math_verify py-spy wandb swanlab -U
 # flash-attn: https://github.com/Dao-AILab/flash-attention/releases

swift/plugin/orm.py

Lines changed: 1 addition & 1 deletion

@@ -239,7 +239,7 @@ def __init__(self):
         import importlib.util
         assert importlib.util.find_spec('math_verify') is not None, (
             'The math_verify package is required but not installed. '
-            "Please install it using 'pip install math_verify==0.5.2'.")
+            "Please install it using 'pip install math_verify'.")

     def __call__(self, completions, solution, **kwargs) -> List[float]:
         from latex2sympy2_extended import NormalizationConfig

swift/trainers/rlhf_trainer/gkd_trainer.py

Lines changed: 29 additions & 15 deletions

@@ -10,7 +10,9 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import trl
 from accelerate.utils import gather_object, is_peft_model
+from packaging import version
 from transformers import PreTrainedModel
 from trl import GKDTrainer as HFGKDTrainer
 from trl import SFTTrainer as HFSFTTrainer

@@ -20,7 +22,8 @@
                                      unwrap_model_for_generation)
 from ..mixin import SwiftMixin
 from .rollout_mixin import DataType, RolloutTrainerMixin
-from .utils import identity_data_collator, patch_profiling_context, patch_profiling_decorator, prepare_deepspeed
+from .utils import (get_gather_if_zero3_context, identity_data_collator, patch_profiling_context,
+                    patch_profiling_decorator, prepare_deepspeed)

 try:
     from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss

@@ -61,10 +64,12 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None
         self._prepare_liger_loss()

         self.teacher_ds3_gather_for_generation = args.ds3_gather_for_generation
+        self.is_teacher_ds3 = None
         # Initialize teacher model
         if self.is_deepspeed_enabled:
             if teacher_deepspeed_config is not None:
-                if teacher_deepspeed_config.get('zero_optimization', {}).get('stage') != 3:
+                self.is_teacher_ds3 = teacher_deepspeed_config.get('zero_optimization', {}).get('stage') == 3
+                if not self.is_teacher_ds3:
                     self.teacher_ds3_gather_for_generation = False
             self.teacher_model = prepare_deepspeed(
                 teacher_model, self.accelerator, deepspeed_config=teacher_deepspeed_config, training_args=args)
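The new `is_teacher_ds3` flag records whether the teacher's DeepSpeed config requests ZeRO stage 3, rather than only using the stage to switch off gather-for-generation. For context, the stage lives under `zero_optimization.stage` in a DeepSpeed config dict; a minimal standalone sketch (the config values here are illustrative, not from the commit):

```python
# Illustrative DeepSpeed config fragment; only the key read by the commit matters.
teacher_deepspeed_config = {
    'zero_optimization': {'stage': 3},
    'bf16': {'enabled': True},
}

# Same check as the commit: a missing 'zero_optimization' block or a
# non-3 stage both yield False.
is_teacher_ds3 = teacher_deepspeed_config.get('zero_optimization', {}).get('stage') == 3
print(is_teacher_ds3)  # True
```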
@@ -88,6 +93,7 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None
             self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model)
         else:
             self.maybe_activation_offload_context = nullcontext()
+        self._trl_version_gte_0_24 = version.parse(trl.__version__) >= version.parse('0.24')

     # Code borrowed from huggingface/trl
     def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token_id=None):

@@ -131,7 +137,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None
         model_inputs = {k: v for k, v in inputs.items() if k not in {'prompt', 'labels'}}
         # If generate is used, then use_logits_to_keep must be set to False.
         use_logits_to_keep = self.get_use_logits_to_keep(True)
-        if use_logits_to_keep:
+        if use_logits_to_keep and not self.use_liger_gkd_loss:
             self.prepare_logits_to_keep(inputs)
             model_inputs['logits_to_keep'] = inputs['logits_to_keep']

@@ -176,17 +182,24 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None
             student_head = unwrapped_student.get_output_embeddings()
             teacher_head = unwrapped_teacher.get_output_embeddings()

-            # Compute liger fused JSD loss
-            loss = self.liger_jsd_loss(
-                student_input=student_hidden,
-                student_weight=student_head.weight,
-                teacher_input=teacher_hidden,
-                teacher_weight=teacher_head.weight,
-                true_labels=true_labels,
-                student_bias=getattr(student_head, 'bias', None),
-                teacher_bias=getattr(teacher_head, 'bias', None),
-            )
-
+            # Prepare context managers for gathering parameters in zero3
+            teacher_context = get_gather_if_zero3_context(self, is_zero3=self.is_teacher_ds3)(teacher_head.weight)
+            student_context = get_gather_if_zero3_context(self)(student_head.weight)
+
+            with teacher_context, student_context:
+                # Compute liger fused JSD loss
+                loss = self.liger_jsd_loss(
+                    student_input=student_hidden,
+                    student_weight=student_head.weight,
+                    teacher_input=teacher_hidden,
+                    teacher_weight=teacher_head.weight,
+                    true_labels=true_labels,
+                    student_bias=getattr(student_head, 'bias', None),
+                    teacher_bias=getattr(teacher_head, 'bias', None),
+                )
+            # loss / grad norm is unexpectedly large, normalize by sequence length
+            # https://github.com/linkedin/Liger-Kernel/blob/v0.6.3/src/liger_kernel/chunked_loss/jsd_loss.py#L9-L39
+            loss /= student_hidden.shape[1]
             # Release hidden states after loss computation
             del student_hidden, teacher_hidden, true_labels
         else:
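Two things changed in the Liger path above. First, under ZeRO-3 the output-embedding weights are sharded across ranks, while the fused linear JSD kernel needs the full `lm_head` matrices, so the call is now wrapped in gather contexts (DeepSpeed's `deepspeed.zero.GatheredParameters` on the ZeRO-3 path, with the teacher's stage resolved separately from the student's). Second, the returned loss is divided by the sequence length, since the chunked kernel otherwise yields an unexpectedly large loss and grad norm per the commit's own comment. A minimal sketch of the gather pattern, assuming the `is_zero3` flag has been resolved as above; the smoke test only exercises the non-distributed no-op branch:

```python
from contextlib import nullcontext

import torch

def gather_if_zero3(param: torch.nn.Parameter, is_zero3: bool):
    """Context that materializes `param` in full when it is ZeRO-3 sharded.

    Under ZeRO-3 each rank holds only a shard of the parameter;
    deepspeed.zero.GatheredParameters temporarily gathers the full tensor so a
    fused kernel (here, Liger's fused linear JSD loss) can consume it directly.
    """
    if is_zero3:
        import deepspeed  # only required on the ZeRO-3 path
        return deepspeed.zero.GatheredParameters(param)
    return nullcontext()

# Non-distributed smoke test: the nullcontext branch is a no-op.
head = torch.nn.Linear(16, 100, bias=False)
with gather_if_zero3(head.weight, is_zero3=False):
    assert head.weight.shape == (100, 16)
```

Inside that context the commit calls `self.liger_jsd_loss(...)` with the gathered weights, then applies `loss /= student_hidden.shape[1]` once the context has exited.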
@@ -222,7 +235,8 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None
                 teacher_logits=shifted_teacher_logits,
                 beta=self.beta,
             )
-
+            if self._trl_version_gte_0_24:
+                loss /= shifted_student_logits.shape[1]
         # Add SFT loss if enabled (common for both paths)
         if self.args.sft_alpha > 0:
             loss = loss + self.args.sft_alpha * outputs_student.loss
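The non-Liger path applies the same sequence-length normalization, but only on TRL >= 0.24, where the scale of the JSD loss evidently changed (this is what the version gate compensates for, per the commit title). A runnable sketch of the gate, written as a pure function so it can be tested without TRL installed; `normalize_gkd_loss` is a hypothetical helper for illustration, not part of the commit:

```python
from packaging import version

def normalize_gkd_loss(loss: float, seq_len: int, trl_version: str) -> float:
    """Divide the JSD loss by the sequence length on TRL >= 0.24.

    The commit computes this flag once at init time from trl.__version__;
    here the version string is a parameter so the gate itself is testable.
    """
    if version.parse(trl_version) >= version.parse('0.24'):
        return loss / seq_len
    return loss

print(normalize_gkd_loss(64.0, seq_len=32, trl_version='0.24.0'))  # 2.0
print(normalize_gkd_loss(64.0, seq_len=32, trl_version='0.23.1'))  # 64.0
```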

swift/trainers/rlhf_trainer/utils.py

Lines changed: 4 additions & 2 deletions

@@ -605,9 +605,11 @@ def patched_len(self) -> int:
     RepeatSampler.old_len_func = origin_len_func


-def get_gather_if_zero3_context(trainer):
+def get_gather_if_zero3_context(trainer, is_zero3: Optional[bool] = None):
     deepspeed_plugin = trainer.accelerator.state.deepspeed_plugin
-    zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
+    zero_stage_3 = is_zero3 if is_zero3 is not None else (deepspeed_plugin is not None
+                                                          and deepspeed_plugin.zero_stage == 3)
+
     if zero_stage_3:
         import deepspeed
         gather_if_zero3 = deepspeed.zero.GatheredParameters
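The `is_zero3` override matters because `trainer.accelerator.state.deepspeed_plugin` describes the student model, while the teacher may run under a different DeepSpeed config; the GKD trainer passes in the `is_teacher_ds3` flag it recorded at init. A standalone sketch of the same precedence logic, with `_Plugin` and `resolve_zero3` as stand-ins for illustration:

```python
from typing import Optional

class _Plugin:
    """Stand-in for accelerate's DeepSpeed plugin; only `zero_stage` is read."""

    def __init__(self, zero_stage: int):
        self.zero_stage = zero_stage

def resolve_zero3(plugin: Optional[_Plugin], is_zero3: Optional[bool] = None) -> bool:
    # An explicit override wins; otherwise fall back to the trainer's own
    # plugin state, mirroring the updated get_gather_if_zero3_context.
    if is_zero3 is not None:
        return is_zero3
    return plugin is not None and plugin.zero_stage == 3

print(resolve_zero3(_Plugin(2)))                 # False: student runs ZeRO-2
print(resolve_zero3(_Plugin(2), is_zero3=True))  # True: teacher override applies
```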
