@@ -472,8 +472,8 @@ def _replace_data_iterator(self, data_iterator, model):
             ]
             assert len(mini_batch_data) == self.steps_per_generation
             self._buffered_inputs = mini_batch_data
-        self._step += 1
         inputs = self._buffered_inputs[self._step % self.steps_per_generation]
+        self._step += 1
         return RerunDataIterator(iter(inputs))

     def _generate_and_score_completions(self, batch):
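This hunk moves the `self._step += 1` increment below the buffer lookup, so each training step consumes its own slot of `self._buffered_inputs` rather than the next one. A minimal sketch of the off-by-one, using toy values and helper names that are not the trainer's actual code:

```python
# Toy reproduction of the indexing bug: buffered mini-batches should be
# consumed in order 0, 1, 2, 3 within one generation round.
steps_per_generation = 4
buffered_inputs = [f'mini_batch_{i}' for i in range(steps_per_generation)]

def old_order(step):
    step += 1                                              # increment first (old code)
    return buffered_inputs[step % steps_per_generation], step

def new_order(step):
    batch = buffered_inputs[step % steps_per_generation]   # index first (fixed code)
    return batch, step + 1

step, consumed = 0, []
for _ in range(steps_per_generation):
    batch, step = new_order(step)
    consumed.append(batch)
assert consumed == buffered_inputs        # fixed order reads slots 0..3

step, consumed = 0, []
for _ in range(steps_per_generation):
    batch, step = old_order(step)
    consumed.append(batch)
assert consumed[0] == 'mini_batch_1'      # old order started at slot 1, wrapping to 0 last
```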
@@ -1241,9 +1241,9 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]):
         reporting_metric = {**avg_metric, **custom_metrics}

         # log_completions
-        if self.log_completions and self.is_main_process and self._step % self.steps_per_generation == 0:
+        if self.log_completions and self.is_main_process and (self._step - 1) % self.steps_per_generation == 0:
             table = {
-                'gen_step': [self._step] * len(self._logs['prompt']),
+                'gen_step': [self._step - 1] * len(self._logs['prompt']),
                 'prompt': list(self._logs['prompt']),
                 'completion': list(self._logs['completion']),
                 **{k: list(v)
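Because `_step` is now incremented right after the buffer lookup, by the time `loss_func` runs the counter is one past the generation step that filled the log buffer; both the logging condition and the reported `gen_step` therefore shift to `self._step - 1`. A hedged walk-through with assumed values:

```python
# Assumes steps_per_generation = 4: generation fills the log buffer at steps
# 0, 4, 8, ..., and loss_func then observes _step = 1, 5, 9, ... because the
# counter was already incremented in _replace_data_iterator.
steps_per_generation = 4

def should_log(step_seen_by_loss_func):
    return (step_seen_by_loss_func - 1) % steps_per_generation == 0

assert [s for s in range(1, 13) if should_log(s)] == [1, 5, 9]
# and the logged 'gen_step' (_step - 1) recovers the true generation step:
assert [s - 1 for s in range(1, 13) if should_log(s)] == [0, 4, 8]
```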
@@ -1432,7 +1432,7 @@ def _prepare_metrics(self):
         from collections import deque
         self.log_completions = args.log_completions
         self.wandb_log_unique_prompts = args.wandb_log_unique_prompts
-        self.jsonl_writer = JsonlWriter(os.path.join(args.save, 'completions.jsonl'))
+        self.jsonl_writer = JsonlWriter(os.path.join(args.save, 'completions.jsonl'), write_on_rank='last')
         self.init_custom_metric = False
         self._logs = {
             'prompt': deque(maxlen=args.generation_batch_size),
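The `write_on_rank='last'` argument presumably routes the JSONL output to the last rank rather than rank 0; under Megatron pipeline parallelism that is the stage that actually runs `loss_func` and holds the completion logs. A toy stand-in (not ms-swift's real `JsonlWriter`) illustrating the intended effect:

```python
import json

class ToyJsonlWriter:
    """Toy stand-in for a rank-aware JSONL writer (not ms-swift's JsonlWriter).

    Only the designated rank appends to the file, so ranks in a distributed
    job do not interleave or duplicate lines.
    """

    def __init__(self, path, write_on_rank='last'):
        self.path = path
        self.write_on_rank = write_on_rank

    def append(self, obj, rank, world_size):
        writer_rank = world_size - 1 if self.write_on_rank == 'last' else 0
        if rank != writer_rank:
            return  # non-writer ranks drop the record silently
        with open(self.path, 'a', encoding='utf-8') as f:
            f.write(json.dumps(obj, ensure_ascii=False) + '\n')

# With 4 ranks, only rank 3 ('last') writes:
writer = ToyJsonlWriter('/tmp/completions.jsonl')
for rank in range(4):
    writer.append({'rank': rank, 'completion': 'hi'}, rank=rank, world_size=4)
```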