from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
import torch
from datasets import Dataset
import os
import json

# Local checkpoint directory. Both loaders below use local_files_only=True,
# so nothing is downloaded.
MODEL_PATH = "/path/to/your/model/DeepSeek-R1-Distill-Qwen-14B"

# Keyword arguments common to the model and tokenizer loaders.
_SHARED_LOAD_OPTIONS = {"trust_remote_code": True, "local_files_only": True}

# Base model in bfloat16; device_map="auto" lets transformers place the weights
# on whatever devices are available.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    **_SHARED_LOAD_OPTIONS,
)

# NOTE(review): left padding is the usual choice for generation; for training,
# right padding is more common — confirm this is intended for the collator in use.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    padding_side='left',
    **_SHARED_LOAD_OPTIONS,
)

# LoRA adapters (rank 32, scaling lora_alpha / r = 2) injected into the listed
# projection modules (q/k/v/o and gate/up/down); bias terms stay frozen.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=32,
    lora_alpha=64,
    lora_dropout=0.1,
    bias="none",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ],
)

# Rebind `model` to the PEFT-wrapped model; only the adapter weights are trainable.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

def load_dataset(file_path: str) -> Dataset:
    """Load a UTF-8 JSON file containing a list of records into a ``Dataset``.

    Args:
        file_path: Path to a JSON file whose top-level value is a list of
            records (dicts), as expected by ``Dataset.from_list``.

    Returns:
        The ``Dataset`` built from that list with ``Dataset.from_list``.

    Raises:
        ValueError: If the top-level JSON value is not a list.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    # Dataset.from_list expects a list of dicts; a top-level JSON object would
    # otherwise fail inside `datasets` with an error that does not name the file.
    if not isinstance(data, list):
        raise ValueError(
            f"Expected a JSON list of records in {file_path!r}, "
            f"got {type(data).__name__}"
        )
    return Dataset.from_list(data)


_DEFAULT_LEARNING_RATE = 5e-5
_DEFAULT_GRAD_ACCUM_STEPS = 8
_DEFAULT_OUTPUT_DIR = "/path/to/save/sft_model"
_DEFAULT_TRAIN_DATA_PATH = "/path/to/sft_train_data.json"


def sft_train(
    learning_rate=_DEFAULT_LEARNING_RATE,
    gradient_accumulation_steps=_DEFAULT_GRAD_ACCUM_STEPS,
    output_dir=_DEFAULT_OUTPUT_DIR,
    train_data_path=_DEFAULT_TRAIN_DATA_PATH,
):
    """Run one epoch of LoRA supervised fine-tuning and save the result.

    Trains the module-level ``model`` (wrapped by ``get_peft_model``) with
    TRL's ``SFTTrainer``, using the module-level ``tokenizer`` as the
    processing class.

    Args:
        learning_rate: Peak learning rate for the cosine schedule (3% warmup).
            ``None`` selects the default (5e-5).
        gradient_accumulation_steps: Number of micro-batches (2 examples per
            device each) accumulated per optimizer step. ``None`` selects the
            default (8).
        output_dir: Directory for per-epoch checkpoints (at most 3 kept) and
            the ``final_checkpoint`` subdirectory. ``None`` selects the default.
        train_data_path: JSON file with the training records, read by
            ``load_dataset``. ``None`` selects the default.
    """
    # Callers may pass None to mean "use the default"; None is not a usable
    # value for these SFTConfig fields (the optimizer and step arithmetic need
    # numbers), so resolve it here instead of failing mid-setup.
    if learning_rate is None:
        learning_rate = _DEFAULT_LEARNING_RATE
    if gradient_accumulation_steps is None:
        gradient_accumulation_steps = _DEFAULT_GRAD_ACCUM_STEPS
    if output_dir is None:
        output_dir = _DEFAULT_OUTPUT_DIR
    if train_data_path is None:
        train_data_path = _DEFAULT_TRAIN_DATA_PATH

    training_config = SFTConfig(
        output_dir=output_dir,
        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=gradient_accumulation_steps,
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        learning_rate=learning_rate,
        eval_strategy="no",
        save_strategy="epoch",
        # Keep optimizer/scheduler state in checkpoints so training can resume.
        save_only_model=False,
        logging_steps=10,
        report_to="tensorboard",
        # Drop the final incomplete batch.
        dataloader_drop_last=True,
        remove_unused_columns=True,
        fp16=False,
        bf16=True,
        save_total_limit=3,
        logging_first_step=True,
    )

    train_dataset = load_dataset(train_data_path)
    print(f"size of dataset: {len(train_dataset)}")

    print("Starting LoRA SFT training setup...")
    # The model is already PEFT-wrapped at module level, so no peft_config here.
    trainer = SFTTrainer(
        model=model,
        args=training_config,
        train_dataset=train_dataset,
        processing_class=tokenizer,
    )

    print("Starting LoRA SFT training...")
    trainer.train()

    # NOTE(review): for a PEFT model this presumably writes the adapter weights
    # (plus the tokenizer) rather than merged full weights — confirm.
    final_checkpoint_dir = os.path.join(training_config.output_dir, "final_checkpoint")
    trainer.save_model(final_checkpoint_dir)

    print("LoRA training completed!")
    print(f"model save to: {final_checkpoint_dir}")

if __name__ == "__main__":
    # Fill in values to override sft_train's defaults. Entries left as None are
    # not forwarded, so sft_train keeps its own default for them.
    your_config = {
        "learning_rate": None,
        "gradient_accumulation_steps": None,
        "output_dir": "/path/to/save/sft_model",
    }
    # Passing None explicitly would replace sft_train's numeric defaults with
    # None, which the trainer cannot use for the learning rate or the number of
    # accumulation steps.
    overrides = {key: value for key, value in your_config.items() if value is not None}
    sft_train(**overrides)