from .stat_calculator import StatCalculator
from lm_polygraph.utils.model import Model

from typing import Dict, List, Tuple

import numpy as np

import torch


class InferCausalLMCalculator(StatCalculator):
    """
    Performs inference of the model and ensures that output contains
    1. logprobas
    2. tokens
    3. embeddings

    For Whitebox model (lm_polygraph.WhiteboxModel), at input texts batch calculates:
    * generation texts
    * tokens of the generation texts
    * probabilities distribution of the generated tokens
    * attention masks across the model (if applicable)
    * embeddings from the model
    """

    def __init__(
        self,
        n_alternatives: int = 10,
        tokenize: bool = False,
        return_embeddings: bool = True,
    ):
        super().__init__()

        self.n_alternatives = n_alternatives
        self._tokenize = tokenize
        self._return_embeddings = return_embeddings

    @staticmethod
    def meta_info() -> Tuple[List[str], List[str]]:
        """
        Returns the statistics and dependencies for the calculator.
        """

        return [
            "greedy_log_probs",
            "greedy_logits",
            "greedy_tokens",
            "greedy_log_likelihoods",
            "greedy_tokens_alternatives",
            "embeddings_decoder",
        ], []

    def _get_embeddings_from_output(
        self,
        output,
        all_layers: bool = False,
    ):
        batch_embeddings_decoder = None

        if not all_layers:
            hidden_layer = -1
            input_tokens_hs = output.hidden_states[0][hidden_layer].cpu().detach()
            if len(output.hidden_states) > 1:
                generated_tokens_hs = torch.cat(
                    [h[hidden_layer].cpu().detach() for h in output.hidden_states[1:]],
                    dim=1,
                )
        else:
            input_tokens_hs = output.hidden_states[0].mean(axis=0).cpu().detach()
            if len(output.hidden_states) > 1:
                generated_tokens_hs = torch.cat(
                    [h.mean(axis=0).cpu().detach() for h in output.hidden_states[1:]],
                    dim=1,
                )

        if len(output.hidden_states) > 1:
            batch_embeddings_decoder = (
                torch.cat([input_tokens_hs, generated_tokens_hs], dim=1)
                .mean(axis=1)
                .cpu()
                .detach()
            )
        else:
            batch_embeddings_decoder = input_tokens_hs.mean(axis=1).cpu().detach()

        return batch_embeddings_decoder

    def _post_process_logits(self, out, model_inputs, eos_token_id):
        cut_logits = []
        cut_sequences = []
        cut_log_probs = []
        cut_alternatives = []
        lls = []

        all_logits = torch.stack(out.scores, dim=1)
        for i in range(len(model_inputs)):
            seq = out.sequences[i, model_inputs.shape[1] :].cpu()

            length = len(seq)
            for j in range(len(seq)):
                if seq[j] == eos_token_id:
                    length = j + 1
                    break

            tokens = seq[:length].tolist()
            cut_sequences.append(tokens)

            logits = all_logits[i, :length, :].cpu()
            cut_logits.append(logits.numpy())

            log_probs = logits.log_softmax(-1)
            cut_log_probs.append(log_probs.numpy())
            lls.append([log_probs[j, tokens[j]] for j in range(len(log_probs))])

            cut_alternatives.append([[] for _ in range(length)])
            for j in range(length):
                lt = logits[j, :].numpy()
                best_tokens = np.argpartition(lt, -self.n_alternatives)
                ln = len(best_tokens)
                best_tokens = best_tokens[ln - self.n_alternatives : ln]
                for t in best_tokens:
                    cut_alternatives[-1][j].append((t.item(), lt[t].item()))

                cut_alternatives[-1][j].sort(
                    key=lambda x: x[0] == cut_sequences[-1][j],
                    reverse=True,
                )

        result_dict = {
            "greedy_log_probs": cut_log_probs,
            "greedy_logits": cut_logits,
            "greedy_tokens": cut_sequences,
            "greedy_log_likelihoods": lls,
            "greedy_tokens_alternatives": cut_alternatives,
        }

        return result_dict

    def __call__(
        self,
        dependencies: Dict[str, np.array],
        texts: List[str],
        model: Model,
        max_new_tokens: int = 100,  # TODO: move to args_generate
        **kwargs,
    ) -> Dict[str, np.ndarray]:
        """
        Calculates the statistics of probabilities at each token position in the generation.

        Parameters:
            dependencies (Dict[str, np.ndarray]): input statistics, can be empty (not used).
            texts (List[str]): Input texts batch used for model generation.
            model (Model): Model used for generation.
            max_new_tokens (int): Maximum number of new tokens at model generation. Default: 100.
        Returns:
            Dict[str, np.ndarray]: dictionary with the following items:
                - 'input_texts' (List[str]): input texts batch,
                - 'input_tokens' (List[List[int]]): tokenized input texts,
                - 'greedy_log_probs' (List[List[np.array]]): logarithms of autoregressive
                        probability distributions at each token,
                - 'greedy_texts' (List[str]): model generations corresponding to the inputs,
                - 'greedy_tokens' (List[List[int]]): tokenized model generations,
                - 'attention' (List[List[np.array]]): attention maps at each token, if applicable to the model,
                - 'greedy_log_likelihoods' (List[List[float]]): log-probabilities of the generated tokens.
        """

        if self._tokenize:
            model_inputs = model.tokenize(texts, padding=True, return_tensors="pt")
        else:
            model_inputs = dependencies["model_inputs"]

        model_inputs = (
            model_inputs
            if isinstance(model_inputs, torch.Tensor)
            else model_inputs["input_ids"]
        )

        args_generate = kwargs.pop("args_generate")
        args_generate.update(
            {
                "return_dict_in_generate": True,
                "output_scores": True,
                "output_hidden_states": True,
            }
        )
        out = model.generate(model_inputs, **args_generate)

        result_dict = self._post_process_logits(
            out, model_inputs, args_generate["generation_config"].eos_token_id
        )

        if self._return_embeddings:
            result_dict.update(
                {"embeddings_decoder": self._get_embeddings_from_output(out)}
            )

        return result_dict
