import logging
from typing import Any, Callable, Mapping, cast

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from datasets import Dataset
from transformers import (
    BatchEncoding,
    PreTrainedModel,
    TrainerCallback,
    TrainerState,
)

from lib_llm.inference import INFERENCE_DEVICE, PredictionConfig, predict

from .metric import MetricArg, SequenceMetric, TokenMetric, TokenValues


logger = logging.getLogger(__name__)
HAS_CUDA = torch.cuda.is_available()
# Metric evaluation always runs on the CPU, independently of the device used
# for model inference.
RESULT_DEVICE = torch.device("cpu")


class EvaluationTask(TrainerCallback):
    """Base class for Trainer callbacks that compute evaluation metrics.

    On every evaluation pass for which ``eval_condition`` holds, the callback
    calls :meth:`evaluate` with the Trainer keyword arguments named in
    ``evaluate_args`` and stores the returned DataFrame together with the
    current epoch. :meth:`result` concatenates all stored DataFrames.
    """

    def __init__(
        self,
        *,
        index_names: list[str] | None = None,
        eval_condition: Callable[[TrainerState], bool] = lambda _: True,
        evaluate_args: list[str] | None = None,
    ) -> None:
        """Create the callback.

        Args:
            index_names: Names of the index levels of the DataFrames produced
                by :meth:`evaluate`. Defaults to ``["sequence"]``.
            eval_condition: Predicate on the Trainer state; evaluation passes
                for which it returns ``False`` are skipped. Defaults to
                always evaluating.
            evaluate_args: Names of the ``on_evaluate`` keyword arguments that
                are forwarded to :meth:`evaluate`. Defaults to ``["model"]``.
        """
        # Build fresh lists per instance instead of sharing a mutable default
        # object between all instances.
        self.index_names = (
            ["sequence"] if index_names is None else list(index_names)
        )
        self.eval_condition = eval_condition
        self.evaluate_args = (
            ["model"] if evaluate_args is None else list(evaluate_args)
        )
        # (epoch, result) pairs, one per evaluation pass that was run.
        self.step_results: list[tuple[float, pd.DataFrame]] = []

    def evaluate(self, *args, **kwargs) -> pd.DataFrame:
        """Evaluate all metrics on the given model; implemented by subclasses."""
        raise NotImplementedError

    def on_evaluate(self, args, state, control, **kwargs) -> None:
        """Run :meth:`evaluate` after an evaluation pass of the Trainer.

        Only the keyword arguments listed in ``evaluate_args`` that the Trainer
        actually supplied are forwarded. The result is recorded together with
        ``state.epoch``, or 0 when the epoch is not known.
        """
        if not self.eval_condition(state):
            return
        eval_args = {
            arg_name: kwargs[arg_name]
            for arg_name in self.evaluate_args
            if arg_name in kwargs
        }
        res = self.evaluate(**eval_args)
        epoch = state.epoch if state.epoch is not None else 0
        self.step_results.append((epoch, res))

    def result(self) -> pd.DataFrame:
        """Concatenate the results of all recorded evaluation passes.

        The returned DataFrame has an extra outermost index level ``"epoch"``
        in front of the index levels of the individual results.

        Raises:
            ValueError: If no evaluation pass has been recorded yet.
        """
        if not self.step_results:
            raise ValueError(
                "No evaluation results have been recorded; on_evaluate must "
                "run at least once before calling result()."
            )
        epochs, results = zip(*self.step_results)
        index_names = results[0].index.names
        return pd.concat(
            results,
            keys=epochs,
            axis=0,
            names=["epoch", *index_names],
        )


