Skip to content

Commit 7e1d218

Browse files
authored
[model] support mistral 2506 (#6624)
1 parent ec8efe7 commit 7e1d218

File tree

9 files changed

+108
-59
lines changed

9 files changed

+108
-59
lines changed

docs/source/Instruction/Supported-models-and-datasets.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,6 +1024,7 @@
10241024
|[google/gemma-3n-E4B-it](https://modelscope.cn/models/google/gemma-3n-E4B-it)|gemma3n|gemma3n|transformers>=4.53.1|✘|-|[google/gemma-3n-E4B-it](https://huggingface.co/google/gemma-3n-E4B-it)|
10251025
|[mistralai/Mistral-Small-3.1-24B-Base-2503](https://modelscope.cn/models/mistralai/Mistral-Small-3.1-24B-Base-2503)|mistral_2503|mistral_2503|transformers>=4.49|✘|-|[mistralai/Mistral-Small-3.1-24B-Base-2503](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Base-2503)|
10261026
|[mistralai/Mistral-Small-3.1-24B-Instruct-2503](https://modelscope.cn/models/mistralai/Mistral-Small-3.1-24B-Instruct-2503)|mistral_2503|mistral_2503|transformers>=4.49|✘|-|[mistralai/Mistral-Small-3.1-24B-Instruct-2503](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)|
1027+
|[mistralai/Mistral-Small-3.2-24B-Instruct-2506](https://modelscope.cn/models/mistralai/Mistral-Small-3.2-24B-Instruct-2506)|mistral_2506|mistral_2506|transformers>=4.49|✘|-|[mistralai/Mistral-Small-3.2-24B-Instruct-2506](https://huggingface.co/mistralai/Mistral-Small-3.2-24B-Instruct-2506)|
10271028
|[PaddlePaddle/PaddleOCR-VL](https://modelscope.cn/models/PaddlePaddle/PaddleOCR-VL)|paddle_ocr|paddle_ocr|-|✘|-|[PaddlePaddle/PaddleOCR-VL](https://huggingface.co/PaddlePaddle/PaddleOCR-VL)|
10281029
|[JinaAI/jina-reranker-m0](https://modelscope.cn/models/JinaAI/jina-reranker-m0)|jina_reranker_m0|jina_reranker_m0|-|✘|reranker, vision|[JinaAI/jina-reranker-m0](https://huggingface.co/JinaAI/jina-reranker-m0)|
10291030

docs/source_en/Instruction/Supported-models-and-datasets.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,6 +1024,7 @@ The table below introduces the models integrated with ms-swift:
10241024
|[google/gemma-3n-E4B-it](https://modelscope.cn/models/google/gemma-3n-E4B-it)|gemma3n|gemma3n|transformers>=4.53.1|✘|-|[google/gemma-3n-E4B-it](https://huggingface.co/google/gemma-3n-E4B-it)|
10251025
|[mistralai/Mistral-Small-3.1-24B-Base-2503](https://modelscope.cn/models/mistralai/Mistral-Small-3.1-24B-Base-2503)|mistral_2503|mistral_2503|transformers>=4.49|✘|-|[mistralai/Mistral-Small-3.1-24B-Base-2503](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Base-2503)|
10261026
|[mistralai/Mistral-Small-3.1-24B-Instruct-2503](https://modelscope.cn/models/mistralai/Mistral-Small-3.1-24B-Instruct-2503)|mistral_2503|mistral_2503|transformers>=4.49|✘|-|[mistralai/Mistral-Small-3.1-24B-Instruct-2503](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)|
1027+
|[mistralai/Mistral-Small-3.2-24B-Instruct-2506](https://modelscope.cn/models/mistralai/Mistral-Small-3.2-24B-Instruct-2506)|mistral_2506|mistral_2506|transformers>=4.49|✘|-|[mistralai/Mistral-Small-3.2-24B-Instruct-2506](https://huggingface.co/mistralai/Mistral-Small-3.2-24B-Instruct-2506)|
10271028
|[PaddlePaddle/PaddleOCR-VL](https://modelscope.cn/models/PaddlePaddle/PaddleOCR-VL)|paddle_ocr|paddle_ocr|-|✘|-|[PaddlePaddle/PaddleOCR-VL](https://huggingface.co/PaddlePaddle/PaddleOCR-VL)|
10281029
|[JinaAI/jina-reranker-m0](https://modelscope.cn/models/JinaAI/jina-reranker-m0)|jina_reranker_m0|jina_reranker_m0|-|✘|reranker, vision|[JinaAI/jina-reranker-m0](https://huggingface.co/JinaAI/jina-reranker-m0)|
10291030

swift/llm/model/constant.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,7 @@ class MLLMModelType:
274274
gemma3_vision = 'gemma3_vision'
275275
gemma3n = 'gemma3n'
276276
mistral_2503 = 'mistral_2503'
277+
mistral_2506 = 'mistral_2506'
277278
paddle_ocr = 'paddle_ocr'
278279

279280

swift/llm/model/model/mistral.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
2-
32
from typing import Any, Dict
43

5-
from transformers import AutoTokenizer
4+
from transformers import AutoProcessor, AutoTokenizer
65

76
from swift.llm import TemplateType
87
from ..constant import LLMModelType, MLLMModelType
@@ -130,12 +129,7 @@ def get_model_tokenizer_mistral_2503(model_dir: str,
130129
model_kwargs: Dict[str, Any],
131130
load_model: bool = True,
132131
**kwargs):
133-
try:
134-
from transformers import Mistral3ForConditionalGeneration
135-
except ImportError:
136-
raise ImportError('Please install Mistral3ForConditionalGeneration by running '
137-
'`pip install git+https://github.com/huggingface/transformers@v4.49.0-Mistral-3`')
138-
132+
from transformers import Mistral3ForConditionalGeneration
139133
kwargs['automodel_class'] = kwargs['automodel_class'] or Mistral3ForConditionalGeneration
140134
model, processor = get_model_tokenizer_multimodal(model_dir, model_info, model_kwargs, load_model, **kwargs)
141135

@@ -184,4 +178,35 @@ def get_model_tokenizer_devstral_2505(model_dir: str,
184178
architectures=['Mistral3ForConditionalGeneration'],
185179
model_arch=ModelArch.llava_hf,
186180
requires=['transformers>=4.49'],
187-
), )
181+
))
182+
183+
184+
def get_model_tokenizer_mistral_2506(model_dir: str,
185+
model_info: ModelInfo,
186+
model_kwargs: Dict[str, Any],
187+
load_model: bool = True,
188+
**kwargs):
189+
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
190+
from transformers import Mistral3ForConditionalGeneration
191+
tokenizer_dir = safe_snapshot_download('mistralai/Mistral-Small-3.1-24B-Instruct-2503', download_model=False)
192+
processor = AutoProcessor.from_pretrained(tokenizer_dir)
193+
kwargs['automodel_class'] = kwargs['automodel_class'] or Mistral3ForConditionalGeneration
194+
kwargs['tokenizer'] = processor.tokenizer
195+
model, _ = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)
196+
return model, processor
197+
198+
199+
register_model(
200+
ModelMeta(
201+
MLLMModelType.mistral_2506,
202+
[
203+
ModelGroup([
204+
Model('mistralai/Mistral-Small-3.2-24B-Instruct-2506', 'mistralai/Mistral-Small-3.2-24B-Instruct-2506'),
205+
]),
206+
],
207+
TemplateType.mistral_2506,
208+
get_model_tokenizer_mistral_2506,
209+
architectures=['Mistral3ForConditionalGeneration'],
210+
model_arch=ModelArch.llava_hf,
211+
requires=['transformers>=4.49'],
212+
))

swift/llm/model/model_arch.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ class MLLMModelArch:
8181
megrez_omni = 'megrez_omni'
8282
valley = 'valley'
8383
gemma3n = 'gemma3n'
84-
mistral_2503 = 'mistral_2503'
8584
keye_vl = 'keye_vl'
8685

8786
midashenglm = 'midashenglm'

swift/llm/template/constant.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ class MLLMTemplateType:
229229
gemma3_vision = 'gemma3_vision'
230230
gemma3n = 'gemma3n'
231231
mistral_2503 = 'mistral_2503'
232+
mistral_2506 = 'mistral_2506'
232233
paddle_ocr = 'paddle_ocr'
233234

234235

swift/llm/template/template/llm.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -119,29 +119,6 @@ def _preprocess_inputs(self, inputs: StdTemplateInputs) -> None:
119119
chat_sep=['</s>[INST] '],
120120
suffix=['</s>']))
121121

122-
today = datetime.now().strftime('%Y-%m-%d')
123-
124-
mistral_2501_system = (
125-
'You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup '
126-
'headquartered in Paris.\n'
127-
f'Your knowledge base was last updated on 2023-10-01. The current date is {today}.\n\n'
128-
"When you're not sure about some information, you say that you don't have the information and don't "
129-
'make up anything.\n'
130-
"If the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer "
131-
'the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. '
132-
'"What are some good restaurants around me?" => "Where are you?" or "When is the next flight to Tokyo" => "'
133-
'Where do you travel from?")')
134-
135-
register_template(
136-
TemplateMeta(
137-
LLMTemplateType.mistral_2501,
138-
prefix=['<s>'],
139-
prompt=['[INST]{{QUERY}}[/INST]'],
140-
chat_sep=['</s>'],
141-
suffix=['</s>'],
142-
system_prefix=['<s>[SYSTEM_PROMPT]{{SYSTEM}}[/SYSTEM_PROMPT]'],
143-
default_system=mistral_2501_system))
144-
145122
register_template(
146123
TemplateMeta(
147124
LLMTemplateType.xverse,

swift/llm/template/template/mistral.py

Lines changed: 59 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,39 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
2+
import os
3+
from dataclasses import dataclass, field
4+
from datetime import datetime, timedelta
25
from typing import Any, Dict, List, Literal, Optional
36

4-
import torch
5-
67
from ..base import Template
7-
from ..constant import MLLMTemplateType
8+
from ..constant import LLMTemplateType, MLLMTemplateType
89
from ..register import TemplateMeta, register_template
910
from ..template_inputs import StdTemplateInputs
10-
from ..utils import Context, findall
11-
from .llm import mistral_2501_system
11+
from ..utils import Context, Prompt, findall
12+
13+
today = datetime.now().strftime('%Y-%m-%d')
14+
15+
mistral_2501_system = (
16+
'You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup '
17+
'headquartered in Paris.\n'
18+
f'Your knowledge base was last updated on 2023-10-01. The current date is {today}.\n\n'
19+
"When you're not sure about some information, you say that you don't have the information and don't "
20+
'make up anything.\n'
21+
"If the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer "
22+
'the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. '
23+
'"What are some good restaurants around me?" => "Where are you?" or "When is the next flight to Tokyo" => "'
24+
'Where do you travel from?")')
25+
26+
27+
@dataclass
28+
class Mistral3TemplateMeta(TemplateMeta):
29+
prefix: Prompt = field(default_factory=lambda: ['<s>'])
30+
prompt: Prompt = field(default_factory=lambda: ['[INST]{{QUERY}}[/INST]'])
31+
chat_sep: Optional[Prompt] = field(default_factory=lambda: ['</s>'])
32+
suffix: Prompt = field(default_factory=lambda: ['</s>'])
33+
system_prefix: Optional[Prompt] = field(default_factory=lambda: ['<s>[SYSTEM_PROMPT]{{SYSTEM}}[/SYSTEM_PROMPT]'])
34+
35+
36+
register_template(Mistral3TemplateMeta(LLMTemplateType.mistral_2501, default_system=mistral_2501_system))
1237

1338

1439
class Mistral2503Template(Template):
@@ -28,15 +53,16 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
2853
labels = encoded['labels']
2954
loss_scale = encoded.get('loss_scale', None)
3055
idx_list = findall(input_ids, self.image_token)
56+
patch_size = processor.patch_size * processor.spatial_merge_size
3157
if idx_list:
32-
image_inputs = processor.image_processor(images, patch_size=processor.patch_size, return_tensors='pt')
58+
image_inputs = processor.image_processor(images, patch_size=patch_size, return_tensors='pt')
3359
encoded['pixel_values'] = image_inputs['pixel_values'].to(self.model_info.torch_dtype)
3460
encoded['image_sizes'] = image_sizes = image_inputs['image_sizes']
3561

3662
def _get_new_tokens(i):
3763
height, width = image_sizes[i]
38-
num_height_tokens = height // (processor.patch_size * processor.spatial_merge_size)
39-
num_width_tokens = width // (processor.patch_size * processor.spatial_merge_size)
64+
num_height_tokens = height // patch_size
65+
num_width_tokens = width // patch_size
4066
replace_tokens = [[processor.image_token] * num_width_tokens + [processor.image_break_token]
4167
] * num_height_tokens
4268
# Flatten list
@@ -52,15 +78,8 @@ def _get_new_tokens(i):
5278

5379

5480
register_template(
55-
TemplateMeta(
56-
MLLMTemplateType.mistral_2503,
57-
prefix=['<s>'],
58-
prompt=['[INST]{{QUERY}}[/INST]'],
59-
chat_sep=['</s>'],
60-
suffix=['</s>'],
61-
system_prefix=['<s>[SYSTEM_PROMPT]{{SYSTEM}}[/SYSTEM_PROMPT]'],
62-
default_system=mistral_2501_system,
63-
template_cls=Mistral2503Template))
81+
Mistral3TemplateMeta(
82+
MLLMTemplateType.mistral_2503, default_system=mistral_2501_system, template_cls=Mistral2503Template))
6483

6584
devstral_small_2505_system = ( # from https://huggingface.co/mistralai/Devstral-Small-2505/blob/main/SYSTEM_PROMPT.txt
6685
'You are Devstral, a helpful agentic model trained by Mistral AI and using the OpenHands scaffold. '
@@ -122,12 +141,27 @@ def _get_new_tokens(i):
122141
'executing a plan from the user, please don\'t try to directly work around it. Instead, propose a new '
123142
'plan and confirm with the user before proceeding.\n</TROUBLESHOOTING>')
124143

144+
register_template(Mistral3TemplateMeta('devstral', default_system=devstral_small_2505_system))
145+
146+
147+
class Mistral2506Template(Mistral2503Template):
148+
149+
def _get_mistral_system(self):
150+
from swift.llm import get_model_name
151+
model_dir = self.model_info.model_dir
152+
model_name = get_model_name(model_dir)
153+
file_path = os.path.join(model_dir, 'SYSTEM_PROMPT.txt')
154+
with open(file_path, 'r') as file:
155+
system_prompt = file.read()
156+
today = datetime.today().strftime('%Y-%m-%d')
157+
yesterday = (datetime.today() - timedelta(days=1)).strftime('%Y-%m-%d')
158+
return system_prompt.format(name=model_name, today=today, yesterday=yesterday)
159+
160+
def _swift_encode(self, inputs: StdTemplateInputs):
161+
if inputs.system is None:
162+
inputs.system = self._get_mistral_system()
163+
return super()._swift_encode(inputs)
164+
165+
125166
register_template(
126-
TemplateMeta(
127-
'devstral',
128-
prefix=['<s>'],
129-
prompt=['[INST]{{QUERY}}[/INST]'], # the user query
130-
chat_sep=['</s>'],
131-
suffix=['</s>'],
132-
system_prefix=['<s>[SYSTEM_PROMPT]{{SYSTEM}}[/SYSTEM_PROMPT]'], # the system prompt
133-
default_system=devstral_small_2505_system))
167+
Mistral3TemplateMeta(MLLMTemplateType.mistral_2506, default_system=None, template_cls=Mistral2506Template))

tests/test_align/test_template/test_vision.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1092,6 +1092,15 @@ def test_ernie_vl_thinking():
10921092
assert response == '\n<think>\n' + response2
10931093

10941094

1095+
def test_mistral_2506():
1096+
pt_engine = PtEngine('mistralai/Mistral-Small-3.2-24B-Instruct-2506')
1097+
response = _infer_model(pt_engine, messages=[{'role': 'user', 'content': 'describe the image.'}])
1098+
assert response[:200] == (
1099+
'The image features a close-up of a kitten with striking blue eyes. The kitten has a soft, '
1100+
'fluffy coat with a mix of white, gray, and brown fur. Its fur pattern includes distinct '
1101+
'stripes, particularly ')
1102+
1103+
10951104
if __name__ == '__main__':
10961105
from swift.llm import PtEngine, RequestConfig
10971106
from swift.utils import get_logger, seed_everything
@@ -1168,4 +1177,5 @@ def test_ernie_vl_thinking():
11681177
# test_llava_onevision1_5()
11691178
# test_paddle_ocr()
11701179
# test_ernie_vl()
1171-
test_ernie_vl_thinking()
1180+
# test_ernie_vl_thinking()
1181+
test_mistral_2506()

0 commit comments

Comments
 (0)