from collections import defaultdict
from itertools import combinations
import time
from tqdm import tqdm
from time_monitor import time_monitor
from logger import VSSMLogger
from uuid import uuid4
import re
import json
from scoring import ConfusionMatrix, Scorecard
from utils import pad_bit_vector


class Hypothesis:
    """A hypothesis represented by an integer bitmask.

    Equality and hashing depend only on `vector`; `id` is a fresh random
    UUID string per object and plays no part in comparisons.
    """

    def __init__(self, vector: int):
        self.vector = vector
        self.id = str(uuid4())
        # Filled in by callers that derive narrower hypotheses from this one.
        self.specializations = set()

    @time_monitor
    def is_more_general_than(self, other: 'Hypothesis') -> bool:
        """Return True when every bit set in `other` is also set in `self`."""
        uncovered_bits = other.vector & ~self.vector
        return uncovered_bits == 0

    def bit_count(self) -> int:
        """Return the number of set bits in the vector."""
        return self.vector.bit_count()

    def __eq__(self, other):
        if not isinstance(other, Hypothesis):
            return False
        return self.vector == other.vector

    def __hash__(self):
        return hash(self.vector)

    def __repr__(self):
        return "H({})".format(bin(self.vector))

    def __and__(self, other: 'Hypothesis') -> int:
        # Note: yields the raw integer intersection, not a Hypothesis.
        return self.vector & other.vector


class VersionSpace:
    """A pair of hypothesis sets: the specific boundary `S` and general boundary `G`."""

    def __init__(self, s_set: set, g_set: set):
        self.S = s_set
        self.G = g_set

    @time_monitor
    def is_consistent(self) -> bool:
        """Return True when both sets are non-empty and every S member is
        covered by at least one G member."""
        if not (self.S and self.G):
            return False
        return all(
            any(g_hyp.is_more_general_than(s_hyp) for g_hyp in self.G)
            for s_hyp in self.S
        )

    def __repr__(self):
        return "VS(S={}, G={})".format(self.S, self.G)

class TrieNode:
    """One node of the binary trie that stores bitmask hypotheses.

    `children` maps a bit value to the next node; `is_hypothesis` flags a
    node at which a stored vector ends.
    """
    # Slots keep per-node memory low; no instance __dict__ is created.
    __slots__ = ['children', 'is_hypothesis']

    def __init__(self):
        self.is_hypothesis = False
        self.children: dict[int, 'TrieNode'] = {}