class SequenceEvaluationTask(EvaluationTask):
    """Trainer callback that computes metrics per sequence of an evaluation set.

    This object is meant to be passed as a callback to the Huggingface
    Trainer class. On each evaluation pass it runs every metric on the whole
    evaluation set and returns one row per sequence. Values requested by the
    metrics (e.g. the model output) are cached within a pass, so metrics that
    need the same quantity share a single inference call.
    """

    def __init__(
        self,
        metrics: Mapping[str, SequenceMetric],
        data: Dataset | tuple[list[list[str]], BatchEncoding],
        *,
        index_names: list[str] | None = None,
        eval_condition: Callable[[TrainerState], bool] = lambda _: True,
        evaluate_args: list[str] | None = None,
        inference_device: torch.device = INFERENCE_DEVICE,
    ) -> None:
        """Create the task.

        Args:
            metrics: Metrics keyed by the column name of their result. Each
                metric is moved to ``RESULT_DEVICE``.
            data: Either a ``(tokens, encoded_text)`` tuple or a Dataset with
                ``input_ids`` and ``attention_mask`` columns.
            index_names: Index name of the result DataFrame; defaults to
                ``["sequence"]``.
            eval_condition: Predicate on the Trainer state deciding whether an
                evaluation pass is evaluated; defaults to always.
            evaluate_args: Names of the ``on_evaluate`` keyword arguments
                forwarded to :meth:`evaluate`; defaults to ``["model"]``.
            inference_device: Device the encoded inputs are moved to.

        Raises:
            ValueError: If ``data`` is neither a tuple nor a Dataset.
        """
        super().__init__(
            index_names=["sequence"] if index_names is None else index_names,
            eval_condition=eval_condition,
            evaluate_args=["model"] if evaluate_args is None else evaluate_args,
        )

        if isinstance(data, tuple):
            self.tokens = data[0]
            encoded_text = data[1]
        elif isinstance(data, Dataset):
            encoded_text = BatchEncoding(
                dict(
                    input_ids=self._as_tensor(data["input_ids"]),
                    attention_mask=self._as_tensor(data["attention_mask"]),
                )
            )
        else:
            raise ValueError(
                "'data' must be a tuple or a Huggingface Dataset "
                "containing token ids and arrays of token string "
                f"representations, but {type(data)} was passed."
            )

        self.sequence_idxs = list(range(encoded_text.input_ids.shape[0]))
        # Metric computation always happens on RESULT_DEVICE.
        self.metrics = {
            name: metric.to(RESULT_DEVICE) for name, metric in metrics.items()
        }
        self.encoded_text = encoded_text.to(inference_device)

    @staticmethod
    def _as_tensor(values: Any) -> torch.Tensor:
        """Return ``values`` unchanged if it is a tensor, else ``torch.tensor(values)``."""
        return values if isinstance(values, torch.Tensor) else torch.tensor(values)

    @staticmethod
    def _to_result_device(value: Any) -> Any:
        """Return ``value`` on RESULT_DEVICE without mutating shared objects."""
        if isinstance(value, torch.nn.Module):
            # nn.Module.to moves the parameters in place; calling it here would
            # relocate the model that the Trainer is training. Hand it over
            # unchanged.
            # NOTE(review): a metric combining "model" with CPU-side values
            # must handle the device itself — confirm no metric relies on the
            # previous (in-place) move to the CPU.
            return value
        if isinstance(value, BatchEncoding):
            # BatchEncoding.to mutates and returns the same instance, which
            # would permanently move the cached self.encoded_text. Move a
            # shallow copy instead.
            return BatchEncoding(dict(value), encoding=value.encodings).to(
                RESULT_DEVICE
            )
        return value.to(RESULT_DEVICE)

    def evaluate(self, model: PreTrainedModel) -> pd.DataFrame:
        """Evaluate all metrics on the given model.

        Each metric receives the values named in its ``required_args``, moved
        to RESULT_DEVICE. Metric state is reset afterwards, even if a metric
        fails, so that no accumulated state leaks into the next pass.
        """
        produce_value = _produce_metric_args(model, self.encoded_text)
        metric_results: dict[str, torch.Tensor] = {}
        try:
            for metric_name, metric in self.metrics.items():
                metric_args = {
                    arg_name: self._to_result_device(produce_value(arg_name))
                    for arg_name in metric.required_args
                }
                metric.update(**metric_args)
                result = metric.compute()
                assert isinstance(result, torch.Tensor)
                metric_results[metric_name] = result
        finally:
            for metric in self.metrics.values():
                metric.reset()
        return self._combine_results(metric_results)

    def _combine_results(
        self,
        metric_results: dict[str, torch.Tensor],
    ) -> pd.DataFrame:
        """Build a float32 DataFrame with one row per sequence and one column
        per metric, indexed by the sequence position in the dataset.
        """
        assert len(self.index_names) == 1
        index = pd.Index(
            self.sequence_idxs,
            dtype=int,
            name=self.index_names[0],
        )
        return pd.DataFrame(
            {
                metric_name: metric_res.type(torch.float32).tolist()
                for metric_name, metric_res in metric_results.items()
            },
            index=index,
            dtype=np.float32,
        )


