Commit d0368be

[bugfix] memory log is missing on Ascend NPU (#6647)
1 parent: 3068633

1 file changed: +5 −7 lines changed

1 file changed

+5
-7
lines changed

swift/trainers/callback.py

Lines changed: 5 additions & 7 deletions
@@ -3,23 +3,22 @@
 import os
 import time
 
-import torch
 from tqdm import tqdm
 from transformers import trainer
 from transformers.trainer_callback import (DefaultFlowCallback, PrinterCallback, ProgressCallback, TrainerControl,
                                            TrainerState)
 from transformers.trainer_utils import IntervalStrategy, has_length
-from transformers.utils import is_torch_npu_available
 
 from swift.utils import append_to_jsonl, format_time, get_device_count, get_logger, is_mp, is_pai_training_job
+from swift.utils.torch_utils import get_torch_device
 from .arguments import TrainingArguments
 
 logger = get_logger()
 
 
-def get_max_cuda_memory() -> float:
+def get_max_reserved_memory() -> float:
     devices = list(range(get_device_count())) if is_mp() else [None]
-    mems = [torch.cuda.max_memory_reserved(device=device) for device in devices]
+    mems = [get_torch_device().max_memory_reserved(device=device) for device in devices]
     return sum(mems) / 1024**3
 
 
@@ -34,9 +33,8 @@ def add_train_message(logs, state, start_time) -> None:
     for k, v in logs.items():
         if isinstance(v, float):
             logs[k] = round(logs[k], 8)
-    state.max_memory = max(getattr(state, 'max_memory', 0), get_max_cuda_memory())
-    if not is_torch_npu_available():
-        logs['memory(GiB)'] = round(state.max_memory, 2)
+    state.max_memory = max(getattr(state, 'max_memory', 0), get_max_reserved_memory())
+    logs['memory(GiB)'] = round(state.max_memory, 2)
 
     logs['train_speed(iter/s)'] = round(state.global_step / elapsed, 6)
 
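What the change does: the old get_max_cuda_memory() queried torch.cuda.max_memory_reserved directly, and the memory(GiB) entry was skipped whenever is_torch_npu_available() was true, so Ascend NPU training logs never contained a memory figure. The renamed get_max_reserved_memory() goes through get_torch_device() from swift.utils.torch_utils, which returns the active accelerator backend, so the same call works on CUDA and NPU and the log field is always emitted. Below is a minimal sketch of that indirection, assuming get_torch_device() returns a module that mirrors the torch.cuda memory API; the helper is illustrative only, not the actual swift implementation.

import torch


def get_torch_device():
    # Illustrative dispatch only: return the accelerator module whose memory
    # API mirrors torch.cuda. torch.npu becomes available once torch_npu has
    # been imported on Ascend hardware; otherwise fall back to torch.cuda.
    if hasattr(torch, 'npu') and torch.npu.is_available():
        return torch.npu
    return torch.cuda


# The same call then reports peak reserved memory (GiB) on either backend.
max_gib = get_torch_device().max_memory_reserved(device=0) / 1024**3
print(f'memory(GiB): {round(max_gib, 2)}')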