import argparse

import torch
from snippets.llm.huggingface.quantize import quantize_model
from transformers import pipeline

from efficient_heads.eval_latency import (
    compare_outputs,
    measure_latency,
)
from efficient_heads.flash_head import FlashHead, get_flash_head_parameters
from efficient_heads.pipeline import GenerationPipeline

# Filenames for the clustering cache and its accompanying config.
# NOTE(review): neither constant is referenced in this script; presumably they
# name the files stored under --cache_dir by get_flash_head_parameters — confirm.
CACHE_FILENAME = "clustering_cache.pt"
CACHE_CONFIG_FILENAME = "clustering_config.json"


def parse_arguments():
    """Build the command-line interface for the FlashHead demo and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``model``, ``n_clusters``,
        ``cache_dir``, ``forward_type``, ``n_probes`` and the boolean switches
        ``evaluate``, ``eval_latency`` and ``quick_eval``.
    """
    cli = argparse.ArgumentParser(
        description="Replace the classification head of a model with FlashHead."
    )
    add = cli.add_argument

    # Model selection and clustering configuration.
    add(
        "--model",
        type=str,
        default="meta-llama/Llama-3.2-1B-Instruct",
        help="The model ID to use for text generation",
    )
    add("--n_clusters", type=int, default=8192, help="Number of clusters")
    # NOTE(review): this default is an absolute path at the filesystem root and is
    # tied to the default --model / --n_clusters values — confirm it is intended.
    add(
        "--cache_dir",
        type=str,
        default="/Llama-3.2-1B-Instruct-cluster-8192/",
        help="Directory to cache the clustering results",
    )

    # FlashHead runtime behaviour; the first entry is the default forward type.
    forward_types = [
        "partial_logits",
        "monte_carlo_full_logits",
        "approximated_full_logits",
    ]
    add(
        "--forward_type",
        type=str,
        choices=forward_types,
        default=forward_types[0],
        help="What type of logits to produce in the forward pass.",
    )
    add(
        "--n_probes",
        type=int,
        default=256,
        help="Number of clusters to probe during search",
    )

    # Optional post-setup stages; store_true switches default to False.
    for flag, text in (
        ("--evaluate", "Run evaluation after setup"),
        ("--eval_latency", "Run latency evaluation"),
        ("--quick_eval", "Run quick evaluation instead of full evaluation"),
    ):
        add(flag, action="store_true", help=text)

    return cli.parse_args()


def main() -> None:
    """Load a model, replace its LM head with FlashHead, and optionally evaluate it.

    Behaviour is driven by the flags from ``parse_arguments``:

    1. Always: load ``--model`` as a bfloat16 text-generation pipeline on CUDA and
       replace its ``lm_head`` with a ``FlashHead`` configured with ``--n_probes``,
       ``--forward_type`` and the parameters returned by
       ``get_flash_head_parameters`` (which receives ``--cache_dir`` and
       ``--n_clusters``).
    2. With ``--eval_latency``: additionally load an unmodified copy of the same
       model as a baseline, measure the latency of both generators and compare
       their outputs.
    3. With ``--evaluate``: run the full evaluation (or the quick one with
       ``--quick_eval``) on the FlashHead pipeline and print the results.
    """
    args = parse_arguments()

    model_id = args.model
    n_clusters = args.n_clusters
    cache_dir = args.cache_dir
    n_probes = args.n_probes
    forward_type = args.forward_type

    print("Running with parameters:")
    print(f"  Model: {model_id}")
    print(f"  Number of clusters: {n_clusters}")
    print(f"  Cache directory: {cache_dir}")
    print(f"  Number of probes: {n_probes}")
    print(f"  Forward Type: {forward_type}")

    def load_pipe():
        """Load a fresh bfloat16 text-generation pipeline for ``model_id`` on CUDA."""
        return pipeline(
            "text-generation",
            model=model_id,
            torch_dtype=torch.bfloat16,
            device_map="cuda",
        )

    with torch.no_grad():
        flash_head_pipe = load_pipe()

        original_head = flash_head_pipe.model.lm_head
        flash_head_pipe.model.lm_head = FlashHead(
            original_head,
            n_probes=n_probes,
            forward_type=forward_type,
            **get_flash_head_parameters(
                lm_head=original_head,
                cache_dir=cache_dir,
                n_clusters=n_clusters,
                tokenizer=flash_head_pipe.tokenizer,
                # Optionally, special tokens can be placed in their own isolated
                # clusters by also passing e.g.:
                # special_token_types=[
                #     'bos_token', 'eos_token', 'unk_token', 'pad_token',
                #     'sep_token', 'cls_token', 'mask_token',
                # ]
            ),
        )

        if args.eval_latency:
            # The unmodified baseline is only needed for the latency/output
            # comparison, so it is loaded here. Runs without --eval_latency
            # therefore skip loading a second copy of the model onto the GPU.
            std_pipe = load_pipe()

            standard_generator = GenerationPipeline(
                std_pipe.model.model,
                std_pipe.model.lm_head,
                tokenizer=std_pipe.tokenizer,
            )
            flash_head_generator = GenerationPipeline(
                flash_head_pipe.model.model,
                flash_head_pipe.model.lm_head,
                tokenizer=flash_head_pipe.tokenizer,
                mode="flash_head",
            )

            measure_latency(standard_generator)
            measure_latency(flash_head_generator)
            compare_outputs(standard_generator, flash_head_generator)

        if args.evaluate:
            # FIXME: imported lazily from a top-level `eval` module that is not
            # part of the efficient_heads package; only required for --evaluate.
            from eval import run_evaluation, run_quick_eval

            if args.quick_eval:
                results = run_quick_eval(flash_head_pipe)
            else:
                results = run_evaluation(flash_head_pipe)

            print("Evaluation Results:")
            print(results)


if __name__ == "__main__":
    main()
