import collections
import json
from itertools import combinations
import time
from typing import Any, Dict, Union
from tqdm import tqdm
from attribute_value_table import AttributeValueTable
from version_space import Hypothesis
from version_space import VersionSpace
from logger import VSSMLogger
from normal_distribution import NormalDistribution
from scoring import Scorecard
from time_monitor import time_monitor


class VSSM:
    """
    An incremental learning algorithm for disjunctive concepts with thorough logging.

    The learned concept is a disjunction of version spaces (``self.version_spaces``); each one keeps a
    specific boundary ``S`` and a general boundary ``G``. A global general boundary ``g0`` supplies the
    ``G`` set of every newly created version space and is specialized by every negative example.

    Assumptions:
    1. All examples have the same attributes (no examples of varying size)
    2. No missing values
    """

    def __init__(self, logging_enabled=False):
        self.av_table = AttributeValueTable()
        # Disjuncts of the learned concept (VersionSpace objects with S and G sets).
        self.version_spaces = []
        # Global general boundary; initialized lazily (as a list) from the first instance seen in ifit().
        self.g0 = None
        self.positive_count = 0
        # Every negative example handled so far; used when a new g0 hypothesis must be derived.
        self.negative_examples = []
        # Negative examples received before the first positive one, held until a positive arrives.
        self.negative_buffer = []
        # Timing: seconds spent processing positive / negative examples
        self.time_to_process_pos = 0
        self.time_to_process_neg = 0

        # Model options
        self.logging_enabled = logging_enabled
        self.logger = VSSMLogger(logging_enabled=self.logging_enabled)

    @time_monitor
    def get_state(self, current_instance=None, label=None):
        """
        Return the current state of the algorithm in a serializable format.

        @param current_instance: The instance currently being processed, if any (stringified).
        @param label: True -> 'Positive', False -> 'Negative', anything else -> 'N/A'.
        @return A dict describing g0 and every version space. The average S/G sizes are 0.0 when
                there are no version spaces yet.
        """
        g0 = self.g0 if self.g0 is not None else []
        num_vs = len(self.version_spaces)
        state = {
            "g0": [str(g) for g in g0],
            "version_spaces": [
                {
                    "s_set": [str(s) for s in vs.S],
                    "g_set": [str(g) for g in vs.G]
                }
                for vs in self.version_spaces
            ],
            "current_instance": str(current_instance) if current_instance else "N/A",
            "label": 'Positive' if label is True else 'Negative' if label is False else "N/A",
            "num_version_spaces": num_vs,
            "num_g0_hypotheses": len(g0),
            # Guard against division by zero before the first version space exists.
            "avg_version_space_S_size": sum(len(vs.S) for vs in self.version_spaces) / num_vs if num_vs else 0.0,
            "avg_version_space_G_size": sum(len(vs.G) for vs in self.version_spaces) / num_vs if num_vs else 0.0,
        }
        return state

    @time_monitor
    def fit(self, instances, labels):
        """
        Train on a batch by feeding each (instance, label) pair to ifit() in order.

        @param instances: A sequence of instances (dicts, or lists of dicts with an "id" key).
        @param labels: A sequence of boolean labels aligned with ``instances``.
        """
        for instance, label in tqdm(zip(instances, labels), total=len(instances)):
            self.ifit(instance, label)

        time_monitor.report(self.time_to_process_pos, self.time_to_process_neg)

    @time_monitor
    def ifit(self, instance: dict, label: bool):
        """
        Incrementally update the model with a single labelled instance.

        Negative instances that arrive before any positive one are only recorded in the
        attribute-value table and buffered; they are handled right after the first positive
        instance has been processed.

        @param instance: A dict of attribute -> value, or a list of dicts (see _convert_instance_to_dict).
        @param label: True for a positive example, False for a negative one.
        """
        instance = self._convert_instance_to_dict(instance)

        self.logger.log(f"Processing instance {instance}, Label: {'Pos' if label else 'Neg'}", 'STEP')
        if self.g0 is None:
            # A list (like the output of _prune_less_general) so _find_and_add_g0_hypothesis can append to it.
            self.g0 = [Hypothesis({attr: '?' for attr in instance})]
            self.logger.log(f"g0 not found. Initialized to most general hypothesis: {self.g0}", "ACTION")
        start = time.time()
        # Update attribute value table
        self.av_table.update(instance, label)

        instance = Hypothesis(instance)
        if not label and self.positive_count == 0:
            # No positive example yet: hold off on processing this negative example. It is buffered as a
            # Hypothesis because that is the form _handle_negative works with.
            self.negative_buffer.append(instance)
            self.logger.log("No positive examples yet. Buffering negative instance.", 'INFO')
            return

        if label:
            self.positive_count += 1
            self._handle_positive(instance)
            self.time_to_process_pos += time.time() - start

            # Can process any backlogged negative examples now
            backlog, self.negative_buffer = self.negative_buffer, []
            for negative in backlog:
                neg_start = time.time()
                self._handle_negative(negative)
                # Time each negative separately; measuring from `start` would re-count the positive's
                # processing time and every earlier negative.
                self.time_to_process_neg += time.time() - neg_start
        else:
            self._handle_negative(instance)
            self.time_to_process_neg += time.time() - start

    @time_monitor
    def predict(self, instance):
        """
        Return True if any learned version space covers ``instance``.

        @param instance: A Hypothesis, or a raw instance (dict, or list of dicts with an "id" key) which
                         is converted the same way ifit() converts training instances.
        """
        if not isinstance(instance, Hypothesis):
            instance = Hypothesis(self._convert_instance_to_dict(instance))
        return any(vs.covers(instance) for vs in self.version_spaces)

    @time_monitor
    def multi_predict(self, instances, labels):
        """
        Calculate the accuracy and detailed performance metrics of the model on a test set.

        The detailed Scorecard is printed to stdout.

        @param instances: A list of instances (dicts, or lists of dicts with an "id" key).
        @param labels: A list of corresponding true boolean labels.
        @return: The accuracy as a float between 0.0 and 1.0 (0.0 for an empty instance list).
        """
        if not instances:
            self.logger.log("Score called with empty instance list, returning 0.0.", level="INFO")
            return 0.0

        self.logger.log(f"Scoring model on {len(instances)} instances.", level="ACTION")

        # Initialize the scorecard
        scorecard = Scorecard(num_instances=len(instances))
        # Convert up front (same representation as training) so only prediction itself is timed.
        hypotheses = [Hypothesis(self._convert_instance_to_dict(instance)) for instance in instances]

        for instance, true_label in zip(hypotheses, labels):
            start_time = time.time()
            predicted_label = self.predict(instance)
            duration = time.time() - start_time

            # Update the scorecard with the results for this instance
            scorecard.matrix.update(predicted_label, true_label)
            scorecard.total_time += duration

            if true_label:
                scorecard.positive_instance_time += duration
                scorecard.num_positive_instances += 1
            else:
                scorecard.negative_instance_time += duration
                scorecard.num_negative_instances += 1

        print(scorecard)
        self.logger.log(f"Scoring complete. Final accuracy: {scorecard.accuracy:.2%}", level="SUCCESS")
        return scorecard.accuracy

    def score(self, instance: Union[dict, list[dict]]):
        """
        Score how closely ``instance`` matches the learned concept.

        For each conjunction produced by get_pyHTN_conds(), the score is the fraction of its literals
        whose value equals the instance's value for the same attribute, i.e. the proportion of literals
        that would not need to be generalized to accept the instance. The best score is returned
        (0 when there are no conjunctions).

        @raises KeyError: if the instance lacks an attribute used by a literal.
        """
        if isinstance(instance, list):
            instance = self._convert_instance_to_dict(instance)

        best_score = 0
        for conjunction in self.get_pyHTN_conds():
            # literal = (attribute, identifier, value); conjunctions are never empty.
            # NOTE(review): numeric bound literals (">= x" / "<= x" strings) never compare equal here.
            matches = sum(1 for attr, _, value in conjunction if value == instance[attr])
            best_score = max(best_score, matches / len(conjunction))
        return best_score

    @time_monitor
    def get_pyHTN_conds(self) -> list:
        """
        Output the learned concept in Disjunctive Normal Form (DNF).

        Each conjunction is built from one hypothesis in the S set of a version space and is a list of
        literals of the form (attribute, identifier, value). For attributes named
        'identifier.attribute' the identifier is the text before the first '.'; for any other
        attribute it is None. Wildcard ('?') values and 'id' attributes are skipped because they add
        no constraint. A NormalDistribution value with positive observations becomes two literals,
        ">= mean - std" and "<= mean + std" (2 decimals); one without positive observations is dropped.

        @return A list of conjunctions (lists of literals); empty conjunctions are omitted.
        """
        dnf = []
        for vs in self.version_spaces:
            for s in vs.S:
                conjunction = []
                for attr, value in s.attributes.items():
                    identifier, base_attr = self._split_attr(attr)
                    # Skip wildcard values (no constraint) and id attributes
                    if value == '?' or base_attr == 'id':
                        continue

                    if isinstance(value, NormalDistribution):
                        # Continuous attribute: approximate it by a one-std band around the positive mean.
                        if value.pos_count > 0:
                            conjunction.extend(self._normal_bound_literals(attr, identifier, value))
                    else:
                        # Categorical attribute: exact value constraint
                        conjunction.append((attr, identifier, value))

                if conjunction:
                    dnf.append(conjunction)

        return dnf

    @time_monitor
    def get_lit_priorities(self) -> list:
        """
        Output a prioritized list of literals, where each literal is of the form
        (attribute, identifier, value) and is assigned a priority.
        The priority is the average of two factors:
        1. Type: literals occurring in a general hypothesis (G set) score 1.0, others 0.5.
        2. Positive support: the literal's positive count (from the attribute-value table, or the
           distribution's pos_count for continuous values) divided by the table's num_positive.

        Only attributes of the form 'identifier.attribute' are considered; wildcard ('?') values and
        'id' attributes are skipped.

        @return A list of tuples (priority, literal) sorted by priority (highest first); empty when
                there are no version spaces.
        """
        if not self.version_spaces:
            return []

        # literal -> {'locations': {(vs_idx, 'G' | 'S'), ...}, 'pos_count': int}
        literal_info = collections.defaultdict(lambda: {'locations': set(), 'pos_count': 0})

        def record(literal, vs_idx, hyp_type, pos_count):
            """Note where ``literal`` occurs and keep the largest positive count seen for it."""
            info = literal_info[literal]
            info['locations'].add((vs_idx, hyp_type))
            info['pos_count'] = max(info['pos_count'], pos_count)

        def process_hypothesis(h: Hypothesis, vs_idx: int, hyp_type: str):
            """Helper to extract literals and their info from a hypothesis."""
            for attr, value in h.attributes.items():
                identifier, base_attr = self._split_attr(attr)
                if identifier is None:
                    continue  # Skip attrs not in 'identifier.attribute' format
                if value == '?' or base_attr == 'id':
                    continue

                if isinstance(value, NormalDistribution):
                    if value.pos_count > 0:
                        # Both derived bound literals use the distribution's positive count.
                        for literal in self._normal_bound_literals(attr, identifier, value):
                            record(literal, vs_idx, hyp_type, value.pos_count)
                else:
                    # Categorical literal: positive count comes from the attribute-value table.
                    pos_count = self.av_table.table.get(attr, {}).get(value, {}).get('p', 0)
                    record((attr, identifier, value), vs_idx, hyp_type, pos_count)

        # Collect all unique literals and their associated info
        for vs_idx, vs in enumerate(self.version_spaces):
            for g in vs.G:
                process_hypothesis(g, vs_idx, 'G')
            for s in vs.S:
                process_hypothesis(s, vs_idx, 'S')

        num_positive = self.av_table.num_positive
        prioritized_literals = []
        for literal, info in literal_info.items():
            is_in_g = any(hyp_type == 'G' for _, hyp_type in info['locations'])
            type_score = 1.0 if is_in_g else 0.5
            positive_support_score = info['pos_count'] / num_positive if num_positive > 0 else 0
            # Combine scores with equal weighting (average).
            priority = (type_score + positive_support_score) / 2
            prioritized_literals.append((priority, literal))

        prioritized_literals.sort(key=lambda item: item[0], reverse=True)
        return prioritized_literals

    @staticmethod
    def _split_attr(attr):
        """
        Split an 'identifier.attribute' key at its first '.'.

        @return (identifier, attribute) for string keys containing '.', otherwise (None, attr).
        """
        if isinstance(attr, str) and '.' in attr:
            identifier, base_attr = attr.split('.', 1)
            return identifier, base_attr
        return None, attr

    @staticmethod
    def _normal_bound_literals(attr, identifier, dist):
        """Return the (lower, upper) literals 'mean - std' / 'mean + std' of a distribution's positive values."""
        lower_bound = dist.pos_mean - dist.pos_std
        upper_bound = dist.pos_mean + dist.pos_std
        return (attr, identifier, f">= {lower_bound:.2f}"), (attr, identifier, f"<= {upper_bound:.2f}")

    @staticmethod
    @time_monitor
    def _convert_instance_to_dict(instance: Union[dict, list]):
        """
        Flatten a list-of-dicts instance into one dict; any other input is returned unchanged.

        Each attribute of every dictionary in the list is keyed as 'identifier.attribute', where the
        identifier is that dictionary's "id" value (so the "id" entry itself becomes 'identifier.id').

        @raises KeyError: if a dictionary in the list has no "id" key.
        """
        if not isinstance(instance, list):
            return instance
        return {f'{obj["id"]}.{attr}': value for obj in instance for attr, value in obj.items()}

    @time_monitor
    def _minimally_specialize_hypothesis(self, h: Hypothesis, instance: Hypothesis):
        """
        Return specializations of ``h`` that exclude the negative ``instance``.

        Wildcard attributes of ``h`` are replaced by a value seen in a positive example that differs
        from the instance's value. Combinations of 1, 2, 3, ... candidate attributes are tried in turn
        and the first combination that yields a specialization is returned.

        @return [h] if h does not cover the instance; a one-element list with the specialization; or
                an empty list when no specialization is possible.
        """
        def _get_attrs_to_specialize():
            """
            Attributes for which some value other than the instance's value has appeared in a
            positive example.
            """
            return [
                attr for attr, inst_value in instance.attributes.items()
                if any(val != inst_value and counts['p'] > 0
                       for val, counts in self.av_table.table[attr].items())
            ]

        if not h.is_more_general_than(instance):
            self.logger.log(f"Hypothesis {h} already excludes instance {instance}. No specialization needed.", 'INFO')
            return [h]

        attributes_to_specialize = _get_attrs_to_specialize()
        if not attributes_to_specialize:
            self.logger.log(f"Can not specialize hypothesis {h} to exclude instance {instance}", 'ACTION')
            return []

        self.logger.log(f"Specializing hypothesis {h} to exclude instance {instance}", 'ACTION')
        # Minimal specialization: try every single attribute first, then every pair, then triples, ...
        for combo_size in range(1, len(attributes_to_specialize) + 1):
            for combo in combinations(attributes_to_specialize, combo_size):
                new_values = {}
                for attr in combo:
                    # Only wildcard attributes are specialized; fixed values are left as they are.
                    if h.attributes[attr] != '?':
                        continue
                    possible_values = self.av_table.positive_values[attr] - {instance.attributes[attr]}
                    if not possible_values:
                        # One attribute of this combination cannot be specialized: discard the combination.
                        break
                    # NOTE(review): pop() takes an arbitrary element of the set of candidate values.
                    new_values[attr] = possible_values.pop()
                else:
                    if new_values:
                        new_attributes = h.attributes.copy()
                        new_attributes.update(new_values)
                        specializations = [Hypothesis(new_attributes)]
                        self.logger.log(f"Generated new specialization: {new_attributes}", 'DEBUG')
                        # Only the first successful combination is used.
                        self.logger.log(f"Generated {len(specializations)} new specializations.", 'DEBUG')
                        return specializations

        self.logger.log("Generated no new specializations.", 'DEBUG')
        return []

    @time_monitor
    def _prune_less_general(self, hyp_set):
        """
        Prunes a set of hypotheses, removing any that are less general than another in the set.

        @param hyp_set: An iterable of Hypothesis objects.
        @return A list of rebuilt, de-duplicated hypotheses containing only the most general ones.
        """
        # De-duplicate first because Hypothesis.is_more_general_than considers two identical hypotheses
        # each more general than the other.
        unique_hyp = {tuple(sorted(h.attributes.items(), key=lambda item: item[0])) for h in hyp_set}
        candidates = [Hypothesis(dict(items)) for items in unique_hyp]

        return [
            h1 for h1 in candidates
            if not any(h2.is_more_general_than(h1) for h2 in candidates if h1 != h2)
        ]

    def _specialize_all(self, hypotheses, instance: Hypothesis) -> list:
        """Concatenate the minimal specializations of every hypothesis in ``hypotheses`` against ``instance``."""
        return [spec for h in hypotheses for spec in self._minimally_specialize_hypothesis(h, instance)]

    @time_monitor
    def _handle_positive(self, new_hyp: Hypothesis) -> None:
        """
        Incorporate a positive example (or a re-processed disjunct) into the disjunctive concept.

        For each version space in turn: stop if it already covers ``new_hyp``; otherwise, if
        generalizing one of its S hypotheses with ``new_hyp`` is accepted by ``vs.can_include``, replace
        that S hypothesis and stop. If no version space absorbs ``new_hyp``, create a new one with
        S = [new_hyp] and G = the g0 hypotheses more general than it (deriving a new g0 hypothesis
        first when none qualifies).
        """
        for i, vs in enumerate(self.version_spaces, start=1):
            # NOTE(review): coverage and generalization are checked one version space at a time, so an
            # earlier VS may be generalized even when a later VS already covers new_hyp.
            self.logger.log(f"Checking if VS-{i} already covers instance...", 'INFO')
            if vs.covers(new_hyp):
                self.logger.log(f"VS-{i} already covers instance. Doing nothing.", 'SUCCESS')
                return
            # Check if S set hypotheses can be generalized to include new_hyp
            self.logger.log(f"Checking if VS-{i} can be generalized...", 'INFO')
            for s in vs.S:
                candidate_hyp = Hypothesis.generalize(s, new_hyp)
                if vs.can_include(candidate_hyp):
                    self.logger.log(f"VS-{i} can be generalized. Attempting merge.", 'SUCCESS')
                    # Safe to mutate vs.S here: we return immediately after.
                    vs.S.remove(s)
                    vs.S.append(candidate_hyp)
                    self.logger.log(f"Successfully merged hypothesis into VS-{i}.", 'SUCCESS')
                    return

        # Not a part of the original algorithm. This is just used for logging
        if not self.version_spaces:
            self.logger.log("There are no existing version spaces.", 'ACTION')

        self.logger.log("No existing VS could be generalized. Creating a new version space.", 'ACTION')

        new_s = [new_hyp]
        new_g = [g for g in self.g0 if g.is_more_general_than(new_hyp)]

        if not new_g:
            self.logger.log("Cannot create new VS: no hypothesis in g0 is general enough. Updating g0.", 'FAIL')
            self._find_and_add_g0_hypothesis(new_hyp)
            # Now that g0 is updated, try again to find the general hypotheses.
            new_g = [g for g in self.g0 if g.is_more_general_than(new_hyp)]

        if new_g:
            new_vs = VersionSpace(new_s, new_g, self.logging_enabled)
            self.version_spaces.append(new_vs)
            self.logger.log(f"Created new Version Space: {new_vs}", 'SUCCESS')
        else:
            self.logger.log("Cannot create a new VS: no hypothesis in g0 is general enough.", 'FAIL')

    @time_monitor
    def _handle_negative(self, instance_hypothesis: Hypothesis) -> None:
        """
        Incorporate a negative example.

        Records it, minimally specializes g0 and the G set of every version space to exclude it, then
        splits every version space that became inconsistent: the version space is removed and the
        specializations of its S hypotheses are re-processed through _handle_positive.
        """
        self.negative_examples.append(instance_hypothesis)
        self.logger.log(f"Handling negative instance: {instance_hypothesis}", 'ACTION')
        self.logger.log("Specializing global g0 set.", 'ACTION')

        self.g0 = self._prune_less_general(self._specialize_all(self.g0, instance_hypothesis))
        self.logger.log(f"Updated g0: {self.g0}", 'DEBUG')

        inconsistent_vss = []
        for i, vs in enumerate(self.version_spaces, start=1):
            self.logger.log(f"Specializing G-set of VS-{i}.", 'ACTION')
            vs.G = self._prune_less_general(self._specialize_all(vs.G, instance_hypothesis))
            self.logger.log(f"Updated G-set for VS-{i}: {vs.G}", 'DEBUG')
            self.logger.log(f"Checking consistency of VS-{i}...", 'INFO')
            if not vs.is_consistent():
                self.logger.log(f"VS-{i} is now INCONSISTENT. Marking for split.", 'FAIL')
                inconsistent_vss.append(vs)

        if inconsistent_vss:
            self.logger.log(f"Splitting {len(inconsistent_vss)} inconsistent version space(s).", 'ACTION')
        for vs in inconsistent_vss:
            self.logger.log(f"Processing split for inconsistent VS: {vs}", 'INFO')
            self.version_spaces.remove(vs)
            for disjunct in self._specialize_all(vs.S, instance_hypothesis):
                self.logger.log(f"Re-processing disjunct {disjunct} from split VS.", 'ACTION')
                self._handle_positive(disjunct)

    @time_monitor
    def _find_and_add_g0_hypothesis(self, positive_instance: Hypothesis):
        """
        Derive a new g0 hypothesis when no existing one covers ``positive_instance``.

        Starting from a copy of ``positive_instance``, each non-wildcard attribute (in attribute order)
        is greedily replaced by '?' unless that would make the candidate more general than some recorded
        negative example. The resulting hypothesis is appended to g0.

        @return True (a hypothesis is always added).
        """
        self.logger.log(f"Attempting to find a new g0 hypothesis for {positive_instance}", 'ACTION')

        candidate_g = Hypothesis(positive_instance.attributes.copy())

        for attr in list(candidate_g.attributes):
            original_value = candidate_g.attributes[attr]
            if original_value == '?':
                continue

            # Tentatively generalize the attribute, then revert if a negative example becomes covered.
            candidate_g.attributes[attr] = '?'
            if any(candidate_g.is_more_general_than(neg_ex) for neg_ex in self.negative_examples):
                candidate_g.attributes[attr] = original_value
                self.logger.log(f"Could not generalize attribute '{attr}'; it would cover a negative example.", 'DEBUG')
            else:
                self.logger.log(f"Successfully generalized attribute '{attr}' for new g0 candidate.", 'DEBUG')

        self.logger.log(f"Generated new g0 hypothesis: {candidate_g}. Adding to g0 set.", 'SUCCESS')
        if self.g0 is None:
            self.g0 = []
        self.g0.append(candidate_g)
        return True
