"""Evaluate latency of different efficient heads."""

import argparse
import json
import logging
from pathlib import Path

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from efficient_heads.eval_outputs import NextTokenEvaluator
from efficient_heads.experiments import create_experiments
from efficient_heads.fgd import get_fgd_model_and_tokenizer
from efficient_heads.flash_head import FlashHead, get_flash_head_parameters
from efficient_heads.midx_head import MIDXHead
from efficient_heads.svd_softmax import SVDSoftmaxLayer
from efficient_heads.vocab_pruning import VocabPruningHead


def parse_cli_args(argv=None):
    """Parse command-line arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``dataset``,
        ``cluster_cache``, ``experiment`` and ``verbose``.

    Raises:
        SystemExit: Raised by argparse when ``--cluster-cache`` is missing
            or ``--dataset`` is not one of the supported names.
    """
    parser = argparse.ArgumentParser(
        description="Top-K Containment evaluation of different efficient heads."
    )
    parser.add_argument(
        "--dataset",
        default="alpaca",
        # Only these names are handled when the data loader is built, so
        # reject anything else up front instead of failing later.
        choices=("alpaca", "math", "xnli"),
        help="Evaluation dataset.",
    )
    parser.add_argument(
        "--cluster-cache",
        # The value is always converted to a Path by the caller, which
        # fails on None; make the requirement explicit.
        required=True,
        help="Path to cluster cache.",
    )
    parser.add_argument(
        "--experiment",
        default="small",
        help="Name of the experiment set to run.",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Whether verbose output is enabled.",
    )

    return parser.parse_args(argv)


def get_dataloader(cli_args) -> DataLoader:
    """Return a data loader yielding prompt strings for the chosen dataset.

    Args:
        cli_args: Parsed command-line arguments; only ``cli_args.dataset``
            is read. Supported values are ``"alpaca"``, ``"math"`` and
            ``"xnli"``.

    Returns:
        A ``DataLoader`` with ``batch_size=1`` whose collate function turns
        each batch of dataset rows into a list of prompt strings.

    Raises:
        ValueError: If ``cli_args.dataset`` is not a supported name.
    """
    dataset_key = cli_args.dataset
    # Extra positional arguments (the config name) for ``load_dataset``.
    config_args = ()

    if dataset_key == "alpaca":
        dataset_name = "tatsu-lab/alpaca_eval"
        dataset_split = "eval"

        def collate_fn(batch):
            # Instruction followed by the input field (empty if absent).
            return [
                item["instruction"] + "\n\n" + item.get("input", "")
                for item in batch
            ]

    elif dataset_key == "math":
        dataset_name = "lighteval/MATH-Hard"
        dataset_split = "test"

        def collate_fn(batch):
            return [item["problem"] for item in batch]

    elif dataset_key == "xnli":
        dataset_name = "facebook/xnli"
        dataset_split = "test"
        config_args = ("all_languages",)

        def collate_fn(batch):
            return [item["premise"] for item in batch]

    else:
        # Previously an unknown name fell through and crashed later with
        # an UnboundLocalError; fail early with a clear message instead.
        raise ValueError(
            f"Unsupported dataset {dataset_key!r}; "
            "expected one of 'alpaca', 'math', 'xnli'."
        )

    dataset = load_dataset(dataset_name, *config_args, split=dataset_split)

    return DataLoader(
        dataset,
        collate_fn=collate_fn,
        batch_size=1,
    )


# Calibration dataset name and text column handed to ``VocabPruningHead``
# for each supported evaluation dataset.
_VOCAB_PRUNING_SOURCES = {
    "alpaca": ("tatsu-lab/alpaca", "text"),
    "math": ("lighteval/MATH-Hard", "problem"),
    "xnli": ("facebook/xnli", "premise"),
}