class VSSM:
    """
    An incremental learning algorithm for disjunctive concepts with thorough logging.
    """

    def __init__(self, logging_enabled=True):
        """Create an empty learner.

        @param logging_enabled: Forwarded to `VSSMLogger` to switch logging on or off.
        """
        # Learned state: the global general boundary and one VersionSpace per disjunct.
        self.g0: set[Hypothesis] = set()
        self.version_spaces = []

        # Feature map: (attribute, value) pair <-> bit index, plus the bit indices per attribute.
        self.num_bits = 0
        self.av_to_bit_idx = {}
        self.bit_idx_to_av = {}
        self.attr_to_bit_indices = defaultdict(set)

        self.logger = VSSMLogger(logging_enabled=logging_enabled)

    @time_monitor
    def get_state(self, current_instance=None, label=None):
        """Return the current state of the algorithm in a serializable format.

        @param current_instance: Optional instance being processed; reported as
            its `str()` form, or "N/A" when falsy.
        @param label: True -> 'Positive', False -> 'Negative', anything else -> "N/A".
        @return: A dict of strings, lists and numbers describing G0 and every
            version space, plus summary counts and average S/G set sizes.
        """
        vs_count = len(self.version_spaces)

        def average_size(sizes) -> float:
            # With no version spaces there is nothing to average; report 0.0
            # instead of dividing by zero.
            return sum(sizes) / vs_count if vs_count else 0.0

        if label is True:
            label_text = 'Positive'
        elif label is False:
            label_text = 'Negative'
        else:
            label_text = "N/A"

        state = {
            "g0": [str(g) for g in self.g0],
            "version_spaces": [
                {
                    "s_set": [str(s) for s in vs.S],
                    "g_set": [str(g) for g in vs.G]
                }
                for vs in self.version_spaces
            ],
            "current_instance": str(current_instance) if current_instance else "N/A",
            "label": label_text,
            "num_version_spaces": vs_count,
            "num_g0_hypotheses": len(self.g0),
            "avg_version_space_S_size": average_size(len(vs.S) for vs in self.version_spaces),
            "avg_version_space_G_size": average_size(len(vs.G) for vs in self.version_spaces),
        }
        return state

    @time_monitor
    def _update_bit_indices(self, instance: dict) -> int:
        """Register every unseen (attribute, value) pair of `instance`.

        Each new pair receives the next free bit index, and the forward,
        inverse and per-attribute maps are updated along with `num_bits`.

        @param instance: Mapping of attribute name to value.
        @return: How many new bits were allocated by this call.
        """
        bits_before = self.num_bits

        for attr, value in instance.items():
            pair = (attr, value)
            if pair in self.av_to_bit_idx:
                continue

            # Indices are handed out in first-seen order: 0, 1, 2, ...
            new_bit_idx = self.num_bits
            self.num_bits += 1

            self.av_to_bit_idx[pair] = new_bit_idx
            self.bit_idx_to_av[new_bit_idx] = pair
            self.attr_to_bit_indices[attr].add(new_bit_idx)
            self.logger.log(f"New feature seen: ('{attr}', '{value}') -> mapped to bit {new_bit_idx}",
                            level="DEBUG")

        return self.num_bits - bits_before

    @time_monitor
    def _expand_hypotheses(self, bit_pad_width: int):
        """Widen every stored hypothesis after the feature space has grown.

        G0 and all G-set vectors are rebuilt through `pad_bit_vector` with
        `pad_with_one=True`; all S-set vectors with `pad_with_one=False`.
        Nothing happens when `bit_pad_width` is not positive.

        @param bit_pad_width: Number of bits added since the last expansion.
        """
        if bit_pad_width <= 0:
            return

        self.logger.log(
            f"Feature space expanded. Updating general hypotheses by {bit_pad_width} bits to {self.num_bits} bits.",
            level="ACTION")

        self.logger.log(f"G0 before expansion: {self.g0}", level="DEBUG")
        if not self.g0:
            # NOTE(review): an empty G0 is treated as "not yet initialised" and
            # is seeded from a zero vector padded by only `bit_pad_width` bits.
            # Confirm this is also what is wanted if G0 was emptied by
            # specialization rather than never populated.
            self.g0.add(Hypothesis(pad_bit_vector(bit_vector=0, pad_width=bit_pad_width, pad_with_one=True)))
            self.logger.log(f"Initialized G0: {self.g0}", level="SUCCESS")
        else:
            self.g0 = {Hypothesis(pad_bit_vector(g.vector, bit_pad_width, pad_with_one=True)) for g in self.g0}
        self.logger.log(f"G0 after expansion: {self.g0}", level="SUCCESS")

        for i, vs in enumerate(self.version_spaces):
            self.logger.log(f"Expanding S-set of VS-{i + 1}: {vs.S}", level="DEBUG")
            vs.S = {Hypothesis(pad_bit_vector(s.vector, bit_pad_width, pad_with_one=False)) for s in vs.S}
            self.logger.log(f"VS-{i + 1} S-set after expansion: {vs.S}", level="DEBUG")
            self.logger.log(f"Expanding G-set of VS-{i + 1}: {vs.G}", level="DEBUG")
            vs.G = {Hypothesis(pad_bit_vector(g.vector, bit_pad_width, pad_with_one=True)) for g in vs.G}
            self.logger.log(f"VS-{i + 1} G-set after expansion: {vs.G}", level="DEBUG")

        # Sanity diagnostics. A vector whose leading bit is 0 has a bit_length
        # below num_bits while still being a valid num_bits-wide mask, so the
        # check is "fits within the width", not "bit_length equals the width".
        s_fit = all(s.vector.bit_length() <= self.num_bits for vs in self.version_spaces for s in vs.S)
        g_fit = all(g.vector.bit_length() <= self.num_bits for vs in self.version_spaces for g in vs.G)
        self.logger.log(f"All S hypotheses expanded to {self.num_bits} bits: "
                        f"{s_fit}",
                        level="SUCCESS")
        self.logger.log(f"All G hypotheses expanded to {self.num_bits} bits: "
                        f"{g_fit}",
                        level="SUCCESS")

        for i, vs in enumerate(self.version_spaces):
            self.logger.log(f"(VS-{i+1}) S-Set: {vs.S} | G-Set: {vs.G}", level="DEBUG")

    @time_monitor
    def _instance_to_hypothesis(self, instance: dict) -> Hypothesis:
        """Encode `instance` as a Hypothesis with one set bit per (attribute, value) pair.

        Bit index 0 occupies the most significant of the `num_bits` positions.
        Every pair must already be present in `av_to_bit_idx`, otherwise the
        lookup raises KeyError.
        """
        top_position = self.num_bits - 1
        encoded = 0
        for pair in instance.items():
            shift = top_position - self.av_to_bit_idx[pair]
            encoded |= 1 << shift
        return Hypothesis(encoded)

    @time_monitor
    def _prune_less_general(self, hyp_set: set['Hypothesis']) -> set['Hypothesis']:
        """Keep only the hypotheses that no other member of `hyp_set` subsumes.

        Sets with fewer than two members are returned unchanged (same object).

        @param hyp_set: Hypotheses to filter.
        @return: A new set holding the maximal (most general) members.
        """
        if len(hyp_set) <= 1:
            return hyp_set

        width = self.num_bits
        trie_root = TrieNode()
        survivors = set()

        # Visit candidates from most to fewest set bits: anything that could
        # subsume a candidate has already been inserted when it is examined.
        ordered = list(hyp_set)
        ordered.sort(key=lambda h: h.vector.bit_count(), reverse=True)

        for candidate in ordered:
            if self._is_subsumed_by_trie(trie_root, candidate.vector, width):
                continue
            self._insert_into_trie(trie_root, candidate.vector, width)
            survivors.add(candidate)

        return survivors

    @time_monitor
    def _prune_less_general2(self, hyp_set: set['Hypothesis']) -> set['Hypothesis']:
        """Keep only the most general hypotheses using a pairwise sort-and-prune.

        This is a trie-free alternative to `_prune_less_general`. Previously
        the whole body was disabled inside a string literal, so the method
        returned None; the disabled code also compared the integer produced by
        `Hypothesis.__and__` with a Hypothesis, which can never be equal.

        @param hyp_set: Hypotheses to filter.
        @return: `hyp_set` itself when it has fewer than two members, otherwise
            a new set of the members not subsumed by any other member.
        """
        if len(hyp_set) < 2:
            return hyp_set

        # Descending bit count guarantees that any hypothesis able to subsume
        # a candidate has been examined (and kept or itself subsumed) earlier.
        ordered = sorted(hyp_set, key=lambda h: h.vector.bit_count(), reverse=True)

        most_general = []
        for candidate in ordered:
            subsumed = any(
                (kept.vector & candidate.vector) == candidate.vector
                for kept in most_general
            )
            if not subsumed:
                most_general.append(candidate)

        return set(most_general)


    @time_monitor
    def _is_subsumed_by_trie(self, node: TrieNode, vector: int, num_bits: int) -> bool:
        """Check whether `vector` is subsumed by any hypothesis stored in the trie.

        A vector `v` is subsumed by a stored `trie_v` when `(trie_v & v) == v`,
        i.e. every 1 bit of `v` is also 1 in `trie_v`.

        @param node: Root of the trie to search.
        @param vector: Bitmask to test.
        @param num_bits: Width of the vectors stored in the trie.
        @return: True when a subsuming hypothesis exists.
        """
        # The search starts at the most significant of the num_bits positions.
        return self._recursive_is_subsumed(node, vector, num_bits - 1)

    @time_monitor
    def _recursive_is_subsumed(self, node: TrieNode, vector: int, bit_idx: int) -> bool:
        """Depth-first search for a stored hypothesis that subsumes `vector`.

        @param node: Trie node reached so far.
        @param vector: Bitmask being tested.
        @param bit_idx: Bit position of `vector` decided by `node`'s children.
        @return: True as soon as a node flagged `is_hypothesis` is reached.
        """
        if node.is_hypothesis:
            # Every branch taken to get here was compatible with `vector`.
            return True
        if bit_idx < 0:
            return False

        # A 1 bit in the vector demands a 1 in the stored hypothesis; a 0 bit
        # is compatible with either value, so both branches are explored.
        allowed_branches = (1,) if (vector >> bit_idx) & 1 else (0, 1)
        return any(
            branch in node.children
            and self._recursive_is_subsumed(node.children[branch], vector, bit_idx - 1)
            for branch in allowed_branches
        )

    @time_monitor
    def _insert_into_trie(self, node: TrieNode, vector: int, num_bits: int):
        """Store `vector` in the trie rooted at `node`, most significant bit first.

        Missing nodes along the path are created and the final node is flagged
        as the end of a hypothesis.
        """
        cursor = node
        for shift in reversed(range(num_bits)):
            bit = (vector >> shift) & 1
            child = cursor.children.get(bit)
            if child is None:
                child = TrieNode()
                cursor.children[bit] = child
            cursor = child
        cursor.is_hypothesis = True

    @time_monitor
    def fit(self, instances: list[dict], labels: list[bool]):
        """Train on all instances in order, then print the state and timing report.

        @param instances: Instance dictionaries.
        @param labels: Boolean label for the instance at the same index.
        """
        total = len(instances)
        self.logger.log(f"Starting training on {total} instances.", level="HEADER")

        started_at = time.time()
        for idx in tqdm(range(total)):
            self.ifit(instances[idx], labels[idx], idx)
        duration = time.time() - started_at

        state_dump = json.dumps(self.get_state(), indent=4)
        print(state_dump)
        time_monitor.report()
        self.logger.log(f"--- Training Complete in {duration:.4f} seconds ---", level="HEADER")
        self.logger.log(f"Final learned concept (Disjunctive Normal Form):\n{self.to_dnf()}", level="SUCCESS")

    @time_monitor
    def ifit(self, instance: dict, label: bool, instance_index: int = None):
        """Incrementally learn from one labelled instance.

        @param instance: Mapping of attribute name to value.
        @param label: True for a positive example, False for a negative one.
        @param instance_index: Optional position of the example, used only in the log line.
        """
        example_tag = 'Example ' + str(instance_index) if instance_index is not None else ''
        polarity = 'Pos' if label else 'Neg'
        self.logger.log(f"Processing {example_tag} [{polarity}] instance {instance}", 'STEP')

        # Register unseen attribute/value pairs, then widen every stored
        # hypothesis by however many bits were just added.
        self._expand_hypotheses(self._update_bit_indices(instance))

        instance_h = self._instance_to_hypothesis(instance)
        if not label:
            self.logger.log(f"Routing to NEGATIVE handler with instance vector {instance_h}", level="ACTION")
            self.logger.log(f"AV to bit: {self.av_to_bit_idx} | attr_to_bit_indices: {self.attr_to_bit_indices} | bit_idx_to_av: {self.bit_idx_to_av} ", level="DEBUG")
            self._handle_negative(instance_h)
        else:
            self.logger.log(f"Routing to POSITIVE handler with instance vector {instance_h}", level="ACTION")
            self._handle_positive(instance_h)

        # Final clean-up: drop subsumed hypotheses from G0 and from both
        # boundary sets of every version space.
        self.g0 = self._prune_less_general(self.g0)
        for space in self.version_spaces:
            space.G = self._prune_less_general(space.G)
            space.S = self._prune_less_general(space.S)

    @time_monitor
    def predict(self, instance: dict) -> bool:
        """
        Predicts the label for a single new, unseen instance.

        The prediction is positive (True) if the instance is covered by any
        hypothesis in any S-set of any version space.

        Unseen (attribute, value) pairs are registered in the feature map, and
        the stored hypotheses are widened by the same number of bits. Bit
        positions are computed relative to `num_bits`, so growing the map
        without widening the hypotheses would shift every instance bit out of
        alignment with them, for this call and for all later training.

        @param instance: A dictionary representing the instance.
        @return: A boolean (True for positive, False for negative).
        """
        self.logger.log(f"Predicting for instance: {instance}", level="DEBUG")

        # Same two-step sequence as ifit(): extend the map, then re-pad the
        # hypotheses so vectors and instance masks share one layout.
        bit_pad_width = self._update_bit_indices(instance)
        self._expand_hypotheses(bit_pad_width)

        # Convert the instance dictionary to its bitmask vector.
        instance_h = self._instance_to_hypothesis(instance)
        self.logger.log(f"Instance vector for prediction: {instance_h}", level="DEBUG")

        is_covered = any(s_hyp.is_more_general_than(instance_h)
                         for vs in self.version_spaces
                         for s_hyp in vs.S)

        self.logger.log(f"Prediction result: {'Positive' if is_covered else 'Negative'}", level="DEBUG")
        return is_covered

    @time_monitor
    def score(self, instances: list[dict], labels: list[bool]) -> float:
        """
        Evaluate the model on a labelled set and print a detailed scorecard.

        @param instances: A list of instance dictionaries.
        @param labels: A list of corresponding true boolean labels.
        @return: The scorecard's accuracy, or 0.0 when `instances` is empty.
        """
        if not instances:
            self.logger.log("Score called with empty instance list, returning 0.0.", level="INFO")
            return 0.0

        self.logger.log(f"Scoring model on {len(instances)} instances.", level="ACTION")
        scorecard = Scorecard(num_instances=len(instances))

        for instance, true_label in zip(instances, labels):
            # Wall-clock time of a single prediction.
            tick = time.time()
            predicted_label = self.predict(instance)
            elapsed = time.time() - tick

            scorecard.matrix.update(predicted_label, true_label)
            scorecard.total_time += elapsed

            # Timing is also tracked separately per true class.
            if not true_label:
                scorecard.negative_instance_time += elapsed
                scorecard.num_negative_instances += 1
            else:
                scorecard.positive_instance_time += elapsed
                scorecard.num_positive_instances += 1

        print(scorecard)
        self.logger.log(f"Scoring complete. Final accuracy: {scorecard.accuracy:.2%}", level="SUCCESS")

        return scorecard.accuracy

    @time_monitor
    def _handle_positive(self, pos_h: Hypothesis):
        """Absorb a positive example into the first version space that can take it.

        Version spaces are visited in list order. For each one:
          * if one of its S hypotheses already covers `pos_h`, stop;
          * otherwise every S hypothesis is OR-ed with `pos_h`, and each result
            that some G hypothesis of that space still covers replaces its
            original; if at least one replacement happened, stop.
        When no space accepts the example, a new VersionSpace is appended with
        S = {pos_h} and G = the members of G0 that cover `pos_h`.

        NOTE: coverage is checked per space while iterating, so an earlier
        space may be generalized even though a later space already covers the
        example.
        """
        for vs_number, vs in enumerate(self.version_spaces, start=1):
            if any(s.is_more_general_than(pos_h) for s in vs.S):
                self.logger.log("Instance is already covered by an existing S-set. No action required.", level="INFO")
                return

            self.logger.log(f"Attempting to generalize VS-{vs_number}...", level="INFO")

            # Collect (original, widened) pairs first; vs.S must not be
            # mutated while it is being iterated.
            replacements = []
            for old_s in vs.S:
                widened = Hypothesis(old_s.vector | pos_h.vector)
                if any(g.is_more_general_than(widened) for g in vs.G):
                    replacements.append((old_s, widened))

            if not replacements:
                continue

            self.logger.log(f"VS-{vs_number} can be generalized. Original S: {vs.S}", level="SUCCESS")
            for old_s, widened in replacements:
                vs.S.remove(old_s)
                vs.S.add(widened)
            self.logger.log(f"VS-{vs_number} updated with new S-set: {vs.S}", level="SUCCESS")
            return

        self.logger.log("No existing VS could be generalized. Creating a new version space.", level="ACTION")
        covering_g = {g for g in self.g0 if g.is_more_general_than(pos_h)}
        new_vs = VersionSpace({pos_h}, covering_g)
        self.version_spaces.append(new_vs)
        self.logger.log(f"Created new VS: {new_vs}", level="SUCCESS")

    @time_monitor
    def _handle_negative(self, neg_h: Hypothesis):
        """Specialize the general boundaries so they exclude a negative example.

        Steps:
          1. Specialize G0 against `neg_h`.
          2. In every version space, replace each G0 member that covers
             `neg_h` with the specializations recorded on it, then prune.
          3. Remove version spaces that are no longer consistent, split their
             S-sets against `neg_h`, and feed the resulting disjuncts back
             through `_handle_positive`.

        @param neg_h: The negative example encoded as a Hypothesis.
        """
        self.logger.log(f"Specializing G0. Before: {self.g0}", level="ACTION")
        old_g0 = self.g0
        new_g0 = self._specialize_set(self.g0, neg_h)

        # Hypothesis objects are not modified in place, so the G-sets are
        # patched by swapping each affected G0 member for its specializations.
        # The swap is keyed on "covers the negative" rather than "has
        # specializations": a member that covers neg_h but cannot be
        # specialized must still leave the G-set, exactly as it leaves G0.
        # NOTE(review): only members of the pre-specialization G0 are examined;
        # confirm a G-set can never hold a covering hypothesis absent from G0.
        for vs in self.version_spaces:
            for g0 in old_g0:
                if g0 in vs.G and g0.is_more_general_than(neg_h):
                    vs.G.remove(g0)
                    # Guard against any recorded specialization that would
                    # itself still cover the negative example.
                    vs.G.update(
                        spec for spec in g0.specializations
                        if not spec.is_more_general_than(neg_h)
                    )
            vs.G = self._prune_less_general(vs.G)
        self.g0 = new_g0
        self.logger.log(f"Specialized G0. After: {self.g0}", level="SUCCESS")

        vss_to_remove = []
        disjuncts_to_reprocess = set()

        for i, vs in enumerate(self.version_spaces):
            if not vs.is_consistent():
                self.logger.log(f"VS-{i + 1} became INCONSISTENT. Marking for removal and splitting S-set.",
                                level="FAIL")
                vss_to_remove.append(vs)
                split_s = self._specialize_set(vs.S, neg_h)
                self.logger.log(f"Disjuncts created from split: {split_s}", level="INFO")
                disjuncts_to_reprocess.update(split_s)
            else:
                self.logger.log(f"VS-{i + 1} remains consistent.", level="INFO")

        if vss_to_remove:
            self.version_spaces = [vs for vs in self.version_spaces if vs not in vss_to_remove]
            self.logger.log(f"Removed {len(vss_to_remove)} inconsistent VS(s).", level="ACTION")

        if disjuncts_to_reprocess:
            self.logger.log(f"Reprocessing {len(disjuncts_to_reprocess)} disjuncts from split (Merge step).",
                            level="ACTION")
            # Sorted by vector so the merge order is deterministic.
            for disjunct in sorted(disjuncts_to_reprocess, key=lambda h: h.vector):
                self._handle_positive(disjunct)

    @time_monitor
    def _specialize_set(self, hyp_set: set[Hypothesis], instance_to_exclude: Hypothesis, instantiate_new: bool=False) -> set[Hypothesis]:
        """Return a pruned set in which no hypothesis covers `instance_to_exclude`.

        Hypotheses that do not cover the instance are kept as they are. A
        covering hypothesis is replaced by one specialization per attribute:
        the instance's value bit for that attribute is cleared, provided the
        hypothesis still admits some other value of that attribute. Each
        specialization is also recorded on the parent's `specializations` set.

        @param hyp_set: Hypotheses to specialize.
        @param instance_to_exclude: Example the result must not cover.
        @param instantiate_new: Unused; kept for interface compatibility.
        @return: The specialized set after `_prune_less_general`.
        """
        self.logger.log(f"Specializing set {hyp_set} to exclude {instance_to_exclude}", level="DEBUG")

        # Per attribute: (mask of all its bits, the instance's bit within it).
        # Attributes the instance does not mention give a zero bit; clearing
        # "nothing" would hand back the covering hypothesis unchanged as its
        # own specialization, so those attributes are skipped.
        exclusion_bits = []
        for attr in self.attr_to_bit_indices:
            attr_mask = self._get_attribute_mask(attr)
            inst_attr_val_bit = instance_to_exclude.vector & attr_mask
            if inst_attr_val_bit:
                exclusion_bits.append((attr_mask, inst_attr_val_bit))

        new_set = set()
        for h in hyp_set:
            if not h.is_more_general_than(instance_to_exclude):
                new_set.add(h)
                continue

            # h covers the instance, so it contains every instance bit already.
            for attr_mask, inst_attr_val_bit in exclusion_bits:
                if (h.vector & attr_mask) != inst_attr_val_bit:
                    specialized_hyp = Hypothesis(h.vector & ~inst_attr_val_bit)
                    new_set.add(specialized_hyp)
                    h.specializations.add(specialized_hyp)

        return self._prune_less_general(new_set)

    @time_monitor
    def _get_attribute_mask(self, attr: str) -> int:
        """Return a bitmask with one set bit for every known value of `attr`.

        Positions follow the same layout as instance vectors: bit index 0 is
        the most significant of the `num_bits` positions.
        """
        top_position = self.num_bits - 1
        attr_mask = 0
        for value_bit_idx in self.attr_to_bit_indices[attr]:
            attr_mask |= 1 << (top_position - value_bit_idx)
        return attr_mask

    @time_monitor
    # Render the S-set hypotheses of every version space as readable DNF text.
    def to_dnf(self) -> str:
        """Describe the learned concept in Disjunctive Normal Form.

        Each S hypothesis of each version space becomes one disjunct. Within a
        disjunct, an attribute contributes a conjunct only when the hypothesis
        admits some but not all of that attribute's known values.

        @return: Disjuncts joined by "OR" lines, or a fixed message when there
            are no version spaces or no constraining conjuncts.
        """
        if not self.version_spaces:
            return "Concept is empty or undefined."

        # Bit index b is stored at position (num_bits - 1) - b, matching
        # _instance_to_hypothesis and _get_attribute_mask; decoding must use
        # the same layout or the wrong value names are reported.
        top_position = self.num_bits - 1
        attr_masks = {attr: self._get_attribute_mask(attr) for attr in self.attr_to_bit_indices}

        disjuncts = []
        sorted_vss = sorted(self.version_spaces, key=lambda vs: min(s.vector for s in vs.S) if vs.S else 0)
        for vs in sorted_vss:
            for s_hyp in sorted(vs.S, key=lambda s: s.vector):
                conjuncts = []
                for attr, bit_indices in self.attr_to_bit_indices.items():
                    attr_mask = attr_masks[attr]
                    hyp_attr_bits = s_hyp.vector & attr_mask
                    # No value or every value admitted: the attribute does not
                    # constrain this disjunct.
                    if hyp_attr_bits == 0 or hyp_attr_bits == attr_mask:
                        continue
                    # Sorted so values appear in first-seen order, deterministically.
                    values = [
                        str(self.bit_idx_to_av[b][1])
                        for b in sorted(bit_indices)
                        if (s_hyp.vector >> (top_position - b)) & 1
                    ]
                    conjuncts.append(f"({attr} is {' or '.join(values)})")
                if conjuncts:
                    disjuncts.append(" and ".join(conjuncts))
        return "\nOR\n".join(disjuncts) if disjuncts else "True (most general concept)"

    def reduce_version_spaces(self):
        """
        Simplify the learned concept by merging pairs of version spaces.

        Every pair (i, j) with i < j is tried in order. A candidate merge takes
        the pruned pairwise OR of the two S-sets and the pruned intersection of
        the two G-sets; it is accepted when the merged G-set is non-empty and
        covers every merged S hypothesis. An accepted merge replaces the pair
        with the new space and restarts the scan from the first space. The scan
        ends once a full pass finds nothing to merge.
        """
        banner = "=" * 20
        self.logger.log(banner + " Starting Version Space Reduction " + banner, level="HEADER")

        if len(self.version_spaces) < 2:
            self.logger.log("Less than two version spaces, no reduction possible.", level="INFO")
        else:
            i = 0
            while i < len(self.version_spaces):
                merged = False
                for j in range(i + 1, len(self.version_spaces)):
                    first, second = self.version_spaces[i], self.version_spaces[j]
                    self.logger.log(f"Attempting to merge VS-{i + 1} and VS-{j + 1}...", level="ACTION")

                    # Candidate S-set: generalize every cross pair of specific hypotheses.
                    merged_s_set = self._prune_less_general({
                        Hypothesis(s1.vector | s2.vector)
                        for s1 in first.S for s2 in second.S
                    })
                    self.logger.log(f"Candidate merged S-set: {merged_s_set}", level="DEBUG")

                    # Candidate G-set: general hypotheses shared by both spaces.
                    merged_g_set = self._prune_less_general(first.G.intersection(second.G))
                    self.logger.log(f"Candidate merged G-set: {merged_g_set}", level="DEBUG")

                    merge_ok = bool(merged_g_set) and all(
                        any(g.is_more_general_than(s) for g in merged_g_set)
                        for s in merged_s_set
                    )
                    if not merge_ok:
                        self.logger.log(f"FAIL: Merge of VS-{i + 1} and VS-{j + 1} is not valid.", level="FAIL")
                        continue

                    self.logger.log(f"SUCCESS: Merge of VS-{i + 1} and VS-{j + 1} is valid.", level="SUCCESS")
                    new_vs = VersionSpace(merged_s_set, merged_g_set)
                    self.logger.log(f"Created new merged VS: {new_vs}", level="SUCCESS")

                    # The merged space takes slot i; slot j (always > i) is dropped.
                    self.version_spaces[i] = new_vs
                    del self.version_spaces[j]
                    merged = True
                    break

                # After a merge the list has changed, so rescan from the start.
                i = 0 if merged else i + 1

        self.logger.log(banner + " Version Space Reduction Complete " + banner, level="HEADER")
        self.logger.log(f"Number of version spaces reduced to: {len(self.version_spaces)}", level="SUCCESS")


if __name__ == '__main__':
    # Toy dataset of (instance, label) pairs used as a smoke run of the learner.
    training_data = [
        ({'size': 'Large', 'color': 'Blue', 'shape': 'Circle'}, True),
        ({'size': 'Large', 'color': 'Red', 'shape': 'Circle'}, False),
        ({'size': 'Small', 'color': 'Red', 'shape': 'Circle'}, True),
        ({'size': 'Small', 'color': 'Red', 'shape': 'Triangle'}, False),
        ({'size': 'Large', 'color': 'Red', 'shape': 'Triangle'}, True),
        ({'size': 'Large', 'color': 'Blue', 'shape': 'Triangle'}, True),
    ]

    # Split the pairs into two parallel lists.
    instances, labels = (list(column) for column in zip(*training_data))

    vssm_model = VSSM(logging_enabled=True)  # Enable logging
    vssm_model.fit(instances, labels)