"""Evaluate efficient head models on LM eval tasks"""

import argparse
import json
import os
from datetime import datetime
from typing import Optional, Tuple, Union

import numpy as np
import torch
from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# from efficient_heads.fgd_head import get_fgd_model_and_tokenizer
from efficient_heads.flash_head import (
    get_flash_head_model_and_tokenizer,
    get_spherical_k_means_model_and_tokenizer,
)
from efficient_heads.midx_head import get_midx_model_and_tokenizer
from efficient_heads.svd_softmax import get_svd_softmax_model_and_tokenizer
from efficient_heads.vocab_pruning import (
    get_vocab_pruning_model_and_tokenizer,
)

# Seed torch's global RNG once, at import time, so runs are repeatable.
torch.manual_seed(0)


# Prefer the GPU when one is visible; models are cast to DTYPE after loading.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16

# Head name -> factory returning ``(model, tokenizer)``. The keys double as the
# valid ``--head-types`` choices. Both spherical k-means entries share a single
# factory and differ only in their ``model_kwargs`` settings.
model_tokenizer_constructors = dict(
    flash_head=get_flash_head_model_and_tokenizer,
    spherical_k_means_100=get_spherical_k_means_model_and_tokenizer,
    spherical_k_means_200=get_spherical_k_means_model_and_tokenizer,
    svd_softmax=get_svd_softmax_model_and_tokenizer,
    midx_head=get_midx_model_and_tokenizer,
    vocab_pruning=get_vocab_pruning_model_and_tokenizer,
    # fgd_head=get_fgd_model_and_tokenizer,  # disabled along with its import
)

# Default keyword arguments forwarded to each head's factory. ``fgd_head`` has
# settings here but no registered factory above, so it cannot be selected yet.
model_kwargs = dict(
    flash_head=dict(
        n_clusters=4096,  # alt: int(16384 / 2.0)
        n_probes=512,
    ),
    spherical_k_means_100=dict(n_clusters=100),
    spherical_k_means_200=dict(n_clusters=200),
    # dataset_name / dataset_split are added at run time, not configured here.
    vocab_pruning=dict(vocab_size=64000),
    svd_softmax=dict(window=256, top_n=12000),
    midx_head=dict(n_codewords=32),
    fgd_head=dict(
        K=384,
        ef_search=300,
        index_M=40,
        ef_construction=300,
    ),
)


