import json
from itertools import combinations
from collections import defaultdict
import queue
import threading
import time
import numpy as np
from scipy.sparse import csr_matrix
from tqdm import tqdm
from archive.ui.visualizer import VisualizationRunner
from attribute_value_table import AttributeValueTable
from version_space import Hypothesis
from version_space import VersionSpace
from logger import VSSMLogger
from normal_distribution import NormalDistribution
from time_monitor import time_monitor


class VSSM:
    """Version-space learner that maintains a disjunction of version spaces.

    Instances are dicts mapping attribute -> value. Every distinct
    (attribute, value) pair seen so far is assigned a column index in
    ``self.av_pairs``; ``ifit`` one-hot encodes each instance over those
    columns before dispatching it to ``_handle_positive`` /
    ``_handle_negative``.

    NOTE(review): this class looks mid-refactor. ``ifit`` stores ``G0`` and
    each version space's ``S``/``G`` as NumPy vectors, while the positive /
    negative handlers, ``get_pyHTN_conds`` and ``get_lit_priorities`` treat
    them as collections of ``Hypothesis`` objects (``.attributes``,
    ``.more_general_than``, list ``.remove``). Confirm which representation
    is intended before relying on end-to-end behaviour.
    """

    def __init__(self, logging_enabled: bool = False, visualize: bool = False):
        """Create an empty learner.

        @param logging_enabled: Forwarded to ``VSSMLogger`` to enable/disable logging.
        @param visualize: When True, start the visualization servers and a
            background broadcast thread (see ``_setup_visualizer``).
        """
        self.av_table = AttributeValueTable()
        self.version_spaces = []
        # Most general boundary; created lazily by the first call to ifit.
        self.G0 = None
        self.positive_count = 0
        self.negative_buffer = []
        self.need_to_prune = False
        self.p = []
        self.n = []
        # (attribute, value) -> column index, stored as a decimal string ('' == unseen).
        self.av_pairs = defaultdict(str)
        # Next free column index in the one-hot encoding.
        self.av_counter = 0
        # Per-attribute NormalDistribution summaries read by get_lit_priorities.
        # Nothing in this class fills it yet; initialised so that method does not
        # raise AttributeError.
        self.attribute_distributions = {}

        # Accumulated seconds spent in ifit, split by label.
        self.time_to_process_pos = 0
        self.time_to_process_neg = 0

        # Model options
        self.logging_enabled = logging_enabled
        self.logger = VSSMLogger(logging_enabled=self.logging_enabled)
        self.visualize = visualize
        self.update_queue = None
        self._visualizer_thread = None

        if self.visualize:
            self._setup_visualizer()

    @staticmethod
    def _pad_last_axis(arr, amount: int, fill):
        """Return *arr* as an ndarray with *amount* copies of *fill* appended on its last axis.

        ``np.pad(arr, (0, amount))`` applies the same padding to *every* axis,
        which would also append rows to a 2-D boundary matrix. Only the
        feature (last) axis should grow; for 1-D input the two are identical.
        """
        arr = np.asarray(arr)
        pad_width = [(0, 0)] * (arr.ndim - 1) + [(0, amount)]
        return np.pad(arr, pad_width, mode='constant', constant_values=fill)

    @time_monitor
    def ifit(self, instance: dict, label: bool):
        """Incrementally train on one labelled instance.

        Unseen (attribute, value) pairs get fresh column indices. Existing S/G
        boundaries are zero-padded and G0 is one-padded to the new width. The
        instance is then one-hot encoded as a 1-row ``csr_matrix`` and
        dispatched according to *label*.

        @param instance: Mapping of attribute -> value.
        @param label: True for a positive example, False for a negative one.
        """
        self.logger.log(f"Processing instance {instance}, Label: {'Pos' if label else 'Neg'}", 'STEP')
        start = time.perf_counter()

        # Assign a new column index to every unseen (attribute, value) pair.
        prev_av_counter = self.av_counter
        for key, value in instance.items():
            if not self.av_pairs.get((key, value)):
                self.av_pairs[(key, value)] = str(self.av_counter)
                self.av_counter += 1
        pad_value = self.av_counter - prev_av_counter

        # Widen every stored boundary to the new encoding width.
        for vs in self.version_spaces:
            vs.S = self._pad_last_axis(vs.S, pad_value, 0)
            vs.G = self._pad_last_axis(vs.G, pad_value, 0)

        if self.G0 is None:
            self.G0 = np.ones(self.av_counter)
        else:
            self.G0 = self._pad_last_axis(self.G0, pad_value, 1)

        # One-hot encode the instance over all known attribute-value columns.
        encoded = np.zeros(self.av_counter)
        for key, value in instance.items():
            encoded[int(self.av_pairs[(key, value)])] = 1
        vector = csr_matrix(encoded)

        if label:
            self.positive_count += 1
            self._handle_positive(vector)
            self.time_to_process_pos += time.perf_counter() - start
        else:
            self._handle_negative(vector)
            self.time_to_process_neg += time.perf_counter() - start

    @time_monitor
    def fit(self, instances, labels, max_instances=1000):
        """Train on a batch of labelled instances, then print the learned state.

        @param instances: Indexable sequence of attribute -> value dicts.
        @param labels: Sequence of booleans aligned with *instances*.
        @param max_instances: Maximum number of instances to consume. The
            default (1000) matches the previously hard-coded cap; pass None to
            train on every instance.
        """
        limit = len(instances) if max_instances is None else min(len(instances), max_instances)
        for i in tqdm(range(limit)):
            self.ifit(instances[i], labels[i])
        print(json.dumps(self.get_state(), indent=4))
        time_monitor.report(self.time_to_process_pos, self.time_to_process_neg)

    @time_monitor
    def predict(self, instance):
        """Return True if any version space covers *instance*."""
        return any(vs.covers(instance) for vs in self.version_spaces)

    @time_monitor
    def score(self, instances, labels):
        """Return the fraction of *instances* whose prediction equals its label.

        Returns 0.0 for an empty batch instead of dividing by zero.
        """
        total = len(instances)
        if total == 0:
            return 0.0
        correct = sum(self.predict(x) == labels[i] for i, x in enumerate(instances))
        return correct / total

    @time_monitor
    def get_pyHTN_conds(self, identifiers=None) -> list:
        """
        Output the learned concept in Disjunctive Normal Form (DNF).
        Each conjunction is a list of literals of the form (attribute, identifier, value);
        every S-set hypothesis of every version space contributes one conjunction.

        NOTE(review): this reads ``vs.s_set`` while the rest of the class uses
        ``vs.S``; confirm VersionSpace exposes both.

        @param identifiers: Optional dict mapping attributes to identifiers
        @return A list of conjunctions (lists of literals); empty conjunctions are omitted
        """
        dnf = []
        for vs in self.version_spaces:
            for s in vs.s_set:
                conjunction = []
                for attr, value in s.attributes.items():
                    # Wildcards add no constraint.
                    if value == '?':
                        continue
                    identifier = None if identifiers is None else identifiers.get(attr, None)
                    if isinstance(value, NormalDistribution):
                        # Continuous attribute: emit a [mean - std, mean + std] range,
                        # but only once positive observations exist.
                        if value.pos_count > 0:
                            conjunction.append((attr, identifier, f">= {value.pos_mean - value.pos_std:.2f}"))
                            conjunction.append((attr, identifier, f"<= {value.pos_mean + value.pos_std:.2f}"))
                    else:
                        # Categorical attribute: exact value constraint.
                        conjunction.append((attr, identifier, value))
                if conjunction:
                    dnf.append(conjunction)
        return dnf

    @staticmethod
    def _literal_for(attr, value, identifiers):
        """Build the (attr, identifier, value) literal used by get_lit_priorities.

        Returns None for wildcard ('?') values and for NormalDistribution values
        with no positive observations; distributions are rendered as "≈ <pos_mean>".
        """
        if value == '?':
            return None
        identifier = None if identifiers is None else identifiers.get(attr, None)
        if isinstance(value, NormalDistribution):
            if value.pos_count > 0:
                return (attr, identifier, f"≈ {value.pos_mean:.2f}")
            return None
        return (attr, identifier, value)

    @time_monitor
    def get_lit_priorities(self, identifiers=None) -> list:
        """
        Output a prioritized list of literals of the form (attribute, identifier, value),
        each assigned a priority in [0.0, 1.0].

        Priorities: G0 literals 0.9; version-space G-set literals 0.7-0.9 and S-set
        literals 0.4-0.7 (scaled by the version space's share of all S hypotheses);
        literals from ``attribute_distributions`` 0.5-1.0 by class separation. When a
        literal appears several times its highest priority wins.

        NOTE(review): this reads ``vs.s_set``/``vs.g_set`` while the rest of the class
        uses ``vs.S``/``vs.G``; confirm VersionSpace exposes both.

        @param identifiers: Optional dict mapping attributes to identifiers
        @return A list of (priority, literal) tuples sorted by priority (highest first)
        """
        literal_priorities = {}

        def bump(literal, priority):
            # Keep the highest priority seen for each literal.
            if literal is not None:
                literal_priorities[literal] = max(literal_priorities.get(literal, 0), priority)

        total_s_hypotheses = sum(len(vs.s_set) for vs in self.version_spaces)

        # G0 (most general) literals get a fixed, high-but-not-maximal priority.
        for g in (self.G0 if self.G0 is not None else ()):
            for attr, value in g.attributes.items():
                literal = self._literal_for(attr, value, identifiers)
                if literal is not None:
                    literal_priorities[literal] = 0.9

        # Visit version spaces from largest to smallest S set.
        for vs in sorted(self.version_spaces, key=lambda space: len(space.s_set), reverse=True):
            vs_weight = len(vs.s_set) / total_s_hypotheses if total_s_hypotheses > 0 else 0

            g_priority = 0.7 + (0.2 * vs_weight)
            for g in vs.g_set:
                for attr, value in g.attributes.items():
                    bump(self._literal_for(attr, value, identifiers), g_priority)

            # Earlier S hypotheses get a larger specificity factor.
            s_count = len(vs.s_set)
            for s_idx, s in enumerate(vs.s_set):
                specificity = 1.0 - (s_idx / s_count) if s_count > 1 else 1.0
                s_priority = 0.4 + (0.3 * specificity * vs_weight)
                for attr, value in s.attributes.items():
                    bump(self._literal_for(attr, value, identifiers), s_priority)

        # Continuous attributes: add mean/range literals weighted by how well the
        # positive and negative distributions separate.
        for attr, dist in getattr(self, 'attribute_distributions', {}).items():
            if dist.pos_count > 0 and dist.neg_count > 0:
                identifier = None if identifiers is None else identifiers.get(attr, None)
                std_sum = dist.pos_std + dist.neg_std
                separation = abs(dist.pos_mean - dist.neg_mean) / std_sum if std_sum > 0 else 1.0
                separation = min(1.0, separation)
                priority = 0.5 + (0.5 * separation)

                bump((attr, identifier, f"≈ {dist.pos_mean:.2f}"), priority)
                bump((attr, identifier, f">= {(dist.pos_mean - dist.pos_std):.2f}"), priority * 0.9)
                bump((attr, identifier, f"<= {(dist.pos_mean + dist.pos_std):.2f}"), priority * 0.9)

        prioritized_literals = [(priority, literal) for literal, priority in literal_priorities.items()]
        # Sort on priority only: comparing literal tuples on ties can raise TypeError
        # (None vs str identifiers, mixed value types). The sort is stable.
        prioritized_literals.sort(key=lambda item: item[0], reverse=True)
        return prioritized_literals

    @time_monitor
    def get_state(self, current_instance=None, label=None):
        """Return the current state of the algorithm in a serializable format.

        ndarray boundaries are converted with ``tolist()`` so they are plain
        (possibly nested) Python lists. A missing G0 is reported as an empty list.
        """
        def as_list(boundary):
            if boundary is None:
                return []
            if isinstance(boundary, np.ndarray):
                return boundary.tolist()
            return list(boundary)

        return {
            "g0": as_list(self.G0),
            "version_spaces": [
                {
                    "s_set": as_list(vs.S),
                    "g_set": as_list(vs.G)
                }
                for vs in self.version_spaces
            ],
            # Compare with None explicitly: truth-testing a sparse matrix raises ValueError.
            "current_instance": "N/A" if current_instance is None else str(current_instance),
            "label": 'Positive' if label is True else 'Negative' if label is False else "N/A"
        }

    @time_monitor
    def _minimally_specialize_hypothesis(self, h: Hypothesis, instance: dict) -> list:
        """Return hypotheses derived from *h* that are intended to exclude *instance*.

        If *h* already does not cover *instance*, ``[h]`` is returned unchanged.
        Otherwise attribute combinations of size 1, 2, ... are tried in turn, and
        the specializations from the first size that yields any are returned. An
        empty list is returned when none can be produced.
        """
        if not h.covers(instance):
            self.logger.log(f"Hypothesis {h} already excludes instance {instance}. No specialization needed.", 'INFO')
            return [h]
        self.logger.log(f"Specializing hypothesis {h} to exclude instance {instance}", 'ACTION')
        specializations = []
        attrs = self.av_table.attrs
        for combo_size in range(1, len(attrs) + 1):
            for combo in combinations(attrs, combo_size):
                # Collect the replacement values for every attribute in the combo first.
                new_values = {}
                for attr in combo:
                    new_values[attr] = tuple()
                    # NOTE(review): len() raises TypeError for non-sized values (e.g. numbers);
                    # an earlier isinstance(..., set) check was commented out — confirm intent.
                    if h.attributes[attr] == '?' or len(h.attributes[attr]) > 0:
                        # Candidates: values seen in at least one positive example that
                        # differ from the instance's value for this attribute.
                        for value in self.av_table.table[attr]:
                            if self.av_table.table[attr][value]['p'] > 0 and value != instance[attr]:
                                new_values[attr] += (value,)

                # One specialization per attribute, each replacing only that attribute.
                # NOTE(review): entries are appended even when the value tuple is empty,
                # so the size-1 pass always returns whenever attrs is non-empty and larger
                # combos are never reached. Confirm whether an `all(new_values.values())`
                # guard was intended.
                for attr, values in new_values.items():
                    new_attributes = h.attributes.copy()
                    new_attributes[attr] = values
                    specializations.append(Hypothesis(new_attributes))
            if specializations:
                self.logger.log(f"Generated {len(specializations)} new specializations.", 'DEBUG')
                return specializations

        self.logger.log("Generated no new specializations.", 'DEBUG')
        return []

    @time_monitor
    def _prune_less_general(self, hyp_set):
        """
        Prunes a collection of hypotheses, removing any that are less general than another.

        Duplicates (same attribute/value pairs) are collapsed first, keeping first-seen
        order so results do not depend on string hash randomization.

        @param hyp_set: An iterable of Hypothesis objects.
        @return A list containing only the most general hypotheses.
        """
        unique_keys = dict.fromkeys(
            tuple(sorted(h.attributes.items(), key=lambda kv: kv[0])) for h in hyp_set
        )
        candidates = [Hypothesis(dict(key)) for key in unique_keys]
        return [
            h1 for h1 in candidates
            if not any(h2.more_general_than(h1) for h2 in candidates if h1 != h2)
        ]

    @time_monitor
    def _handle_positive(self, new_hyp: Hypothesis) -> None:
        """Absorb a positive example into the version-space disjunction.

        In order: do nothing if some S hypothesis is already more general than
        *new_hyp*. Otherwise generalize the first S hypothesis whose generalization
        the owning version space can still include. Failing both, start a new
        version space seeded from the G0 hypotheses more general than *new_hyp*.
        """
        for i, vs in enumerate(self.version_spaces):
            for j, s in enumerate(vs.S):
                self.logger.log(f"Checking if VS-{i + 1}|S-{j + 1} already covers instance...", 'INFO')
                if s.more_general_than(new_hyp):
                    self.logger.log(f"VS-{i + 1}| S-{j + 1} already covers instance. Doing nothing.", 'SUCCESS')
                    return
                # Does generalizing this S member to include the instance stay
                # consistent with the version space's G set?
                self.logger.log(f"Checking if VS-{i + 1}| S-{j + 1} can be generalized...", 'INFO')
                candidate_hyp = Hypothesis.generalize(s, new_hyp, self.av_table.table)
                if vs.can_include(candidate_hyp):
                    self.logger.log(f"VS-{i + 1} can be generalized. Attempting merge.", 'SUCCESS')
                    # Mutating vs.S mid-iteration is safe only because we return immediately.
                    vs.S.remove(s)
                    vs.S.append(candidate_hyp)
                    self.logger.log(f"Successfully merged hypothesis into VS-{i + 1}.", 'SUCCESS')
                    return

        if not self.version_spaces:
            self.logger.log("There are no existing version spaces.", 'ACTION')

        self.logger.log("No existing VS could be generalized. Creating a new version space.", 'ACTION')

        new_s = [new_hyp]
        new_g = [g for g in self.G0 if g.more_general_than(new_hyp)]
        if new_g:
            new_vs = VersionSpace(new_s, new_g, self.logging_enabled)
            self.version_spaces.append(new_vs)
            self.logger.log(f"Created new Version Space: {new_vs}", 'SUCCESS')
        else:
            self.logger.log("Cannot create a new VS: no hypothesis in G0 is general enough.", 'FAIL')

    @time_monitor
    def _handle_negative(self, instance: dict) -> None:
        """Specialize G0 and every version space's G set to exclude a negative example.

        Version spaces that become inconsistent are removed, and the
        specializations of their S members are re-inserted via _handle_positive.
        (Buffering negatives until the first positive is currently disabled.)
        """
        self.logger.log(f"Handling negative instance: {instance}", 'ACTION')

        self.logger.log("Specializing global G0 set.", 'ACTION')
        new_G0 = [spec for g in self.G0 for spec in self._minimally_specialize_hypothesis(g, instance)]
        self.G0 = self._prune_less_general(new_G0)
        self.logger.log(f"Updated G0: {self.G0}", 'DEBUG')

        inconsistent_vss = []
        for i, vs in enumerate(self.version_spaces):
            self.logger.log(f"Specializing G-set of VS-{i + 1}.", 'ACTION')
            new_G = [spec for g in vs.G for spec in self._minimally_specialize_hypothesis(g, instance)]
            vs.G = self._prune_less_general(new_G)
            self.logger.log(f"Updated G-set for VS-{i + 1}: {vs.G}", 'DEBUG')
            self.logger.log(f"Checking consistency of VS-{i + 1}...", 'INFO')
            if not vs.is_consistent():
                self.logger.log(f"VS-{i + 1} is now INCONSISTENT. Marking for split.", 'FAIL')
                inconsistent_vss.append(vs)

        if inconsistent_vss:
            self.logger.log(f"Splitting {len(inconsistent_vss)} inconsistent version space(s).", 'ACTION')
        for vs in inconsistent_vss:
            self.logger.log(f"Processing split for inconsistent VS: {vs}", 'INFO')
            self.version_spaces.remove(vs)
            disjuncts = [spec for s in vs.S for spec in self._minimally_specialize_hypothesis(s, instance)]
            for disjunct in disjuncts:
                self.logger.log(f"Re-processing disjunct {disjunct} from split VS.", 'ACTION')
                self._handle_positive(disjunct)

    @time_monitor
    def _update_g0(self, instance: dict, label: bool) -> None:
        """Initialise G0 to the all-wildcard hypothesis over *instance*'s attributes if unset.

        *label* is currently unused. NOTE(review): no caller is visible here, and
        this builds G0 as a set of Hypothesis objects while ifit builds an ndarray.
        """
        if self.G0 is None:
            self.G0 = {Hypothesis({attr: '?' for attr in instance})}
            self.logger.log(f"G0 not found. Initialized to most general hypothesis: {self.G0}", "ACTION")

    @time_monitor
    def _setup_visualizer(self):
        """Initializes and starts the visualization servers in a background thread."""
        self.logger.log("Visualization enabled. Setting up servers...", "STEP")
        self.update_queue = queue.Queue()

        runner = VisualizationRunner(self.update_queue)
        runner.start_servers()

        # The daemon broadcast thread pulls items from update_queue and forwards
        # them to the runner's servers.
        self._visualizer_thread = threading.Thread(
            target=runner.broadcast_updates,
            daemon=True
        )
        self._visualizer_thread.start()
        self.logger.log("Visualization setup complete.", "SUCCESS")

    @time_monitor
    def close_visualizer(self):
        """Signal the visualization thread to shut down and wait up to 2 s for it."""
        if self.visualize and self.update_queue:
            self.logger.log("Sending shutdown signal to visualizer.", "ACTION")
            self.update_queue.put('SHUTDOWN')
            if self._visualizer_thread:
                self._visualizer_thread.join(timeout=2)