"""Evaluate latency of different efficient heads."""

from dataclasses import asdict, fields
from logging import getLogger
from typing import Dict, List, Optional, Union

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from efficient_heads.pipeline import (
    GenerationPipeline,
    GenerationResults,
    ProfilingResults,
)

logger = getLogger(__name__)


# Fallback prompts used by `measure_latency` and `compare_outputs` when the
# caller passes no prompts. Note that `measure_latency` spends its first
# `num_warmup_prompts` (default 3) items on warmup, so with defaults only the
# last three prompts below are actually measured.
DEFAULT_PROMPTS = [
    # Short open-ended explanation prompts.
    "Explain how dichlorodifluoromethane interacts with hydrochlorofluorocarbons in the stratosphere.",
    "Explain how the Schwarzschild metric describes spacetime curvature near a non-rotating black hole.",
    "Explain how the Nasdaq-100's PowerShares QQQ Trust affects market volatility.",
    # Multiple-choice prompts that instruct the model to reason step by step
    # and finish with "the answer is (X)".
    "The following are multiple choice questions about biology. Think step by step and then finish your answer with 'the answer is (X)' where X is the correct letter choice.\n\nQ: Which of the following best describes the role of carbonic anhydrase in the human body?\nA) It catalyzes the conversion of CO2 and H2O to carbonic acid\nB) It breaks down glucose in red blood cells\nC) It facilitates lipid absorption in the small intestine\nD) It converts amino acids to proteins",
    "The following are multiple choice questions about physics. Think step by step and then finish your answer with 'the answer is (X)' where X is the correct letter choice.\n\nQ: A quantum harmonic oscillator is in its ground state. What is the probability of finding the particle at its equilibrium position?\nA) Zero\nB) Maximum\nC) Equal to finding it anywhere else\nD) Depends on temperature",
    "The following are multiple choice questions about computer science. Think step by step and then finish your answer with 'the answer is (X)' where X is the correct letter choice.\n\nQ: Which of the following sorting algorithms has the best worst-case time complexity?\nA) Quicksort\nB) Mergesort\nC) Insertion sort\nD) Bubble sort",
]


def _avg_metric(profiling_results: List[ProfilingResults], metric: str):
    """
    Compute the arithmetic mean of field `metric` over `profiling_results`.

    :param profiling_results:
        Non-empty list of profiling results.
    :param metric:
        Name of a numeric field present on every result.

    :return:
        The mean value of that field across all results.
    :raises ValueError:
        If `profiling_results` is empty.
    """
    if not profiling_results:
        raise ValueError(
            f"Cannot average metric {metric!r} over zero profiling results."
        )

    # Read the field directly: `asdict` would deep-copy every field of the
    # result on each access just to extract this single value.
    total = sum(
        getattr(profiling_result, metric)
        for profiling_result in profiling_results
    )
    return total / len(profiling_results)


def _avg_metrics(
    profiling_results: List[ProfilingResults],
) -> ProfilingResults:
    """
    Average every constructor field of `profiling_results` into one result.

    :param profiling_results:
        Non-empty list of profiling results (dataclass instances).

    :return:
        A new `ProfilingResults` whose fields are the per-field means.
    :raises ValueError:
        If `profiling_results` is empty.
    """
    if not profiling_results:
        raise ValueError("Cannot average an empty list of profiling results.")

    # `dataclasses.fields` covers inherited fields and excludes ClassVar
    # pseudo-fields, unlike the class's own `__annotations__`. Only init
    # fields can be passed back to the constructor.
    avg_metrics = {
        data_field.name: _avg_metric(profiling_results, data_field.name)
        for data_field in fields(profiling_results[0])
        if data_field.init
    }

    return ProfilingResults(**avg_metrics)


