# train.py

# Format of one training sample (full chat text):
# f'<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{inputs_user} <|im_end|>\n<|im_start|>assistant\n<think>\n{summary['思考摘要']}\n</think>\n\n{json.dumps(play_hand, ensure_ascii=False)}<|im_end|>'


import os
import json
import torch
from tqdm import tqdm
from trl import SFTTrainer
from tool import read_jsonl
from datasets import Dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig


# Read the raw JSONL records and reshape each one into a three-turn
# "messages" conversation (system / user / assistant) before wrapping the
# result in a Dataset for the trainer.
def _as_chat_sample(record):
    """Map one raw record onto a ``{"messages": [...]}`` conversation."""
    role_to_field = (
        ("system", "system_prompt"),
        ("user", "inputs"),
        ("assistant", "target"),
    )
    return {
        "messages": [
            {"role": role, "content": record[field]}
            for role, field in role_to_field
        ]
    }


train_data = Dataset.from_list(
    [_as_chat_sample(record) for record in read_jsonl("./train_data.jsonl")]
)


# Directory for training outputs and the location of the final saved model.
output_dir = "./qlora-finetuned/"
final_model_path = "./qlora-finetuned-final/"

# Local path of the base checkpoint; the tokenizer is loaded from it as well.
model_name = "./Qwen3-4B-Thinking-2507"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Only fall back to EOS when the tokenizer ships without a pad token.
# Unconditionally aliasing pad to EOS lets collators that mask padding by
# token id also drop every end-of-turn token from the loss, so the fine-tuned
# model may never learn to stop generating.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# 4-bit quantization settings passed to the model loader.
_bnb_options = {
    "load_in_4bit": True,                      # enable 4-bit quantization
    "bnb_4bit_quant_type": "nf4",              # use the nf4 quantization type
    "bnb_4bit_compute_dtype": torch.bfloat16,  # dtype used for computation
    "bnb_4bit_use_double_quant": True,         # enable double quantization
}
bnb_config = BitsAndBytesConfig(**_bnb_options)

# Load the base causal LM from `model_name`, quantized per `bnb_config`.
_model_load_kwargs = dict(
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(model_name, **_model_load_kwargs)

# LoRA configuration.
# The targeted module names are presumably the attention (q/k/v/o) and MLP
# (gate/up/down) projection layers of the base model -- confirm against the
# checkpoint's module names.
_attention_targets = ["q_proj", "k_proj", "v_proj", "o_proj"]
_mlp_targets = ["gate_proj", "up_proj", "down_proj"]

peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=_attention_targets + _mlp_targets,
)

# Training hyper-parameters.
training_args = TrainingArguments(
    output_dir=output_dir,             # directory for training outputs
    per_device_train_batch_size=8,     # batch size per device
    gradient_accumulation_steps=8,     # gradient accumulation steps
    learning_rate=1e-5,                # learning rate
    num_train_epochs=3,                # number of training epochs
    logging_steps=1000,                # log every N steps
    save_steps=1000,                   # save a checkpoint every N steps
    bf16=True,                         # enable bf16 if the GPU supports it
    warmup_ratio=0.03,                 # fraction of steps used for warmup
    # The plain "constant" schedule has no warmup phase, so warmup_ratio
    # would be silently ignored; "constant_with_warmup" ramps the learning
    # rate up first and then holds it at `learning_rate`.
    lr_scheduler_type="constant_with_warmup",
)


# Assemble the SFT trainer from the model, tokenizer, dataset, LoRA config
# and training arguments defined earlier in the script, then run training.
trainer = SFTTrainer(
    args=training_args,
    model=model,
    processing_class=tokenizer,
    train_dataset=train_data,
    peft_config=peft_config,
)
trainer.train()

# Save the trained model to its final location and report the path.
trainer.save_model(final_model_path)
print(f"QLoRA模型已保存至: {final_model_path}")
