import json
from itertools import combinations

from itertools import compress
from collections import defaultdict
import queue
import threading
import time
import numpy as np
from scipy.sparse import csr_matrix, vstack
from tqdm import tqdm
from archive.ui.visualizer import VisualizationRunner
from archive.vssm_vector.version_space import Hypothesis
from archive.vssm_vector.version_space import VersionSpace
from logger import VSSMLogger
from normal_distribution import NormalDistribution
from time_monitor import time_monitor
from utils import create_instance_vector, pad_csr_vector, update_attr_to_indices_map


class VSSM:
    def __init__(self, logging_enabled: bool = False, visualize: bool = False):
        """
        Create an empty model.

        @param logging_enabled: Forwarded to VSSMLogger to turn logging on or off.
        @param visualize: If True, start the visualization servers and broadcast
            thread immediately via ``_setup_visualizer``.
        """
        self.version_spaces = []
        # Global set of most general hypotheses; seeded with an all-ones vector by ifit().
        self.G0 = []
        self.positive_count = 0
        self.negative_buffer = []
        self.need_to_prune = False
        self.p = []
        self.n = []
        # (attribute, value) -> feature index, stored as a string ("0", "1", ...).
        self.av_pairs = defaultdict(str)
        self.attr_to_indices_map = {}
        # Next feature index to assign to an unseen (attribute, value) pair.
        self.av_counter = 0
        # attribute -> distribution object, read by get_lit_priorities(). Initialized
        # here so that method does not fail with AttributeError when nothing has
        # populated it.
        self.attribute_distributions = {}

        # Timing
        self.time_to_process_pos = 0
        self.time_to_process_neg = 0

        # Model options
        self.logging_enabled = logging_enabled
        self.logger = VSSMLogger(logging_enabled=self.logging_enabled)
        self.visualize = visualize
        self.update_queue = None
        self._visualizer_thread = None

        if self.visualize:
            self._setup_visualizer()

    @time_monitor
    def ifit(self, instance: dict, label: bool):
        """
        Incrementally update the model with a single labelled instance.

        Unseen (attribute, value) pairs of ``instance`` get the next feature
        index, and every stored hypothesis vector is widened by the number of new
        features. G0 is seeded with an all-ones hypothesis if it is empty. The
        encoded instance is then handed to the positive or negative update, and
        the elapsed wall-clock time is added to the matching timing counter.

        @param instance: Mapping of attribute -> value.
        @param label: True for a positive example, False for a negative one.
        """
        self.logger.log(f"Processing instance {instance}, Label: {'Pos' if label else 'Neg'}", 'STEP')
        began = time.time()

        # Register unseen attribute-value pairs. Indexing the defaultdict yields ""
        # (falsy) for a new pair, which is then overwritten with its index string.
        width_before = len(self.av_pairs)
        for attr in instance:
            pair = (attr, instance[attr])
            if self.av_pairs[pair]:
                continue
            self.av_pairs[pair] = f"{self.av_counter}"
            self.av_counter += 1
        added = len(self.av_pairs) - width_before
        self.attr_to_indices_map = update_attr_to_indices_map(self.av_pairs)

        # Widen existing hypotheses so they line up with the enlarged feature space.
        for vs in self.version_spaces:
            vs.lengthen_hypothesis_vectors(pad_width=added)

        if self.G0:
            for g in self.G0:
                g.vector = pad_csr_vector(g.vector, pad_width=added, constant_value=1)
        else:
            # Empty G0: start from the fully general (all-ones) hypothesis.
            self.G0.append(Hypothesis(csr_matrix(np.ones(len(self.av_pairs)), dtype=np.int8)))

        encoded = create_instance_vector(self.av_pairs, instance)

        if not label:
            self._handle_negative(encoded)
            self.time_to_process_neg += time.time() - began
            return
        self._handle_positive(Hypothesis(encoded))
        self.time_to_process_pos += time.time() - began

    @time_monitor
    def fit(self, instances, labels):
        """
        Train on ``instances`` / ``labels`` in order via ``ifit``.

        After training, the model state is printed as JSON and the timing
        totals are reported.

        @param instances: Sequence of attribute -> value mappings.
        @param labels: Sequence of booleans aligned with ``instances``.
        @raises IndexError: If ``labels`` is shorter than ``instances``. This is
            checked before any training so the model is not left partially updated.
        """
        if len(labels) < len(instances):
            raise IndexError(
                f"fit() received {len(instances)} instances but only {len(labels)} labels"
            )
        for instance, label in tqdm(zip(instances, labels), total=len(instances)):
            self.ifit(instance, label)
        print(json.dumps(self.get_state(), indent=4))
        time_monitor.report(self.time_to_process_pos, self.time_to_process_neg)

    @time_monitor
    def predict(self, instance):
        """
        Return True if any version space covers ``instance``, else False.

        @param instance: Either an already encoded ``csr_matrix`` feature vector,
            or a raw attribute -> value mapping (encoded via create_instance_vector).
        """
        if isinstance(instance, csr_matrix):
            vector = instance
        else:
            vector = create_instance_vector(self.av_pairs, instance)
        for vs in self.version_spaces:
            if vs.covers(vector):
                return True
        return False

    @time_monitor
    def score(self, instances, labels):
        """
        Return the fraction of ``instances`` whose prediction equals its label.

        Each instance is encoded first; ``labels`` is indexed positionally.
        An empty ``instances`` raises ZeroDivisionError.
        """
        encoded = [create_instance_vector(self.av_pairs, raw) for raw in instances]
        correct = 0
        for idx, vec in enumerate(encoded):
            correct += self.predict(vec) == labels[idx]
        return correct / len(encoded)

    @time_monitor
    def get_pyHTN_conds(self, identifiers=None) -> list:
        """
        Export the learned concept in Disjunctive Normal Form (DNF).

        Every hypothesis in every version space's ``s_set`` becomes one
        conjunction, i.e. a list of ``(attribute, identifier, value)`` literals.
        The returned list is the disjunction of those conjunctions.

        - Wildcard values ('?') are skipped.
        - A NormalDistribution with positive observations becomes the pair
          ``">= mean-std"`` / ``"<= mean+std"`` (2 decimals).
        - Any other value becomes an exact-match literal.
        - Conjunctions that end up empty are omitted.

        @param identifiers: Optional dict mapping attributes to identifiers
        @return A list representing the DNF of the learned concept
        """
        # NOTE(review): the other methods of this class reach version spaces via
        # ``vs.S`` and hypotheses via ``.vector``; confirm that ``vs.s_set`` and
        # ``Hypothesis.attributes`` exist on the vector-based classes.
        disjuncts = []
        for vs in self.version_spaces:
            for hyp in vs.s_set:
                literals = []
                for attr, value in hyp.attributes.items():
                    ident = identifiers.get(attr, None) if identifiers is not None else None
                    if value == '?':
                        continue
                    if not isinstance(value, NormalDistribution):
                        literals.append((attr, ident, value))
                        continue
                    if value.pos_count > 0:
                        low = value.pos_mean - value.pos_std
                        high = value.pos_mean + value.pos_std
                        literals.append((attr, ident, f">= {low:.2f}"))
                        literals.append((attr, ident, f"<= {high:.2f}"))
                if literals:
                    disjuncts.append(literals)
        return disjuncts

    @time_monitor
    def get_lit_priorities(self, identifiers=None) -> list:
        """
        Output a prioritized list of literals, where each literal is of the form
        (attribute, identifier, value) and is assigned a priority in [0.0, 1.0].

        A literal reachable from several sources keeps its highest priority:

        * G0 hypotheses: 0.9
        * a version space's G set: 0.7 + 0.2 * vs_weight
        * a version space's S set: 0.4 + 0.3 * specificity * vs_weight, where
          specificity is 1.0 for the first S member and decreases for later ones
        * entries of ``self.attribute_distributions`` (if that attribute exists)
          with both positive and negative observations:
          0.5 + 0.5 * separation for the mean literal, and 0.9 times that for
          the two range literals

        ``vs_weight`` is the version space's share of all S hypotheses.

        For hypothesis literals:

        * wildcard values ('?') are skipped;
        * NormalDistribution values without positive observations are skipped;
        * other NormalDistribution values are rendered as ``"≈ <pos_mean>"``.

        @param identifiers: Optional dict mapping attributes to identifiers
        @return A list of tuples (priority, literal) sorted by priority (highest
            first). Literals with equal priority keep the order in which they were
            first recorded.
        """
        # NOTE(review): the other methods of this class use ``vs.S`` / ``vs.G`` and
        # ``Hypothesis.vector``; confirm that ``s_set``, ``g_set`` and ``attributes``
        # exist on the vector-based VersionSpace / Hypothesis classes.
        literal_priorities = {}

        def literal_for(attr, value):
            # Literal for one (attr, value) pair, or None if it adds no constraint.
            if value == '?':
                return None
            identifier = None if identifiers is None else identifiers.get(attr, None)
            if isinstance(value, NormalDistribution):
                if value.pos_count > 0:
                    return (attr, identifier, f"≈ {value.pos_mean:.2f}")
                return None
            return (attr, identifier, value)

        def bump(literal, priority):
            # Keep the highest priority seen so far for this literal.
            literal_priorities[literal] = max(literal_priorities.get(literal, 0), priority)

        def bump_hypothesis(hyp, priority):
            # Apply `priority` to every constraining literal of one hypothesis.
            for attr, value in hyp.attributes.items():
                literal = literal_for(attr, value)
                if literal is not None:
                    bump(literal, priority)

        # G0 (most general hypotheses): high but not maximal priority.
        for g in self.G0:
            bump_hypothesis(g, 0.9)

        total_s_hypotheses = sum(len(vs.s_set) for vs in self.version_spaces)

        # Version spaces from largest to smallest S set.
        for vs in sorted(self.version_spaces, key=lambda v: len(v.s_set), reverse=True):
            vs_weight = len(vs.s_set) / total_s_hypotheses if total_s_hypotheses > 0 else 0

            for g in vs.g_set:
                bump_hypothesis(g, 0.7 + (0.2 * vs_weight))

            s_count = len(vs.s_set)
            for s_idx, s in enumerate(vs.s_set):
                # Earlier members of the S set are treated as more specific.
                specificity = 1.0 - (s_idx / s_count) if s_count > 1 else 1.0
                bump_hypothesis(s, 0.4 + (0.3 * specificity * vs_weight))

        # Continuous attributes: rank by how well positives and negatives separate.
        # getattr fallback: the attribute may never have been populated.
        for attr, dist in getattr(self, 'attribute_distributions', {}).items():
            if not (dist.pos_count > 0 and dist.neg_count > 0):
                continue
            identifier = None if identifiers is None else identifiers.get(attr, None)

            spread = dist.pos_std + dist.neg_std
            separation = abs(dist.pos_mean - dist.neg_mean) / spread if spread > 0 else 1.0
            priority = 0.5 + (0.5 * min(1.0, separation))

            bump((attr, identifier, f"≈ {dist.pos_mean:.2f}"), priority)
            bump((attr, identifier, f">= {(dist.pos_mean - dist.pos_std):.2f}"), priority * 0.9)
            bump((attr, identifier, f"<= {(dist.pos_mean + dist.pos_std):.2f}"), priority * 0.9)

        prioritized_literals = [(priority, literal) for literal, priority in literal_priorities.items()]
        # Sort on priority alone. Sorting the raw tuples would compare literals on
        # ties, which raises TypeError for mixed types (e.g. None vs str identifiers).
        prioritized_literals.sort(key=lambda item: item[0], reverse=True)

        return prioritized_literals

    @time_monitor
    def get_state(self, current_instance=None, label=None):
        """
        Return the current state of the algorithm in a JSON-serializable dict.

        @param current_instance: Optional instance being processed, rendered with
            ``str``. It is compared against None explicitly because a sparse row
            vector has no truth value (``bool()`` on it raises ValueError).
        @param label: True -> 'Positive', False -> 'Negative', anything else -> 'N/A'.
        @return dict containing:
            - G0 and each version space's S/G sets, as strings;
            - counts;
            - average S and G set sizes, which are 0.0 when there are no
              version spaces instead of raising ZeroDivisionError.
        """
        num_vs = len(self.version_spaces)
        avg_s = sum(len(vs.S) for vs in self.version_spaces) / num_vs if num_vs else 0.0
        avg_g = sum(len(vs.G) for vs in self.version_spaces) / num_vs if num_vs else 0.0

        state = {
            "g0": [str(g) for g in self.G0],
            "version_spaces": [
                {
                    "s_set": [str(s) for s in vs.S],
                    "g_set": [str(g) for g in vs.G]
                }
                for vs in self.version_spaces
            ],
            "current_instance": str(current_instance) if current_instance is not None else "N/A",
            "label": 'Positive' if label is True else 'Negative' if label is False else "N/A",
            "num_version_spaces": num_vs,
            "num_g0_hypotheses": len(self.G0),
            "avg_version_space_S_size": avg_s,
            "avg_version_space_G_size": avg_g,
        }
        return state

    @time_monitor
    def _minimally_specialize_hypothesis(self, h: Hypothesis, instance: csr_matrix):
        """
        Generate the minimal specializations of ``h`` that exclude ``instance``.

        A specialization switches off, for each attribute in a chosen
        combination, the bit of the value ``instance`` has for that attribute.
        Combinations are tried in increasing size (1, 2, 3, ...). All
        specializations of the first size that produces any are returned.

        Only attributes for which ``instance`` actually has a feature are used.
        An attribute absent from the instance offers no bit to switch off, so a
        combination containing it would leave ``h`` covering the instance.

        @param h: Hypothesis to specialize.
        @param instance: Sparse row vector of the instance to exclude.
        @return ``[h]`` if ``h`` does not cover ``instance``; otherwise the list of
            new Hypothesis objects, or ``[]`` if none could be built.
        """
        if not h.covers(instance):
            self.logger.log("Hypothesis already excludes instance.", 'INFO')
            return [h]

        self.logger.log("Specializing hypothesis to exclude instance.", 'ACTION')

        # Column indices of the features present in the instance.
        instance_feature_indices = set(instance.indices)

        # attribute -> the instance's feature index for that attribute, computed once
        # rather than per combination. Only the first match is used, since an
        # instance sets one value per attribute.
        instance_index_of = {}
        for attr, attr_indices in self.attr_to_indices_map.items():
            matches = [idx for idx in attr_indices if idx in instance_feature_indices]
            if matches:
                instance_index_of[attr] = matches[0]

        candidate_attrs = list(instance_index_of)
        for size in range(1, len(candidate_attrs) + 1):
            specializations = []
            for combo in combinations(candidate_attrs, size):
                # LIL copy: cheap element assignment; h.vector itself is not modified.
                new_h_lil = h.vector.tolil()
                for attr in combo:
                    new_h_lil[0, instance_index_of[attr]] = 0
                # NOTE: no check is made that an attribute keeps at least one allowed
                # value after its bit is switched off.
                specializations.append(Hypothesis(new_h_lil.tocsr()))

            if specializations:
                # Densifying every vector just for this message is costly; only do it
                # when logging is on (assumes VSSMLogger discards messages when
                # logging_enabled is False -- confirm).
                if self.logging_enabled:
                    self.logger.log(
                        f"Generated {len(specializations)} new minimal specializations: "
                        f"{[s.vector.toarray() for s in specializations]}",
                        'DEBUG',
                    )
                return specializations

        return []

    @time_monitor
    def _prune_less_general(self, hyp_set: list[Hypothesis]) -> list[Hypothesis]:
        """
        Remove every hypothesis that another hypothesis in ``hyp_set`` subsumes.

        Hypothesis vectors are treated as feature sets (their non-zero positions).
        ``h_i`` subsumes ``h_j`` when ``h_j``'s features are a subset of ``h_i``'s.
        A hypothesis is dropped if another one is strictly more general, or if an
        identical one appears earlier in the list, so exactly one copy of each
        duplicate survives.

        @param hyp_set: A list of Hypothesis objects.
        @return The surviving hypotheses, in their original order.
        """
        if len(hyp_set) < 2:
            return hyp_set

        # Binarize (non-zero -> 1) and widen to int32. Vectors may be int8, and
        # scipy keeps the operand dtype for sparse products, so counts of shared
        # features above 127 would otherwise overflow.
        H = (vstack([h.vector for h in hyp_set], format='csr') != 0).astype(np.int32)

        # intersection[i, j] = number of features shared by h_i and h_j.
        intersection = (H @ H.T).toarray()
        sizes = H.getnnz(axis=1)

        # covers[i, j]: every feature of h_j is also in h_i (h_i at least as general).
        covers = intersection == sizes[np.newaxis, :]
        # identical[i, j]: h_i and h_j have exactly the same features.
        identical = covers & covers.T
        strictly_more_general = covers & ~identical

        # Column j is dropped if some h_i is strictly more general, or an identical
        # h_i occurs earlier (i < j: strict upper triangle, diagonal excluded).
        dropped = strictly_more_general.any(axis=0) | np.triu(identical, k=1).any(axis=0)

        return list(compress(hyp_set, ~dropped))

    @time_monitor
    def _handle_positive(self, vector_hyp: Hypothesis) -> None:
        """
        Incorporate a positive example, already wrapped as a Hypothesis.

        1. If any S hypothesis of any version space is already more general than
           the example, nothing changes.
        2. Otherwise, take the first S hypothesis, scanning version spaces and then
           their S members in order, whose generalization with the example is
           accepted by its version space's ``can_include``. Replace it with that
           generalization.
        3. If none qualifies, create a new version space from the example.

        Coverage is checked across all version spaces before any generalization.
        This way, an example already covered by a later S hypothesis does not
        cause an earlier one to be generalized needlessly.
        """
        # Pass 1: is the example already covered by some S hypothesis?
        for i, vs in enumerate(self.version_spaces):
            for j, s in enumerate(vs.S):
                self.logger.log(f"Checking if VS-{i + 1}|S-{j + 1} already covers instance...", 'INFO')
                if s.more_general_than(vector_hyp):
                    self.logger.log(f"VS-{i + 1}| S-{j + 1} already covers instance. Doing nothing.", 'SUCCESS')
                    return

        # Pass 2: find an S hypothesis whose generalization stays consistent with its G set.
        for i, vs in enumerate(self.version_spaces):
            for j, s in enumerate(vs.S):
                self.logger.log(f"Checking if VS-{i + 1}| S-{j + 1} can be generalized...", 'INFO')
                self.logger.log(f"Generalizing {s} and {vector_hyp}...", 'INFO')
                candidate_hyp = Hypothesis.generalize(s, vector_hyp)
                self.logger.log(f"Generalized vector: {candidate_hyp}", 'INFO')
                if vs.can_include(candidate_hyp):
                    self.logger.log(f"VS-{i + 1} can be generalized. Attempting merge.", 'SUCCESS')
                    # Safe to mutate vs.S here: we return immediately afterwards.
                    vs.S.remove(s)
                    vs.S.append(candidate_hyp)
                    self.logger.log(f"Successfully merged hypothesis into VS-{i + 1}.", 'SUCCESS')
                    return

        if not self.version_spaces:
            self.logger.log("There are no existing version spaces.", 'ACTION')

        self.logger.log("No existing VS could be generalized. Creating a new version space.", 'ACTION')

        self._create_new_version_space(vector_hyp)


    @time_monitor
    def _create_new_version_space(self, vector_hyp: Hypothesis):
        """
        Start a new version space seeded by ``vector_hyp``.

        S is a fresh copy of the seed. G is a fresh copy of every G0 hypothesis
        that is more general than the seed. If no G0 hypothesis qualifies, no
        version space is created and a FAIL message is logged.
        """
        specific = [Hypothesis(vector_hyp.vector)]
        general = [Hypothesis(g0.vector) for g0 in self.G0 if g0.more_general_than(vector_hyp)]
        if not general:
            self.logger.log("Cannot create a new VS: no hypothesis in G0 is general enough.", 'FAIL')
            return
        created = VersionSpace(specific, general, self.logging_enabled)
        self.version_spaces.append(created)
        self.logger.log(f"Created new Version Space: {created}", 'SUCCESS')


    @time_monitor
    def _handle_negative(self, hyp_vector: csr_matrix) -> None:
        """
        Incorporate a negative example given as a sparse feature vector.

        G0 and each version space's G set are replaced by the pruned minimal
        specializations of their members with respect to the example. Version
        spaces that are then no longer consistent are removed. Each removed
        space's S hypotheses are minimally specialized, and every result seeds
        a new version space.
        """
        self.logger.log(f"Handling negative instance: {hyp_vector.toarray()}", 'ACTION')

        self.logger.log(f"Specializing global G0 set consisting of {list(self.G0)}.", 'ACTION')

        specialized_g0 = []
        for g0 in self.G0:
            specialized_g0.extend(self._minimally_specialize_hypothesis(g0, hyp_vector))
        self.G0 = self._prune_less_general(specialized_g0)
        self.logger.log(f"Updated G0: {self.G0}", 'SUCCESS')

        to_split = []
        for vs_num, vs in enumerate(self.version_spaces, start=1):
            self.logger.log(f"Specializing G-set of VS-{vs_num}.", 'ACTION')
            specialized_g = []
            for g in vs.G:
                specialized_g.extend(self._minimally_specialize_hypothesis(g, hyp_vector))
            vs.G = self._prune_less_general(specialized_g)
            self.logger.log(f"Updated G-set for VS-{vs_num}: {vs.G}", 'DEBUG')
            self.logger.log(f"Checking consistency of VS-{vs_num}...", 'INFO')
            if not vs.is_consistent():
                self.logger.log(f"VS-{vs_num} is now INCONSISTENT. Marking for split.", 'FAIL')
                to_split.append(vs)

        if not to_split:
            return

        self.logger.log(f"Splitting {len(to_split)} inconsistent version space(s).", 'ACTION')
        for vs in to_split:
            self.logger.log(f"Processing split for inconsistent VS: {vs}", 'INFO')
            self.version_spaces.remove(vs)
            # Materialize every specialized S member before creating any new version space.
            disjuncts = [d for s in vs.S for d in self._minimally_specialize_hypothesis(s, hyp_vector)]
            for disjunct in disjuncts:
                self.logger.log(f"Re-processing disjunct {disjunct} from split VS.", 'ACTION')
                self._create_new_version_space(disjunct)

    @time_monitor
    def _setup_visualizer(self):
        """
        Create the update queue and start the visualization.

        The servers are started via VisualizationRunner, and its broadcast loop
        runs on a daemon thread.
        """
        self.logger.log("Visualization enabled. Setting up servers...", "STEP")
        self.update_queue = queue.Queue()

        viz = VisualizationRunner(self.update_queue)
        viz.start_servers()

        # The broadcast loop pulls items from the queue. daemon=True keeps this
        # thread from blocking interpreter exit.
        self._visualizer_thread = threading.Thread(target=viz.broadcast_updates, daemon=True)
        self._visualizer_thread.start()
        self.logger.log("Visualization setup complete.", "SUCCESS")

    @time_monitor
    def close_visualizer(self):
        """
        Signal the visualization thread to shut down.

        Does nothing unless ``self.visualize`` is set and an update queue exists.
        Otherwise, puts the 'SHUTDOWN' sentinel on the update queue. If a
        broadcast thread was started, it then waits up to 2 seconds for that
        thread to finish.
        """
        if self.visualize and self.update_queue:
            self.logger.log("Sending shutdown signal to visualizer.", "ACTION")
            # Sentinel presumably recognized by VisualizationRunner.broadcast_updates -- confirm.
            self.update_queue.put('SHUTDOWN')
            if self._visualizer_thread:
                # Bounded wait: join() returns after the timeout even if the thread is still alive.
                self._visualizer_thread.join(timeout=2)