import numpy as np
from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu

from typing import Dict

from .estimator import Estimator

from absl import logging as absl_logging

from .process_probs import process_probs

# Raise absl's logging threshold to WARNING to silence the noisy log output of rouge_scorer
absl_logging.set_verbosity(absl_logging.WARNING)


class LexicalSimilarity(Estimator):
    """
    Estimates the sequence-level uncertainty of a language model following the method of
    "Lexical Similarity" as provided in the paper https://arxiv.org/abs/2302.09664.
    Works with both whitebox and blackbox models (initialized using
    lm_polygraph.utils.model.BlackboxModel/WhiteboxModel).

    The method calculates mean similarity between all pairs of sampled generations with minus sign.
    The number of samples is controlled by lm_polygraph.stat_calculators.sample.SamplingGenerationCalculator
    'samples_n' parameter.
    """

    def __init__(
            self,
            metric: str = "rougeL",
            samples_source: str = "sample",
    ):
        """
        Parameters:
            metric (str): similarity metric (default: 'rougeL'). Possible values:
                * rouge1 / rouge2 / rougeL
                * BLEU
            samples_source (str): prefix of the statistics key that holds the generated
                texts; they are read from stats[f"{samples_source}_texts"] (default: 'sample').
        """
        self.metric = metric
        # A ROUGE scorer is only needed (and only created) for the rouge* metrics.
        if self.metric.startswith("rouge"):
            self.scorer = rouge_scorer.RougeScorer([self.metric], use_stemmer=True)
        super().__init__([f"{samples_source}_texts"], "sequence")
        self.samples_source = samples_source

    def __str__(self):
        """
        Returns the estimator name: 'LexicalSimilarity_<metric>', followed by
        '_<samples_source>' when a non-default samples source is used.
        """
        base = f"LexicalSimilarity_{self.metric}"
        if self.samples_source != "sample":
            base += f'_{self.samples_source}'
        return base

    def _score_single(self, t1: str, t2: str) -> float:
        """
        Computes the similarity between two texts using the configured metric.

        Parameters:
            t1 (str): first text (used as the reference/target).
            t2 (str): second text (used as the hypothesis/prediction).
        Returns:
            float: ROUGE F-measure or sentence-level BLEU score.
        Raises:
            Exception: if the configured metric is neither rouge* nor 'BLEU'.
        """
        if self.metric.startswith("rouge"):
            return self.scorer.score(t1, t2)[self.metric].fmeasure
        elif self.metric == "BLEU":
            reference_tokens = t1.split()
            hypothesis_tokens = t2.split()
            # Texts shorter than 4 tokens have no higher-order n-grams, so the BLEU
            # weight is spread uniformly over the n-gram orders that can exist.
            min_sentence_len = min(len(reference_tokens), len(hypothesis_tokens))
            if min_sentence_len == 1:
                weights = [1.0, 0.0, 0.0, 0.0]
            elif min_sentence_len == 2:
                weights = [0.5, 0.5, 0.0, 0.0]
            elif min_sentence_len == 3:
                # Exact thirds: the non-zero weights must sum to 1, otherwise
                # every partial-match score is slightly inflated.
                weights = [1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0, 0.0]
            else:
                # default weights in sentence_bleu
                weights = [0.25, 0.25, 0.25, 0.25]
            return sentence_bleu([reference_tokens], hypothesis_tokens, weights=weights)
        else:
            raise Exception(f"Unknown metrics for lexical similarity: {self.metric}")

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """
        Estimates the mean similarity with minus sign for each sample in the input statistics.

        Parameters:
            stats (Dict[str, np.ndarray]): input statistics, which for multiple samples includes:
                * several sampled texts in '<samples_source>_texts'
                  ('sample_texts' with the default samples source)
        Returns:
            np.ndarray: float uncertainty for each sample in input statistics.
                Higher values indicate more uncertain samples.
                The value is NaN for an input with fewer than two sampled texts,
                since there is no pair to compare.
        """
        batch_texts = stats[f"{self.samples_source}_texts"]
        res = []
        for texts in batch_texts:
            # Similarity of every unordered pair (i < j) of sampled texts.
            sims = [
                self._score_single(texts[i], texts[j])
                for i in range(len(texts))
                for j in range(i + 1, len(texts))
            ]
            if sims:
                res.append(-np.mean(sims))
            else:
                # np.mean([]) would also give NaN, but with a RuntimeWarning.
                res.append(np.nan)
        return np.array(res)


