import numpy as np
import logging
from typing import Dict
from lm_polygraph.estimators.estimator import Estimator
import torch.nn as nn

from lm_polygraph.stat_calculators.step.utils import flatten, reconstruct

# Module-level logger. NOTE(review): not referenced elsewhere in this file.
log = logging.getLogger(__name__)

# Softmax over dim=1. NOTE(review): not used by StepsNumSemSets in this file;
# possibly imported by other modules — confirm before removing.
softmax = nn.Softmax(dim=1)


class StepsNumSemSets(Estimator):
    """
    Estimates claim-level (per-step) uncertainty of a language model following the
    "Number of Semantic Sets" method from the paper https://arxiv.org/abs/2305.19187.

    For every step taken from the nested ``sample_steps_texts`` statistic, the sampled
    alternatives are grouped into semantic sets. Two samples are linked when, in both
    directions, their entailment score exceeds their contradiction score. The number
    of connected components of the resulting graph is the uncertainty value.
    Works with both whitebox and blackbox models (initialized using
    lm_polygraph.utils.model.BlackboxModel/WhiteboxModel).
    """

    def __init__(
            self,
            verbose: bool = False,
    ):
        super().__init__(
            [
                "steps_semantic_matrix_entail",
                "steps_semantic_matrix_contra",
                "sample_steps_texts",
            ],
            "claim",
        )
        # Stored for API compatibility; not read anywhere in this class.
        self.verbose = verbose

    def __str__(self):
        return "StepsNumSemSets"

    def find_connected_components(self, graph):
        """
        Return the connected components of an undirected graph.

        Parameters:
            graph: adjacency list; ``graph[v]`` is an iterable of the neighbours of
                node ``v``. Nodes are the integers ``0 .. len(graph) - 1``.
        Returns:
            list[list[int]]: one list per component, with nodes in depth-first
                preorder. Components are ordered by their smallest unvisited start node.
        """
        visited = [False] * len(graph)
        components = []

        for start in range(len(graph)):
            if visited[start]:
                continue
            visited[start] = True
            component = [start]
            # Explicit stack of neighbour iterators instead of recursion. This
            # yields the same preorder as a recursive DFS, without hitting
            # Python's recursion limit on long path-shaped graphs.
            stack = [iter(graph[start])]
            while stack:
                for neighbor in stack[-1]:
                    if not visited[neighbor]:
                        visited[neighbor] = True
                        component.append(neighbor)
                        stack.append(iter(graph[neighbor]))
                        break
                else:
                    # Current node has no unvisited neighbours left: backtrack.
                    stack.pop()
            components.append(component)

        return components

    def U_NumSemSets(self, semantic_matrix_entail, semantic_matrix_contra):
        """
        Count the semantic sets among the samples of one step.

        Parameters:
            semantic_matrix_entail: square matrix of pairwise entailment scores
                between samples (presumably entry [i, j] scores sample i vs. sample j
                — confirm against the stat calculator).
            semantic_matrix_contra: matching matrix of contradiction scores.
        Returns:
            float: number of connected components of the graph that links i and j
                when entail > contra holds for both (i, j) and (j, i).
        Raises:
            ValueError: if the element-wise comparison is not a square 2-D array.
        """
        # Directed relation: "i entails j more than it contradicts j".
        forward = np.asarray(semantic_matrix_entail) > np.asarray(semantic_matrix_contra)
        if forward.ndim != 2 or forward.shape[0] != forward.shape[1]:
            raise ValueError(
                "semantic matrices must be square 2-D arrays, got shape "
                f"{forward.shape}"
            )

        # Keep only pairs where the relation holds in both directions. The strict
        # upper triangle lists each unordered pair once and ignores self-pairs.
        linked = np.triu(forward & forward.T, k=1)

        graph = [[] for _ in range(forward.shape[0])]
        for i, j in np.argwhere(linked).tolist():
            graph[i].append(j)
            graph[j].append(i)

        # Cast to float for consistency with other estimators.
        return float(len(self.find_connected_components(graph)))

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """
        Estimates the number of semantic sets for every step in the input statistics.

        Parameters:
            stats (Dict[str, np.ndarray]): input statistics, which include:
                * per-step generated samples in 'sample_steps_texts',
                * matching pairwise semantic matrices in
                  'steps_semantic_matrix_entail' and 'steps_semantic_matrix_contra'.
                All three are nested containers that are flattened with ``flatten``.
        Returns:
            np.ndarray: number of semantic sets for each step, regrouped with
                ``reconstruct`` to follow the nesting of ``stats["sample_steps_texts"]``.
                Higher values indicate more uncertain steps.
        """
        sample_steps_texts = flatten(stats["sample_steps_texts"])
        steps_semantic_matrix_entail = flatten(stats["steps_semantic_matrix_entail"])
        steps_semantic_matrix_contra = flatten(stats["steps_semantic_matrix_contra"])

        # The flattened texts define how many steps there are. The matrices are
        # indexed in the same flattened order.
        res = [
            self.U_NumSemSets(
                steps_semantic_matrix_entail[i], steps_semantic_matrix_contra[i]
            )
            for i in range(len(sample_steps_texts))
        ]
        return reconstruct(res, stats["sample_steps_texts"])
