Skip to content

Commit 6cfaeaa

Browse files
authored
[bugfix] fix task_type=rm eval for trl>=0.25 (#6701)
1 parent eef4f7f commit 6cfaeaa

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

swift/trainers/rlhf_trainer/reward_trainer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
import pandas as pd
77
import torch
88
import torch.nn as nn
9+
import trl
910
from accelerate.utils import gather_object
11+
from packaging import version
1012
from transformers import PreTrainedModel
1113
from trl import RewardTrainer as HFRewardTrainer
1214
from trl.trainer.utils import print_rich_table
@@ -33,6 +35,10 @@ def __init__(self, *args, **kwargs):
3335
except ImportError:
3436
self.maybe_activation_offload_context = nullcontext()
3537
self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)}
38+
if version.parse(trl.__version__) >= version.parse('0.24'):
39+
# During evaluation, Trainer calls compute_loss() only if can_return_loss is True and label_names is empty.
40+
self.can_return_loss = True
41+
self.label_names = []
3642

3743
def compute_loss(self,
3844
model: Union[PreTrainedModel, nn.Module],

0 commit comments

Comments (0)