class LexicalSimilarityP(Estimator):
    """
    Probability-weighted variant of the lexical similarity uncertainty estimate.

    For every pair of generated texts, the lexical similarity (ROUGE or BLEU) is multiplied
    by the processed sequence probabilities of both texts; the estimate is the mean of these
    weighted similarities with minus sign.

    Sequence probabilities are obtained as exp of the summed token log-likelihoods and are
    then passed through process_probs (imported from .process_probs) together with the
    keyword arguments given to the constructor.
    """

    def __init__(
            self,
            metric: str = "rougeL",
            samples_source: str = "beamsearch",
            **process_probs_args,
    ):
        """
        Parameters:
            metric (str): similarity metric (default: 'rougeL'). Possible values:
                * rouge1 / rouge2 / rougeL
                * BLEU
            samples_source (str): prefix of the statistics keys that hold the generated texts
                and their token log-likelihoods; they are read from
                stats[f"{samples_source}_texts"] and stats[f"{samples_source}_log_likelihoods"]
                (default: 'beamsearch').
            **process_probs_args: keyword arguments forwarded unchanged to process_probs.
        """
        self.metric = metric
        # A ROUGE scorer is only needed (and only created) for the rouge* metrics.
        if self.metric.startswith("rouge"):
            self.scorer = rouge_scorer.RougeScorer([self.metric], use_stemmer=True)
        super().__init__([f"{samples_source}_texts", f"{samples_source}_log_likelihoods"], "sequence")
        self.samples_source = samples_source
        self.process_probs_args = process_probs_args

    def __str__(self):
        """
        Returns the estimator name: 'LexicalSimilarityP_<metric>', followed by
        '_<samples_source>' unless the samples source is 'sample'. With the default
        samples source ('beamsearch') the suffix is therefore present.
        """
        base = f"LexicalSimilarityP_{self.metric}"
        if self.samples_source != "sample":
            base += f'_{self.samples_source}'
        return base

    def _score_single(self, t1: str, t2: str) -> float:
        """
        Computes the similarity between two texts using the configured metric.

        Parameters:
            t1 (str): first text (used as the reference/target).
            t2 (str): second text (used as the hypothesis/prediction).
        Returns:
            float: ROUGE F-measure or sentence-level BLEU score.
        Raises:
            Exception: if the configured metric is neither rouge* nor 'BLEU'.
        """
        if self.metric.startswith("rouge"):
            return self.scorer.score(t1, t2)[self.metric].fmeasure
        elif self.metric == "BLEU":
            reference_tokens = t1.split()
            hypothesis_tokens = t2.split()
            # Texts shorter than 4 tokens have no higher-order n-grams, so the BLEU
            # weight is spread uniformly over the n-gram orders that can exist.
            min_sentence_len = min(len(reference_tokens), len(hypothesis_tokens))
            if min_sentence_len == 1:
                weights = [1.0, 0.0, 0.0, 0.0]
            elif min_sentence_len == 2:
                weights = [0.5, 0.5, 0.0, 0.0]
            elif min_sentence_len == 3:
                # Exact thirds: the non-zero weights must sum to 1, otherwise
                # every partial-match score is slightly inflated.
                weights = [1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0, 0.0]
            else:
                # default weights in sentence_bleu
                weights = [0.25, 0.25, 0.25, 0.25]
            return sentence_bleu([reference_tokens], hypothesis_tokens, weights=weights)
        else:
            raise Exception(f"Unknown metrics for lexical similarity: {self.metric}")

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """
        Estimates the mean probability-weighted similarity with minus sign for each sample
        in the input statistics.

        Parameters:
            stats (Dict[str, np.ndarray]): input statistics, which for multiple samples includes:
                * several generated texts in '<samples_source>_texts'
                * token log-likelihoods of each generated text in
                  '<samples_source>_log_likelihoods'
        Returns:
            np.ndarray: float uncertainty for each sample in input statistics.
                The value is NaN for an input with fewer than two generated texts,
                since there is no pair to compare.
        """
        batch_sample_token_lls = stats[f"{self.samples_source}_log_likelihoods"]
        batch_texts = stats[f"{self.samples_source}_texts"]
        res = []
        for sample_token_lls, texts in zip(batch_sample_token_lls, batch_texts):
            # Sequence probability = exp(sum of token log-likelihoods).
            sample_probs = np.array([np.exp(sum(x)) for x in sample_token_lls])
            sample_probs = process_probs(sample_probs, **self.process_probs_args)

            # Similarity of every unordered pair (i < j), weighted by both probabilities.
            sims = [
                self._score_single(texts[i], texts[j]) * sample_probs[i] * sample_probs[j]
                for i in range(len(texts))
                for j in range(i + 1, len(texts))
            ]
            if sims:
                res.append(-np.mean(sims))
            else:
                # np.mean([]) would also give NaN, but with a RuntimeWarning.
                res.append(np.nan)
        return np.array(res)
