import numpy as np
import logging
from typing import Dict
from .estimator import Estimator
import torch.nn as nn

from .process_probs import process_probs

# Module-level logger named after this module; used for verbose debug output.
log = logging.getLogger(__name__)

# Softmax over dimension 1.
# NOTE(review): not referenced by the estimators visible in this file --
# confirm whether anything else imports it before removing.
softmax = nn.Softmax(dim=1)


class NumSemSets(Estimator):
    """
    Estimates the sequence-level uncertainty of a language model following the method of
    "Number of Semantic Sets" as provided in the paper https://arxiv.org/abs/2305.19187.
    Works with both whitebox and blackbox models (initialized using
    lm_polygraph.utils.model.BlackboxModel/WhiteboxModel).

    Two samples are linked when the entailment score exceeds the contradiction
    score in both directions; the estimate is the number of connected
    components of the resulting undirected graph.
    """

    def __init__(
            self,
            verbose: bool = False,
            samples_source: str = "sample",
    ):
        """
        Parameters:
            verbose (bool): when True, `__call__` logs the sampled answers of
                every input at DEBUG level.
            samples_source (str): prefix of the statistics keys that are read
                (`{samples_source}_semantic_matrix_entail`,
                `{samples_source}_semantic_matrix_contra`,
                `{samples_source}_texts`).
        """
        # The key list and the string "sequence" are forwarded to the base
        # Estimator constructor.
        super().__init__(
            [
                f"{samples_source}_semantic_matrix_entail",
                f"{samples_source}_semantic_matrix_contra",
                f"{samples_source}_texts",
            ],
            "sequence",
        )
        self.verbose = verbose
        self.samples_source = samples_source

    def __str__(self) -> str:
        """
        Returns "NumSemSets", suffixed with "_{samples_source}" whenever the
        samples source is not the default "sample".
        """
        base = "NumSemSets"
        if self.samples_source != "sample":
            base += f'_{self.samples_source}'
        return base

    def find_connected_components(self, graph):
        """
        Finds the connected components of an undirected graph.

        The traversal is an iterative depth-first search (explicit stack), so
        long chains of nodes cannot exhaust Python's recursion limit. Nodes
        are reported in the same pre-order a recursive DFS would produce.

        Parameters:
            graph (List[List[int]]): adjacency list; `graph[v]` holds the
                indices of the neighbours of node `v`.
        Returns:
            List[List[int]]: one list of node indices per component, ordered
                by the index of the node each traversal started from.
        """
        visited = [False] * len(graph)
        components = []

        for start in range(len(graph)):
            if visited[start]:
                continue

            component = []
            stack = [start]
            while stack:
                node = stack.pop()
                if visited[node]:
                    # A node may be pushed by several neighbours before it
                    # is reached; only its first pop counts.
                    continue
                visited[node] = True
                component.append(node)
                # Reversed so the first listed neighbour is popped first.
                stack.extend(
                    nb for nb in reversed(graph[node]) if not visited[nb]
                )
            components.append(component)

        return components

    def _build_semantic_graph(self, W_entail, W_contra):
        """
        Builds the adjacency list linking mutually entailing samples.

        Parameters:
            W_entail (np.ndarray): square matrix of entailment scores.
            W_contra (np.ndarray): square matrix of contradiction scores.
        Returns:
            List[List[int]]: `graph[j]` lists every `k != j` for which
                entailment beats contradiction at both [j, k] and [k, j].
        """
        # The condition must hold for the pairing in both directions, hence
        # the element-wise AND with the transpose (a symmetric mask).
        agrees = W_entail > W_contra
        linked = agrees & agrees.T

        return [
            [k for k in np.flatnonzero(row).tolist() if k != j]
            for j, row in enumerate(linked)
        ]

    def U_NumSemSets(self, i: int, stats: Dict[str, np.ndarray]) -> float:
        """
        Computes the number of semantic sets for the i-th input.

        Parameters:
            i (int): index of the input inside the batch statistics.
            stats (Dict[str, np.ndarray]): input statistics, see `__call__`.
        Returns:
            float: number of connected components of the semantic graph.
        """
        # We have forward upper triangular and backward in lower triangular
        # parts of the semantic matrices.
        # Index the i-th input first so that only its own matrices are
        # converted, instead of re-materialising the whole batch per call.
        W_entail = np.asarray(
            stats[f"{self.samples_source}_semantic_matrix_entail"][i]
        )
        W_contra = np.asarray(
            stats[f"{self.samples_source}_semantic_matrix_contra"][i]
        )

        graph = self._build_semantic_graph(W_entail, W_contra)
        connected_components = self.find_connected_components(graph)

        # Cast to float for consistency with other estimators
        return float(len(connected_components))

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """
        Estimates the number of semantic sets for each sample in the input statistics.

        Parameters:
            stats (Dict[str, np.ndarray]): input statistics, which for multiple samples includes
                (keys are prefixed with `samples_source`, "sample" by default):
                * generated samples in '{samples_source}_texts',
                * matrix with corresponding semantic similarities in
                  '{samples_source}_semantic_matrix_entail' and
                  '{samples_source}_semantic_matrix_contra'
        Returns:
            np.ndarray: number of semantic sets for each sample in input statistics.
                Higher values indicate more uncertain samples.
        """
        res = []
        for i, answers in enumerate(stats[f"{self.samples_source}_texts"]):
            if self.verbose:
                log.debug(f"generated answers: {answers}")
            res.append(self.U_NumSemSets(i, stats))

        return np.array(res)


