"""Evaluate latency of different efficient heads."""

import time
from dataclasses import dataclass
from textwrap import dedent
from typing import Optional

import torch

# from snippets.llm.huggingface.quantize import quantize_head, quantize_model
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer


def _bold(input_str, /) -> str:
    """Wrap ``input_str`` in the ANSI escape sequences for bold text."""
    bold_on, reset = "\x1b[1m", "\x1b[0m"
    return "".join((bold_on, input_str, reset))


def _underline(input_str, /) -> str:
    """Wrap ``input_str`` in the ANSI escape sequences for underlined text."""
    underline_on, reset = "\x1b[4m", "\x1b[0m"
    return "".join((underline_on, input_str, reset))


@dataclass
class ProfilingResults:
    """
    Benchmark figures collected while profiling one generation run.
    """

    #: How many tokens the decoding run produced.
    num_tokens_generated: int

    #: Total inference time (latency) of the model, in milliseconds.
    total_ms: float

    #: Time to first token (TTFT) of the model, in milliseconds.
    ttft_ms: float

    #: Tokens per second (TPS) of the whole model.
    tps: float

    #: Tokens per second (TPS) of the model body alone.
    tps_body: float

    #: Tokens per second (TPS) of the model head alone.
    tps_head: float

    #: Time per output token (TPOT) of the whole model, in milliseconds.
    #: TTFT is not part of TPOT.
    tpot_ms: float

    #: Time per output token (TPOT) of the model body, in milliseconds.
    #: TTFT is not part of TPOT.
    tpot_body_ms: float

    #: Time per output token (TPOT) of the model head, in milliseconds.
    #: TTFT is not part of TPOT.
    tpot_head_ms: float

    def __str__(self):
        # Each section is an underlined title followed by tab-indented
        # "label value" rows; every value is rendered with four decimals.
        sections = (
            (
                "Model metrics",
                (
                    (
                        "Total inference time (latency)",
                        f"{self.total_ms:.4f}ms",
                    ),
                    ("TTFT", f"{self.ttft_ms:.4f}ms"),
                    ("TPS", f"{self.tps:.4f}"),
                    ("TPOT", f"{self.tpot_ms:.4f}ms"),
                ),
            ),
            (
                "Body metrics",
                (
                    ("TPS", f"{self.tps_body:.4f}"),
                    ("TPOT", f"{self.tpot_body_ms:.4f}ms"),
                ),
            ),
            (
                "Head metrics",
                (
                    ("TPS", f"{self.tps_head:.4f}"),
                    ("TPOT", f"{self.tpot_head_ms:.4f}ms"),
                ),
            ),
        )

        report = [f"{_bold('Tokens generated')} {self.num_tokens_generated}"]
        for title, rows in sections:
            report.append(_underline(title))
            for label, value in rows:
                report.append(f"\t{_bold(label)} {value}")

        # The report is framed by one leading and one trailing newline.
        return "\n" + "\n".join(report) + "\n"


@dataclass
class GenerationResults:
    """
    Contains the decoded model output and, optionally, profiling metrics.
    """

    #: Decoded output text of the model.
    generated_text: str

    #: Number of tokens produced during decoding.
    num_tokens_generated: int

    #: Benchmark metrics from model profiling during generation, if any
    #: were collected.
    profiling_metrics: Optional[ProfilingResults] = None

    def __str__(self):
        header = _bold(f"Model output ({self.num_tokens_generated} tokens):")
        rendered = header + str(self.generated_text) + "\n"
        # Append the metrics report only when one exists; formatting the
        # default ``None`` would print the literal word "None".
        if self.profiling_metrics is not None:
            rendered += str(self.profiling_metrics)
        return rendered


