from unsloth import FastLanguageModel, PatchFastRL
from datasets import Dataset,load_from_disk
from datetime import datetime
import os
PatchFastRL("GRPO", FastLanguageModel)
from unsloth import is_bfloat16_supported
from trl import SFTTrainer
from transformers import TrainingArguments
import torch
# Maximum sequence length handed to the model loader and the trainer.
# Can be increased for longer reasoning traces.
max_seq_length = 2048

# File holding the system-prompt template.
cot_path = "src/data/prompts/sqv.txt"

# Value substituted for any `{task_specific_prompt}` placeholder in the template.
task_specific_prompt = ""

with open(cot_path, "r", encoding="utf-8") as f:
    # Read the whole template text and fill in the placeholder in one step.
    SYSTEM_PROMPT = f.read().format(task_specific_prompt=task_specific_prompt)

# Echo the final prompt so it shows up in the run log.
print(SYSTEM_PROMPT)
def load_peft_model(model_name):
    """Load a 4-bit base model through Unsloth and attach LoRA adapters.

    Args:
        model_name: Model identifier or path forwarded to
            ``FastLanguageModel.from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple: the model returned by
        ``FastLanguageModel.get_peft_model`` and the tokenizer, whose pad
        token has been set to its EOS token.
    """
    # Larger rank = smarter, but slower. Any number > 0 works;
    # suggested values are 8, 16, 32, 64, 128.
    rank = 64

    # Projection layers that receive adapters (remove QKVO if out of memory).
    lora_targets = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ]

    base_model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=max_seq_length,
        load_in_4bit=True,  # False for LoRA 16bit
        fast_inference=True,  # Enable vLLM fast inference
        max_lora_rank=rank,
        gpu_memory_utilization=0.5,  # Reduce if out of memory
    )

    # Reuse the EOS token for padding.
    tokenizer.pad_token = tokenizer.eos_token

    peft_model = FastLanguageModel.get_peft_model(
        base_model,
        r=rank,
        target_modules=lora_targets,
        lora_alpha=rank,  # alpha is kept equal to the rank
        use_gradient_checkpointing="unsloth",  # Enable long context finetuning
        random_state=3407,
    )
    return peft_model, tokenizer
def get_sft_data(data, eos_token=None) -> Dataset:
    """Add a ``"text"`` column holding the full SFT training string per row.

    Each text is ``SYSTEM_PROMPT`` immediately followed by the
    context/question/answer template filled from the row's ``n_fact``,
    ``question`` and ``n_answer`` fields, terminated by the EOS token.

    Args:
        data: Dataset-like object exposing ``map``; every row must provide
            the ``n_fact``, ``question`` and ``n_answer`` keys.
        eos_token: End-of-sequence string appended to every example. When
            omitted, the module-level ``tokenizer`` is consulted, which keeps
            the original single-argument call working.

    Returns:
        The result of ``data.map`` with the extra ``"text"`` field.

    Raises:
        NameError: If ``eos_token`` is omitted and no module-level
            ``tokenizer`` has been defined.
    """
    if eos_token is None:
        # Backward-compatible fallback: earlier versions always read the
        # global tokenizer created by the script's entry point.
        try:
            eos_token = tokenizer.eos_token
        except NameError as err:
            raise NameError(
                "get_sft_data() needs eos_token=... or a module-level "
                "`tokenizer` to be defined before it is called"
            ) from err

    template = "Context: {c}\n\nQuestion: {q}\n\nAnswer: {a}"

    def _add_text(row):
        # Build one training string: system prompt + filled template + EOS.
        filled = template.format(
            c=row["n_fact"], q=row["question"], a=row["n_answer"]
        )
        return {"text": SYSTEM_PROMPT + filled + eos_token}

    return data.map(_add_text)
def train(model, tokenizer, data, res_dir):
    """Run supervised fine-tuning with ``SFTTrainer`` and save the model.

    Args:
        model: Model to train; it must also provide
            ``save_pretrained_merged``.
        tokenizer: Tokenizer passed to the trainer and to the final save.
        data: Training dataset containing a ``"text"`` column.
        res_dir: Output directory for trainer checkpoints. The final model
            is written to ``<res_dir>/adapter``.

    Returns:
        Whatever ``trainer.train()`` returns, so callers can inspect or log
        the run's statistics. Earlier versions discarded this value and
        returned ``None``.
    """
    # Query hardware support once; exactly one of fp16/bf16 is enabled.
    use_bf16 = is_bfloat16_supported()

    training_args = TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        warmup_steps=5,
        num_train_epochs=2,
        learning_rate=2e-4,
        fp16=not use_bf16,
        bf16=use_bf16,
        logging_steps=10,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        save_strategy="steps",  # write a checkpoint every `save_steps` steps
        save_steps=50,
        output_dir=res_dir,
    )

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=data,
        dataset_text_field="text",
        max_seq_length=max_seq_length,
        dataset_num_proc=2,
        packing=False,  # Can make training 5x faster for short sequences.
        args=training_args,
    )
    trainer_stats = trainer.train()

    # NOTE: despite the directory name, this is saved with
    # save_method="merged_16bit" rather than as a bare adapter.
    lora_dir = os.path.join(res_dir, "adapter")
    model.save_pretrained_merged(lora_dir, tokenizer, save_method="merged_16bit")
    return trainer_stats
if __name__ == "__main__":
    STAGE = "SFT"

    # NOTE: YOUR_MODEL_PATH and YOUR_DATA_PATH are placeholders that are not
    # defined in this file; replace them with real paths before running.
    base_model_path = YOUR_MODEL_PATH
    # Normalise first so a trailing slash does not produce an empty name.
    base_model = os.path.basename(os.path.normpath(base_model_path))

    # `tokenizer` must stay a module-level name: get_sft_data may read it
    # as a global when no EOS token is passed explicitly.
    model, tokenizer = load_peft_model(base_model_path)

    data_path = YOUR_DATA_PATH

    # Results go to src/ckpt/<stage>/<model>/<YYYY-MM-DD>. makedirs creates
    # every missing parent and tolerates an already-existing directory,
    # whereas os.mkdir fails when an intermediate directory is absent.
    today = datetime.now().date()
    res_dir = os.path.join(f"src/ckpt/{STAGE}/{base_model}", str(today))
    os.makedirs(res_dir, exist_ok=True)

    # Build the SFT dataset from the on-disk dataset.
    data = get_sft_data(load_from_disk(data_path))

    # Start training.
    train(model, tokenizer, data, res_dir)
    