class TokenEvaluationTask(SequenceEvaluationTask):
    """Trainer callback that computes metrics per token of an evaluation set.

    This object is meant to be passed as a callback to the Huggingface
    Trainer class. On each evaluation pass it computes the metrics for every
    token of every sequence and returns a DataFrame indexed by
    ``(sequence, token)``. The first token of each sequence gets NaN because
    there is no result for it.
    """

    def __init__(
        self,
        metrics: Mapping[str, TokenMetric],
        data: Dataset | tuple[list[list[str]], BatchEncoding],
        index_names: list[str] | None = None,
        inference_device: torch.device = INFERENCE_DEVICE,
        *,
        eval_condition: Callable[[TrainerState], bool] = lambda _: True,
        evaluate_args: list[str] | None = None,
    ) -> None:
        """Create the task.

        Args:
            metrics: Token-level metrics keyed by result column name.
            data: Either a ``(tokens, encoded_text)`` tuple or a Dataset with
                ``input_ids``, ``attention_mask`` and ``tokens`` columns, where
                ``tokens`` holds the string representation of each token.
            index_names: The two index level names of the result; defaults
                to ``["sequence", "token"]``.
            inference_device: Device the encoded inputs are moved to.
            eval_condition: Predicate on the Trainer state deciding whether an
                evaluation pass is evaluated; defaults to always.
            evaluate_args: Names of the ``on_evaluate`` keyword arguments
                forwarded to ``evaluate``; defaults to ``["model"]``.

        Raises:
            ValueError: If ``data`` is neither a tuple nor a Dataset, or is a
                Dataset without a ``tokens`` column.
        """
        # Validate before the parent converts the inputs and moves them to
        # the inference device.
        if isinstance(data, Dataset) and "tokens" not in data.column_names:
            raise ValueError(
                "'tokens' must be a column in the dataset containing "
                "the string representations of the tokens."
            )
        super().__init__(
            metrics,
            data,
            index_names=(
                ["sequence", "token"] if index_names is None else index_names
            ),
            eval_condition=eval_condition,
            evaluate_args=["model"] if evaluate_args is None else evaluate_args,
            inference_device=inference_device,
        )
        # The parent constructor has already rejected anything that is not a
        # tuple or a Dataset, so exactly one of these two cases applies.
        self.tokens = data[0] if isinstance(data, tuple) else data["tokens"]

    def _combine_results(
        self,
        metric_results: dict[str, torch.Tensor],
    ) -> pd.DataFrame:
        """Arrange per-position metric results into one float32 row per token.

        Assumes each tensor in ``metric_results`` has one row per sequence,
        with entries aligned to ``attention_mask[:, 1:]`` — TODO confirm
        against the TokenMetric implementations.
        """
        assert len(self.index_names) == 2
        index = pd.MultiIndex.from_tuples(
            [
                (sequence_idx, token)
                for sequence_idx, sequence_tokens in zip(
                    self.sequence_idxs, self.tokens
                )
                for token in sequence_tokens
            ],
            names=self.index_names,
        )

        attention_mask = self.encoded_text.attention_mask.to(RESULT_DEVICE)
        for sequence_idx, (seq_tokens, seq_mask) in enumerate(
            zip(self.tokens, attention_mask)
        ):
            assert len(seq_tokens) == seq_mask.sum(), (
                f"Sequence {sequence_idx} has {len(seq_tokens)} token strings "
                f"but {int(seq_mask.sum())} unmasked positions."
            )

        # Entry i of each row is the mask of token i + 1.
        shifted_mask = attention_mask[:, 1:].bool()
        return pd.DataFrame(
            {
                metric_name: self._token_column(metric_res, shifted_mask)
                for metric_name, metric_res in metric_results.items()
            },
            index=index,
            dtype=np.float32,
        )

    @staticmethod
    def _token_column(
        metric_res: torch.Tensor,
        shifted_mask: torch.Tensor,
    ) -> np.ndarray:
        """Flatten one metric's results over all unmasked positions.

        A NaN is prepended for each sequence because there is no result for
        the first token of a sequence.
        """
        return np.concatenate(
            [
                np.insert(
                    row[mask_row].type(torch.float32).numpy().flatten(),
                    0,
                    np.nan,
                )
                for row, mask_row in zip(metric_res, shifted_mask)
            ],
            axis=0,
        )