class GenerationPipeline:
    """
    Generate text and measure latency.

    :param model_body:
        The model body.
    :param model_head:
        The model head.
    :param tokenizer:
        The tokenizer.
    :param mode:
        The generation mode. If set to ``"flash_head"`` or ``"midx"``, next
        tokens are obtained from the head's own ``get_next_token`` method; for
        all other heads it should be ``"standard"``.
    """

    def __init__(
        self,
        model_body: nn.Module,
        model_head: nn.Module,
        tokenizer: AutoTokenizer,
        mode: str = "standard",
    ):
        self.device = model_body.device
        self.mode = mode
        self.model_body = model_body
        self.model_head = model_head
        self.tokenizer = tokenizer

        self.is_flash_head = mode == "flash_head"
        self.is_midx = mode == "midx"

    def get_next_token_standard(
        self,
        logits: torch.Tensor,
        do_sample: bool = False,
        temperature: float = 1.0,
    ):
        """
        Generate the next token according to `logits`

        This will always generate the next token according to the "standard" way, by
        considering the `logits` of the model as the full logits.

        :param logits:
            Logits from the output of a model.
        :param do_sample:
            Whether to randomly sample for the next token.
        :param temperature:
            The softmax temperature, applicable only if `do_sample` is ``True``.

        :return:
            The next token index.
        """
        if do_sample:
            probs = (logits[:, -1, :] / temperature).softmax(dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        else:
            next_token = logits[:, -1:].argmax(dim=-1)
        return next_token

    def get_next_token(
        self,
        logits: torch.Tensor,
        do_sample: bool = False,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        """
        Generate the next token according to `logits`

        :param logits:
            Logits from the output of a model.
        :param do_sample:
            Whether to randomly sample for the next token.
        :param temperature:
            The softmax temperature, applicable only if `do_sample` is ``True``.
        :return:
            The next token index.
        """
        if self.is_flash_head or self.is_midx:
            return self.model_head.get_next_token(
                logits,
                do_sample=do_sample,
                temperature=temperature,
            )

        logits = self.model_head(logits)
        return self.get_next_token_standard(logits, do_sample, temperature)

    def generate(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        max_new_tokens: int,
        do_sample: bool = False,
        temperature: float = 1.0,
    ) -> GenerationResults:
        """
        Generate text and measure throughput.

        :param input_ids:
            The input prompt.
        :param attention_mask:
            The starting attention mask.
        :param max_new_tokens:
            The maximum number of tokens allowed to be generated, will
            otherwise cancel when end-of-sequence token is reached.
        :param do_sample:
            Whether to randomly sample for the next token.
        :param temperature:
            The softmax temperature, only applicable if `do_sample` is ``True``.

        :return:
            Generated output tokens and metrics.

        :raises RuntimeError:
            If the first token is an end-of-sentence (EOS) token.
        """
        eot_token_id = self.tokenizer.eos_token_id
        use_cpu = False  # self.model_head.device.type == "cpu"

        if not use_cpu:

            first_token_start_event = torch.cuda.Event(enable_timing=True)
            first_token_end_event = torch.cuda.Event(enable_timing=True)

            first_token_head_start_event = torch.cuda.Event(enable_timing=True)
            first_token_head_end_event = torch.cuda.Event(enable_timing=True)
            first_token_body_start_event = torch.cuda.Event(enable_timing=True)
            first_token_body_end_event = torch.cuda.Event(enable_timing=True)

        with torch.no_grad():
            if use_cpu:
                first_token_start_event = time.time()
                first_token_body_start_event = time.time()
            else:
                first_token_start_event.record()
                first_token_body_start_event.record()

            outputs = self.model_body(
                input_ids, attention_mask=attention_mask, use_cache=True
            )
            if use_cpu:
                ttft_body = time.time() - first_token_body_start_event
                first_token_head_start_event = time.time()
            else:
                first_token_body_end_event.record()

                first_token_head_start_event.record()
            original_head = self.model_head
            if self.is_flash_head:
                original_head = original_head.original_lm_head
            logits = original_head(outputs.last_hidden_state)
            next_token = self.get_next_token_standard(
                logits, do_sample, temperature
            )
            if use_cpu:
                ttft_head = time.time() - first_token_head_start_event
                ttft_ms = time.time() - first_token_start_event
            else:
                first_token_head_end_event.record()
                first_token_end_event.record()

            past_key_values = outputs.past_key_values

            if use_cpu:
                total_time_start_event = time.time()
            else:
                total_time_start_event = torch.cuda.Event(enable_timing=True)
                total_time_end_event = torch.cuda.Event(enable_timing=True)

                total_time_start_event.record()

            if next_token.item() == eot_token_id:
                if not use_cpu:
                    total_time_end_event.record()
                    torch.cuda.synchronize()
                print("Token generation unsuccessful.")
                return None

            generated = next_token
            current_token = next_token
            total_body_time = 0.0
            total_head_time = 0.0

            for _ in range(max_new_tokens - 1):
                if current_token.item() == eot_token_id:
                    break

                if use_cpu:
                    start_body_event = time.time()
                else:
                    start_body_event = torch.cuda.Event(enable_timing=True)
                    end_body_event = torch.cuda.Event(enable_timing=True)

                    start_body_event.record()

                outputs = self.model_body(
                    current_token,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=True,
                )
                if use_cpu:
                    end_body_event = time.time() - start_body_event
                    total_body_time += end_body_event
                else:
                    end_body_event.record()
                    torch.cuda.synchronize()
                    total_body_time += start_body_event.elapsed_time(
                        end_body_event
                    )

                past_key_values = outputs.past_key_values

                next_token = self.get_next_token(
                    outputs.last_hidden_state, do_sample, temperature
                )
                if use_cpu:
                    start_head_event = time.time()
                else:
                    start_head_event = torch.cuda.Event(enable_timing=True)
                    end_head_event = torch.cuda.Event(enable_timing=True)

                    start_head_event.record()

                num_measurements = 100
                for ind in range(num_measurements):
                    next_token = self.get_next_token(
                        outputs.last_hidden_state, do_sample, temperature
                    )
                if use_cpu:
                    end_head_event = time.time() - start_head_event
                    total_head_time += end_head_event / num_measurements
                else:
                    end_head_event.record()
                    torch.cuda.synchronize()
                    total_head_time += (
                        start_head_event.elapsed_time(end_head_event)
                        / num_measurements
                    )

                generated = torch.cat([generated, next_token], dim=-1)
                current_token = next_token
                attention_mask = torch.cat(
                    [attention_mask, attention_mask.new_ones((1, 1))], dim=-1
                )

                if next_token.item() == eot_token_id:
                    break

        if use_cpu:
            total_time_ms = time.time() - total_time_start_event
        else:
            total_time_end_event.record()
            torch.cuda.synchronize()

            ttft_head = first_token_head_start_event.elapsed_time(
                first_token_head_end_event
            )
            ttft_body = first_token_body_start_event.elapsed_time(
                first_token_body_end_event
            )

            total_time_ms = total_time_start_event.elapsed_time(
                total_time_end_event
            )
            ttft_ms = first_token_start_event.elapsed_time(
                first_token_end_event
            )
        num_tokens_generated = generated.shape[1]

        return GenerationResults(
            generated_text=self.tokenizer.decode(
                generated[0], skip_special_tokens=True
            ),
            num_tokens_generated=num_tokens_generated,
            profiling_metrics=ProfilingResults(
                num_tokens_generated=num_tokens_generated,
                total_ms=total_time_ms,
                tps=num_tokens_generated * 1000 / total_time_ms,
                ttft_ms=ttft_ms,
                tps_head=num_tokens_generated
                * 1000
                / (total_head_time + ttft_head),
                tps_body=num_tokens_generated
                * 1000
                / (total_body_time + ttft_body),
                # TPOT does not include TTFT.
                tpot_ms=(total_time_ms - ttft_ms) / (num_tokens_generated - 1),
                tpot_head_ms=total_head_time / (num_tokens_generated - 1),
                tpot_body_ms=total_body_time / (num_tokens_generated - 1),
            ),
        )

    def __call__(
        self,
        prompt: str,
        max_new_tokens: int = 32,
        do_sample: bool = False,
        temperature: float = 1.0,
    ) -> GenerationResults:
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(
            self.device
        )
        attention_mask = torch.ones_like(input_ids)

        return self.generate(
            input_ids, attention_mask, max_new_tokens, do_sample, temperature
        )


def get_standard_pipeline(
    model_id: str = "meta-llama/Llama-3.2-1B-Instruct",
    device_map: str = "cuda",
) -> GenerationPipeline:
    """Get the standard generation pipeline.

    The model is loaded in ``bfloat16`` and split into its body
    (``model.model``) and its head (``model.lm_head``).

    :param model_id:
        The HuggingFace model, defaults to "meta-llama/Llama-3.2-1B-Instruct".
    :param device_map:
        The device to load the model at.
    :return:
        A standard model generation pipeline.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map=device_map
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # NOTE: head quantization is currently disabled. The ``quantize_head``
    # import at the top of this file is commented out, so it must be
    # restored before ``model.lm_head`` is quantized here again.
    return GenerationPipeline(
        model.model,
        model.lm_head,
        tokenizer=tokenizer,
    )


class NextTokenGenerator:
    """Next token generator for heads with a get_next_token method."""

    def __init__(
        self,
        baseline_model: nn.Module,
        head: nn.Module,
        tokenizer: AutoTokenizer,
    ):
        self.device = baseline_model.model.device
        self.baseline_model = baseline_model
        self.model_body = baseline_model.model
        self.model_head = head
        self.tokenizer = tokenizer

    def get_top_k_token_standard(
        self,
        logits: torch.Tensor,
        k: int = 3,
        do_sample: bool = False,
        temperature: float = 1.0,
    ):
        """
        Generate the topk token according to `logits`.

        This will always generate the next token according to the "standard" way, by
        considering the `logits` of the model as the full logits.

        :param logits:
            Logits from the output of a model.
        :param k:
            The number of to tokens to return.
        :param do_sample:
            Whether to randomly sample for the next token.
        :param temperature:
            The softmax temperature, applicable only if `do_sample` is ``True``.

        :return:
            The next token index.
        """
        if do_sample:
            probs = (logits[:, -1, :] / temperature).softmax(dim=-1)
            next_tokens = torch.multinomial(
                probs, num_samples=k, replacement=False
            )
        else:
            # use_identical_tiebreak
            vals, idx = torch.sort(
                logits[:, -1, :], dim=-1, descending=True, stable=True
            )
            next_tokens = idx[..., :k]

        return next_tokens

    def get_next_token_standard(
        self,
        logits: torch.Tensor,
        do_sample: bool = False,
        temperature: float = 1.0,
    ):
        """
        Generate the next token according to `logits`.

        This will always generate the next token according to the "standard" way, by
        considering the `logits` of the model as the full logits.

        :param logits:
            Logits from the output of a model.
        :param do_sample:
            Whether to randomly sample for the next token.
        :param temperature:
            The softmax temperature, applicable only if `do_sample` is ``True``.

        :return:
            The next token index.
        """
        if do_sample:
            probs = (logits[:, -1, :] / temperature).softmax(dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        else:
            next_token = logits[:, -1:].argmax(dim=-1)
        return next_token

    def generate(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        max_new_tokens: int,
        do_sample: bool = False,
        temperature: float = 1.0,
    ):
        """
        Generate text and measure throughput.

        :param input_ids:
            The input prompt.
        :param attention_mask:
            The starting attention mask.
        :param max_new_tokens:
            The maximum number of tokens allowed to be generated, will
            otherwise cancel when end-of-sequence token is reached.
        :param do_sample:
            Whether to randomly sample for the next token.
        :param temperature:
            The softmax temperature, only applicable if `do_sample` is ``True``.

        :return:
            Generated output tokens and metrics.
        """
        eot_token_ids = self.baseline_model.generation_config.eos_token_id
        original_head = self.baseline_model.lm_head

        with torch.no_grad():
            # Generate the outputs to get the first token
            outputs = self.model_body(
                input_ids, attention_mask=attention_mask, use_cache=True
            )

            logits = original_head(outputs.last_hidden_state)
            next_token_baseline = self.get_next_token_standard(
                logits, do_sample, temperature
            )
            top_k_token_baseline = self.get_top_k_token_standard(
                logits,
                do_sample=do_sample,
                k=5,
                temperature=temperature,
            )
            generated_top10_baseline = [top_k_token_baseline.ravel()]

            next_token_head = self.model_head.get_next_token(
                outputs.last_hidden_state[:, -1:, :],
                do_sample=do_sample,
                temperature=temperature,
            )

            generated_baseline = next_token_baseline
            generated_head = next_token_head

            if next_token_baseline.item() in eot_token_ids:
                return (
                    generated_head.ravel(),
                    generated_baseline.ravel(),
                    generated_top10_baseline,
                )

            for _ in range(max_new_tokens - 1):
                past_key_values = outputs.past_key_values

                attention_mask = torch.cat(
                    [attention_mask, attention_mask.new_ones((1, 1))],
                    dim=-1,
                )
                outputs = self.model_body(
                    next_token_baseline,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=True,
                )

                logits = original_head(outputs.last_hidden_state)
                next_token_baseline = self.get_next_token_standard(
                    logits, do_sample, temperature
                )
                top_5_token_baseline = self.get_top_k_token_standard(
                    logits,
                    do_sample=do_sample,
                    k=5,
                    temperature=temperature,
                )
                generated_top10_baseline.append(top_5_token_baseline.ravel())

                next_token_head = self.model_head.get_next_token(
                    outputs.last_hidden_state[:, -1:, :],
                    do_sample=do_sample,
                    temperature=temperature,
                )
                generated_head = torch.cat(
                    [generated_head, next_token_head], dim=-1
                )
                generated_baseline = torch.cat(
                    [generated_baseline, next_token_baseline], dim=-1
                )

                if next_token_baseline.item() in eot_token_ids:
                    break
        return (
            generated_head.ravel(),
            generated_baseline.ravel(),
            generated_top10_baseline,
        )

    def __call__(
        self,
        prompt: str,
        max_new_tokens: int = 32,
        do_sample: bool = False,
        temperature: float = 1.0,
    ):
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(
            self.device
        )
        attention_mask = torch.Tensor(input_ids)
        return self.generate(
            input_ids, attention_mask, max_new_tokens, do_sample, temperature
        )
