1 file changed: +5 −4 lines

@@ -257,14 +257,15 @@ def _patch_sequence_classification(model, model_meta):
     lm_heads = ['lm_head', 'output', 'embed_out', 'output_layer']
     llm_model = get_lm_head_model(model, model_meta, lm_heads)
     llm_model.num_labels = model.config.num_labels
-    llm_model.score = nn.Linear(hidden_size, llm_model.num_labels, bias=False, dtype=llm_model.dtype)
-    if llm_model.score.weight.device == torch.device('meta'):
-        llm_model.score.to_empty(device='cpu')
-    llm_model.score.weight.data.normal_(mean=0.0, std=initializer_range)
     for lm_head in lm_heads:
         if hasattr(llm_model, lm_head):
+            hidden_size = getattr(llm_model, lm_head).in_features
             setattr(llm_model, lm_head, nn.Identity())
             break
+    llm_model.score = nn.Linear(hidden_size, llm_model.num_labels, bias=False, dtype=llm_model.dtype)
+    if llm_model.score.weight.device == torch.device('meta'):
+        llm_model.score.to_empty(device='cpu')
+    llm_model.score.weight.data.normal_(mean=0.0, std=initializer_range)
 
     origin_forward = llm_model.forward
 
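For context, here is a minimal, self-contained sketch of the pattern this hunk implements: read the hidden size from the existing LM head's in_features before that head is replaced with nn.Identity(), and only then build and initialize the new score projection. The ToyCausalLM and patch_for_classification names below are hypothetical illustrations, not the project's actual API; only the ordering of the steps mirrors the diff.

# Hypothetical sketch of the patch order shown in the diff above.
import torch
import torch.nn as nn


class ToyCausalLM(nn.Module):
    """Stand-in for a causal LM whose head carries the hidden size."""

    def __init__(self, hidden_size=32, vocab_size=100):
        super().__init__()
        self.backbone = nn.Linear(hidden_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)


def patch_for_classification(llm_model, num_labels, initializer_range=0.02,
                             lm_heads=('lm_head', 'output', 'embed_out', 'output_layer')):
    """Replace the LM head with Identity and attach a freshly initialized score head."""
    llm_model.num_labels = num_labels
    hidden_size = None
    for name in lm_heads:
        if hasattr(llm_model, name):
            # Read the hidden size from the existing head *before* discarding it.
            hidden_size = getattr(llm_model, name).in_features
            setattr(llm_model, name, nn.Identity())
            break
    llm_model.score = nn.Linear(hidden_size, num_labels, bias=False,
                                dtype=next(llm_model.parameters()).dtype)
    # Materialize meta tensors before touching the weights.
    if llm_model.score.weight.device == torch.device('meta'):
        llm_model.score.to_empty(device='cpu')
    llm_model.score.weight.data.normal_(mean=0.0, std=initializer_range)
    return llm_model


model = patch_for_classification(ToyCausalLM(), num_labels=3)
print(model.score)  # Linear(in_features=32, out_features=3, bias=False)

The ordering matters because once the LM head is swapped for nn.Identity() its in_features is gone, so the hidden size must be captured first; the diff moves the score-head construction below the loop for exactly that reason.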