def _build_head(head_type, model, tokenizer, model_id, head_kwargs, dataset):
    """Construct the head and matching tokenizer for one experiment run.

    Args:
        head_type: Name of the head implementation to build.
        model: Baseline causal LM whose ``lm_head`` is wrapped or replaced.
        tokenizer: Tokenizer loaded for ``model_id``.
        model_id: Hugging Face model identifier.
        head_kwargs: Keyword arguments for the head. For ``"vocab_pruning"``
            this dict is updated in place with the dataset settings so the
            caller can report them.
        dataset: Name of the evaluation dataset from the command line.

    Returns:
        A ``(head, tokenizer)`` tuple; the tokenizer is replaced by the
        head-specific one for ``"vocab_pruning"`` and ``"fgd"``.

    Raises:
        ValueError: If ``head_type`` is unknown, or ``dataset`` has no
            vocabulary-pruning source configured.
    """
    if head_type == "midx":
        return MIDXHead(model.lm_head, **head_kwargs), tokenizer

    if head_type in ("flash_head", "spherical_k_means"):
        # "spherical_k_means" is built as a FlashHead fixed to one probe.
        n_probes = head_kwargs["n_probes"] if head_type == "flash_head" else 1
        head = FlashHead(
            model.lm_head,
            n_probes=n_probes,
            forward_type="partial_logits",
            **get_flash_head_parameters(
                lm_head=model.lm_head,
                tokenizer=tokenizer,
                n_clusters=head_kwargs["n_clusters"],
                cache_dir=head_kwargs["cache_dir"],
            ),
        )
        return head, tokenizer

    if head_type == "svd_softmax":
        return SVDSoftmaxLayer(model.lm_head.weight, **head_kwargs), tokenizer

    if head_type == "vocab_pruning":
        if dataset not in _VOCAB_PRUNING_SOURCES:
            raise ValueError(
                f"No vocabulary pruning source for dataset {dataset!r}."
            )
        dataset_name, dataset_column = _VOCAB_PRUNING_SOURCES[dataset]
        head_kwargs["dataset_name"] = dataset_name
        head_kwargs["dataset_column"] = dataset_column
        head_kwargs["dataset_split"] = "train"

        head = VocabPruningHead(model_id, **head_kwargs)
        return head, head.pruned_tokenizer

    if head_type == "fgd":
        fgd_model, fgd_tokenizer = get_fgd_model_and_tokenizer(
            model_id, **head_kwargs
        )
        return fgd_model.lm_head, fgd_tokenizer

    # Previously an unknown type left ``head`` unbound and crashed later.
    raise ValueError(f"Unsupported head type: {head_type!r}")


def main():
    """Evaluate different efficient heads.

    For every pipeline configuration of the selected experiment, load the
    baseline model, build the configured head, run ``NextTokenEvaluator``
    on the chosen dataset and print the accumulated results as JSON.

    Raises:
        ValueError: If ``--cluster-cache`` was not supplied, or a pipeline
            configuration names an unsupported head type.
    """
    cli_args = parse_cli_args()
    model_id = "meta-llama/Llama-3.2-1B-Instruct"
    if cli_args.verbose:
        logging.basicConfig(level=logging.INFO)

    if cli_args.cluster_cache is None:
        # ``Path(None)`` would raise an opaque TypeError.
        raise ValueError("--cluster-cache is required.")

    experiment = create_experiments(
        cli_args.experiment, Path(cli_args.cluster_cache)
    )

    # The data loader depends only on the CLI arguments, so build it once
    # instead of reloading the dataset for every configuration.
    dataloader = get_dataloader(cli_args)

    results = {}

    experiment_iterator = tqdm(experiment)
    for pipeline_config in experiment_iterator:
        head_type = pipeline_config.head_type

        experiment_iterator.set_description(
            desc=f"{head_type} {pipeline_config.kwargs}"
        )

        results.setdefault(head_type, [])

        # A fresh baseline model and tokenizer are loaded for every run.
        model = AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype=torch.bfloat16, device_map="cuda"
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        # Work on a copy so the experiment's own configuration object is
        # not mutated when dataset settings are injected.
        head_kwargs = dict(pipeline_config.kwargs)
        head, tokenizer = _build_head(
            head_type,
            model,
            tokenizer,
            model_id,
            head_kwargs,
            cli_args.dataset,
        )

        evaluator = NextTokenEvaluator(
            baseline_model=model,
            custom_head=head,
            tokenizer=tokenizer,
            head_type=head_type,
        )

        accuracy_result = evaluator.evaluate(
            dataloader=dataloader,
            max_new_tokens=128,
        )

        if accuracy_result is not None:
            results[head_type].append(
                {
                    "config": {
                        **head_kwargs,
                        "device": torch.cuda.get_device_name(),
                    },
                    "accuracy": accuracy_result,
                }
            )

        # Emit the cumulative results after every configuration so partial
        # output is available while the experiment is still running.
        print(json.dumps(results))


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
