import torch
import psutil
import builtins
builtins.psutil = psutil
from unsloth import FastLanguageModel
from datasets import load_dataset
from trl import SFTTrainer
from transformers import TrainingArguments, TrainerCallback 
import os
import json


# Name of the experiment to run; it is used as the lookup key into CONFIGS.
EXPERIMENT_MODE = "Cure-SFT"

# Per-experiment settings: training data file, checkpoint directory and
# number of training epochs.
CONFIGS = {
    "Cure-SFT": {
        "data_file": "alpaca/data/cure-sft.jsonl",
        "output_dir": "checkpoints/alpaca/Cure-SFT/cure-sft",
        "num_epochs": 3,
    },
}

# Settings of the selected experiment (raises KeyError for an unknown mode).
CURRENT_CFG = CONFIGS[EXPERIMENT_MODE]

# Startup banner: one line each for the experiment, its data and its output.
print(
    f"\n🚀 Starting Experiment: {EXPERIMENT_MODE}",
    f"📂 Data: {CURRENT_CFG['data_file']}",
    f"💾 Output: {CURRENT_CFG['output_dir']}\n",
    sep="\n",
)


# Create the checkpoint directory (and any missing parents) up front.
# exist_ok=True makes this a single idempotent call instead of a
# check-then-create pair, which could fail if the directory appeared
# between the existence check and the makedirs call.
os.makedirs(CURRENT_CFG['output_dir'], exist_ok=True)


# Loader settings. max_seq_length is reused later when building the trainer.
max_seq_length = 2048
# Forwarded unchanged to the loader; presumably None lets the loader pick the
# precision itself -- confirm against the Unsloth documentation.
dtype = None
load_in_4bit = True

# Load the base checkpoint together with its tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="NousResearch/Meta-Llama-3-8B",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)

# Projection layers that receive LoRA adapters.
lora_target_modules = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]

# Wrap the base model with LoRA adapters (rank 16, alpha 16, no dropout).
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,
    lora_dropout=0,
    target_modules=lora_target_modules,
    bias="none",
    use_rslora=False,
    loftq_config=None,
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)


# Prompt template with three positional slots, in order: the instruction,
# an optional pre-rendered input section (or an empty string), and the
# response.
alpaca_prompt = "".join(
    [
        # Preamble.
        "Below is an instruction that describes a task, ",
        "paired with an input that provides further context. ",
        "Write a response that appropriately completes the request.",
        "\n\n",
        # Slot 1: instruction.
        "### Instruction:\n{}",
        "\n\n",
        # Slot 2: optional input section (carries its own trailing blank line).
        "{}",
        # Slot 3: response.
        "### Response:\n{}",
    ]
)

# End-of-sequence marker appended to every rendered training sample.
EOS_TOKEN = tokenizer.eos_token


def formatting_prompts_func(examples):
    """Render a batch of instruction/input/output rows into prompt strings.

    Args:
        examples: Column-oriented batch with parallel "instruction",
            "input" and "output" sequences.

    Returns:
        A dict with a single "text" key holding one formatted prompt per
        row, each terminated by EOS_TOKEN. The "### Input:" section is
        included only for rows whose input is truthy.
    """
    rows = zip(examples["instruction"], examples["input"], examples["output"])
    rendered = [
        alpaca_prompt.format(
            instruction,
            f"### Input:\n{context}\n\n" if context else "",
            response,
        )
        + EOS_TOKEN
        for instruction, context, response in rows
    ]
    return {"text": rendered}

# Read the configured JSON data file and add the rendered "text" column,
# processing the rows in batches.
dataset = load_dataset(
    "json",
    data_files=CURRENT_CFG["data_file"],
    split="train",
).map(formatting_prompts_func, batched=True)

print(f"✅ Loaded {len(dataset)} samples.")


class LossLoggerCallback(TrainerCallback):
    """Trainer callback that appends training-loss records to a JSONL file.

    Each record holds the global step, the epoch, the loss and the learning
    rate. Log events that carry no "loss" value are skipped.
    """

    def __init__(self, log_file_path):
        """Remember the target path and start from an empty log file.

        Args:
            log_file_path: Path of the JSONL file to write. Any existing
                content is discarded.
        """
        self.log_file_path = log_file_path
        # Opening in "w" mode creates the file or truncates an old one.
        with open(self.log_file_path, "w", encoding="utf-8"):
            pass

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Append one JSON line for a log event that contains a loss.

        Args:
            args: Unused here.
            state: Object providing ``global_step`` and ``epoch``.
            control: Unused here.
            logs: Mapping of logged values; "loss" and "learning_rate"
                are read from it. Nothing is written when it is empty
                or None.
            **kwargs: Ignored.
        """
        if not logs:
            return

        step, epoch = state.global_step, state.epoch
        loss = logs.get("loss")
        learning_rate = logs.get("learning_rate")

        # Events without a loss value produce no record.
        if loss is None:
            return

        record = {"step": step, "epoch": epoch, "loss": loss, "lr": learning_rate}
        with open(self.log_file_path, "a", encoding="utf-8") as sink:
            sink.write(f"{json.dumps(record)}\n")


# The loss log is written inside the experiment's output directory.
log_file = os.path.join(CURRENT_CFG["output_dir"], "training_loss.jsonl")
print("📝 Logging training loss to: " + log_file)


# Mixed-precision mode: bf16 when the GPU reports support for it, otherwise fp16.
use_bf16 = torch.cuda.is_bf16_supported()

# Optimisation settings. Batch 4 with 16 accumulation steps gives 64 samples
# per optimiser step on each device. No intermediate checkpoints are saved.
training_args = TrainingArguments(
    output_dir=CURRENT_CFG["output_dir"],
    num_train_epochs=CURRENT_CFG["num_epochs"],
    per_device_train_batch_size=4,
    gradient_accumulation_steps=16,
    learning_rate=2e-4,
    lr_scheduler_type="linear",
    warmup_ratio=0.03,
    weight_decay=0,
    optim="adamw_8bit",
    bf16=use_bf16,
    fp16=not use_bf16,
    logging_steps=5,
    save_strategy="no",
    seed=3407,
)

# Supervised fine-tuning over the rendered "text" column, with the loss
# logger attached as a callback.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=dataset,
    dataset_text_field="text",
    dataset_num_proc=2,
    max_seq_length=max_seq_length,
    packing=False,
    callbacks=[LossLoggerCallback(log_file)],
)

# Run training; the returned statistics object is kept for later inspection.
print("🔥 Training Start...")
trainer_stats = trainer.train()

# Save the model first, then the tokenizer, into the same output directory.
print("💾 Saving Model...")
for artifact in (model, tokenizer):
    artifact.save_pretrained(CURRENT_CFG["output_dir"])
print(f"✅ Done! Check logs at: {log_file}")