Skip to content

Commit 032993c

Browse files
authored
[bugfix] fix grpo 'num_iterms_in_batch' in forward kwargs (#6717)
1 parent 57cd99f commit 032993c

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1314,7 +1314,7 @@ def _get_per_token_logps_and_entropies_sp(
13141314
k: v
13151315
for k, v in inputs.items() if k not in [
13161316
'logits_to_keep', 'completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps',
1317-
'truncated_mask', 'seq_lengths'
1317+
'truncated_mask', 'seq_lengths', 'num_items_in_batch'
13181318
]
13191319
}
13201320
sequence_parallel.prepare_inputs(inputs)
@@ -1394,7 +1394,7 @@ def _get_per_token_logps_and_entropies_single(self,
13941394
k: v
13951395
for k, v in inputs.items() if k not in [
13961396
'logits_to_keep', 'completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps',
1397-
'truncated_mask', 'seq_lengths'
1397+
'truncated_mask', 'seq_lengths', 'num_items_in_batch'
13981398
]
13991399
}
14001400
if 'logits_to_keep' in self.model_kwarg_keys:
@@ -1488,7 +1488,7 @@ def _get_last_hidden_state(self, unwrapped_model, inputs, logits_to_keep):
14881488
k: v
14891489
for k, v in inputs.items() if k not in [
14901490
'logits_to_keep', 'completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps',
1491-
'truncated_mask', 'seq_lengths'
1491+
'truncated_mask', 'seq_lengths', 'num_items_in_batch'
14921492
]
14931493
}
14941494
if 'logits_to_keep' in self.model_kwarg_keys:

0 commit comments

Comments
 (0)