"""Evaluate different LLMs against a baseline model's outputs."""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import (
    Any,
    Callable,
    Generator,
    Sequence,
)

import torch
from torch.utils.data import DataLoader
from torcheval.metrics import Metric, MulticlassAccuracy
from torcheval.metrics.text import BLEUScore, WordErrorRate
from torchmetrics.functional import kl_divergence as tm_kl_divergence
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from efficient_heads.pipeline import NextTokenGenerator


class BaseEvaluator(ABC):
    """An evaluator base class to compare outputs from different models.

    Subclasses implement `generate_outputs`, which produces the prediction /
    target pairs, and register the metrics they support through
    `register_metric`. Every subclass gets its own metric registry.
    """

    # Metric factory registry that can be extended easily. It maps a metric
    # name to the metric *class*; `evaluate` instantiates it without arguments.
    _Metrics: dict[str, type[Metric]] = {}

    def __init__(
        self,
        model_pipeline: pipeline,
        base_model_pipeline: pipeline,
        do_sample: bool = False,
    ):
        """Initialize with two different pipelines for comparison.

        Note that both pipelines are modified in place: the tokenizer's pad
        token id and the model generation config's pad token id are set to
        the tokenizer's EOS token id.

        :param model_pipeline:
            The model pipeline to evaluate which contains a tokenizer.
        :param base_model_pipeline:
            The base model pipeline to compare against.
        :param do_sample:
            Whether to use sampling for generation, defaults to False.
        """
        self.model_pipeline = model_pipeline
        self.base_model_pipeline = base_model_pipeline
        self.do_sample = do_sample

        for pipe in (self.model_pipeline, self.base_model_pipeline):
            pipe.tokenizer.pad_token_id = pipe.tokenizer.eos_token_id
            pipe.model.generation_config.pad_token_id = (
                pipe.tokenizer.pad_token_id
            )

    def __init_subclass__(cls, **kwargs):
        """Subclass init so that `_Metrics` are not shared."""
        super().__init_subclass__(**kwargs)
        # A fresh dict per subclass: metrics registered on one evaluator are
        # neither inherited by nor leaked into another evaluator.
        cls._Metrics = {}

    @classmethod
    def register_metric(
        cls, name: str
    ) -> Callable[[type[Metric]], type[Metric]]:
        """Decorator to register a `torcheval.Metric` to the class.

        This decorator can be used to define new metrics anywhere in your code
        and register it to the class as:

        @EvaluatorImplementation.register_metric(name="custom_metric")
        class CustomMetric(Metric[float]):
            def update(self, predictions: Sequence[str], targets: Sequence[str]):
                ...

            def compute(self) -> float:
                ...

            def merge_state(self, other: CustomMetric):
                ...

        :param name: The name under which the metric class is registered.
        :return:
            A decorator that stores the metric class in the registry and
            returns it unchanged.
        :raises ValueError:
            When the decorator is applied and `name` is already registered
            on this class.
        """

        def decorator(metric_cls: type[Metric]) -> type[Metric]:
            if name in cls._Metrics:
                raise ValueError(f"Metric '{name}' is already registered.")
            cls._Metrics[name] = metric_cls
            return metric_cls

        return decorator

    @classmethod
    def list_registered_metrics(cls) -> Sequence[str]:
        """List the name of registered metrics.

        :return: The names of the registered metrics.
        """
        return list(cls._Metrics)

    @abstractmethod
    def generate_outputs(
        self, dataloader: DataLoader, num_batches: int | None = None
    ) -> Generator[Any, Any, Any]:
        """Yield (prompts, predictions, targets) for metric comparisons."""

    def evaluate(
        self,
        dataloader: DataLoader,
        metrics: Sequence[str] | None = None,
        num_batches: int | None = None,
    ) -> dict[str, float]:
        """Evaluate a set of metrics using the given data.

        :param dataloader:
            A dataloader that generates batches of data with the first item
            being the input prompt to send to the models.
        :param metrics:
            A list of metrics that have been registered to evaluate. These
            metrics can be registered by the user anywhere in the code with the
            `register_metric` decorator. Defaults to None, in which case all
            metrics registered on this evaluator class are evaluated.
        :param num_batches:
            The number of batches to use from the dataloader,
            defaults to None. If not set, the entire dataloader is
            iterated through.
        :return: A dictionary of results specifying the scores for each metric.
        :raises KeyError: If a requested metric name is not registered.
        """
        if metrics is None:
            metrics = self.list_registered_metrics()

        # Fail early with a helpful message instead of a bare KeyError raised
        # from inside the dict comprehension below.
        unknown = [name for name in metrics if name not in self._Metrics]
        if unknown:
            raise KeyError(
                f"Unregistered metric(s) {unknown}; registered metrics are "
                f"{self.list_registered_metrics()}."
            )

        metric_objects: dict[str, Metric] = {
            name: self._Metrics[name]() for name in metrics
        }

        for _, predictions, targets in tqdm(
            self.generate_outputs(dataloader, num_batches)
        ):
            for metric in metric_objects.values():
                metric.update(predictions, targets)

        return {
            name: float(metric.compute())
            for name, metric in metric_objects.items()
        }


