Skip to content

Commit d04ecfe

Browse files
authored
[bugfix] fix qwen3_omni seq_cls (#6673)
1 parent e8c0282 commit d04ecfe

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

swift/llm/model/patcher.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -257,14 +257,15 @@ def _patch_sequence_classification(model, model_meta):
257 257
lm_heads = ['lm_head', 'output', 'embed_out', 'output_layer']
258 258
llm_model = get_lm_head_model(model, model_meta, lm_heads)
259 259
llm_model.num_labels = model.config.num_labels
260-
llm_model.score = nn.Linear(hidden_size, llm_model.num_labels, bias=False, dtype=llm_model.dtype)
261-
if llm_model.score.weight.device == torch.device('meta'):
262-
llm_model.score.to_empty(device='cpu')
263-
llm_model.score.weight.data.normal_(mean=0.0, std=initializer_range)
264 260
for lm_head in lm_heads:
265 261
if hasattr(llm_model, lm_head):
262+
hidden_size = getattr(llm_model, lm_head).in_features
266 263
setattr(llm_model, lm_head, nn.Identity())
267 264
break
265+
llm_model.score = nn.Linear(hidden_size, llm_model.num_labels, bias=False, dtype=llm_model.dtype)
266+
if llm_model.score.weight.device == torch.device('meta'):
267+
llm_model.score.to_empty(device='cpu')
268+
llm_model.score.weight.data.normal_(mean=0.0, std=initializer_range)
268 269

269 270
origin_forward = llm_model.forward
270 271

0 commit comments

Comments
 (0)