
Commit 4da83ad

[bugfix] fix train_type full freeze_llm (#6651)
1 parent 0b6aee8 · commit 4da83ad

3 files changed: +22 −20 lines changed

swift/llm/model/model_arch.py

Lines changed: 15 additions & 15 deletions
@@ -337,7 +337,7 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> None:
 register_model_arch(
     MultiModelKeys(
         MLLMModelArch.llava_hf,
-        language_model='model.language_model',
+        language_model=['model.language_model', 'lm_head'],
         aligner='model.multi_modal_projector',
         vision_tower='model.vision_tower',
     ))
@@ -362,7 +362,7 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> None:
     register_model_arch(
         MultiModelKeys(
             MLLMModelArch.llava_next_video_hf,
-            language_model='model.language_model',
+            language_model=['model.language_model', 'lm_head'],
             aligner=['model.multi_modal_projector'],
             vision_tower='model.vision_tower'))
 else:
@@ -400,7 +400,7 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> None:
 register_model_arch(
     MultiModelKeys(
         MLLMModelArch.interns1,
-        language_model='model.language_model',
+        language_model=['model.language_model', 'lm_head'],
         aligner='model.multi_modal_projector',
         vision_tower='model.vision_tower',
     ))
@@ -521,31 +521,31 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> None:
     register_model_arch(
         MultiModelKeys(
             MLLMModelArch.qwen2_vl,
-            language_model='model.language_model',
+            language_model=['model.language_model', 'lm_head'],
             aligner='model.visual.merger',
             vision_tower='model.visual',
         ))
 else:
     register_model_arch(
         MultiModelKeys(
             MLLMModelArch.qwen2_vl,
-            language_model='model',
+            language_model=['model', 'lm_head'],
             aligner='visual.merger',
             vision_tower='visual',
         ))

 register_model_arch(
     MultiModelKeys(
         MLLMModelArch.qwen3_vl,
-        language_model='model.language_model',
+        language_model=['model.language_model', 'lm_head'],
         aligner=['model.visual.merger', 'model.visual.deepstack_merger_list'],
         vision_tower='model.visual',
     ))

 register_model_arch(
     MultiModelKeys(
         MLLMModelArch.qwen2_5_omni,
-        language_model='thinker.model',
+        language_model=['thinker.model', 'thinker.lm_head'],
         vision_tower=['thinker.audio_tower', 'thinker.visual'],
         aligner=['thinker.audio_tower.proj', 'thinker.visual.merger'],
         generator=['talker', 'token2wav'],
@@ -554,7 +554,7 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> None:
 register_model_arch(
     MultiModelKeys(
         MLLMModelArch.qwen3_omni,
-        language_model='thinker.model',
+        language_model=['thinker.model', 'thinker.lm_head'],
         vision_tower=['thinker.audio_tower', 'thinker.visual'],
         aligner=[
             'thinker.audio_tower.proj1', 'thinker.audio_tower.proj2', 'thinker.visual.merger',
@@ -574,7 +574,7 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> None:
 register_model_arch(
     MultiModelKeys(
         MLLMModelArch.step_audio2_mini,
-        language_model='model',
+        language_model=['model', 'lm_head'],
         aligner=['adapter'],
         vision_tower=['encoder'],
     ))
@@ -589,7 +589,7 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> None:
 register_model_arch(
     MultiModelKeys(
         MLLMModelArch.glm4_1v,
-        language_model='model.language_model',
+        language_model=['model.language_model', 'lm_head'],
         aligner='model.visual.merger',
         vision_tower='model.visual',
     ))
@@ -622,7 +622,7 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> None:
 register_model_arch(
     MultiModelKeys(
         MLLMModelArch.ernie_vl,
-        language_model='model',
+        language_model=['model', 'lm_head'],
         aligner='model.resampler_model',
         vision_tower='vision_model',
     ))
@@ -631,7 +631,7 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> None:
 register_model_arch(
     MultiModelKeys(
         MLLMModelArch.llama3_2_vision,
-        language_model='model.language_model',
+        language_model=['model.language_model', 'lm_head'],
         aligner='model.multi_modal_projector',
         vision_tower='model.vision_model',
     ))
@@ -696,15 +696,15 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> None:
 register_model_arch(
     MultiModelKeys(
         MLLMModelArch.gemma3n,
-        language_model='model.language_model',
+        language_model=['model.language_model', 'lm_head'],
         aligner=['model.embed_vision', 'model.embed_audio'],
         vision_tower=['model.vision_tower', 'model.audio_tower'],
     ))

 register_model_arch(
     MultiModelKeys(
         MLLMModelArch.keye_vl,
-        language_model='model',
+        language_model=['model', 'lm_head'],
         aligner='mlp_AR',
         vision_tower='visual',
     ))
@@ -717,7 +717,7 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> None:
 register_model_arch(
     MultiModelKeys(
         MLLMModelArch.llava_onevision1_5,
-        language_model='model.language_model',
+        language_model=['model.language_model', 'lm_head'],
         aligner='model.visual.merger',
         vision_tower='model.visual',
     ))
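
The change above registers 'lm_head' alongside the language-model prefix for each of these multimodal architectures. A minimal sketch of why that matters, assuming a freeze routine that simply walks the registered attribute paths (the helper names below are illustrative, not ms-swift's actual API): with train_type='full' and freeze_llm=True, only the paths listed under language_model get frozen, so a top-level lm_head that is not listed would keep training.

import torch.nn as nn

def resolve_path(model: nn.Module, path: str) -> nn.Module:
    # Resolve a dotted attribute path such as 'model.language_model' or 'lm_head'.
    obj = model
    for name in path.split('.'):
        obj = getattr(obj, name)
    return obj

def freeze_language_model(model: nn.Module, language_model_paths) -> None:
    # Freeze every registered LLM prefix; before this commit 'lm_head' was not
    # registered for the architectures above and therefore stayed trainable.
    for path in language_model_paths:
        for param in resolve_path(model, path).parameters():
            param.requires_grad = False

# e.g. freeze_language_model(model, ['model.language_model', 'lm_head'])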

swift/llm/train/tuner.py

Lines changed: 5 additions & 2 deletions
@@ -111,13 +111,16 @@ def get_multimodal_target_regex(
     res = []
     for module in modules:
         rejected_modules = []
-        if not freeze_vit:
+        if not freeze_vit or not freeze_llm:
             for aligner in model_arch.aligner:
                 if aligner.startswith(f'{module}.'):
                     rejected_modules.append(aligner)

         sub_module = deep_getattr(model, module)
-        target_modules = find_all_linears(sub_module, model_arch, extra_layers)
+        if isinstance(sub_module, nn.Linear) and module.endswith('lm_head'):
+            target_modules = []
+        else:
+            target_modules = find_all_linears(sub_module, model_arch, extra_layers)
         if exclude_router and model.model_info.is_moe_model:
             target_modules = [tm for tm in target_modules if tm not in {'gate'}]
         if not target_modules:
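
The second hunk guards against feeding a bare nn.Linear to find_all_linears. A simplified, self-contained sketch of the underlying issue (the helper below approximates, but is not, ms-swift's find_all_linears): a composite language-model block exposes Linear children that can be matched by suffix, whereas 'lm_head' is itself a Linear with no children, so it has to be handled as a whole module with an empty suffix list.

import torch.nn as nn

def linear_child_suffixes(module: nn.Module):
    # Leaf names of Linear submodules, e.g. ['k_proj', 'q_proj'].
    return sorted({name.split('.')[-1] for name, m in module.named_modules()
                   if isinstance(m, nn.Linear) and name})

block = nn.ModuleDict({'q_proj': nn.Linear(8, 8), 'k_proj': nn.Linear(8, 8)})
lm_head = nn.Linear(8, 32)

print(linear_child_suffixes(block))    # ['k_proj', 'q_proj'] -> suffix patterns work
print(linear_child_suffixes(lm_head))  # []                   -> the module path itself
                                       #    ('lm_head') is the unit to target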

swift/megatron/trainers/utils.py

Lines changed: 2 additions & 3 deletions
@@ -111,9 +111,8 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]):
     if cp_size > 1:
         args = get_args()
         keys = ['labels', 'attention_mask', 'position_ids', 'loss_scale']
-        if args.is_multimodal:
-            keys.append('decoder_input')
-        else:
+        if not args.is_multimodal:
+            # Multimodal models will handle CP in input_embeds.
             keys.append('input_ids')

         packed_seq_params = batch.get('packed_seq_params')
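
Condensed, the new key selection behaves as sketched below (illustrative only, not the full get_batch_on_this_cp_rank): under context parallelism, only the token-aligned tensors listed in keys are sliced per CP rank, and for multimodal models 'input_ids' is skipped because, per the new comment, CP is applied later on the input embeddings; the old 'decoder_input' branch is dropped entirely.

def cp_sliced_keys(cp_size: int, is_multimodal: bool):
    # Tensors that get split across context-parallel ranks for this batch.
    if cp_size <= 1:
        return []
    keys = ['labels', 'attention_mask', 'position_ids', 'loss_scale']
    if not is_multimodal:
        keys.append('input_ids')
    return keys

assert 'input_ids' in cp_sliced_keys(cp_size=2, is_multimodal=False)
assert 'input_ids' not in cp_sliced_keys(cp_size=2, is_multimodal=True)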