def measure_latency(
    pipe: GenerationPipeline,
    prompts: Optional[Union[List[str], DataLoader]] = None,
    num_eval_prompts: Optional[int] = 100,
    num_warmup_prompts: int = 3,
    max_new_tokens: int = 128,
) -> ProfilingResults:
    """
    Measure performance metrics for a single pipeline.

    The first `num_warmup_prompts` items of `prompts` are run but not
    recorded; up to `num_eval_prompts` further items are then run and the
    `profiling_metrics` of every non-None result are averaged.

    :param pipe:
        The generation pipeline to test.
    :param prompts:
        Prompts to test on, as a list or a DataLoader. `None` (or an empty
        list) falls back to `DEFAULT_PROMPTS`.
    :param num_eval_prompts:
        Maximum number of prompts, after warmup, used for evaluation; `None`
        evaluates every remaining prompt.
    :param num_warmup_prompts:
        The number of leading prompts used only for warmup.
    :param max_new_tokens:
        Maximum number of new tokens generated per prompt.

    :return:
        Field-wise average of the collected profiling results.
    :raises ValueError:
        If no evaluation prompt produced a result.
    """
    # Don't truth-test a DataLoader: `bool()` falls back to `len()`, which
    # raises TypeError for iterable-style datasets without a length.
    if not isinstance(prompts, DataLoader):
        prompts = prompts or DEFAULT_PROMPTS

    mode_name = pipe.mode.capitalize()
    header = f"Running {mode_name} generation tests:"
    logger.info(f"\n{header}")
    logger.info("=" * len(header))

    # Index of the first prompt past the evaluation window (None = no limit).
    stop_idx = (
        None
        if num_eval_prompts is None
        else num_warmup_prompts + num_eval_prompts
    )

    profiling_results: List[ProfilingResults] = []

    for idx, prompt in tqdm(enumerate(prompts)):
        if idx < num_warmup_prompts:
            # Warmup runs are executed for their side effects only.
            logger.info("Warmup prompt ...")
            pipe(prompt, max_new_tokens=max_new_tokens)
            continue
        if stop_idx is not None and idx >= stop_idx:
            break

        logger.info(f"\nPrompt: {prompt}")
        result: Optional[GenerationResults] = pipe(
            prompt,
            max_new_tokens=max_new_tokens,
        )
        logger.info(result)
        if result is None:
            continue

        profiling_results.append(result.profiling_metrics)

    if not profiling_results:
        raise ValueError(
            f"No profiling results were collected for {mode_name} generation; "
            "check that `prompts` has more items than `num_warmup_prompts` "
            "and that the pipeline returns results."
        )

    return _avg_metrics(profiling_results)


def compare_outputs(
    primary_pipe: GenerationPipeline,
    comparison_pipe: GenerationPipeline,
    prompts: Optional[List[str]] = None,
    max_new_tokens: int = 512,
) -> Dict:
    """
    Compare outputs between two pipelines.

    Each prompt is tokenized with the primary pipeline's tokenizer and passed
    to `compare_generations`; per-prompt agreement is logged and the mean
    agreement percentage is reported.

    :param primary_pipe:
        The primary pipeline to use as reference.
    :param comparison_pipe:
        The pipeline to compare against the primary.
    :param prompts:
        List of prompts to test on; `None` or an empty list falls back to
        `DEFAULT_PROMPTS`.
    :param max_new_tokens:
        Maximum number of tokens generated per prompt (default 512).

    :return:
        Dictionary with "comparison" (list of per-prompt metric dicts) and
        "average_agreement" (mean agreement percentage, 0 if no prompts ran).
    """
    prompts = prompts or DEFAULT_PROMPTS

    primary_name = primary_pipe.mode.capitalize()
    comparison_name = comparison_pipe.mode.capitalize()

    header = f"Comparing {primary_name} vs {comparison_name} Outputs:"
    logger.info(f"\n{header}")
    logger.info("=" * len(header))

    comparison_results = []

    for prompt in prompts:
        logger.info(f"\nPrompt: {prompt}")

        # Tokenize once with the primary tokenizer; a single unpadded prompt
        # attends to every position, hence an all-ones mask.
        input_ids = primary_pipe.tokenizer(
            prompt, return_tensors="pt"
        ).input_ids.to(primary_pipe.device)
        attention_mask = torch.ones_like(input_ids)

        result = compare_generations(
            primary_pipe,
            comparison_pipe,
            input_ids,
            attention_mask,
            max_new_tokens=max_new_tokens,
        )

        metrics = result["metrics"]
        comparison_results.append(metrics)

        logger.info(f"\n{primary_name} vs {comparison_name} comparison:")
        logger.info(f"{metrics['marked_text']}")
        logger.info(
            f"Agreement: {metrics['agreements']}/{metrics['tokens_generated']} "
            f"({metrics['agreement_percentage']:.1f}%)"
        )

    avg_agreement = (
        sum(m["agreement_percentage"] for m in comparison_results)
        / len(comparison_results)
        if comparison_results
        else 0
    )

    logger.info(f"\nAverage agreement: {avg_agreement:.1f}%")

    return {
        "comparison": comparison_results,
        "average_agreement": avg_agreement,
    }


def _empty_generation_comparison(generated: torch.Tensor) -> Dict:
    """Return the comparison result reported when no tokens were compared."""
    return {
        "generated": generated,
        "metrics": {
            "tokens_generated": 0,
            "agreements": 0,
            "agreement_percentage": 0,
            "marked_text": "",
        },
    }


