from random import choice
from typing import Dict, Set, Any, Union
import numpy as np
from scipy.sparse import csr_matrix, vstack
from normal_distribution import NormalDistribution
from logger import VSSMLogger
from time_monitor import time_monitor
from utils import pad_csr_vector



class Hypothesis:
    """
    Represents a hypothesis in the version space as a sparse vector.

    The hypothesis is stored as a csr_matrix in ``self.vector``. Generality between two hypotheses is
    decided from the column indices stored in their vectors: a vector whose stored indices are a
    superset of another's is treated as at least as general.

    NOTE(review): the vector is assumed to be a single row (1 x n); ``__repr__`` only prints row 0.
    NOTE(review): earlier documentation described '?' wildcards and NormalDistribution values for
    continuous attributes. Nothing in this class handles them directly - presumably they are encoded
    into the vector by the caller; confirm against the encoder.
    """

    def __init__(self, vector: csr_matrix):
        """
        @param vector: The sparse vector that encodes this hypothesis
        """
        self.vector = vector

    @time_monitor
    def covers(self, instance_vector: csr_matrix) -> bool:
        """
        Check if the hypothesis covers (matches) the given instance.

        @param instance_vector: An instance to check
        @return True if the hypothesis covers the instance, False otherwise
        """
        return self._is_more_general_sparse(self.vector, instance_vector)

    @time_monitor
    def more_general_than(self, other: 'Hypothesis') -> bool:
        """
        Check if this hypothesis is at least as general as another one.

        @param other: The hypothesis to compare against
        @return True if every index stored in other's vector is also stored in this vector
        """
        return self._is_more_general_sparse(self.vector, other.vector)

    @staticmethod
    @time_monitor
    def _is_more_general_sparse(v1: csr_matrix, v2: csr_matrix) -> bool:
        """
        Check whether v1 is at least as general as v2.

        v1 is treated as more general when every column index stored in v2 is also stored in v1.
        Only the stored index structure is compared; the stored values are not inspected.

        @param v1: The candidate more-general vector
        @param v2: The candidate more-specific vector
        @return True if the stored indices of v2 are a subset of those of v1, False otherwise
        """
        # Cheap rejection: v2 cannot be contained in v1 if it stores more entries.
        if v2.nnz > v1.nnz:
            return False
        return set(v2.indices) <= set(v1.indices)

    @staticmethod
    @time_monitor
    def generalize(h1: Union['Hypothesis', csr_matrix], h2: Union['Hypothesis', csr_matrix]) -> 'Hypothesis':
        """
        Minimally generalize two hypotheses (or a hypothesis and an instance vector).

        The two vectors are added element-wise and every stored value is capped at 1, so for
        non-negative inputs the result stores an entry wherever either input does. The inputs are
        not modified; the sum is a new matrix.

        @param h1: The first hypothesis (or a raw sparse vector).
        @param h2: The second hypothesis (or a raw sparse vector).
        @return A new, generalized Hypothesis object.
        """
        v1 = h1.vector if isinstance(h1, Hypothesis) else h1
        v2 = h2.vector if isinstance(h2, Hypothesis) else h2
        v_sum = v1 + v2
        # Positions present in both inputs sum to 2; clamp in place to keep the vector binary.
        v_sum.data.clip(max=1, out=v_sum.data)
        return Hypothesis(v_sum)

    def __eq__(self, other: object) -> bool:
        """
        Two hypotheses are equal when their vectors have the same shape and identical values.

        @param other: The object to compare against
        @return True if other is a Hypothesis with an equal vector, False otherwise
        """
        if not isinstance(other, Hypothesis):
            return False
        # Subtracting vectors of different shapes raises; differently sized hypotheses are simply unequal.
        if self.vector.shape != other.vector.shape:
            return False
        return (self.vector - other.vector).nnz == 0

    def __hash__(self) -> int:
        """
        Hash consistent with __eq__: equal vectors always produce the same hash.

        The hash is built from the shape and the set of non-zero coordinates, so it does not depend
        on storage order, duplicate entries or explicitly stored zeros. It reflects the current
        vector: reassigning ``self.vector`` while the hypothesis sits in a set or dict changes it.

        @return The hash value of this hypothesis
        """
        # Work on a copy so canonicalising never mutates the hypothesis itself.
        canonical = self.vector.tocsr(copy=True)
        canonical.sum_duplicates()
        rows, cols = canonical.nonzero()
        return hash((canonical.shape, frozenset(zip(rows.tolist(), cols.tolist()))))

    def __str__(self) -> str:
        return str(self.vector.toarray())

    def __repr__(self) -> str:
        return f"Hypothesis({self.vector.toarray().tolist()[0]})"

    def __json__(self):
        """
        Return a JSON-serialisable representation.

        Currently a placeholder: the returned dict holds an empty "attributes" mapping and does not
        include the vector contents.
        """
        return {
            "attributes": {}
        }