class TextEvaluator(BaseEvaluator):
    """An implementation of an evaluator for text outputs."""

    def __init__(self, *base_args, max_new_tokens: int = 128, **base_kwargs):
        """Initialize the text evaluator.

        :param base_args:
            Positional arguments forwarded to `BaseEvaluator.__init__`.
        :param max_new_tokens:
            The max number of new tokens requested from each pipeline,
            defaults to 128.
        :param base_kwargs:
            Keyword arguments forwarded to `BaseEvaluator.__init__`.
        """
        super().__init__(*base_args, **base_kwargs)
        self.max_new_tokens = max_new_tokens

    def generate_outputs(
        self, dataloader: DataLoader, num_batches: int | None = None
    ):
        """Yield (prompts, predictions, targets) of generated text per batch.

        Each batch of the dataloader is passed as-is to both pipelines. The
        predictions come from the evaluated model and the targets from the
        base model.

        :param dataloader: A dataloader whose batches are the prompts.
        :param num_batches:
            The number of batches to use, defaults to None which iterates
            through the entire dataloader. A value of 0 yields nothing.
        """
        generation_kwargs = {
            "max_new_tokens": self.max_new_tokens,
            "do_sample": self.do_sample,
        }

        with torch.no_grad():
            for i, prompts in enumerate(dataloader):
                # Compare against None explicitly so that `num_batches=0`
                # means "no batches" rather than "all batches".
                if num_batches is not None and i >= num_batches:
                    break

                outputs_model = self.model_pipeline(
                    prompts, **generation_kwargs
                )
                outputs_baseline = self.base_model_pipeline(
                    prompts, **generation_kwargs
                )

                # Keep the first entry returned for every prompt.
                # NOTE(review): assumes the pipelines return one list of
                # candidates per prompt -- confirm for single-string batches.
                predictions = [
                    out[0]["generated_text"] for out in outputs_model
                ]
                targets = [
                    out[0]["generated_text"] for out in outputs_baseline
                ]

                yield prompts, predictions, targets


# Register text metrics
# `WordErrorRate` becomes available to `TextEvaluator.evaluate` under the name
# "wer"; it is updated with the evaluated model's text as predictions and the
# base model's text as targets.
TextEvaluator.register_metric(name="wer")(WordErrorRate)


@TextEvaluator.register_metric(name="bleu1")
class BLEUScore1(BLEUScore):
    """`BLEUScore` fixed to `n_gram=1`, registered as "bleu1".

    The evaluator instantiates registered metrics without arguments, so the
    n-gram order is pinned here instead of being passed in.
    """

    def __init__(self) -> None:
        """Create the metric with the n-gram order set to 1."""
        unigram_order = 1
        super().__init__(n_gram=unigram_order)