def convert(obj):
    """Recursively turn numpy values into plain Python types so ``json.dump`` works.

    Dicts, lists and tuples are rebuilt with converted elements (tuples stay
    tuples, which ``json`` writes as arrays). numpy integer / floating / bool
    scalars become ``int`` / ``float`` / ``bool``, and arrays become nested
    lists. Anything else is returned unchanged.
    """
    if isinstance(obj, dict):
        return {key: convert(val) for key, val in obj.items()}
    if isinstance(obj, list):
        return [convert(val) for val in obj]
    if isinstance(obj, tuple):
        # Previously tuples were passed through untouched, leaving e.g.
        # np.int64 elements that json cannot serialize.
        return tuple(convert(val) for val in obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return obj


def run_evaluation(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    task_list: Optional[list] = None,
    limit: Optional[int] = None,
    write_out: bool = False,
    log_samples: bool = False,
    batch_size: Union[str, int] = "auto",
    use_cache: Optional[str] = None,
):
    """Run lm-eval tasks on an already-loaded model.

    Args:
        model: Model to evaluate; it is wrapped in lm_eval's ``HFLM``.
        tokenizer: Tokenizer for ``model``. If it has no pad token, its
            ``pad_token`` is set to ``eos_token`` in place (the caller's
            tokenizer object is modified).
        task_list: lm-eval task names; defaults to ``["bbh"]`` when ``None``.
        limit: Optional cap on examples per task, forwarded to
            ``simple_evaluate``.
        write_out: Forwarded to ``simple_evaluate``.
        log_samples: Forwarded to ``simple_evaluate``.
        batch_size: Integer batch size or ``"auto"``. Passed both to ``HFLM``
            and to ``simple_evaluate``.
        use_cache: Forwarded to ``simple_evaluate``. This is presumably a path
            for lm_eval's response cache, with ``None`` disabling it; confirm
            against the lm_eval version in use.

    Returns:
        Whatever ``evaluator.simple_evaluate`` returns (its results dict).
    """
    if task_list is None:
        # Broader suite used previously:
        # ["mmlu_pro", "bbh", "boolq", "hellaswag", "truthfulqa"]
        task_list = ["bbh"]
    if tokenizer.pad_token is None:
        # Some tokenizers ship without a pad token; reuse EOS so padded
        # batches can presumably be built by HFLM.
        tokenizer.pad_token = tokenizer.eos_token
    lm_model = HFLM(
        pretrained=model,
        tokenizer=tokenizer,
        batch_size=batch_size,
    )
    results = evaluator.simple_evaluate(
        model=lm_model,
        tasks=task_list,
        batch_size=batch_size,
        limit=limit,
        write_out=write_out,
        log_samples=log_samples,
        use_cache=use_cache,
    )
    return results


def get_dataset_name_and_split_based_on_dataset_type(
    dataset_name: str,
) -> Tuple[str, str]:
    """Map an evaluation dataset name to the dataset name and split to load.

    Any name containing ``"alpaca_eval"`` is replaced by ``"tatsu-lab/alpaca"``.
    Names containing ``"flores_plus"`` use the ``"dev"`` split; every other
    name uses ``"train"``.

    Returns:
        A ``(dataset_name, dataset_split)`` pair.
    """
    if "alpaca_eval" in dataset_name:
        dataset_name = "tatsu-lab/alpaca"
    dataset_split = "dev" if "flores_plus" in dataset_name else "train"
    return dataset_name, dataset_split


def get_model_and_tokenizer(
    model_id: str, head_type: str, kwargs: dict
) -> Tuple[nn.Module, AutoTokenizer]:
    """Build the model/tokenizer pair for ``head_type`` and cast the model to ``DTYPE``.

    Args:
        model_id: HuggingFace model identifier forwarded to the head factory.
        head_type: Key into ``model_tokenizer_constructors``.
        kwargs: Extra keyword arguments forwarded to the factory.

    Returns:
        ``(model, tokenizer)``. The model has been converted with
        ``.to(DTYPE)``; it is presumably a ``torch.nn.Module`` /
        HF ``PreTrainedModel`` (confirm per head factory).

    Raises:
        KeyError: If ``head_type`` has no registered factory.
    """
    constructor = model_tokenizer_constructors[head_type]
    model, tokenizer = constructor(model_id=model_id, device=DEVICE, **kwargs)
    return model.to(DTYPE), tokenizer


def _parse_batch_size(value: str) -> Union[int, str]:
    """Convert a ``--batch-size`` CLI value to an int, keeping ``auto`` modes as strings.

    Values starting with ``"auto"`` are passed through unchanged; anything else
    must be an integer. Invalid values fail at argument parsing instead of deep
    inside model construction.
    """
    if value.startswith("auto"):
        return value
    try:
        return int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"batch size must be an integer or 'auto', got {value!r}"
        ) from None


