import argparse
import gc
import os
import time
import json
from pathlib import Path

import torch
from scp_dr.cp_model import CPModel
from examples.config import DATA_DIR_CACHE
from examples.data import DatasetLoader
from examples.metrics import compute_statistics, plot_and_analyze_data
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel


def parse_arguments(argv=None):
    """
    Parses command-line arguments for model evaluation configuration.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) reads ``sys.argv``
            exactly as before; passing a list is useful for tests and for
            calling the script programmatically.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="Evaluation script for scp_dr models")
    parser.add_argument("--model_checkpoint1", type=str, required=True, help="Path to the first model checkpoint")
    parser.add_argument("--model_checkpoint2", type=str, help="Path to the second model checkpoint for CPModel")
    # Not ``required``: combining required=True with a default made the default unreachable.
    parser.add_argument("--dataset_name", type=str, default="MathAbstracts", help="Name of the dataset")
    parser.add_argument("--n_test_samples", type=int, default=500, help="Number of test samples")
    parser.add_argument("--output_dir", type=str, default="./eval", help="Directory to save evaluation results")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for evaluation")
    parser.add_argument("--grid_size", type=int, default=10, help="Grid size for grid search in CPModel")
    parser.add_argument("--verbose", action="store_true", help="Enable verbose output during evaluation")
    parser.add_argument(
        "--fixed_coef",
        type=float,
        default=None,
        help="Fixed coefficient for the first model (omit for an adaptive coefficient)",
    )
    parser.add_argument("--use_relative_probs", action="store_true", help="use relative probabilities when aggregating")
    parser.add_argument("--use_minimum", action="store_true", help="use minimum as the aggregation function instead of weighted average")
    return parser.parse_args(argv)


def init_tokenizer(model_checkpoint):
    """
    Load the tokenizer stored alongside ``model_checkpoint``.

    The tokenizer pads on the left (presumably so batched causal-LM inputs
    end at the same position — confirm against compute_statistics) and may
    execute custom tokenizer code shipped with the checkpoint
    (``trust_remote_code=True``). Its vocabulary size is printed before the
    tokenizer is returned.
    """
    tokenizer_options = {"padding_side": "left", "trust_remote_code": True}
    tok = AutoTokenizer.from_pretrained(model_checkpoint, **tokenizer_options)

    # Intentionally no extra special tokens (such as [SEP]/[PAD]) are added:
    # the vocabulary must match what the checkpoint was trained with, because
    # the loaded models are resized to len(tokenizer).
    vocab_size = len(tok)
    print("Tokenizer vocab size:", vocab_size)
    return tok

def load_model(path, tokenizer):
    """
    Load a causal LM from ``path`` in float16, as a full model or a PEFT adapter.

    If ``path`` contains ``adapter_config.json``, the base model named by its
    ``base_model_name_or_path`` entry is loaded, its token embeddings are
    resized to ``len(tokenizer)``, and the adapter at ``path`` is applied on
    top via ``PeftModel.from_pretrained``. Otherwise ``path`` is loaded
    directly as a full model. In both cases weights are placed with
    ``device_map="auto"`` and the result's embeddings are resized to
    ``len(tokenizer)`` before returning.
    """
    ckpt_dir = Path(path)
    adapter_cfg_file = ckpt_dir / "adapter_config.json"
    load_kwargs = {"device_map": "auto", "torch_dtype": torch.float16, "trust_remote_code": True}

    if not adapter_cfg_file.exists():
        print(f"Loading model from {ckpt_dir}")
        # NOTE(review): .half() looks redundant with torch_dtype=float16;
        # kept in case some modules are loaded in another dtype.
        model = AutoModelForCausalLM.from_pretrained(ckpt_dir, **load_kwargs).half()
    else:
        print(f"Loading PEFT model from {ckpt_dir}")
        with adapter_cfg_file.open() as fh:
            adapter_cfg = json.load(fh)
        base_name = adapter_cfg["base_model_name_or_path"]
        print("base_model:", base_name)
        base = AutoModelForCausalLM.from_pretrained(base_name, **load_kwargs).half()
        # Resize before attaching the adapter so embedding shapes match the tokenizer.
        base.resize_token_embeddings(len(tokenizer))
        model = PeftModel.from_pretrained(base, ckpt_dir)

    # For the PEFT branch this repeats the resize above; presumably a no-op
    # since the size already matches — TODO confirm.
    model.resize_token_embeddings(len(tokenizer))
    return model

def load_models(args, tokenizer):
    """
    Build the model to evaluate, plus a descriptive name for it.

    Without ``args.model_checkpoint2`` only ``args.model_checkpoint1`` is
    loaded and returned as-is. With it, both checkpoints are loaded and
    wrapped in a ``CPModel`` configured from ``args`` (grid size, verbosity,
    fixed coefficient, relative probabilities, minimum aggregation).

    Returns:
        A ``(model, model_name)`` tuple; the name is
        ``"<dataset>_single_model"`` or ``"<dataset>_cp_model"``.
    """
    if not args.model_checkpoint2:
        single = load_model(args.model_checkpoint1, tokenizer)
        return single, f"{args.dataset_name}_single_model"

    first = load_model(args.model_checkpoint1, tokenizer)
    second = load_model(args.model_checkpoint2, tokenizer)
    combined = CPModel(
        model1=first,
        model2=second,
        grid_size=args.grid_size,
        verbose=args.verbose,
        fixed_coef=args.fixed_coef,
        use_relative_probs=args.use_relative_probs,
        use_minimum=args.use_minimum,
    )
    return combined, f"{args.dataset_name}_cp_model"


def evaluate_datasets(train_dataset, validation, model, tokenizer, eval_dir, batch_size):
    """
    Evaluates the train and validation splits and saves results to CSV files.

    For each split, ``compute_statistics`` is run and its result written to
    ``<eval_dir>/<split>.csv``. A split whose CSV already exists is NOT
    recomputed; the existing file is reused as a cache. Once every split has a
    CSV, ``plot_and_analyze_data`` is called on ``eval_dir`` with the split
    names ("train", "validation").

    Args:
        train_dataset: Data for the "train" split, passed to compute_statistics.
        validation: Data for the "validation" split, passed to compute_statistics.
        model: Model to evaluate.
        tokenizer: Tokenizer passed to compute_statistics.
        eval_dir: Existing directory where the CSVs are read/written.
        batch_size: Batch size passed to compute_statistics as ``batch``.
    """
    splits = {"train": train_dataset, "validation": validation}
    eval_folder_names = []

    for name, data in splits.items():
        file_name = os.path.join(eval_dir, f"{name}.csv")
        if os.path.isfile(file_name):
            # Make cache reuse visible: a stale CSV would otherwise go unnoticed.
            print(f"Found existing {name} results at {file_name}, skipping evaluation.")
        else:
            print(f"Evaluating {name} set...")
            start_time = time.perf_counter()
            eval_res = compute_statistics(model=model, data=data, tokenizer=tokenizer, batch=batch_size)
            eval_res.to_csv(file_name)
            print(f"Evaluation of {name} completed in {time.perf_counter() - start_time:.2f} seconds.")
            del eval_res
            # Collect garbage first so memory held by unreachable objects is
            # returned to PyTorch's caching allocator; only then can
            # empty_cache() release it.
            gc.collect()
            torch.cuda.empty_cache()
        eval_folder_names.append(name)

    plot_and_analyze_data(eval_dir, eval_folder_names)
    print("Evaluation completed. Results saved in:", eval_dir)


def main():
    """
    Script entry point: parse arguments, load data and model(s), run evaluation.

    Results are written under ``<output_dir>/<dataset>_<tag>_evaluation``.
    Because ``evaluate_datasets`` reuses any CSV already present there, the
    tag must distinguish evaluation setups. For a CP model it encodes the
    aggregation flags. For a single model it is ``single_model``; before this
    change, a single-model run collided with a default-flag CP run and reused
    that run's cached results.
    """
    args = parse_arguments()
    print("Evaluation configuration:", args)

    # Set up evaluation directory
    if args.model_checkpoint2:
        # These flags only configure CPModel, so they only make sense in CP run names.
        avg_type = "minimum" if args.use_minimum else "geometric"
        prob_type = "relative" if args.use_relative_probs else "raw"
        coef_type = "adaptive" if args.fixed_coef is None else "fixed_" + str(args.fixed_coef).replace(".", "")
        run_tag = f"{prob_type}_{avg_type}_{coef_type}"
    else:
        # Matches the "<dataset>_single_model" name used by load_models.
        run_tag = "single_model"
    # NOTE(review): checkpoint paths are not part of the name, so two runs with
    # different checkpoints but the same flags still share cached CSVs.
    eval_dir = os.path.join(args.output_dir, f"{args.dataset_name}_{run_tag}_evaluation")
    os.makedirs(eval_dir, exist_ok=True)

    # Initialize tokenizer (shared by both models in the CP case)
    tokenizer = init_tokenizer(args.model_checkpoint1)

    # Load datasets; the first return value is not used here
    dataloader = DatasetLoader()
    _, train_dataset, validation_dataset = dataloader.load_or_create_datasets(
        dataset_name=args.dataset_name,
        ntrain=args.n_test_samples,
    )
    print("first training example:", train_dataset[0])
    print("first validation example:", validation_dataset[0])

    # Load model(s)
    model, _ = load_models(args, tokenizer)

    # Run evaluation
    evaluate_datasets(
        train_dataset=train_dataset,
        validation=validation_dataset,
        model=model,
        tokenizer=tokenizer,
        eval_dir=eval_dir,
        batch_size=args.batch_size,
    )


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
