import numpy as np

from collections import defaultdict
from typing import Dict, List, Tuple

from .stat_calculator import StatCalculator
from lm_polygraph.utils.model import WhiteboxModel


class SemanticClassesCalculatorBase(StatCalculator):
    """
    Clusters generated texts into semantic classes via bidirectional entailment.

    For each input, texts are visited in order. A text joins the first
    existing class whose representative (the first text placed in that class)
    and the text entail each other in both directions; otherwise it opens a
    new class.

    ``sample_source`` is the prefix of the dependency keys that are read and
    of the statistic key that is written (e.g. ``"sample"`` or
    ``"beamsearch"``).
    """

    def __init__(self, sample_source: str):
        super().__init__()
        self.sample_source = sample_source

    def __call__(
            self,
            dependencies: Dict[str, np.ndarray],
            texts: List[str],
            model: WhiteboxModel,
            max_new_tokens: int = 100,
    ) -> Dict[str, Dict[str, Dict]]:
        """
        Computes the semantic classes for every input in the batch.

        Parameters:
            dependencies: precomputed statistics. Reads
                ``f"{sample_source}_semantic_matrix_classes"``,
                ``"entailment_id"`` and ``f"{sample_source}_texts"``.
            texts: not used by this calculator.
            model: not used by this calculator.
            max_new_tokens: not used by this calculator.

        Returns:
            ``{f"{sample_source}_semantic_classes_entail":
            {"sample_to_class": ..., "class_to_sample": ...}}``, where
            ``sample_to_class[idx][i]`` is the class id of text ``i`` of input
            ``idx`` and ``class_to_sample[idx][c]`` lists the text indices
            belonging to class ``c``.
        """
        # Mask of text pairs whose predicted class equals the entailment id.
        # Assumes the classes matrix is a NumPy array indexable as
        # [input, text_a, text_b] — TODO confirm against the producer.
        self._is_entailment = (
                dependencies[f"{self.sample_source}_semantic_matrix_classes"]
                == dependencies["entailment_id"]
        )
        self.get_classes(dependencies[f"{self.sample_source}_texts"])

        return {
            f"{self.sample_source}_semantic_classes_entail": {
                "sample_to_class": self._sample_to_class,
                "class_to_sample": self._class_to_sample,
            }
        }

    def get_classes(
            self, hyps_list: List[List[str]]
    ) -> Tuple[Dict[int, Dict[int, int]], Dict[int, List[List[int]]]]:
        """
        Assigns every text of every input to a semantic class.

        Requires ``self._is_entailment`` to have been set (done in
        ``__call__``).

        Parameters:
            hyps_list: texts per input. Only the number of texts per input is
                used; pairwise decisions come from ``self._is_entailment``.

        Returns:
            ``(sample_to_class, class_to_sample)``, which are also stored on
            the instance. No entry is stored for an input with zero texts.
        """
        self._sample_to_class = {}
        self._class_to_sample: Dict[int, List] = defaultdict(list)

        # _determine_class is invoked for its side effects on the two
        # mappings, so a plain loop is used instead of a throwaway list.
        for idx, hyps in enumerate(hyps_list):
            for i in range(len(hyps)):
                self._determine_class(idx, i)

        return self._sample_to_class, self._class_to_sample

    def _determine_class(self, idx: int, i: int) -> int:
        """
        Places text ``i`` of input ``idx`` into a semantic class.

        Must be called with ``i == 0`` first for each ``idx``: that call
        (re)initialises the input's mappings with class 0 containing text 0.

        Returns:
            The class id assigned to the text.
        """
        # The first text always founds class 0.
        if i == 0:
            self._class_to_sample[idx] = [[0]]
            self._sample_to_class[idx] = {0: 0}
            return 0

        # Compare only against each class's representative (first member);
        # membership requires entailment in both directions.
        for class_id, members in enumerate(self._class_to_sample[idx]):
            representative = members[0]
            forward_entailment = self._is_entailment[idx, representative, i]
            backward_entailment = self._is_entailment[idx, i, representative]
            if forward_entailment and backward_entailment:
                members.append(i)
                self._sample_to_class[idx][i] = class_id
                return class_id

        # No existing class matched: open a new one.
        new_class_id = len(self._class_to_sample[idx])
        self._sample_to_class[idx][i] = new_class_id
        self._class_to_sample[idx].append([i])
        return new_class_id


class SemanticClassesCalculator(SemanticClassesCalculatorBase):
    """
    Semantic-class calculator for texts produced by sampling.

    Configures the base calculator with the ``"sample"`` key prefix, so the
    produced statistic is ``sample_semantic_classes_entail``.
    """

    def __init__(self):
        super().__init__(sample_source="sample")

    @staticmethod
    def meta_info() -> Tuple[List[str], List[str]]:
        """
        Returns the statistics and dependencies for the calculator.

        Returns:
            A pair ``(statistics, dependencies)``: the names of the
            statistics this calculator produces and of those it requires.
        """
        produced = ["sample_semantic_classes_entail"]
        # NOTE(review): "sample_semantic_matrix_entail" is declared here but
        # the base __call__ does not read it directly — confirm it is needed.
        required = [
            "sample_texts",
            "sample_semantic_matrix_entail",
            "sample_semantic_matrix_classes",
            "entailment_id",
        ]
        return produced, required


class BeamSemanticClassesCalculator(SemanticClassesCalculatorBase):
    """
    Semantic-class calculator for texts produced by beam search.

    Configures the base calculator with the ``"beamsearch"`` key prefix, so
    the produced statistic is ``beamsearch_semantic_classes_entail``.
    """

    def __init__(self):
        super().__init__(sample_source="beamsearch")

    @staticmethod
    def meta_info() -> Tuple[List[str], List[str]]:
        """
        Returns the statistics and dependencies for the calculator.

        Returns:
            A pair ``(statistics, dependencies)``: the names of the
            statistics this calculator produces and of those it requires.
        """
        produced = ["beamsearch_semantic_classes_entail"]
        # NOTE(review): "beamsearch_semantic_matrix_entail" is declared here
        # but the base __call__ does not read it directly — confirm it is
        # needed.
        required = [
            "beamsearch_texts",
            "beamsearch_semantic_matrix_entail",
            "beamsearch_semantic_matrix_classes",
            "entailment_id",
        ]
        return produced, required