def _summarize_token_agreement(
    primary_decoded: List[str],
    comparison_decoded: List[str],
    token_matches: List[bool],
) -> Dict:
    """
    Build agreement metrics and a diff-marked text from per-token decodes.

    Matching positions show the primary token; mismatches are rendered as
    ``([primary] [comparison])``.
    """
    marked_text = "".join(
        primary_token
        if matches
        else f"([{primary_token}] [{comparison_token}])"
        for primary_token, comparison_token, matches in zip(
            primary_decoded, comparison_decoded, token_matches
        )
    )
    agreement_count = sum(1 for matches in token_matches if matches)
    total = len(token_matches)

    return {
        "tokens_generated": total,
        "agreements": agreement_count,
        "agreement_percentage": (
            (agreement_count / total * 100) if total else 0
        ),
        "marked_text": marked_text,
    }


def compare_generations(
    primary_pipe: GenerationPipeline,
    comparison_pipe: GenerationPipeline,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    max_new_tokens: int,
    do_sample: bool = False,
    temperature: float = 1.0,
) -> Dict:
    """
    Compare token generation between two pipelines.

    The primary pipeline's body and head drive generation. At each decode
    step the comparison pipeline predicts a token from the same primary
    hidden states, and the two predictions are compared by decoded text.
    Generation stops at the primary EOS token or after `max_new_tokens`.

    :param primary_pipe:
        The primary pipeline.
    :param comparison_pipe:
        The pipeline to compare against.
    :param input_ids:
        Tokenized prompt, assumed shape (1, prompt_len); generated tokens are
        read with `.item()`, so only batch size 1 is supported.
    :param attention_mask:
        Attention mask for the prompt. A new, extended mask is built for
        each decode step; the caller's tensor is not modified.
    :param max_new_tokens:
        Maximum number of tokens to generate; values <= 0 generate nothing.
    :param do_sample:
        Whether to use sampling.
    :param temperature:
        Temperature for sampling.

    :return:
        Dictionary with "generated" (prompt followed by the primary tokens)
        and "metrics" (tokens_generated, agreements, agreement_percentage,
        marked_text).
    """
    if max_new_tokens <= 0:
        return _empty_generation_comparison(input_ids)

    eot_token_id = primary_pipe.tokenizer.eos_token_id
    primary_head = primary_pipe.model_head

    primary_decoded: List[str] = []
    comparison_decoded: List[str] = []
    token_matches: List[bool] = []

    def record(primary_token: torch.Tensor, comparison_token: torch.Tensor):
        # Compare decoded text, so distinct ids that decode identically
        # still count as agreement. Each token is decoded exactly once.
        primary_text = primary_pipe.tokenizer.decode(primary_token.item())
        comparison_text = comparison_pipe.tokenizer.decode(
            comparison_token.item()
        )
        primary_decoded.append(primary_text)
        comparison_decoded.append(comparison_text)
        token_matches.append(primary_text == comparison_text)

    with torch.no_grad():
        # Prefill: run the whole prompt through the primary model.
        primary_outputs = primary_pipe.model_body(
            input_ids, attention_mask=attention_mask, use_cache=True
        )
        primary_logits = primary_head(primary_outputs.last_hidden_state)
        next_token = primary_pipe.get_next_token_standard(
            primary_logits, do_sample, temperature
        )

        if next_token.item() == eot_token_id:
            return _empty_generation_comparison(input_ids)

        # First token: a flash-head comparison reuses the primary prediction;
        # otherwise the comparison pipe picks from the primary logits. The
        # match is computed rather than assumed, since sampling can differ.
        if comparison_pipe.is_flash_head:
            first_comparison_token = next_token
        else:
            first_comparison_token = comparison_pipe.get_next_token_standard(
                primary_logits, do_sample, temperature
            )
        record(next_token, first_comparison_token)

        generated = torch.cat([input_ids, next_token], dim=-1)
        past_key_values = primary_outputs.past_key_values
        current_token = next_token

        for _ in range(max_new_tokens - 1):
            # The mask must span the cached positions plus the token fed in
            # this step, so it is extended *before* the forward pass.
            attention_mask = torch.cat(
                [
                    attention_mask,
                    attention_mask.new_ones((attention_mask.shape[0], 1)),
                ],
                dim=-1,
            )

            primary_outputs = primary_pipe.model_body(
                current_token,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=True,
            )
            past_key_values = primary_outputs.past_key_values
            hidden_states = primary_outputs.last_hidden_state

            primary_next_token = primary_pipe.get_next_token_standard(
                primary_head(hidden_states), do_sample, temperature
            )
            if primary_next_token.item() == eot_token_id:
                break

            # The comparison head predicts from the primary body's states.
            comparison_next_token = comparison_pipe.get_next_token(
                hidden_states, do_sample, temperature
            )
            record(primary_next_token, comparison_next_token)

            generated = torch.cat([generated, primary_next_token], dim=-1)
            current_token = primary_next_token

    return {
        "generated": generated,
        "metrics": _summarize_token_agreement(
            primary_decoded, comparison_decoded, token_matches
        ),
    }
