import json
import matplotlib.pyplot as plt
import pandas as pd
import yaml
from pathlib import Path
import os
from datasets import Dataset, DatasetDict,load_dataset
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling
from utils.loaders import load_model_tokenizer_with_lora
from transformers.integrations import TensorBoardCallback
#from utils.dataset_formatters import generate_prompt
import torch

def generate_prompt(sample, tokenizer):
    """
    Render one physics chain-of-thought sample as a chat-template string.

    The conversation consists of a fixed system instruction, the sample's
    ``question`` as the user turn and its ``cot_solution`` as the assistant
    turn; a missing key contributes an empty string. The messages are handed
    to ``tokenizer.apply_chat_template`` with ``tokenize=False`` and whatever
    it returns is passed back to the caller unchanged.
    """
    # Sentences of the system instruction; joined with single spaces below.
    instruction_sentences = (
        "You are an expert physics assistant.",
        "You are given a question.",
        "Your task is to generate the final solution of the given question.",
        "Make sure the solution is mathematically accurate and correct,",
        "and at the end return the final correct option after all the intermediate steps.",
        "Let's think step by step.",
    )
    system_instruction = " ".join(instruction_sentences)

    question_text = sample.get("question", "")
    solution_text = sample.get("cot_solution", "")

    turns = (
        ("system", system_instruction),
        ("user", question_text),
        ("assistant", solution_text),
    )
    conversation = [{"role": role, "content": content} for role, content in turns]

    return tokenizer.apply_chat_template(conversation, tokenize=False)

def _is_cuda_oom(exc):
    """Return True when *exc* is a CUDA out-of-memory failure."""
    # Newer torch releases expose a dedicated exception type; older ones only
    # raise a RuntimeError whose message contains the phrase checked below.
    oom_type = getattr(torch.cuda, "OutOfMemoryError", None)
    if oom_type is not None and isinstance(exc, oom_type):
        return True
    return "CUDA out of memory" in str(exc)


def _release_gpu_memory():
    """Collect unreachable objects and release cached CUDA memory."""
    import gc  # local import so this block does not depend on the file header

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _plot_loss_curves(log_df, plot_path):
    """Save a loss-vs-step figure built from the trainer's log history.

    The log history mixes training rows (``loss``) and evaluation rows
    (``eval_loss``), so each column is NaN on the other kind of row. Every
    curve is therefore plotted from its own NaN-free rows; otherwise the
    evaluation curve consists of isolated points and nothing is drawn.
    Columns that are absent or empty are skipped.
    """
    curves = (
        ("loss", "Training Loss", "-"),
        ("eval_loss", "Eval Loss", "--"),
    )
    plt.figure(figsize=(10, 6))
    try:
        drew_curve = False
        for column, label, linestyle in curves:
            if "step" not in log_df.columns or column not in log_df.columns:
                continue
            points = log_df[["step", column]].dropna()
            if points.empty:
                continue
            plt.plot(points["step"], points[column], label=label, linestyle=linestyle)
            drew_curve = True

        plt.xlabel("Steps")
        plt.ylabel("Loss")
        plt.title("Training & Evaluation Loss over Steps")
        if drew_curve:
            plt.legend()
        plt.grid()
        plt.savefig(plot_path)
    finally:
        # Always close the figure so a failed save does not leak it.
        plt.close()


