from dataclasses import dataclass
from scipy.stats import norm
import math



@dataclass
class NormalDistribution:
    """
    Two-class Gaussian model of a single continuous attribute.

    One normal distribution is fitted incrementally to the values seen in
    positive examples and another to the values seen in negative examples.
    A value is classified by comparing the likelihood ratio
    P(value|positive) / P(value|negative) against ``threshold``.

    Instances are mutable (see ``update``) yet hashable; do not mutate an
    instance while it is stored in a set or used as a dictionary key.
    """
    # Floor applied to every updated standard deviation so that a density
    # never collapses to zero width (which would make the pdf undefined).
    _MIN_STD = 0.1
    # Spread assigned while a class has exactly one sample, when no variance
    # can be estimated yet.
    _DEFAULT_STD = 1.0
    # Number of decimals used to compare/hash means and standard deviations.
    _KEY_DIGITS = 5

    # Parameters for positive examples
    pos_mean: float = 0.0
    pos_std: float = 1.0
    pos_count: int = 0

    # Parameters for negative examples
    neg_mean: float = 0.0
    neg_std: float = 1.0
    neg_count: int = 0

    # Minimum and maximum values observed
    min_val: float = float('inf')
    max_val: float = float('-inf')

    # Threshold for classifying (likelihood ratio)
    threshold: float = 1.0  # Default equal weighting

    def _fold_sample(self, mean: float, std: float, count: int, value: float) -> tuple:
        """
        Fold one sample into a running (mean, std) pair using Welford's update.

        @param mean: Mean of the previous ``count - 1`` samples
        @param std: Standard deviation stored for the previous samples
        @param count: Number of samples including ``value``
        @param value: The new sample
        @return Tuple ``(new_mean, new_std)``; the std is the population
                standard deviation, floored at ``_MIN_STD``
        """
        if count == 1:
            # A single sample carries no spread information: use a placeholder.
            return value, self._DEFAULT_STD

        new_mean = mean + (value - mean) / count

        # With exactly one earlier sample the stored std is only the
        # placeholder, so the real prior variance is zero. Using the
        # placeholder here would inflate every later estimate.
        prior_var = std ** 2 if count > 2 else 0.0
        new_var = prior_var + ((value - mean) * (value - new_mean) - prior_var) / count

        # Rounding can push a (mathematically non-negative) variance slightly
        # below zero, which would make sqrt raise.
        new_std = math.sqrt(max(new_var, 0.0))

        # NOTE: the floored std is what gets stored, so the floor also feeds
        # into the next update's prior variance.
        return new_mean, max(new_std, self._MIN_STD)

    def update(self, value: float, is_positive: bool) -> None:
        """
        Update the distribution with a new value. Uses online updating of mean and variance.

        @param value: The new value to incorporate
        @param is_positive: Whether this value comes from a positive or negative example
        """
        # Update min and max bounds
        self.min_val = min(self.min_val, value)
        self.max_val = max(self.max_val, value)

        if is_positive:
            self.pos_count += 1
            self.pos_mean, self.pos_std = self._fold_sample(
                self.pos_mean, self.pos_std, self.pos_count, value)
        else:
            self.neg_count += 1
            self.neg_mean, self.neg_std = self._fold_sample(
                self.neg_mean, self.neg_std, self.neg_count, value)

    def likelihood_ratio(self, value: float) -> float:
        """
        Calculate the likelihood ratio P(value|positive) / P(value|negative).

        The ratio is computed in log space so that values far from both means
        (where both densities underflow to zero) are still compared correctly.

        @param value: Value to evaluate
        @return The likelihood ratio (> 1 means more likely positive, < 1 means more likely negative)
        """
        # Handle case where we have no examples of one class
        if self.pos_count == 0:
            return 0.0  # Assume negative
        if self.neg_count == 0:
            return float('inf')  # Assume positive

        log_pos = norm.logpdf(value, self.pos_mean, self.pos_std)
        log_neg = norm.logpdf(value, self.neg_mean, self.neg_std)
        log_ratio = float(log_pos - log_neg)

        try:
            return math.exp(log_ratio)
        except OverflowError:
            # The ratio is too large for a float: overwhelmingly positive.
            return float('inf')

    def classify(self, value: float) -> bool:
        """
        Classify a value as positive or negative based on likelihood ratio.

        @param value: The value to classify
        @return True if classified as positive, False if negative
        """
        return bool(self.likelihood_ratio(value) >= self.threshold)

    def contains(self, value: float) -> bool:
        """
        Check if the value would be classified as belonging to this distribution.

        @param value: The value to check
        @return True if classified as positive, False otherwise
        """
        return self.classify(value)

    def overlaps(self, other: 'NormalDistribution') -> bool:
        """
        Check if two distributions overlap 'significantly'.

        The observed [min_val, max_val] ranges must intersect, and both
        distributions must classify the midpoint between their positive means
        the same way.

        @param other: Another distribution to compare with
        @return True if distributions overlap significantly
        """
        # Disjoint observed ranges can never overlap
        if self.max_val < other.min_val or self.min_val > other.max_val:
            return False

        # Check middle point between the positive means
        midpoint = (self.pos_mean + other.pos_mean) / 2

        # If both distributions classify the midpoint the same way, they overlap
        return bool(self.classify(midpoint) == other.classify(midpoint))

    def __str__(self) -> str:
        """String representation of the distribution."""
        return (f"ProbDist[pos(μ={self.pos_mean:.2f},σ={self.pos_std:.2f},n={self.pos_count}), "
                f"neg(μ={self.neg_mean:.2f},σ={self.neg_std:.2f},n={self.neg_count}), "
                f"range=[{self.min_val:.2f},{self.max_val:.2f}]]")

    def _identity_key(self) -> tuple:
        """
        Build the tuple that defines equality and hashing.

        Means and standard deviations are rounded to ``_KEY_DIGITS`` decimals;
        counts are used as-is. ``min_val``, ``max_val`` and ``threshold`` do
        not take part in the comparison.
        """
        digits = self._KEY_DIGITS
        return (round(self.pos_mean, digits),
                round(self.pos_std, digits),
                self.pos_count,
                round(self.neg_mean, digits),
                round(self.neg_std, digits),
                self.neg_count)

    def __eq__(self, other: object) -> bool:
        """
        Equality comparison.

        Two distributions are equal when their parameters agree after
        rounding. The same rounded key is used by ``__hash__``, so equal
        objects always have equal hashes.
        """
        if not isinstance(other, NormalDistribution):
            return NotImplemented
        return self._identity_key() == other._identity_key()

    def __hash__(self) -> int:
        """Hash function for use in sets and dictionaries."""
        return hash(self._identity_key())

    def __json__(self):
        """Return all fields as a plain dictionary keyed by field name."""
        return {
            "pos_mean": self.pos_mean,
            "pos_std": self.pos_std,
            "pos_count": self.pos_count,
            "neg_mean": self.neg_mean,
            "neg_std": self.neg_std,
            "neg_count": self.neg_count,
            "min_val": self.min_val,
            "max_val": self.max_val,
            "threshold": self.threshold
        }