"""Evaluate outputs of efficient LM heads against a baseline model on different datasets."""

import argparse
import json
import os
from datetime import datetime
from typing import Tuple

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, pipeline
from utils import get_clustering_cache_dir

from efficient_heads.eval_outputs import (
    NextTokenEvaluator as Evaluator,
)
from efficient_heads.flash_head import (
    get_flash_head_model_and_tokenizer,
    get_spherical_k_means_model_and_tokenizer,
)
from efficient_heads.midx_head import get_midx_model_and_tokenizer
from efficient_heads.svd_softmax import get_svd_softmax_model_and_tokenizer
from efficient_heads.vocab_pruning import get_vocab_pruning_model_and_tokenizer

# Fixed torch seed so repeated runs sample identically.
torch.manual_seed(0)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16
# One clustering cache directory per cluster count; the n_clusters values
# here must match the corresponding entries in ``head_kwargs``.
CLUSTERING_CACHE_DIR = get_clustering_cache_dir(n_clusters=8192)
CLUSTERING_CACHE_DIR_SKMEANS_100 = get_clustering_cache_dir(n_clusters=100)
CLUSTERING_CACHE_DIR_SKMEANS_200 = get_clustering_cache_dir(n_clusters=200)


# Maps each ``--head_type`` name to the factory returning (model, tokenizer).
# Both spherical k-means variants share one factory; they differ only in the
# cluster count supplied through ``head_kwargs``.
model_tokenizer_constructors = dict(
    flash_head=get_flash_head_model_and_tokenizer,
    spherical_k_means_100=get_spherical_k_means_model_and_tokenizer,
    spherical_k_means_200=get_spherical_k_means_model_and_tokenizer,
    svd_softmax=get_svd_softmax_model_and_tokenizer,
    midx_head=get_midx_model_and_tokenizer,
    vocab_pruning=get_vocab_pruning_model_and_tokenizer,
)

# Per-head keyword arguments forwarded to the head's factory; ``main`` also
# copies them into that head's saved evaluation record.
head_kwargs = dict(
    flash_head=dict(
        cache_dir=CLUSTERING_CACHE_DIR,
        n_clusters=8192,
        n_probes=256,
    ),
    spherical_k_means_100=dict(
        cache_dir=CLUSTERING_CACHE_DIR_SKMEANS_100,
        n_clusters=100,
    ),
    spherical_k_means_200=dict(
        cache_dir=CLUSTERING_CACHE_DIR_SKMEANS_200,
        n_clusters=200,
    ),
    # The vocab_pruning factory additionally receives dataset_name and
    # dataset_split, derived from --dataset in get_model_head_tokenizer.
    vocab_pruning=dict(vocab_size=64000),
    svd_softmax=dict(window=256, top_n=16000),
    midx_head=dict(n_codewords=32),
)


def get_dataset_name_and_split_based_on_dataset_type(
    dataset_name: str,
) -> Tuple[str, str]:
    """
    Get the dataset name and split based on the dataset we are evaluating on.

    Any name containing ``"alpaca_eval"`` is mapped to ``"tatsu-lab/alpaca"``
    (presumably so vocab pruning is fitted on Alpaca train data rather than
    on the evaluation set itself -- confirm with the vocab_pruning owner).
    The split is ``"dev"`` for FLORES+ and ``"train"`` otherwise.

    Args:
        dataset_name: HuggingFace dataset name passed via ``--dataset``.

    Returns:
        A ``(dataset_name, dataset_split)`` pair.
    """
    if "alpaca_eval" in dataset_name:
        dataset_name = "tatsu-lab/alpaca"

    dataset_split = "dev" if "flores_plus" in dataset_name else "train"
    return dataset_name, dataset_split


def get_model_head_tokenizer(
    head_type: str, args: argparse.Namespace
) -> Tuple[torch.nn.Module, torch.nn.Module, object]:
    """Get a model, head and tokenizer.

    Builds the efficient model for ``head_type`` and keeps only its
    ``lm_head`` (cast to ``DTYPE``). It then loads a separate baseline
    ``meta-llama/Llama-3.2-1B-Instruct`` model on ``DEVICE`` for comparison.

    Args:
        head_type: Key into ``model_tokenizer_constructors``.
        args: Parsed CLI arguments; ``args.dataset`` is used for vocab pruning.

    Returns:
        ``(baseline_model, head, tokenizer)``.

    Raises:
        KeyError: If ``head_type`` is not a known head type.
    """
    model_id = "meta-llama/Llama-3.2-1B-Instruct"
    constructor = model_tokenizer_constructors[head_type]
    # Copy so the per-run dataset settings added below never leak into the
    # shared module-level ``head_kwargs`` table.
    kwargs = dict(head_kwargs.get(head_type, {}))

    if head_type == "vocab_pruning":
        dataset_name, dataset_split = (
            get_dataset_name_and_split_based_on_dataset_type(args.dataset)
        )
        kwargs.update(
            {"dataset_name": dataset_name, "dataset_split": dataset_split}
        )
    efficient_model, tokenizer = constructor(
        model_id=model_id, device=DEVICE, **kwargs
    )
    head = efficient_model.lm_head.to(DTYPE)
    # Only the head is needed; drop the rest of the efficient model before
    # loading the baseline so it can be freed earlier.
    del efficient_model
    baseline_model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=DTYPE
    ).to(DEVICE)
    return baseline_model, head, tokenizer


