import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, BitsAndBytesConfig, DataCollatorWithPadding
from peft import get_peft_model, LoraConfig, TaskType
from trl import DataCollatorForCompletionOnlyLM, SFTTrainer, SFTConfig

# Initialize the tokenizer (the model is loaded further below).
model_name = "Meta-Llama-3.1-8B"  # Replace with the LLaMA-3-8B model path or hub identifier
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Choose a pad token. DataCollatorForCompletionOnlyLM (via
# DataCollatorForLanguageModeling) sets every label equal to pad_token_id to -100.
# Reusing EOS as the pad token would therefore also mask every genuine EOS label,
# and the model would never learn to stop generating. Prefer a dedicated,
# already-in-vocabulary pad token so no embedding resize is needed. Llama 3.1
# tokenizers reserve "<|finetune_right_pad_id|>"; its presence is checked rather
# than assumed. Fall back to EOS only when nothing better exists.
if tokenizer.pad_token is None:
    _dedicated_pad_token = "<|finetune_right_pad_id|>"
    if _dedicated_pad_token in tokenizer.get_vocab():
        tokenizer.pad_token = _dedicated_pad_token
    else:
        tokenizer.pad_token = tokenizer.eos_token

# bitsandbytes 4-bit quantization: NF4 weights with nested ("double")
# quantization of the quantization constants; compute runs in bfloat16.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the base causal LM quantized per nf4_config; torch_dtype sets the dtype
# used for the weights that are not quantized.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16,
)

# LoRA adapter configuration: adapters on the attention (q/k/v/o) and
# MLP (gate/up/down) projection layers.
_attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
_mlp_projections = ["gate_proj", "up_proj", "down_proj"]

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=128,  # rank of the low-rank update matrices
    lora_alpha=32,  # PEFT scales the update by lora_alpha / r, i.e. 32 / 128 = 0.25 here
    lora_dropout=0.1,  # dropout applied on the LoRA branch
    target_modules=_attention_projections + _mlp_projections,
)

# Wrap the quantized base model so that only the LoRA adapter weights train.
model = get_peft_model(model, lora_config)

from benchmark_test import NLDataset

# Both splits are built with train_or_test="train", including valid.json.
# NOTE(review): presumably so validation examples carry targets for eval loss;
# confirm against NLDataset.
_split_paths = (
    "benchmark/rfft_tasks/train.json",
    "benchmark/rfft_tasks/valid.json",
)
train_dataset, eval_dataset = (
    NLDataset(path=split_path, train_or_test="train", num_each=None, tokenizer=tokenizer)
    for split_path in _split_paths
)

# Completion-only loss: labels up to and including the response template " ="
# are set to -100, so only the text after it contributes to the loss.
collator = DataCollatorForCompletionOnlyLM(
    response_template=" =",
    tokenizer=tokenizer,
    mlm=False,
)


# Training hyperparameters.
training_args = SFTConfig(
    output_dir="./nupa_finetuned",
    # `evaluation_strategy` was deprecated in transformers 4.41 in favour of
    # `eval_strategy` and removed in later releases. Llama 3.1 support already
    # requires transformers >= 4.43, so the new name is always available here.
    eval_strategy="steps",
    eval_steps=400,
    learning_rate=2e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=1,
    weight_decay=0.01,
    lr_scheduler_type="cosine",
    warmup_steps=100,
    logging_dir="./logs_nupaft",
    logging_steps=100,
    # load_best_model_at_end requires the save strategy to match eval_strategy
    # and save_steps to be a round multiple of eval_steps.
    save_strategy="steps",
    save_steps=400,
    save_total_limit=3,
    bf16=True,
    load_best_model_at_end=True,
)

# Build the supervised fine-tuning trainer; labels come from the
# completion-only collator rather than from the datasets.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    data_collator=collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# Start training.
trainer.train()