class LogitsEvaluator(BaseEvaluator):
    """
    An evaluator to compare logits from two different models.
    """

    def generate_outputs(
        self, dataloader: DataLoader, num_batches: int | None = None
    ):
        """Yield (prompts, model logits, baseline logits) for each batch.

        The prompts are taken from the first item of every batch, tokenized
        separately with each pipeline's tokenizer and passed through the
        corresponding model with gradients disabled.

        :param dataloader:
            A dataloader whose batches hold the prompts as their first item.
        :param num_batches:
            The number of batches to use, defaults to None which iterates
            through the entire dataloader. A value of 0 yields nothing.
        """
        # No progress bar here: `BaseEvaluator.evaluate` already wraps this
        # generator in one, and a second bar would only duplicate it.
        with torch.no_grad():
            for i, items in enumerate(dataloader):
                # Compare against None explicitly so that `num_batches=0`
                # means "no batches" rather than "all batches".
                if num_batches is not None and i >= num_batches:
                    break
                prompts = items[0]

                model_inputs = self.model_pipeline.tokenizer(
                    prompts, return_tensors="pt", padding=True, truncation=True
                ).to(self.model_pipeline.model.device)
                baseline_inputs = self.base_model_pipeline.tokenizer(
                    prompts, return_tensors="pt", padding=True, truncation=True
                ).to(self.base_model_pipeline.model.device)

                model_outputs = self.model_pipeline.model(**model_inputs)
                baseline_outputs = self.base_model_pipeline.model(
                    **baseline_inputs
                )

                # Models may return an output object carrying `.logits` or
                # the logits themselves; accept both.
                model_logits = getattr(model_outputs, "logits", model_outputs)
                baseline_logits = getattr(
                    baseline_outputs, "logits", baseline_outputs
                )

                # `squeeze(0)` only drops the leading dimension when it has
                # size 1.
                # NOTE(review): presumably that is the batch dimension of a
                # single-prompt batch -- confirm against the dataloader.
                yield (
                    prompts,
                    model_logits.squeeze(0),
                    baseline_logits.squeeze(0),
                )


@LogitsEvaluator.register_metric("kldiv")
class KLDivergence(Metric[float]):
    """KL Divergence metric.

    Accumulates KL(P || Q), where P is the softmax of the prediction logits
    and Q the softmax of the target logits, and reports the mean over all
    rows (distributions) seen so far.

    NOTE(review): the running totals are plain attributes and are not
    registered through torcheval's state mechanism -- confirm whether state
    dict / device moves / distributed sync are required for this metric.
    """

    def __init__(
        self,
    ) -> None:
        """Start with an empty running sum and row count."""
        super().__init__()
        # Kept as Python numbers so the state does not depend on the device
        # of the logits passed to `update`.
        self.total_kl = 0.0
        self.count = 0

    def update(
        self, predictions: torch.Tensor, targets: torch.Tensor
    ) -> KLDivergence:
        """Update state.

        :param predictions:
            Raw logits of the evaluated model; the last dimension is
            normalized into a distribution.
        :param targets:
            Raw logits of the reference model with the same shape as
            `predictions`.
        :return: The metric itself, so that calls can be chained.
        """
        # Assume raw logits. Work with log-probabilities instead of
        # probabilities: `p * log(p / q)` turns into NaN as soon as a
        # probability underflows to zero, while `exp(log_p) * (log_p - log_q)`
        # stays finite.
        log_p = torch.nn.functional.log_softmax(predictions, dim=-1)
        log_q = torch.nn.functional.log_softmax(targets, dim=-1)

        # The torchmetrics function only accepts 2D input, so collapse every
        # leading dimension into rows. 2D input is left unchanged.
        log_p = log_p.reshape(-1, log_p.size(-1))
        log_q = log_q.reshape(-1, log_q.size(-1))

        kl = tm_kl_divergence(log_p, log_q, log_prob=True, reduction="none")
        self.total_kl += float(kl.sum())
        self.count += log_p.size(0)
        return self

    def compute(self) -> float:
        """Compute the score.

        :return:
            The mean KL divergence per row, or NaN if `update` has not been
            called with any data yet.
        """
        if self.count == 0:
            return float("nan")
        return float(self.total_kl / self.count)

    def merge_state(self, metrics: Sequence[KLDivergence]) -> KLDivergence:
        """Merge metrics.

        :param metrics:
            A sequence of metric objects of the same type to combine.
        :return: The metric itself with the merged state.
        """
        for metric in metrics:
            self.total_kl += float(metric.total_kl)
            self.count += metric.count
        return self


