@@ -1314,7 +1314,7 @@ def _get_per_token_logps_and_entropies_sp(
13141314 k : v
13151315 for k , v in inputs .items () if k not in [
13161316 'logits_to_keep' , 'completion_mask' , 'ref_per_token_logps' , 'advantages' , 'old_per_token_logps' ,
1317- 'truncated_mask' , 'seq_lengths'
1317+ 'truncated_mask' , 'seq_lengths' , 'num_items_in_batch'
13181318 ]
13191319 }
13201320 sequence_parallel .prepare_inputs (inputs )
@@ -1394,7 +1394,7 @@ def _get_per_token_logps_and_entropies_single(self,
13941394 k : v
13951395 for k , v in inputs .items () if k not in [
13961396 'logits_to_keep' , 'completion_mask' , 'ref_per_token_logps' , 'advantages' , 'old_per_token_logps' ,
1397- 'truncated_mask' , 'seq_lengths'
1397+ 'truncated_mask' , 'seq_lengths' , 'num_items_in_batch'
13981398 ]
13991399 }
14001400 if 'logits_to_keep' in self .model_kwarg_keys :
@@ -1488,7 +1488,7 @@ def _get_last_hidden_state(self, unwrapped_model, inputs, logits_to_keep):
14881488 k : v
14891489 for k , v in inputs .items () if k not in [
14901490 'logits_to_keep' , 'completion_mask' , 'ref_per_token_logps' , 'advantages' , 'old_per_token_logps' ,
1491- 'truncated_mask' , 'seq_lengths'
1491+ 'truncated_mask' , 'seq_lengths' , 'num_items_in_batch'
14921492 ]
14931493 }
14941494 if 'logits_to_keep' in self .model_kwarg_keys :
0 commit comments