# Copyright (c) Alibaba, Inc. and its affiliates.
2+ import os
3+ from dataclasses import dataclass , field
4+ from datetime import datetime , timedelta
25from typing import Any , Dict , List , Literal , Optional
36
4- import torch
5-
67from ..base import Template
7- from ..constant import MLLMTemplateType
8+ from ..constant import LLMTemplateType , MLLMTemplateType
89from ..register import TemplateMeta , register_template
910from ..template_inputs import StdTemplateInputs
10- from ..utils import Context , findall
11- from .llm import mistral_2501_system
11+ from ..utils import Context , Prompt , findall
12+
# Snapshot the current date once at import time; Mistral's official system
# prompt embeds it verbatim.
today = datetime.now().strftime('%Y-%m-%d')

# Default system prompt for Mistral Small 3 (2501), taken from the model card.
mistral_2501_system = (
    'You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup '
    'headquartered in Paris.\n'
    f'Your knowledge base was last updated on 2023-10-01. The current date is {today}.\n\n'
    "When you're not sure about some information, you say that you don't have the information and don't "
    'make up anything.\n'
    "If the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer "
    'the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. '
    '"What are some good restaurants around me?" => "Where are you?" or "When is the next flight to Tokyo" => "'
    'Where do you travel from?")')
25+
26+
@dataclass
class Mistral3TemplateMeta(TemplateMeta):
    """Template metadata shared by the Mistral-3 family of chat templates.

    Encodes the Mistral v3 chat layout: each user query is wrapped in
    ``[INST]...[/INST]``, an optional system prompt is wrapped in
    ``[SYSTEM_PROMPT]...[/SYSTEM_PROMPT]`` right after ``<s>``, and ``</s>``
    serves both as the turn separator and as the generation suffix.
    """
    prefix: Prompt = field(default_factory=lambda: ['<s>'])
    prompt: Prompt = field(default_factory=lambda: ['[INST]{{QUERY}}[/INST]'])
    chat_sep: Optional[Prompt] = field(default_factory=lambda: ['</s>'])
    suffix: Prompt = field(default_factory=lambda: ['</s>'])
    system_prefix: Optional[Prompt] = field(default_factory=lambda: ['<s>[SYSTEM_PROMPT]{{SYSTEM}}[/SYSTEM_PROMPT]'])
34+
35+
# Plain-text Mistral Small 2501: Mistral-3 chat format plus the official system prompt.
register_template(Mistral3TemplateMeta(LLMTemplateType.mistral_2501, default_system=mistral_2501_system))
1237
1338
1439class Mistral2503Template (Template ):
@@ -28,15 +53,16 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
2853 labels = encoded ['labels' ]
2954 loss_scale = encoded .get ('loss_scale' , None )
3055 idx_list = findall (input_ids , self .image_token )
56+ patch_size = processor .patch_size * processor .spatial_merge_size
3157 if idx_list :
32- image_inputs = processor .image_processor (images , patch_size = processor . patch_size , return_tensors = 'pt' )
58+ image_inputs = processor .image_processor (images , patch_size = patch_size , return_tensors = 'pt' )
3359 encoded ['pixel_values' ] = image_inputs ['pixel_values' ].to (self .model_info .torch_dtype )
3460 encoded ['image_sizes' ] = image_sizes = image_inputs ['image_sizes' ]
3561
3662 def _get_new_tokens (i ):
3763 height , width = image_sizes [i ]
38- num_height_tokens = height // ( processor . patch_size * processor . spatial_merge_size )
39- num_width_tokens = width // ( processor . patch_size * processor . spatial_merge_size )
64+ num_height_tokens = height // patch_size
65+ num_width_tokens = width // patch_size
4066 replace_tokens = [[processor .image_token ] * num_width_tokens + [processor .image_break_token ]
4167 ] * num_height_tokens
4268 # Flatten list
@@ -52,15 +78,8 @@ def _get_new_tokens(i):
5278
5379
# Multimodal Mistral Small 3.1 (2503): same chat format as 2501, with
# image-token expansion handled by Mistral2503Template.
register_template(
    Mistral3TemplateMeta(
        MLLMTemplateType.mistral_2503, default_system=mistral_2501_system, template_cls=Mistral2503Template))
6483
6584devstral_small_2505_system = ( # from https://huggingface.co/mistralai/Devstral-Small-2505/blob/main/SYSTEM_PROMPT.txt
6685 'You are Devstral, a helpful agentic model trained by Mistral AI and using the OpenHands scaffold. '
@@ -122,12 +141,27 @@ def _get_new_tokens(i):
122141 'executing a plan from the user, please don\' t try to directly work around it. Instead, propose a new '
123142 'plan and confirm with the user before proceeding.\n </TROUBLESHOOTING>' )
124143
# Devstral Small 2505: plain-text Mistral-3 format with the Devstral agent system prompt.
register_template(Mistral3TemplateMeta('devstral', default_system=devstral_small_2505_system))
145+
146+
class Mistral2506Template(Mistral2503Template):
    """Template for Mistral Small 3.2 (2506).

    The 2506 checkpoints ship their system prompt as a ``SYSTEM_PROMPT.txt``
    file inside the model directory, templated with the model name and the
    current/previous date, so the default system prompt is resolved lazily at
    encode time instead of being hard-coded in this file.
    """

    def _get_mistral_system(self) -> str:
        """Read SYSTEM_PROMPT.txt from the model dir and fill its placeholders.

        Raises:
            FileNotFoundError: if the checkpoint does not ship SYSTEM_PROMPT.txt.
        """
        from swift.llm import get_model_name
        model_dir = self.model_info.model_dir
        model_name = get_model_name(model_dir)
        file_path = os.path.join(model_dir, 'SYSTEM_PROMPT.txt')
        # Explicit encoding: the prompt file is UTF-8; the platform default may not be.
        with open(file_path, 'r', encoding='utf-8') as file:
            system_prompt = file.read()
        # Take a single reference time so `today` and `yesterday` stay consistent
        # even if this happens to run across a midnight boundary.
        now = datetime.now()
        today = now.strftime('%Y-%m-%d')
        yesterday = (now - timedelta(days=1)).strftime('%Y-%m-%d')
        return system_prompt.format(name=model_name, today=today, yesterday=yesterday)

    def _swift_encode(self, inputs: StdTemplateInputs):
        # Fall back to the model-shipped system prompt only when the caller did
        # not supply one.
        if inputs.system is None:
            inputs.system = self._get_mistral_system()
        return super()._swift_encode(inputs)
164+
165+
# Multimodal Mistral Small 3.2 (2506): no static default_system — the system
# prompt is loaded from the checkpoint's SYSTEM_PROMPT.txt by Mistral2506Template.
register_template(
    Mistral3TemplateMeta(MLLMTemplateType.mistral_2506, default_system=None, template_cls=Mistral2506Template))