Skip to content

Commit eef4f7f

Browse files
authored
[bugfix] fix megatron grpo local jsonl writer (#6700)
1 parent 7809cb8 commit eef4f7f

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

swift/megatron/trainers/grpo_trainer.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -472,8 +472,8 @@ def _replace_data_iterator(self, data_iterator, model):
472 472   ]
473 473   assert len(mini_batch_data) == self.steps_per_generation
474 474   self._buffered_inputs = mini_batch_data
475     - self._step += 1
476 475   inputs = self._buffered_inputs[self._step % self.steps_per_generation]
    476 + self._step += 1
477 477   return RerunDataIterator(iter(inputs))
478 478
479 479   def _generate_and_score_completions(self, batch):
@@ -1241,9 +1241,9 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]):
1241 1241   reporting_metric = {**avg_metric, **custom_metrics}
1242 1242
1243 1243   # log_completions
1244      - if self.log_completions and self.is_main_process and self._step % self.steps_per_generation == 0:
     1244 + if self.log_completions and self.is_main_process and (self._step - 1) % self.steps_per_generation == 0:
1245 1245   table = {
1246      - 'gen_step': [self._step] * len(self._logs['prompt']),
     1246 + 'gen_step': [self._step - 1] * len(self._logs['prompt']),
1247 1247   'prompt': list(self._logs['prompt']),
1248 1248   'completion': list(self._logs['completion']),
1249 1249   **{k: list(v)
@@ -1432,7 +1432,7 @@ def _prepare_metrics(self):
1432 1432   from collections import deque
1433 1433   self.log_completions = args.log_completions
1434 1434   self.wandb_log_unique_prompts = args.wandb_log_unique_prompts
1435      - self.jsonl_writer = JsonlWriter(os.path.join(args.save, 'completions.jsonl'))
     1435 + self.jsonl_writer = JsonlWriter(os.path.join(args.save, 'completions.jsonl'), write_on_rank='last')
1436 1436   self.init_custom_metric = False
1437 1437   self._logs = {
1438 1438   'prompt': deque(maxlen=args.generation_batch_size),

0 commit comments

Comments
 (0)