Commit 35aa111

[algo] support cispo algorithm (#6572)
1 parent 1174780 commit 35aa111

15 files changed: +566 -85 lines changed

docs/source/Instruction/Command-line-parameters.md

Lines changed: 1 addition & 1 deletion
@@ -561,7 +561,7 @@ Reward model parameters are used in PPO and GRPO.
- dataset_shuffle: Whether to shuffle the dataset. Default is True.
- truncation_strategy: How to handle inputs exceeding `max_length`. Supported values are `delete` and `left`, meaning deletion and left-side truncation respectively; the default is `left`. Note that for multi-modal models,
left-side truncation may remove multi-modal tokens and cause a shape-mismatch error in the model forward pass. With the `delete` strategy, over-long or encoding-failed samples are replaced by resampling other data from the original dataset.
564- - loss_type: The type of loss normalization. Options are ['grpo', 'bnpo', 'dr_grpo'], default is 'grpo'. For details, see this [pr](https://github.com/huggingface/trl/pull/3256#discussion_r2033213348)
564+ - loss_type: The type of loss normalization. Options are ['grpo', 'bnpo', 'dr_grpo', 'dapo', 'cispo'], default is 'grpo'. For details, refer to the [documentation](./GRPO/DeveloperGuide/loss_types.md)
- log_completions: Whether to log the model-generated content during training, to be used together with `--report_to wandb/swanlab`. Default is False.
  - Note: If `--report_to wandb/swanlab` is not set, a `completions.jsonl` file is created in the checkpoint directory to store the generated content.
- use_vllm: Whether to use vLLM as the infer_backend for GRPO generation. Default is False.
Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
# Clipped Importance Sampling Policy Optimization (CISPO)

**Version requirement**: ms-swift>=3.11

Clipped Importance Sampling Policy Optimization (CISPO) is a reinforcement learning algorithm proposed in the [MiniMax-M1](https://arxiv.org/abs/2506.13585) paper. Compared with GRPO (Group Relative Policy Optimization), CISPO clips the importance sampling weights themselves.

## Algorithm Overview

For clarity, we explain CISPO by contrasting it with GRPO.

GRPO limits the magnitude of policy updates by clipping the policy ratio. Its loss function is:

$$
\mathcal{L}_{\text{GRPO}}(\theta) = -\mathbb{E}\left[\min\left(r_t(\theta) \cdot \hat{A}_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \cdot \hat{A}_t\right)\right]
$$

where $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)}$ is the importance sampling ratio.

For long reasoning chains, this form of clipping can cause the following problem:

**Gradient suppression of critical tokens**: In complex reasoning tasks, certain critical low-probability tokens (such as *However, Recheck, Wait, Aha*) are crucial for triggering deep thinking and correcting reasoning errors. These tokens have low probability under the old policy $\pi_{\theta_{\text{old}}}$; when the new policy tries to increase their probability, the policy ratio $r_t(\theta)$ becomes large and GRPO's clipping mechanism discards their gradient contribution.

### CISPO's Solution

The core idea of CISPO is to clip the importance sampling weights while preserving gradient updates. Specifically, CISPO's loss function is:

$$
\mathcal{L}_{\text{CISPO}}(\theta) = -\mathbb{E}\left[\text{detach}\left(\min(r_t(\theta), \epsilon_{\text{high}})\right) \cdot \hat{A}_t \cdot \log \pi_\theta(a_t|s_t)\right]
$$

where $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)}$ is the importance sampling ratio.

**Key mechanisms**:
- Clip the importance sampling weights: $\min(r_t(\theta), \epsilon_{\text{high}})$
- **Detach operation**: the clipped weights do not participate in gradient computation and act as constant coefficients
- Gradients come from the $\log \pi_\theta(a_t|s_t)$ term, so every token contributes a gradient

## Implementation Details

The pseudo-code implementation of CISPO is as follows:

```python
log_ratio = per_token_logps - old_per_token_logps
importance_weights = torch.exp(log_ratio)  # r_t(θ) = π_θ / π_θ_old

clamped_ratios = torch.clamp(importance_weights, max=epsilon_high).detach()

per_token_loss = -clamped_ratios * advantages.unsqueeze(1) * per_token_logps
```

## Parameter Configuration

CISPO training can be enabled on top of `GRPOTrainer` by setting the following parameters:

```bash
--loss_type cispo
--epsilon_high 5.0
```

> Compared with other algorithms, CISPO generally uses a larger value for epsilon_high. The MiniMax paper does not give a specific setting; the value here follows the experimental setup of the [ScaleRL](https://arxiv.org/pdf/2510.13786) paper.

For other training parameters, refer to the [GRPO parameter documentation](../../Command-line-parameters.md#grpo参数).
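
For intuition, the following minimal PyTorch snippet (illustrative only, not code from this commit) reproduces the gradient behaviour described above: for a low-probability token with positive advantage whose ratio already exceeds $1+\epsilon$, the GRPO objective yields a zero gradient, while CISPO's detached clipped weight still lets the gradient flow through $\log \pi_\theta$.

```python
import torch

# A single token with positive advantage whose ratio already exceeds 1 + eps.
adv = torch.tensor(1.0)
eps, eps_high = 0.2, 5.0
old_logp = torch.log(torch.tensor(0.01))                # low-prob token under pi_old
logp = torch.log(torch.tensor(0.05)).requires_grad_()   # pi_theta raised it: r = 5.0

ratio = torch.exp(logp - old_logp)

# GRPO: min(r * A, clip(r, 1-eps, 1+eps) * A) selects the clipped (constant) branch,
# so the token contributes no gradient.
grpo_loss = -torch.min(ratio * adv, torch.clamp(ratio, 1 - eps, 1 + eps) * adv)
grpo_loss.backward()
print(logp.grad)   # zero gradient: the token is effectively dropped from the update

# CISPO: the clipped weight is detached; the gradient flows through log pi_theta.
logp2 = torch.log(torch.tensor(0.05)).requires_grad_()
ratio2 = torch.exp(logp2 - old_logp)
cispo_loss = -torch.clamp(ratio2, max=eps_high).detach() * adv * logp2
cispo_loss.backward()
print(logp2.grad)  # non-zero gradient: the update still raises this token's log-prob
```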

docs/source/Instruction/GRPO/AdvancedResearch/DAPO.md

Lines changed: 3 additions & 2 deletions
@@ -49,8 +49,9 @@ DAPO uses token-level normalization to avoid the bias that response length introduces into the loss computation.

Parameters:

52- - `loss_type bnpo` enables token-level normalization
52+ - `loss_type bnpo/dapo` enables token-level normalization

54+ > For the loss_type formulas, refer to the [documentation](../DeveloperGuide/loss_types.md)

## Overlong Filtering
DAPO argues that forcibly truncated responses carry noisy rewards, which can make it hard for the model to distinguish quality issues from length issues. It therefore filters out truncated samples during training so that they do not contribute to the loss computation.
@@ -92,7 +93,7 @@

| Parameter | Type | Value |
|----------------------|-----------|-------------|
95- | `--loss_type` | `str` | `bnpo` |
96+ | `--loss_type` | `str` | `bnpo`/`dapo` |
| `--epsilon_high` | `float` | `0.28` |
| `--dynamic_sample` | `bool` | `true` |
| `--max_resample_times` | `int` | `3` |

docs/source/Instruction/GRPO/AdvancedResearch/index.rst

Lines changed: 1 addition & 0 deletions
@@ -10,3 +10,4 @@ Advanced Research
   RLOO.md
   REINFORCEPP.md
   CHORD.md
13+    CISPO.md

docs/source/Instruction/GRPO/DeveloperGuide/index.rst

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@ Developer Guide
.. toctree::
   :maxdepth: 1

6+    loss_types.md
   multi_turn.md
   multi_task.md
   reward_function.md
Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
# Loss Types

GRPO training supports five different loss types, which differ mainly in the dimension over which the loss is normalized.

## Loss Function

At the token level, GRPO training uses the following loss:

$$\mathcal{L}_{i,t} = -\min\left(\rho_{i,t} A_{i,t}, \text{clip}(\rho_{i,t}, 1-\epsilon, 1+\epsilon) A_{i,t}\right)$$

When `loss_type cispo` is set, the CISPO loss is used instead:

$$\mathcal{L}_{i,t}^{\text{CISPO}} = -\text{detach}\left(\min(\rho_{i,t}, \epsilon_{\text{high}})\right) \cdot A_{i,t} \cdot \log \pi_\theta(y_{i,t}|y_{i,<t})$$

where:
- $\rho_{i,t} = \frac{\pi_\theta(y_{i,t}|y_{i,<t})}{\pi_{\theta_{\text{old}}}(y_{i,t}|y_{i,<t})}$ is the importance sampling weight
- $A_{i,t}$ is the advantage
- $\epsilon$ and $\epsilon_{\text{high}}$ are clipping parameters
- $\text{detach}(\cdot)$ means the term does not participate in gradient computation

## GRPO

`--loss_type grpo`

GRPO is the standard loss implementation: the token-level losses are averaged within each sample, and then averaged over all samples.

**Formula:**

$$\mathcal{L}_{\text{GRPO}} = \frac{1}{N} \sum_{i=1}^{N} \frac{1}{T_i} \sum_{t=1}^{T_i} \mathcal{L}_{i,t}$$

where:
- $N$ is the number of samples in the batch
- $T_i$ is the number of completion tokens of the $i$-th sample

**Normalization dimension:** sample level (average over each sample's tokens first, then over all samples)

## BNPO (Batch Normalized Policy Optimization)

`--loss_type bnpo`

BNPO sums the losses of all tokens of all samples and divides by the total number of completion tokens.

**Formula:**

$$\mathcal{L}_{\text{BNPO}} = \frac{\sum_{i=1}^{N} \sum_{t=1}^{T_i} \mathcal{L}_{i,t}}{\sum_{i=1}^{N} T_i}$$

where:
- $N$ is the number of samples in the batch
- $T_i$ is the number of completion tokens of the $i$-th sample

**Normalization dimension:** token level (average over all completion tokens)

## DR-GRPO

`--loss_type dr_grpo`

DR-GRPO sums the losses of all tokens of all samples and divides by the batch size times the maximum completion length.

**Formula:**

$$\mathcal{L}_{\text{DR-GRPO}} = \frac{\sum_{i=1}^{N} \sum_{t=1}^{T_i} \mathcal{L}_{i,t}}{N \times L_{\text{max}}}$$

where:
- $N$ is the number of samples in the batch
- $T_i$ is the number of completion tokens of the $i$-th sample
- $L_{\text{max}}$ is the maximum completion length

**Normalization dimension:** fixed (batch size × maximum completion length)

## CISPO

`--loss_type cispo`

The CISPO loss is normalized by the total number of completion tokens across all processes.

**Formula:**

$$\mathcal{L}_{\text{CISPO}} = \frac{\sum_{i=1}^{N} \sum_{t=1}^{T_i} \mathcal{L}_{i,t}^{\text{CISPO}}}{\sum_{\text{all processes}} \sum_{i=1}^{N_p} T_{p,i}}$$

where:
- $N$ is the number of samples in the current process's batch
- $T_i$ is the number of completion tokens of the $i$-th sample
- $N_p$ is the number of samples on process $p$

**Normalization dimension:** global token level (total completion tokens across all processes)

## DAPO

`--loss_type dapo`

DAPO is similar to BNPO in using token-level normalization, but it normalizes over the global data (across processes).

**Formula:**

$$\mathcal{L}_{\text{DAPO}} = \frac{\sum_{i=1}^{N} \sum_{t=1}^{T_i} \mathcal{L}_{i,t}}{\sum_{\text{all processes}} \sum_{i=1}^{N_p} T_{p,i}}$$

where:
- $N$ is the number of samples in the current process's batch
- $T_i$ is the number of completion tokens of the $i$-th sample
- $N_p$ is the number of samples on process $p$

**Normalization dimension:** global token level (total completion tokens across all processes)
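
To make the difference between these denominators concrete, here is a minimal single-process PyTorch sketch (illustrative, not the trainer's actual implementation; names such as `aggregate_loss` and `completion_mask` are assumed). In a real multi-GPU run the `dapo`/`cispo` denominator would be the completion-token count summed across all processes, e.g. via an all-reduce.

```python
import torch

def aggregate_loss(per_token_loss, completion_mask, loss_type, max_completion_length):
    """Reduce a (batch, seq_len) per-token loss to a scalar under each loss_type.

    completion_mask is 1 for completion tokens and 0 for padding. For 'cispo',
    per_token_loss is assumed to already be the CISPO per-token loss defined above.
    """
    masked_sum = (per_token_loss * completion_mask).sum()
    if loss_type == "grpo":
        # Average over each sample's tokens, then over the samples in the batch.
        per_sample = (per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1)
        return per_sample.mean()
    if loss_type == "bnpo":
        # Divide by this batch's completion-token count.
        return masked_sum / completion_mask.sum().clamp(min=1)
    if loss_type == "dr_grpo":
        # Fixed denominator: batch size x maximum completion length.
        return masked_sum / (per_token_loss.size(0) * max_completion_length)
    if loss_type in ("dapo", "cispo"):
        # Token-level, but normalized by the completion-token count across all
        # processes; a single process stands in for the global sum here.
        global_token_count = completion_mask.sum().clamp(min=1)
        return masked_sum / global_token_count
    raise ValueError(f"unknown loss_type: {loss_type}")
```

On a single process, `bnpo` and `dapo`/`cispo` reduce to the same denominator; the variants only diverge once the token count is accumulated over multiple data-parallel ranks (and, for `cispo`, the per-token loss itself is the CISPO variant).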

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 1 addition & 1 deletion
@@ -572,7 +572,7 @@ The meanings of the following parameters can be referenced [here](https://huggin
- reward_model_plugin: The logic for the reward model, which defaults to ORM logic. For more information, please refer to [Customized Reward Models](./GRPO/DeveloperGuide/reward_model.md#custom-reward-model).
- dataset_shuffle: Whether to shuffle the dataset randomly. Default is True.
- truncation_strategy: The method to handle inputs exceeding `max_length`. Supported values are `delete` and `left`, representing deletion and left-side truncation respectively. The default is `left`. Note that for multi-modal models, left-side truncation may remove multi-modal tokens and cause a shape mismatch error during model forward. With the delete strategy, over-long or encoding-failed samples are discarded, and new samples are resampled from the original dataset to maintain the intended batch size.
575- - loss_type: The type of loss normalization. Options are ['grpo', 'bnpo', 'dr_grpo'], default is 'grpo'. For details, see this [pr](https://github.com/huggingface/trl/pull/3256#discussion_r2033213348)
575+ - loss_type: The type of loss normalization. Options are ['grpo', 'bnpo', 'dr_grpo', 'dapo', 'cispo'], default is 'grpo'. For details, refer to this [doc](./GRPO/DeveloperGuide/loss_types.md)
- log_completions: Whether to log the model-generated content during training, to be used in conjunction with `--report_to wandb/swanlab`, default is False.
  - Note: If `--report_to wandb/swanlab` is not set, a `completions.jsonl` will be created in the checkpoint to store the generated content.
- use_vllm: Whether to use vLLM as the infer_backend for GRPO generation, default is False.
Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
# Clipped Importance Sampling Policy Optimization (CISPO)

**Version requirement**: ms-swift>=3.11

Clipped Importance Sampling Policy Optimization (CISPO) is a reinforcement learning algorithm proposed in the [MiniMax-M1](https://arxiv.org/abs/2506.13585) paper. Compared to GRPO (Group Relative Policy Optimization), CISPO clips the importance sampling weights themselves.

## Algorithm Overview

For clarity, we explain CISPO by contrasting it with GRPO.

GRPO limits the magnitude of policy updates by clipping the policy ratio. Its loss function is:

$$
\mathcal{L}_{\text{GRPO}}(\theta) = -\mathbb{E}\left[\min\left(r_t(\theta) \cdot \hat{A}_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \cdot \hat{A}_t\right)\right]
$$

where $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)}$ is the importance sampling ratio.

When handling long reasoning chains, this clipping approach can lead to the following issue:

**Gradient suppression of critical tokens**: In complex reasoning tasks, certain critical low-probability tokens (such as *However, Recheck, Wait, Aha*) are crucial for triggering deep thinking and correcting reasoning errors. These tokens have low probability under the old policy $\pi_{\theta_{\text{old}}}$. When the new policy attempts to increase their probability, the policy ratio $r_t(\theta)$ becomes large, and GRPO's clipping mechanism discards their gradient contribution.

### CISPO's Solution

The core idea of CISPO is to clip the importance sampling weights while preserving gradient updates. Specifically, CISPO's loss function is:

$$
\mathcal{L}_{\text{CISPO}}(\theta) = -\mathbb{E}\left[\text{detach}\left(\min(r_t(\theta), \epsilon_{\text{high}})\right) \cdot \hat{A}_t \cdot \log \pi_\theta(a_t|s_t)\right]
$$

where $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)}$ is the importance sampling ratio.

**Key Mechanisms**:
- Clip the importance sampling weights: $\min(r_t(\theta), \epsilon_{\text{high}})$
- **Detach operation**: the clipped weights do not participate in gradient computation and serve as constant coefficients
- Gradients come from the $\log \pi_\theta(a_t|s_t)$ term, ensuring all tokens contribute gradients

## Implementation Details

The pseudo-code implementation of CISPO is as follows:

```python
log_ratio = per_token_logps - old_per_token_logps
importance_weights = torch.exp(log_ratio)  # r_t(θ) = π_θ / π_θ_old

clamped_ratios = torch.clamp(importance_weights, max=epsilon_high).detach()

per_token_loss = -clamped_ratios * advantages.unsqueeze(1) * per_token_logps
```

## Parameter Configuration

CISPO training can be enabled on top of `GRPOTrainer` by setting the following parameters:

```bash
--loss_type cispo
--epsilon_high 5.0
```

> Compared to other algorithms, CISPO generally uses a larger value for epsilon_high. The MiniMax paper does not provide specific parameter settings; the value used here follows the experimental setup of the [ScaleRL](https://arxiv.org/pdf/2510.13786) paper.

For other training parameters, refer to the [GRPO parameter documentation](../../Command-line-parameters.md#grpo-arguments).
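
To connect the snippet above with the normalization described in the loss-types document, here is a minimal sketch (an assumed helper, not the actual `GRPOTrainer` code) that assembles the per-token CISPO loss into a scalar; `completion_mask` and `total_tokens` are illustrative names, and a distributed run would obtain `total_tokens` by summing completion-token counts across processes.

```python
import torch

def cispo_loss(per_token_logps, old_per_token_logps, advantages,
               completion_mask, epsilon_high=5.0, total_tokens=None):
    """Sketch of the CISPO per-token loss and its reduction to a scalar.

    Shapes: per_token_logps, old_per_token_logps, and completion_mask are
    (batch, seq_len); advantages is (batch,).
    """
    # r_t = pi_theta / pi_theta_old
    importance_weights = torch.exp(per_token_logps - old_per_token_logps)

    # Clip from above and detach, so the weight acts as a constant coefficient
    # and every token keeps a gradient through log pi_theta.
    clamped_ratios = torch.clamp(importance_weights, max=epsilon_high).detach()

    per_token_loss = -clamped_ratios * advantages.unsqueeze(1) * per_token_logps

    # Normalize by the total completion-token count (global across processes
    # in distributed training; the local count is used here as a stand-in).
    if total_tokens is None:
        total_tokens = completion_mask.sum()
    return (per_token_loss * completion_mask).sum() / total_tokens
```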

docs/source_en/Instruction/GRPO/AdvancedResearch/DAPO.md

Lines changed: 4 additions & 2 deletions
@@ -42,7 +42,9 @@ GRPO normalizes losses at the sentence level, which introduces bias based on res
DAPO uses token-level normalization to avoid this bias in loss calculation.

Parameters:
45- - `loss_type bnpo` enables token-level normalization.
45+ - `loss_type bnpo/dapo` enables token-level normalization.

47+ > For the loss_type formula, please refer to the [documentation](../DeveloperGuide/loss_types.md).

## Overlong Filtering
DAPO argues that forcibly truncated responses contain high reward noise, making it difficult for the model to distinguish between quality issues and length issues. To address this, DAPO filters out truncated data during training, excluding it from loss computation.
@@ -78,7 +80,7 @@ In summary, the following parameters can be set based on GRPOTrainer to implemen

| Parameter | Type | Value |
|-----------------------|-----------|-------------|
81- | `--loss_type` | `str` | `bnpo` |
83+ | `--loss_type` | `str` | `bnpo`/`dapo` |
| `--epsilon_high` | `float` | `0.28` |
| `--dynamic_sample` | `bool` | `true` |
| `--max_resample_times`| `int` | `3` |

docs/source_en/Instruction/GRPO/AdvancedResearch/index.rst

Lines changed: 1 addition & 0 deletions
@@ -10,3 +10,4 @@ Advanced Research
   REINFORCEPP.md
   RLOO.md
   CHORD.md
13+    CISPO.md
