import json
import yaml
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from trl import DPOTrainer, DPOConfig
from peft import PeftModel
import matplotlib.pyplot as plt
from pathlib import Path
import pandas as pd
import os

# System instruction prepended to every question when building DPO prompts.
_SYSTEM_PROMPT = (
    "You are an expert physics assistant. You are given a question. "
    "Your task is to generate the final solution of the given question. "
    "Make sure the solution is mathematically accurate and correct, "
    "and at the end return the final correct option after all the intermediate steps. "
    "Let's think step by step."
)


def _make_prompt_formatter(tokenizer):
    """Return a ``Dataset.map`` callback that builds DPO triples with *tokenizer*.

    The tokenizer is bound explicitly so the callback does not depend on a
    variable that is only assigned later in the caller's scope.
    """
    system_message = {"role": "system", "content": _SYSTEM_PROMPT}

    def chatml_format(example):
        """Map a raw record to the ``prompt``/``chosen``/``rejected`` columns."""
        messages = [
            system_message,
            {"role": "user", "content": example["question"]},
        ]
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return {
            "prompt": prompt,
            "chosen": example["cot_solution"],
            "rejected": example["incorrect_solution"],
        }

    return chatml_format


def _plot_losses(log_df, plot_path):
    """Save a training/eval loss-vs-step figure built from the trainer log.

    Each log row carries either ``loss`` or ``eval_loss``, so the other column
    is NaN on that row. Matplotlib breaks a line at NaN values, which is why
    every curve is plotted from its own NaN-free rows. Columns that are absent
    from *log_df* are skipped instead of raising ``KeyError``.
    """
    curves = (
        ("loss", "Training Loss", "-"),
        ("eval_loss", "Eval Loss", "--"),
    )

    plt.figure(figsize=(10, 6))
    drew_curve = False
    for column, label, linestyle in curves:
        if "step" not in log_df.columns or column not in log_df.columns:
            continue
        points = log_df[["step", column]].dropna()
        if points.empty:
            continue
        plt.plot(points["step"], points[column], label=label, linestyle=linestyle)
        drew_curve = True

    plt.xlabel("Steps")
    plt.ylabel("Loss")
    plt.title("Training & Evaluation Loss over Steps")
    if drew_curve:
        # legend() warns when there are no labelled artists.
        plt.legend()
    plt.grid()
    plt.savefig(plot_path)
    plt.close()


def main():
    """Run DPO fine-tuning for every model listed in ``config_dpo.yaml``.

    For each entry of ``config["model_names"]`` the function:

    1. loads the base model and attaches the SFT LoRA adapter twice, once as
       the trainable ``training2`` adapter and once as the frozen
       ``reference`` adapter;
    2. formats the JSON preference dataset into prompt/chosen/rejected
       columns and splits it 90/10 into train/eval;
    3. trains with ``DPOTrainer`` and saves the adapter and tokenizer;
    4. writes the trainer log history to CSV and a loss plot to PNG.

    A failure for one model is reported with its traceback and the loop moves
    on to the next model.
    """
    import gc
    import traceback

    import torch

    # Load config
    with open("config_dpo.yaml", "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    for model_name in config["model_names"]:
        # Last path component of the hub id, used for all on-disk locations.
        short_name = model_name.split("/")[-1]
        base_model = model = trainer = None

        try:
            # Load base model + SFT adapter (trainable copy and frozen reference copy)
            adapter_path = os.path.join(config["model_dir"], short_name, "lora_sft_adapter")
            base_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                trust_remote_code=True,
                device_map="auto",
                local_files_only=True,
            )
            model = PeftModel.from_pretrained(
                base_model, adapter_path, is_trainable=True, adapter_name="training2"
            )
            model.load_adapter(adapter_path, adapter_name="reference", is_trainable=False)

            # Load tokenizer
            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.padding_side = "left"

            # Load raw dataset and convert it to prompt/chosen/rejected columns
            raw_dataset = load_dataset("json", data_files=config["dpo_dataset"], split="train")
            formatted_dataset = raw_dataset.map(
                _make_prompt_formatter(tokenizer),
                remove_columns=raw_dataset.column_names,
            )

            # Split into train/eval
            split_dataset = formatted_dataset.train_test_split(test_size=0.1, seed=42)
            train_dataset = split_dataset["train"]
            eval_dataset = split_dataset["test"]

            run_dir = Path(config["output_dir"]) / short_name
            model_output_dir = run_dir / "dpo_lora"
            model_output_dir.mkdir(parents=True, exist_ok=True)

            # Training arguments
            training_args = DPOConfig(
                output_dir=str(model_output_dir),
                per_device_train_batch_size=config["batch_size"],
                per_device_eval_batch_size=config["batch_size"],
                num_train_epochs=config["epochs"],
                learning_rate=config["lr"],
                gradient_accumulation_steps=4,
                gradient_checkpointing=True,
                eval_strategy="steps",
                eval_steps=100,
                save_strategy="steps",
                lr_scheduler_type="cosine",
                save_steps=100,
                save_total_limit=2,
                warmup_steps=250,
                weight_decay=0.01,
                load_best_model_at_end=True,
                metric_for_best_model="rewards/margins",  # or "rewards/accuracies"
                greater_is_better=True,
                max_grad_norm=1.0,
                optim="paged_adamw_32bit",
                bf16=config["use_bf16"],
                logging_steps=10,
                beta=config["beta"],
                max_prompt_length=512,
                max_length=1024,
                model_adapter_name="training2",
                ref_adapter_name="reference",
                report_to="none",
            )

            # DPOTrainer
            trainer = DPOTrainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                processing_class=tokenizer,
            )

            # Train
            trainer.train()

            # Save LoRA adapter. The tokenizer is saved from the local object
            # because the trainer was given it via ``processing_class`` and the
            # ``trainer.tokenizer`` accessor is deprecated.
            adapter_dir = run_dir / "lora_dpo_adapter"
            trainer.save_model(str(adapter_dir))
            tokenizer.save_pretrained(str(adapter_dir))

            # Save log history
            log_df = pd.DataFrame(trainer.state.log_history)
            log_df.to_csv(run_dir / "train_log.csv", index=False)

            # Plot losses
            _plot_losses(log_df, run_dir / "loss_plot.png")
            print(f"Training completed for {short_name}. Model and logs saved to {model_output_dir}")

        except Exception as e:
            # Best-effort batch run: report the failure (with traceback so the
            # failing stage is identifiable) and carry on with the next model.
            print(f"Error training model {model_name}: {e}")
            traceback.print_exc()
        finally:
            # Drop references to this iteration's model before the next one is
            # loaded, so two models are not kept alive at the same time.
            base_model = model = trainer = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

# Script entry point: run the DPO training loop only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()