"""Evaluate latency (GPU) of different efficient heads."""

import argparse
import json
import logging
from dataclasses import asdict
from pathlib import Path

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

from efficient_heads.eval_latency import measure_latency
from efficient_heads.experiments import (
    PIPELINE_CONSTRUCTORS,
    create_experiments,
)


def parse_cli_args(argv=None) -> argparse.Namespace:
    """Parse command-line arguments.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, which keeps
            the original no-argument call working.

    Returns:
        Namespace with ``verbose`` (bool), ``cluster_cache`` (str) and
        ``experiment`` (str, default ``"small"``) attributes.

    Raises:
        SystemExit: If the arguments are invalid or ``--cluster-cache`` is
            missing (argparse prints a usage message first).
    """
    parser = argparse.ArgumentParser(
        description="Latency evaluation (GPU) of different efficient heads."
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Whether verbose output is enabled.",
    )
    # The value is always wrapped in ``Path(...)`` by the caller, which
    # raises an opaque ``TypeError`` on ``None``; require it up front so a
    # missing flag produces a proper usage error instead.
    parser.add_argument(
        "--cluster-cache",
        required=True,
        help="Path to cluster cache.",
    )
    parser.add_argument(
        "--experiment",
        default="small",
        help="Experiment name passed to create_experiments "
        "(default: %(default)s).",
    )

    return parser.parse_args(argv)


def get_alpaca_dataloader(dataset=None) -> DataLoader:
    """Return data loader for ``tatsu-lab/alpaca_eval``.

    Batches hold one prompt each (``batch_size=1``) and are lists of
    strings. A prompt is the record's ``instruction``, followed by a blank
    line and the record's ``input`` only when that input is present and
    non-empty.

    Args:
        dataset: Optional map-style collection (supports ``len`` and integer
            indexing) of mapping records with an ``"instruction"`` key and an
            optional ``"input"`` key. Defaults to the ``eval`` split of
            ``tatsu-lab/alpaca_eval`` loaded with ``datasets.load_dataset``.

    Returns:
        A ``DataLoader`` yielding single-element lists of prompt strings.
    """
    if dataset is None:
        dataset = load_dataset("tatsu-lab/alpaca_eval", split="eval")

    return DataLoader(
        dataset,
        # A module-level function (unlike a lambda) can be pickled, so the
        # loader also works if worker processes are enabled later.
        collate_fn=_alpaca_prompts_from_batch,
        batch_size=1,
    )


def _alpaca_prompts_from_batch(batch):
    """Format a batch of Alpaca-style records into a list of prompt strings."""
    # Join only the parts that are present. Unconditionally appending
    # "\n\n" + input left a dangling blank line on records without an input,
    # and an explicit ``None`` input failed on string concatenation.
    return [
        "\n\n".join(
            part for part in (item["instruction"], item.get("input")) if part
        )
        for item in batch
    ]


def main():
    """Evaluate different efficient heads.

    Parses the command line and builds the selected experiment. For every
    pipeline configuration it constructs a pipeline on
    ``meta-llama/Llama-3.2-1B-Instruct`` and measures its latency on the
    alpaca_eval prompts. It then prints one JSON record per configuration.

    Returns:
        Mapping from head type to the list of records printed for that head
        type. Each record has a ``config`` entry (pipeline kwargs plus the
        CUDA device name) and, when ``measure_latency`` returned a result, a
        ``profiling_results`` entry.
    """
    cli_args = parse_cli_args()

    if cli_args.verbose:
        logging.basicConfig(level=logging.INFO)

    experiment = create_experiments(
        cli_args.experiment, Path(cli_args.cluster_cache)
    )

    # Loop-invariant inputs, computed once. A DataLoader is re-iterable, so a
    # single instance serves every configuration (previously the dataset was
    # reloaded per iteration). The device name is fixed for the whole run,
    # and querying it here also fails fast before any model is built.
    prompts = get_alpaca_dataloader()
    device_name = torch.cuda.get_device_name()

    results = {}

    experiment_iterator = tqdm(experiment)

    for pipeline_config in experiment_iterator:
        head_type = pipeline_config.head_type

        experiment_iterator.set_description(
            desc=f"{head_type} {pipeline_config.kwargs}"
        )

        pipeline = PIPELINE_CONSTRUCTORS[head_type](
            model_id="meta-llama/Llama-3.2-1B-Instruct",
            **pipeline_config.kwargs,
        )

        avg_results = measure_latency(pipeline, prompts=prompts)

        # Release this pipeline before the next one is constructed. Otherwise
        # the old model stays referenced (presumably still on the GPU) while
        # the new one is being built, roughly doubling peak memory.
        del pipeline
        torch.cuda.empty_cache()

        result = {
            "config": {
                **pipeline_config.kwargs,
                "device": device_name,
            }
        }
        if avg_results is not None:
            result["profiling_results"] = asdict(avg_results)

        results.setdefault(head_type, []).append(result)

        # ``tqdm.write`` prints to stdout without corrupting the progress bar.
        tqdm.write(json.dumps(result, indent=2))

    return results


# Script entry point: run the latency evaluation only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()
