import torch
import psutil
import builtins
builtins.psutil = psutil
from unsloth import FastLanguageModel
from datasets import load_dataset
from trl import SFTTrainer
from transformers import TrainingArguments, TrainerCallback
import os
import json



# Name of the experiment to run; it must be one of the keys of CONFIGS below.
EXPERIMENT_MODE = "Cure-SFT"

# Per-experiment settings: training data file, checkpoint/log directory and
# number of training epochs.
CONFIGS = {
    "Cure-SFT": {
        "data_file": "wizardLM/data/cure_sft.jsonl",
        "output_dir": "checkpoints/wizardLM/Cure-SFT/cure-sft",
        "num_epochs": 3,
    },
}

# Settings of the selected experiment; an unknown mode raises KeyError here.
CURRENT_CFG = CONFIGS[EXPERIMENT_MODE]
print(f"\n🚀 Starting Experiment: {EXPERIMENT_MODE}")

# exist_ok=True creates the whole directory tree idempotently and avoids the
# race between a separate existence check and the creation call.
os.makedirs(CURRENT_CFG['output_dir'], exist_ok=True)

# Sequence length limit; the same value is handed to the model loader here and
# to the SFTTrainer further down.
max_seq_length = 2048 
# Load the base model together with its tokenizer.
# dtype=None leaves the dtype choice to the library (presumably auto-detected
# from the GPU -- confirm against the Unsloth docs); load_in_4bit=True requests
# 4-bit loading of the weights.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "NousResearch/Meta-Llama-3-8B",
    max_seq_length = max_seq_length,
    dtype = None,
    load_in_4bit = True,
)

# Wrap the loaded model with PEFT (LoRA) adapters on the listed projection
# modules. The returned object replaces `model` and is what gets trained and
# saved later in the script.
# NOTE(review): r / lora_alpha / lora_dropout are taken to be the usual LoRA
# rank, scaling and dropout hyper-parameters -- verify against the
# get_peft_model documentation of the installed Unsloth version.
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = "unsloth", 
    random_state = 3407,
    use_rslora = False,
    loftq_config = None,
)

# Prompt template with three positional "{}" slots, filled in this order:
# the instruction, an optional pre-rendered "### Input:" section (or ""),
# and the response. Written line by line and joined with newlines.
alpaca_prompt = "\n".join([
    "Below is an instruction that describes a task, paired with an input "
    "that provides further context. "
    "Write a response that appropriately completes the request.",
    "",
    "### Instruction:",
    "{}",
    "",
    "{}### Response:",
    "{}",
])
# End-of-sequence marker taken from the tokenizer; it is appended to every
# formatted training example below.
EOS_TOKEN = tokenizer.eos_token

def formatting_prompts_func(examples):
    """Render a batch of records into prompt strings for training.

    Two column layouts are accepted: ``instruction`` / ``input`` / ``output``
    or the same names prefixed with ``original_``. In both layouts the input
    column is optional; when it is absent every record is treated as having
    no input.

    Args:
        examples: Batch mapping each column name to a list of values, one
            entry per record (the function is used with ``batched=True``).

    Returns:
        A dict with a single ``"text"`` key holding one formatted string per
        record, each terminated with ``EOS_TOKEN``.

    Raises:
        KeyError: If neither layout's instruction/output columns are present.
    """
    # Pick the column naming scheme once so both layouts share one code path.
    prefix = "" if "instruction" in examples else "original_"
    instructions = examples[prefix + "instruction"]
    outputs = examples[prefix + "output"]

    # The input column is optional in either layout; default to empty inputs.
    input_key = prefix + "input"
    if input_key in examples:
        inputs = examples[input_key]
    else:
        inputs = [""] * len(instructions)

    texts = []
    for instruction, input_text, output in zip(instructions, inputs, outputs):
        # Empty or None inputs drop the "### Input:" section entirely.
        input_section = f"### Input:\n{input_text}\n\n" if input_text else ""
        texts.append(alpaca_prompt.format(instruction, input_section, output) + EOS_TOKEN)
    return {"text": texts}

# Read the experiment's JSON-lines file as the "train" split and, in the same
# expression, add the formatted "text" column batch by batch.
dataset = load_dataset(
    "json",
    data_files=CURRENT_CFG['data_file'],
    split="train",
).map(formatting_prompts_func, batched=True)

class LossLoggerCallback(TrainerCallback):
    """Trainer callback that records loss-bearing log events as JSON lines."""

    def __init__(self, log_file_path):
        """Store the destination path and start from an empty log file.

        Args:
            log_file_path: Path of the JSON-lines file to write.
        """
        self.log_file_path = log_file_path
        # Opening in "w" mode creates the file or truncates an existing one.
        open(self.log_file_path, "w", encoding="utf-8").close()

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Append one JSON line for a log event that carries a "loss" value.

        Events with no logs, or whose logs have no "loss" entry, are ignored.
        Each written line holds the step, epoch, loss and learning rate.
        """
        if not logs:
            return
        loss = logs.get("loss")
        if loss is None:
            return

        record = json.dumps({
            "step": state.global_step,
            "epoch": state.epoch,
            "loss": loss,
            "lr": logs.get("learning_rate"),
        })
        with open(self.log_file_path, "a", encoding="utf-8") as sink:
            sink.write(record + "\n")

# The loss log is written next to the checkpoints of the selected experiment.
log_file = os.path.join(CURRENT_CFG['output_dir'], "training_loss.jsonl")

# Built first so the log file is reset before the training arguments are set up.
loss_logger = LossLoggerCallback(log_file)

# Use bf16 when the GPU reports support for it, otherwise fall back to fp16.
use_bf16 = torch.cuda.is_bf16_supported()

training_args = TrainingArguments(
    per_device_train_batch_size = 64,
    gradient_accumulation_steps = 1,
    warmup_ratio = 0.03,
    num_train_epochs = CURRENT_CFG['num_epochs'],
    learning_rate = 2e-4,
    fp16 = not use_bf16,
    bf16 = use_bf16,
    logging_steps = 5,
    optim = "adamw_8bit",
    weight_decay = 0,
    lr_scheduler_type = "linear",
    seed = 3407,
    output_dir = CURRENT_CFG['output_dir'],
    save_strategy = "no",
)

# Supervised fine-tuning over the pre-formatted "text" column.
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False,
    callbacks = [loss_logger],
    args = training_args,
)

# Run the training loop; the returned statistics stay available at module level.
print("🔥 Training Start...")
trainer_stats = trainer.train()

# The model and the tokenizer are written, in that order, to the experiment's
# output directory.
print("💾 Saving Model...")
save_dir = CURRENT_CFG['output_dir']
for saveable in (model, tokenizer):
    saveable.save_pretrained(save_dir)
print(f"✅ Done! Check logs at: {log_file}")