# 采样

采样是SWIFT新支持的重要能力之一，这部分可以理解为`test-time compute`的落地实现。同时，该能力对RFT（强化微调）的实现也至关重要。

## 能力介绍

SWIFT的sample能力可以使用下面的例子进行：
```shell
swift sample --model LLM-Research/Meta-Llama-3.1-8B-Instruct --sampler_engine pt --num_return_sequences 5 --dataset AI-ModelScope/alpaca-gpt4-data-zh#5
```
在当前文件夹的`sample_output`目录下，会生成以时间戳为文件名的jsonl文件，该文件应该包含25行，每一行都是一个完整`messages`格式的数据。

采样的参数列表请参考[这里](命令行参数.md)。

## 环境准备

```shell
pip install ms-swift[llm] -U
```

或从源代码安装：

```shell
git clone https://github.com/modelscope/ms-swift.git
cd ms-swift
pip install -e '.[llm]'
```

## 使用PRM和ORM进行结果过滤

采样重要的能力就是对过程和结果进行监督，这可以通过设置额外参数来支持。

```shell
swift sample --model LLM-Research/Meta-Llama-3.1-8B-Instruct --sampler_engine lmdeploy --num_return_sequences 5 --n_best_to_keep 2 --dataset tastelikefeet/competition_math#5 --prm_model AI-ModelScope/GRM-llama3.2-3B-rewardmodel-ft --orm_model math
```

在当前文件夹的`sample_output`目录下，会生成以时间戳为文件名的jsonl文件，该文件**至多包含**10行，每一行都是一个完整`messages`格式的数据。
> 之所以至多包含10行，是因为虽然设置了共处理5个数据，每个数据保留2个（n_best_to_keep），但是orm可能会校验失败，失败数据不会保留到文件中。
> 另外，增加了--prm_model或--orm_model后文件格式有所不同，包含了rejected_response key，内容来自于prm评分最低的行。

## 自定义PRM或ORM

PRM和ORM的自定义可以在plugin中按照现有代码增加一个新的实现。例如：
```python
class CustomPRM:

    # 构造需要是无参的
    def __init__(self):
        # init here
        pass

    def __call__(self, infer_requests: List[InferRequest], ground_truths: List[str], **kwargs) -> List[Union[float, List[float]]]:
        ...


prms = {'custom': CustomPRM}
```

之后在命令行中使用`--prm_model custom`即可。

## 显存控制

如果被采样模型和PRM共同加载进显存，则可能出现OOM的问题。因此采样可以分为两段进行：

- 第一段指定`--model`和``--sampler_engine`，同时不指定`--orm_model`和`--prm_model`，仅进行采样，并存储为文件
- 第二段指定`--sampler_engine no`，指定`--orm_model`和`--prm_model`，并同时指定`--cache_files`，仅进行RM数据过滤，不重新采样

通过两段方式可以每次仅加载一个模型，防止OOM。

## 实际例子

请参考[强化微调脚本](https://github.com/modelscope/ms-swift/tree/main/examples/train/rft/rft.py)。该脚本给出了使用采样进行强化微调的实际例子。

> 注意：该脚本的实际效果和模型、数据、RM的质量强相关，因此仅作为样例出现，用户请自行修改该脚本并训练自己的RM和generator模型。

## 大模型蒸馏采样

SWIFT的sample支持使用OpenAI API的方式，用大模型蒸馏数据，如下示例：
```shell
OPENAI_API_KEY="your_api_key" \
swift sample \
    --sampler_type distill \
    --sampler_engine client \
    --model deepseek-r1 \
    --stream true \
    --dataset tastelikefeet/competition_math#5 \
    --num_return_sequences 1 \
    --temperature 0.6 \
    --top_p 0.95 \
    --engine_kwargs '{"base_url":"https://dashscope.aliyuncs.com/compatible-mode/v1"}'
```
在以上示例中，base_url和model分别是api地址和模型名称，stream表示发起请求的stream参数。

注意，对于Deepseek-R1系列模型，输出会被格式化为：`<think>{reasoning_content}</think>\n\n<answer>{content}</answer>`。