def _train_single_model(config, model_id, quantize=False):
    """Run LoRA supervised fine-tuning for one model and save its artifacts.

    Args:
        config: Parsed ``config.yaml``. Keys read here: ``sft_dataset``,
            ``output_dir``, ``batch_size``, ``epochs``, ``lr``, ``use_bf16``.
        model_id: Model identifier exactly as listed in the config; it is
            passed unmodified to ``load_model_tokenizer_with_lora``.
        quantize: When True the loader is called with ``quantize=True``.

    Artifacts are written under ``<output_dir>/<short model name>/``:
    checkpoints and TensorBoard logs in ``sft_lora/``, the final model and
    tokenizer in ``lora_sft_adapter/``, plus ``train_log.csv`` and
    ``loss_plot.png``.

    Raises:
        Any exception raised by loading, tokenisation, training or saving;
        the caller decides whether to retry.
    """
    if quantize:
        model, tokenizer = load_model_tokenizer_with_lora(model_id, quantize=True)
    else:
        model, tokenizer = load_model_tokenizer_with_lora(model_id)

    # Only the last path component is used for directory names; the full
    # identifier stays in ``model_id`` so a retry can load the same model.
    short_name = model_id.split("/")[-1]
    print(f"Model name: {short_name}")

    def format_example(sample):
        # Convert question + CoT into one chat prompt, then tokenise it.
        formatted_prompt = generate_prompt(sample, tokenizer)
        return tokenizer(
            formatted_prompt,
            truncation=True,
            padding="max_length",
            max_length=1024,
        )

    raw_dataset = load_dataset("json", data_files=config["sft_dataset"], split="train")
    formatted_dataset = raw_dataset.map(
        format_example, remove_columns=raw_dataset.column_names
    )

    # Fixed seed keeps the train/eval split identical across attempts.
    dataset = formatted_dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset = dataset["train"]
    eval_dataset = dataset["test"]

    model_root = Path(config["output_dir"]) / short_name
    checkpoint_dir = model_root / "sft_lora"
    adapter_dir = model_root / "lora_sft_adapter"
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    args = TrainingArguments(
        output_dir=str(checkpoint_dir),
        per_device_train_batch_size=config["batch_size"],
        per_device_eval_batch_size=config["batch_size"],
        gradient_accumulation_steps=8,
        num_train_epochs=config["epochs"],
        eval_strategy="steps",
        eval_steps=50,                   # Evaluate every 50 steps
        save_strategy="steps",
        save_steps=50,                   # Save a checkpoint every 50 steps
        save_total_limit=2,              # Keep only the last 2 checkpoints
        load_best_model_at_end=True,     # Restore best model after training
        metric_for_best_model="loss",    # Choose best model by eval loss
        greater_is_better=False,         # Lower loss is better
        learning_rate=config["lr"],
        bf16=config["use_bf16"],
        logging_dir=str(checkpoint_dir / "logs"),  # TensorBoard event files
        report_to="tensorboard",
        logging_steps=10,
    )

    # report_to="tensorboard" already registers the TensorBoard callback, so
    # it is not passed again via ``callbacks`` (that would add it twice).
    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )

    trainer.train()

    # Save the trained model together with its tokenizer.
    trainer.save_model(str(adapter_dir))
    tokenizer.save_pretrained(str(adapter_dir))

    print(f"Model saved to {adapter_dir}")
    print("Training complete for model:", short_name)

    # Persist the raw log history, then derive the loss plot from it.
    log_df = pd.DataFrame(trainer.state.log_history)
    log_df.to_csv(model_root / "train_log.csv", index=False)
    _plot_loss_curves(log_df, model_root / "loss_plot.png")


def main():
    """Fine-tune every model listed in ``config.yaml`` with LoRA SFT.

    Reads ``config.yaml`` from the current working directory and trains each
    entry of ``config["model_names"]`` in turn. A failure on one model is
    printed and does not stop the remaining models. If the failure is a CUDA
    out-of-memory error, the same model is retried once with quantization
    enabled, using the same training arguments as the first attempt.
    """
    with open("config.yaml", "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    for model_name in config["model_names"]:
        retry_quantized = False
        try:
            _train_single_model(config, model_name)
        except Exception as e:
            print(f"Error training model {model_name}: {e}")
            retry_quantized = _is_cuda_oom(e)

        # The retry runs outside the ``except`` block on purpose: while the
        # handler is active the exception's traceback still references the
        # failed model and trainer, so their GPU memory cannot be reclaimed.
        if retry_quantized:
            print("CUDA out of memoryerror encountered. Trying to load model with quantization.")
            _release_gpu_memory()
            try:
                _train_single_model(config, model_name, quantize=True)
            except Exception as e:
                print(f"Error model {model_name} after quantization : {e}")

        # Start the next model from a clean allocator state.
        _release_gpu_memory()

    print("All models processed.")

if __name__ == "__main__":
    main()