from utils import load_single_dataset, save_dataset
from transformers import AutoTokenizer

dataset = load_single_dataset("~/datasets/launch-thinkprm-1K-verification-cots", dataset_split="train")

def format_example(example):
    """
    按照论文 Figure 12 和 Figure 13 的格式进行转换
    """
    # 处理输入部分 (Input Prompt)
    # 包含问题、解答前缀以及引导语
    problem = example['problem'].strip()
    prefix = example['prefix'].strip()
    
    # 构造输入模板
    input_text = f"Problem: {problem}\nSolution:\n{prefix}\n"
    
    # 处理输出部分 (Target Output)
    # 论文在预处理时会添加 <think> 标签 [cite: 635]
    # prefix_label 在数据集中是布尔值，需要转为 "Yes" 或 "No"
    cot_reasoning = example['cot'].strip()
    is_correct = "Yes" if example['prefix_label'] else "No"
    
    # 构造目标文本
    # 注意：数据集中的 cot 字段通常已经包含了具体的步级判定 \boxed{correct} 等
    output_text = f"{cot_reasoning}"
    output_text = '<|end_of_thought|>'.join(output_text.replace('<think>', '<|begin_of_thought|>', 1).rsplit('</think>', 1))
    return {
        "instruction": input_text,
        "output": output_text,
    }

# 2. 应用格式化
formatted_dataset = dataset.map(format_example, remove_columns=dataset.column_names)

# 3. 打印一个样本查看结果
print("--- 训练输入 (Input) ---")
print(formatted_dataset[0]['instruction'])
print("\n--- 训练输出 (Output) ---")
print(formatted_dataset[0]['output'])

# 4. 保存为 JSONL 文件以便训练
save_dataset(formatted_dataset, "~/LLaMA-Factory-250514/data/llama32-thinkprm.jsonl")

# 建议使用与 ThinkPRM 基础模型一致的 tokenizer (例如 llama3.2 系列)
tokenizer = AutoTokenizer.from_pretrained("~/7b_model/keeeeenw-Llama-3.2-1B-Instruct-Open-R1-Distill")

def get_total_tokens(example):
    # 按照你的 format_example，计算拼接后的总长度
    full_text = example['instruction'] + example['output']
    return len(tokenizer.encode(full_text))

# 计算数据集中所有样本的 token 长度
all_lengths = sorted([get_total_tokens(ex) for ex in formatted_dataset])

# 获取统计结果
max_tokens = max(all_lengths)
min_tokens = min(all_lengths)
avg_tokens = sum(all_lengths) / len(all_lengths)

print(f"--- 数据集 Token 统计 ---")
print(f"样本总数: {len(all_lengths)}")
print(f"最长样本长度: {max_tokens} tokens")
print(f"最短样本长度: {min_tokens} tokens")
print(f"平均样本长度: {avg_tokens:.2f} tokens")

# 打印最长样本的索引，方便你检查
max_idx = all_lengths.index(max_tokens)
print(f"最长样本在 formatted_dataset 中的索引为: {max_idx}")

print(all_lengths)