def get_dataloader(
    dataset_name: str = "tatsu-lab/alpaca_eval", batch_size: int = 1
):
    """Build a DataLoader whose batches are lists of prompt strings.

    The split is chosen from the dataset name. Names containing ``"eval"``
    use ``"eval"``. Otherwise FLORES+ uses ``"dev"`` and everything else uses
    ``"train"``.

    Prompt text per row:
    - FLORES+ rows contribute their ``text`` field.
    - Other rows contribute ``instruction``, a blank line, then the optional
      ``input`` field.
    """
    is_flores = "flores_plus" in dataset_name

    if "eval" in dataset_name:
        split = "eval"
    elif is_flores:
        split = "dev"
    else:
        split = "train"

    def to_prompts(rows):
        if is_flores:
            return [row["text"] for row in rows]
        return [
            row["instruction"] + "\n\n" + row.get("input", "") for row in rows
        ]

    return DataLoader(
        load_dataset(dataset_name, split=split),
        collate_fn=to_prompts,
        batch_size=batch_size,
    )


def parse_args():
    """Parse all arguments.

    ``--head_type`` values are validated against
    ``model_tokenizer_constructors``. A typo then fails at argument parsing
    instead of with a KeyError after the dataset has already been loaded.
    """
    head_choices = list(model_tokenizer_constructors.keys())
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--dataset",
        type=str,
        default="tatsu-lab/alpaca_eval",
        help="Dataset name from HuggingFace, e.g., openlanguagedata/flores_plus",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        default="evaluation_results.json",
        help="JSON file that evaluation results are appended to.",
    )
    parser.add_argument(
        "--num_batches",
        type=int,
        default=8,
        help="Batches to evaluate per head (passed to the evaluator).",
    )
    parser.add_argument(
        "--batch_size", type=int, default=1, help="DataLoader batch size."
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=512,
        help="Stored in the saved result metadata.",
    )
    parser.add_argument(
        "--head_type",
        nargs="+",
        choices=head_choices,
        default=head_choices,
        help=(
            "Head types to evaluate. Options: "
            + ", ".join(head_choices)
            + ". Default: all."
        ),
    )
    return parser.parse_args()


def _load_existing_results(save_path: str) -> list:
    """Load the result list previously written to ``save_path``.

    Returns ``[]`` when the file is missing, or when it is not valid JSON
    (with a warning). Raises ``ValueError`` when the file holds valid JSON
    that is not a list, so it is neither overwritten nor crashes after a
    full evaluation.
    """
    if not os.path.exists(save_path):
        return []
    try:
        with open(save_path, encoding="utf-8") as f:
            results = json.load(f)
    except json.JSONDecodeError:
        print("Warning: save file is not valid JSON. Starting fresh.")
        return []
    if not isinstance(results, list):
        raise ValueError(
            f"{save_path} does not contain a JSON list of results; "
            "refusing to overwrite it."
        )
    return results


def _save_results(save_path: str, results: list) -> None:
    """Write ``results`` to ``save_path`` through a temporary file."""
    tmp_path = f"{save_path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    # os.replace swaps the file atomically, so a failing dump cannot
    # truncate results saved by earlier runs.
    os.replace(tmp_path, save_path)


def _recorded_head_kwargs(head_type: str, dataset: str) -> dict:
    """Return the head settings stored alongside a head's result."""
    recorded = dict(head_kwargs.get(head_type, {}))
    if head_type == "vocab_pruning":
        dataset_name, dataset_split = (
            get_dataset_name_and_split_based_on_dataset_type(dataset)
        )
        recorded.update(
            {"dataset_name": dataset_name, "dataset_split": dataset_split}
        )
    return recorded


def main():
    """Main function to run evaluations.

    Evaluates each requested head type and appends one result dict per head.
    The save file is rewritten after every head, so heads already evaluated
    are kept if a later one fails.
    """
    from datetime import timezone  # module header only imports ``datetime``

    args = parse_args()
    dataloader = get_dataloader(args.dataset, args.batch_size)
    all_results = _load_existing_results(args.save_path)

    for head_type in args.head_type:
        print("Head type:", head_type)
        model, head, tokenizer = get_model_head_tokenizer(head_type, args)
        evaluator = Evaluator(
            baseline_model=model,
            custom_head=head,
            tokenizer=tokenizer,
            head_type=head_type,
        )

        eval_result = evaluator.evaluate(
            dataloader, num_batches=args.num_batches
        )

        eval_result.update(
            {
                "head_type": head_type,
                "dataset": args.dataset,
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "batch_size": args.batch_size,
                "num_batches": args.num_batches,
                # NOTE(review): max_new_tokens is recorded but not passed to
                # the evaluator in this file -- confirm whether it is used.
                "max_new_tokens": args.max_new_tokens,
                **_recorded_head_kwargs(head_type, args.dataset),
            }
        )
        print(eval_result)
        all_results.append(eval_result)
        _save_results(args.save_path, all_results)

        # Release this head's models before the next iteration loads new
        # ones, instead of keeping them alive until reassignment.
        del model, head, tokenizer, evaluator
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


if __name__ == "__main__":
    # Script entry point: parse CLI args and run the requested evaluations.
    main()