def parse_args() -> argparse.Namespace:
    """Parse command-line options for the efficient-head evaluation script."""
    parser = argparse.ArgumentParser(
        description="CLI for evaluating efficient heads on LM-eval tasks"
    )
    parser.add_argument(
        "--model-id",
        type=str,
        default="google/gemma-3-270m-it",  # alt: "meta-llama/Llama-3.2-1B-Instruct"
        help="HuggingFace model identifier",
    )
    parser.add_argument(
        "--head-types",
        type=str,
        nargs="+",
        choices=list(model_tokenizer_constructors.keys()),
        default=["flash_head"],
        help="Which head types to evaluate (you can specify multiple)",
    )
    parser.add_argument(
        "--clustering-cache-dir",
        type=str,
        help="Clustering cache directory for FlashHead",
        default="gemma-3-270m-it-cluster-4096-eq/",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default=None,
        help=(
            "Dataset name for the vocab_pruning head (e.g. alpaca_eval); "
            "required when evaluating vocab_pruning"
        ),
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=10,  # use None to evaluate every example
        help="Limit number of examples per task",
    )
    parser.add_argument(
        "--tasks",
        type=str,
        nargs="+",
        default=["bbh"],
        help="List of LM-eval task names (e.g. boolq, mmlu_pro, bbh)",
    )
    parser.add_argument(
        "--batch-size",
        # argparse applies `type` only to CLI strings; the int default is kept.
        type=_parse_batch_size,
        default=8,  # pass "auto" to let lm_eval pick the batch size
        help="Batch size for HFLM (or 'auto')",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=os.path.expanduser("~/lm_eval_results"),
        help="Directory to write out JSON results",
    )
    return parser.parse_args()


def _build_head_kwargs(head_type: str, args: argparse.Namespace) -> dict:
    """Return the factory kwargs for ``head_type``, completed from the CLI args.

    Works on a copy of ``model_kwargs[head_type]`` so the module-level defaults
    are never mutated between loop iterations.

    Raises:
        ValueError: If a required input (dataset name for vocab_pruning, or a
            clustering cache for FlashHead / spherical k-means) is missing.
    """
    kwargs = dict(model_kwargs.get(head_type, {}))

    if head_type == "vocab_pruning":
        # Prefer a dataset name from the CLI (when that option exists), falling
        # back to one configured statically in ``model_kwargs``. Previously this
        # read kwargs["dataset_name"], which was never set (KeyError).
        dataset_name = getattr(args, "dataset_name", None) or kwargs.get(
            "dataset_name"
        )
        if not dataset_name:
            raise ValueError(
                "vocab_pruning requires a dataset name (e.g. 'alpaca_eval'), but "
                "none was given on the command line or in "
                "model_kwargs['vocab_pruning']."
            )
        ds_name, ds_split = get_dataset_name_and_split_based_on_dataset_type(
            dataset_name
        )
        kwargs.update({"dataset_name": ds_name, "dataset_split": ds_split})

    # Heads that load precomputed centroids from a clustering cache.
    if head_type == "flash_head":
        cache_label = "FlashHead"
    elif "spherical_k_means" in head_type:
        # NOTE(review): the default --clustering-cache-dir name mentions 4096
        # clusters while spherical_k_means_* use 100/200 — confirm the cache
        # passed here matches n_clusters.
        cache_label = "Spherical K means"
    else:
        cache_label = None

    if cache_label is not None:
        if args.clustering_cache_dir is None:
            raise ValueError(
                f"{cache_label} requires a path to a clustering cache. "
                "Generate it with `efficient_heads.flash_head._get_centroids` and pass as "
                "--clustering-cache-dir path/to/clustering_cache/"
            )
        kwargs.update({"cache_dir": args.clustering_cache_dir})

    return kwargs


def main() -> None:
    """Evaluate each requested head type and write one JSON results file per head."""
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    for head_type in args.head_types:
        kwargs = _build_head_kwargs(head_type, args)

        # Load model + tokenizer
        model, tokenizer = get_model_and_tokenizer(args.model_id, head_type, kwargs)

        # Timestamped filename so separate runs write separate files.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_file = os.path.join(
            args.output_dir, f"{head_type}_results_{timestamp}.json"
        )

        print("Running evaluation for ", head_type, "task", args.tasks)
        print("Parameters ", kwargs)
        results = run_evaluation(
            model=model,
            tokenizer=tokenizer,
            task_list=args.tasks,
            limit=args.limit,
            batch_size=args.batch_size,
        )
        print(results)

        # Keep only the per-task metrics, made JSON-safe, tagged with the head config.
        results = convert(results["results"])
        results.update({"head_type": head_type, "kwargs": kwargs})
        print(results)
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)

        # Release this head's model so it can be freed before the next head
        # type is loaded, instead of overlapping with it.
        del model, tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


if __name__ == "__main__":
    main()