def _produce_metric_args(
    model: PreTrainedModel,
    encoded_sequences: BatchEncoding,
) -> Callable[[MetricArg], Any]:
    """Build a memoizing accessor for the values a metric can request.

    The returned function maps an argument name to its value:

    - ``"input"``: ``encoded_sequences``
    - ``"model"``: ``model``
    - ``"output"``: the result of ``predict`` run with
      ``PredictionConfig(trim_last_token=True)``
    - ``"logits"``: ``output["logits"]``
    - ``"token_probs"``: ``_compute_token_probs`` applied to those logits

    Every value is computed at most once per returned function. In particular,
    the model is run at most once, regardless of which of ``"output"``,
    ``"logits"`` and ``"token_probs"`` are requested and in which order.

    The returned function raises:
        NotImplementedError: For ``"sequence_probs"``.
        ValueError: For any other unknown name.
    """
    item_values: dict[str, Any] = {}

    def model_output() -> Any:
        # Shared by all model-derived items so inference runs only once.
        if "output" not in item_values:
            config = PredictionConfig(trim_last_token=True)
            item_values["output"] = predict(model, encoded_sequences, config)
        return item_values["output"]

    def get_metric_arg(item_name: MetricArg) -> Any:
        if item_name in item_values:
            return item_values[item_name]

        if item_name == "input":
            value = encoded_sequences
        elif item_name == "model":
            value = model
        elif item_name == "output":
            value = model_output()
        elif item_name == "logits":
            value = model_output()["logits"]
        elif item_name == "token_probs":
            value = _compute_token_probs(
                model_output()["logits"],
                encoded_sequences,
            )
        elif item_name == "sequence_probs":
            # TODO: derive from "token_probs" (formerly sketched as
            # compute_sequence_token_probs(encoded_sequences, token_probs)).
            raise NotImplementedError("'sequence_probs' is not implemented yet")
        else:
            raise ValueError(f"Invalid item name: {item_name}")

        item_values[item_name] = value
        return value

    return get_metric_arg


def _compute_token_probs(
    logits: torch.Tensor,
    inputs: BatchEncoding,
) -> TokenValues:
    """Convert model logits into log-probabilities over the vocabulary.

    Args:
        logits: The model logits. They must line up with
            ``attention_mask[:, :-1]``, i.e. have the last position trimmed, as
            produced by ``predict`` with ``trim_last_token=True``; assumed
            shape (batch, seq_len - 1, vocab) — TODO confirm.
        inputs: The model inputs as BatchEncoding; only ``attention_mask`` is
            read.

    Returns:
        TokenValues whose ``values`` are ``log_softmax(logits)`` computed in
        float32, with masked positions set to 0, and whose ``mask`` is
        ``attention_mask[:, :-1]``.
        NOTE(review): this mask only checks the token at each position, so
        with right padding the last real position (whose successor is padding)
        stays unmasked. Confirm this is what the token metrics expect.

    Raises:
        AssertionError: If any sequence has fewer than two unmasked tokens.
    """
    attention_mask = inputs.attention_mask

    # Require at least two unmasked tokens per sequence. The mask may be a
    # tensor or an array-like whose sum is iterable.
    num_sequence_tokens = attention_mask.sum(dim=-1)
    if isinstance(num_sequence_tokens, torch.Tensor):
        has_enough_tokens = bool(torch.all(num_sequence_tokens > 1))
    else:
        has_enough_tokens = all(
            n_seq_tokens > 1 for n_seq_tokens in num_sequence_tokens
        )
    assert has_enough_tokens, "All sequences must have at least 2 tokens"

    # Upcast before the softmax so low-precision logits stay numerically stable.
    log_probs = F.log_softmax(logits.float(), dim=-1)
    resized_mask = attention_mask[:, :-1]
    log_probs[resized_mask == 0] = 0
    return TokenValues(
        values=log_probs,
        mask=resized_mask,
    )