class NextTokenEvaluator:
    """Evaluator for comparing next generated token by two different heads.

    Besides the registered metrics, `evaluate` reports how often the token
    produced by the custom head appears among the first 1, 2, 3 and 5
    candidate tokens returned for the baseline.
    """

    # Registry mapping a metric name to the metric class; `evaluate`
    # instantiates the classes without arguments.
    _Metrics: dict[str, type[Metric]] = {}

    # Cut-offs k of the top-k agreement scores, in ascending order.
    _TOP_KS = (1, 2, 3, 5)
    # Dataset configuration whose batches hold a mapping of several prompts
    # as first item, and the cap on the batches consumed for it.
    # NOTE(review): presumably the XNLI "all_languages" configuration --
    # confirm against the dataset in use.
    _MULTI_PROMPT_CONFIG = "all_languages"
    _MULTI_PROMPT_MAX_BATCHES = 100

    def __init__(
        self,
        baseline_model: AutoModelForCausalLM,
        custom_head: torch.nn.Module,
        head_type: str,
        tokenizer: AutoTokenizer,
        do_sample: bool = False,
    ):
        """Initialize the evaluator and build the generation pipeline.

        Note that the tokenizer is modified in place: its pad token id is set
        to its EOS token id.

        :param baseline_model: The model providing the baseline outputs.
        :param custom_head: The head whose outputs are evaluated.
        :param head_type:
            A label for the head, used in the progress bar description.
        :param tokenizer: The tokenizer handed to the generation pipeline.
        :param do_sample:
            Whether to use sampling for generation, defaults to False.
        """
        self.baseline_model = baseline_model
        self.custom_head = custom_head
        self.head_type = head_type
        tokenizer.pad_token_id = tokenizer.eos_token_id

        self.tokenizer = tokenizer
        self.do_sample = do_sample

        self.pipeline = NextTokenGenerator(
            baseline_model=baseline_model,
            head=custom_head,
            tokenizer=tokenizer,
        )

    @classmethod
    def register_metric(
        cls, name: str
    ) -> Callable[[type[Metric]], type[Metric]]:
        """Return a decorator registering a metric class under `name`.

        :param name: The name under which the metric class is registered.
        :return:
            A decorator that stores the metric class in the registry and
            returns it unchanged.
        :raises ValueError:
            When the decorator is applied and `name` is already registered.
        """

        def decorator(metric_cls: type[Metric]) -> type[Metric]:
            if name in cls._Metrics:
                raise ValueError(f"Metric '{name}' is already registered.")
            cls._Metrics[name] = metric_cls
            return metric_cls

        return decorator

    @classmethod
    def list_registered_metrics(cls) -> Sequence[str]:
        """List the names of the registered metrics."""
        return list(cls._Metrics)

    @classmethod
    def _tally_top_k(
        cls,
        outputs_head: Any,
        top_tokens_baseline: Any,
        hits: dict[int, int],
    ) -> int:
        """Add top-k matches to `hits` and return the number of scored pairs.

        A head token found within the first k baseline candidates counts as a
        hit for that k and for every larger k.
        """
        scored = 0
        for head_pred, base_pred in zip(outputs_head, top_tokens_baseline):
            scored += 1
            for k in cls._TOP_KS:
                if k == 1:
                    matched = head_pred == base_pred[0]
                else:
                    matched = head_pred in base_pred[:k]
                if matched:
                    for wider_k in cls._TOP_KS:
                        if wider_k >= k:
                            hits[wider_k] += 1
                    break
        return scored

    def _score_prompts(
        self,
        prompts: Any,
        metric_objects: dict[str, Metric],
        hits: dict[int, int],
        max_new_tokens: int,
    ) -> int:
        """Run the pipeline on `prompts`, update metrics and top-k counts.

        Returns the number of scored pairs, which is 0 when the pipeline
        reports no output for the head.
        """
        outputs_head, outputs_baseline, top_tokens_baseline = self.pipeline(
            prompts,
            max_new_tokens=max_new_tokens,
            do_sample=self.do_sample,
        )
        if outputs_head is None:
            return 0

        for metric in metric_objects.values():
            metric.update(outputs_head, outputs_baseline)

        return self._tally_top_k(outputs_head, top_tokens_baseline, hits)

    def evaluate(
        self,
        dataloader: DataLoader,
        metrics: Sequence[str] | None = None,
        num_batches: int | None = None,
        max_new_tokens: int = 128,
    ) -> dict[str, float]:
        """Evaluate the registered metrics and the top-k agreement scores.

        :param dataloader:
            A dataloader yielding the prompts. If its dataset has a
            `config_name` equal to "all_languages", the first item of every
            batch is treated as a mapping whose values are each sent to the
            pipeline, and at most 100 batches are consumed.
        :param metrics:
            The names of registered metrics to evaluate, defaults to None
            which evaluates all registered metrics.
        :param num_batches:
            The number of batches to use from the dataloader, defaults to
            None which applies no limit.
        :param max_new_tokens:
            The max number of new tokens passed to the pipeline, defaults
            to 128.
        :return:
            The score of each metric plus "top1_acc", "top2_acc", "top3_acc"
            and "top5_acc". The top-k scores are NaN if nothing was scored.
        :raises KeyError: If a requested metric name is not registered.
        """
        if metrics is None:
            metrics = self.list_registered_metrics()

        # Fail early with a helpful message instead of a bare KeyError raised
        # from inside the dict comprehension below.
        unknown = [name for name in metrics if name not in self._Metrics]
        if unknown:
            raise KeyError(
                f"Unregistered metric(s) {unknown}; registered metrics are "
                f"{self.list_registered_metrics()}."
            )

        metric_objects = {name: self._Metrics[name]() for name in metrics}

        # Datasets without a `config_name` attribute are handled as regular
        # single-prompt datasets instead of raising AttributeError.
        multi_prompt = (
            getattr(dataloader.dataset, "config_name", None)
            == self._MULTI_PROMPT_CONFIG
        )

        # The effective limit is the tighter of the caller's limit and the
        # cap that applies to multi-prompt datasets.
        limits = [num_batches] if num_batches is not None else []
        if multi_prompt:
            limits.append(self._MULTI_PROMPT_MAX_BATCHES)
        batch_limit = min(limits) if limits else None

        batch_iter = enumerate(dataloader)
        if num_batches is not None:
            batch_iter = tqdm(
                batch_iter,
                total=num_batches,
                desc=f"Evaluating {num_batches} batches for {self.head_type}:",
            )
        else:
            batch_iter = tqdm(batch_iter, desc="Evaluating all batches")

        hits = dict.fromkeys(self._TOP_KS, 0)
        total = 0

        for batch_idx, input_prompts in batch_iter:
            # Checked once per batch so that reaching the limit stops the
            # iteration over the dataloader for both kinds of dataset.
            if batch_limit is not None and batch_idx >= batch_limit:
                break

            if multi_prompt:
                prompt_groups = input_prompts[0].values()
            else:
                prompt_groups = (input_prompts,)

            for prompts in prompt_groups:
                total += self._score_prompts(
                    prompts, metric_objects, hits, max_new_tokens
                )

        result = {
            name: float(metric.compute())
            for name, metric in metric_objects.items()
        }
        for k in self._TOP_KS:
            # Nothing scored means the accuracy is undefined; report NaN
            # rather than failing with ZeroDivisionError.
            result[f"top{k}_acc"] = hits[k] / total if total else float("nan")
        return result


# Make `MulticlassAccuracy` available to `NextTokenEvaluator.evaluate` under
# the name "accuracy"; the evaluator instantiates it without arguments.
NextTokenEvaluator.register_metric("accuracy")(MulticlassAccuracy)