class VersionSpace:
    """
    Represents a version space with S and G boundary sets.
    """

    def __init__(self, s_set: list[Hypothesis], g_set: list[Hypothesis], logging_enabled: bool = False):
        """
        @param s_set: Hypotheses of the specific boundary (stored as self.S)
        @param g_set: Hypotheses of the general boundary (stored as self.G)
        @param logging_enabled: Forwarded to the VSSMLogger constructor
        """
        self.S = s_set
        self.G = g_set

        self.logger = VSSMLogger(logging_enabled)

    @time_monitor
    def is_consistent(self) -> bool:
        """
        Check if the version space is consistent.

        The version space is consistent when both boundary sets are non-empty and every
        S-hypothesis has at least one G-hypothesis that is more general than it.

        @return True if consistent, False otherwise
        """
        if not self.S or not self.G:
            self.logger.log("VS is inconsistent because S or G set is empty.", "FAIL")
            return False
        for s in self.S:
            if not any(g.more_general_than(s) for g in self.G):
                self.logger.log(f"S-hypothesis {s} is not covered by any G-hypothesis.", "FAIL")
                return False
        self.logger.log("Version space remains consistent.", "SUCCESS")
        return True

    @time_monitor
    def covers(self, instance: csr_matrix) -> bool:
        """
        Check if an instance is covered by this version space. 'Covered' means that at least one
        hypothesis in the specific boundary S covers the instance.

        @param instance: The sparse vector of the instance to check
        @return True if some S-hypothesis covers the instance, False otherwise (also when S is empty)
        """
        return any(s.covers(instance) for s in self.S)

    @time_monitor
    def can_include(self, h: Hypothesis) -> bool:
        """
        Check if the hypothesis can be included in this version space.

        The whole G-set is compared against h in one sparse matrix product instead of looping over
        the G-hypotheses. h can be included when, for at least one G-hypothesis, the product with h
        equals the number of entries stored in h.

        NOTE(review): the equality test only matches the index-subset semantics of
        Hypothesis.more_general_than if all stored values are 1 - confirm vectors are always binary.

        @param h: A hypothesis to check
        @return True if the hypothesis can be included, False otherwise
        """
        if not self.G:
            return False

        g_stack = vstack([g.vector for g in self.G], format='csr')
        h_vector = h.vector

        # Entry i is the dot product of G[i] with h, i.e. how much of h is shared with G[i].
        intersection_counts = (g_stack @ h_vector.T).toarray().flatten()
        generality_mask = (intersection_counts == h_vector.nnz)

        # ndarray.any() yields numpy.bool_; convert so the annotated bool is what callers get.
        return bool(generality_mask.any())

    @time_monitor
    def lengthen_hypothesis_vectors(self, pad_width: int) -> None:
        """
        Replace every boundary hypothesis vector with a padded version of itself.

        S-vectors are padded with constant value 0 and G-vectors with constant value 1.

        NOTE(review): the exact meaning of pad_width (columns added vs. target width) is defined by
        utils.pad_csr_vector - confirm there.

        @param pad_width: Width argument forwarded to pad_csr_vector
        """
        for s in self.S:
            s.vector = pad_csr_vector(s.vector, pad_width, constant_value=0)
        for g in self.G:
            g.vector = pad_csr_vector(g.vector, pad_width, constant_value=1)

    def __str__(self) -> str:
        return f"VersionSpace(S={self.S}, G={self.G})"

    def __repr__(self) -> str:
        return f"VersionSpace(s_set={repr(self.S)}, g_set={repr(self.G)})"