class NumSemSetsP(Estimator):
    """
    Probability-weighted variant of the "Number of Semantic Sets" estimator.

    Samples are grouped into semantic sets exactly as in `NumSemSets`: two
    samples are linked when the entailment score exceeds the contradiction
    score in both directions, and the sets are the connected components of
    that graph. Each set then receives the summed (processed) sequence
    probability `p` of its members and contributes `1 - (1 - p) ** n` to the
    score, where `n` is the number of samples.

    NOTE(review): if the processed probabilities are read as a distribution
    over semantic sets, this is the expected number of distinct sets hit by
    `n` independent draws -- confirm that this is the intended reading.
    """

    def __init__(
            self,
            verbose: bool = False,
            samples_source: str = "beamsearch",
            **process_probs_args,
    ):
        """
        Parameters:
            verbose (bool): when True, `__call__` logs the sampled answers of
                every input at DEBUG level.
            samples_source (str): prefix of the statistics keys that are read
                (`{samples_source}_log_likelihoods`,
                `{samples_source}_semantic_matrix_entail`,
                `{samples_source}_semantic_matrix_contra`,
                `{samples_source}_texts`).
            **process_probs_args: keyword arguments forwarded unchanged to
                `process_probs` when sequence probabilities are processed.
        """
        # The key list and the string "sequence" are forwarded to the base
        # Estimator constructor.
        super().__init__(
            [
                f"{samples_source}_log_likelihoods",
                f"{samples_source}_semantic_matrix_entail",
                f"{samples_source}_semantic_matrix_contra",
                f"{samples_source}_texts",
            ],
            "sequence",
        )
        self.verbose = verbose
        self.samples_source = samples_source
        self.process_probs_args = process_probs_args

    def __str__(self) -> str:
        """
        Returns "NumSemSetsP", suffixed with "_{samples_source}" whenever the
        samples source is not "sample" (so the default "beamsearch" source
        yields "NumSemSetsP_beamsearch").
        """
        base = "NumSemSetsP"
        if self.samples_source != "sample":
            base += f'_{self.samples_source}'
        return base

    def find_connected_components(self, graph):
        """
        Finds the connected components of an undirected graph.

        The traversal is an iterative depth-first search (explicit stack), so
        long chains of nodes cannot exhaust Python's recursion limit. Nodes
        are reported in the same pre-order a recursive DFS would produce.

        Parameters:
            graph (List[List[int]]): adjacency list; `graph[v]` holds the
                indices of the neighbours of node `v`.
        Returns:
            List[List[int]]: one list of node indices per component, ordered
                by the index of the node each traversal started from.
        """
        visited = [False] * len(graph)
        components = []

        for start in range(len(graph)):
            if visited[start]:
                continue

            component = []
            stack = [start]
            while stack:
                node = stack.pop()
                if visited[node]:
                    # A node may be pushed by several neighbours before it
                    # is reached; only its first pop counts.
                    continue
                visited[node] = True
                component.append(node)
                # Reversed so the first listed neighbour is popped first.
                stack.extend(
                    nb for nb in reversed(graph[node]) if not visited[nb]
                )
            components.append(component)

        return components

    def _build_semantic_graph(self, W_entail, W_contra):
        """
        Builds the adjacency list linking mutually entailing samples.

        Parameters:
            W_entail (np.ndarray): square matrix of entailment scores.
            W_contra (np.ndarray): square matrix of contradiction scores.
        Returns:
            List[List[int]]: `graph[j]` lists every `k != j` for which
                entailment beats contradiction at both [j, k] and [k, j].
        """
        # The condition must hold for the pairing in both directions, hence
        # the element-wise AND with the transpose (a symmetric mask).
        agrees = W_entail > W_contra
        linked = agrees & agrees.T

        return [
            [k for k in np.flatnonzero(row).tolist() if k != j]
            for j, row in enumerate(linked)
        ]

    def U_NumSemSets(self, i: int, stats: Dict[str, np.ndarray]) -> float:
        """
        Computes the probability-weighted number of semantic sets for the
        i-th input.

        Parameters:
            i (int): index of the input inside the batch statistics.
            stats (Dict[str, np.ndarray]): input statistics, see `__call__`.
        Returns:
            float: sum over semantic sets of `1 - (1 - p_set) ** n_samples`.
        """
        # We have forward upper triangular and backward in lower triangular
        # parts of the semantic matrices.
        # Index the i-th input first so that only its own matrices are
        # converted, instead of re-materialising the whole batch per call.
        W_entail = np.asarray(
            stats[f"{self.samples_source}_semantic_matrix_entail"][i]
        )
        W_contra = np.asarray(
            stats[f"{self.samples_source}_semantic_matrix_contra"][i]
        )
        n_samples = W_entail.shape[0]

        graph = self._build_semantic_graph(W_entail, W_contra)
        connected_components = self.find_connected_components(graph)

        # Sequence probability of a sample = exp(sum of its token
        # log-likelihoods); post-processing is delegated to process_probs.
        sample_token_lls = stats[f"{self.samples_source}_log_likelihoods"][i]
        probs = np.array([np.exp(sum(s)) for s in sample_token_lls])
        probs = process_probs(probs, **self.process_probs_args)

        expected_n_components = 0.0
        for component in connected_components:
            c_prob = np.sum([probs[x] for x in component])
            p_comp_in_multinomial = 1 - (1 - c_prob) ** n_samples
            expected_n_components += p_comp_in_multinomial

        # Cast to float for consistency with other estimators
        return float(expected_n_components)

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """
        Estimates the probability-weighted number of semantic sets for each
        sample in the input statistics.

        Parameters:
            stats (Dict[str, np.ndarray]): input statistics, which for multiple samples includes
                (keys are prefixed with `samples_source`, "beamsearch" by default):
                * generated samples in '{samples_source}_texts',
                * per-token log-likelihoods of every generated sample in
                  '{samples_source}_log_likelihoods',
                * matrix with corresponding semantic similarities in
                  '{samples_source}_semantic_matrix_entail' and
                  '{samples_source}_semantic_matrix_contra'
        Returns:
            np.ndarray: probability-weighted number of semantic sets for each
                sample in input statistics.
        """
        res = []
        for i, answers in enumerate(stats[f"{self.samples_source}_texts"]):
            if self.verbose:
                log.debug(f"generated answers: {answers}")
            res.append(self.U_NumSemSets(i, stats))

        return np.array(res)
