import numpy as np
import torch
import torch.nn as nn
from .cutpoint_initializer import RandomInitializer
from .initializers import CutPointInitializer
import torch.nn.functional as F
from itertools import combinations
import flowtorch.distributions as dist


def _get_data_driven_bounds(X, assignments, rule_idx):
    """
    Calculates the tight bounding box for a rule based on its assigned data points.
    """
    assigned_mask = assignments == rule_idx
    if not torch.any(assigned_mask):
        return None

    assigned_data = X[assigned_mask]

    if assigned_data.dim() == 1:
        mins = torch.min(assigned_data)
        maxs = torch.max(assigned_data)
    else:
        mins = torch.min(assigned_data, dim=0).values
        maxs = torch.max(assigned_data, dim=0).values

    # Stack to get shape [n_features, 2]
    return torch.stack([mins, maxs], dim=1).to(X.device)


def _calculate_containment(box_container, box_contained):
    """
    Calculates the containment of box_contained within box_container.
    Containment is defined as: Intersection(A, B) / Volume(B).
    This measures what fraction of B's volume is inside A.

    Args:
        box_container (torch.Tensor): Bounding box of the container, shape [n_features, 2].
        box_contained (torch.Tensor): Bounding box of the contained, shape [n_features, 2].

    Returns:
        torch.Tensor: The containment value (scalar).
    """
    if box_container is None or box_contained is None:
        return torch.tensor(
            0.0, device=box_container.device if box_container is not None else "cpu"
        )

    # Calculate intersection volume
    intersection_lower = torch.max(box_container[:, 0], box_contained[:, 0])
    intersection_upper = torch.min(box_container[:, 1], box_contained[:, 1])
    intersection_dims = torch.clamp(intersection_upper - intersection_lower, min=0)
    intersection_volume = torch.prod(intersection_dims)

    # Calculate volume of the contained box
    contained_dims = torch.clamp(box_contained[:, 1] - box_contained[:, 0], min=0)
    contained_volume = torch.prod(contained_dims)

    if contained_volume < 1e-8:
        return torch.tensor(0.0, device=contained_volume.device)

    return intersection_volume / contained_volume


def _calculate_iou(box1, box2):
    """Calculates Intersection over Union for 1D intervals."""
    box1_lower, box1_upper = torch.min(box1), torch.max(box1)
    box2_lower, box2_upper = torch.min(box2), torch.max(box2)

    intersection_lower = torch.max(box1_lower, box2_lower)
    intersection_upper = torch.min(box1_upper, box2_upper)

    intersection = torch.clamp(intersection_upper - intersection_lower, min=0)

    box1_area = box1_upper - box1_lower
    box2_area = box2_upper - box2_lower

    union = box1_area + box2_area - intersection

    iou = intersection / (union + 1e-8)
    return iou


def _js_divergence(p, q):
    """Calculates Jensen-Shannon divergence for two discrete probability distributions."""
    m = 0.5 * (p + q)
    jsd = 0.5 * (
        F.kl_div(m.log(), p, reduction="sum") + F.kl_div(m.log(), q, reduction="sum")
    )
    return jsd


class GMMRemixer(nn.Module):
    """
    Remixes components of a pre-trained global GMM based on rule assignments.
    Can optionally include a background component.

    The learnable parameter ``mixing_weights`` has shape
    [n_gmm_components, n_total_components], where n_total_components is
    ``n_rules`` plus one when the background component is enabled (the
    background column is the last one).
    """

    def __init__(
        self,
        n_rules,
        n_gmm_components,
        initial_mixing_weights=None,
        diagonal=False,
        use_background_component=False,
    ):
        """
        Args:
            n_rules (int): Number of interpretable rules.
            n_gmm_components (int): Number of components of the global GMM.
            initial_mixing_weights (torch.Tensor or array-like, optional): Raw
                (pre-softmax) weights to start from. The values are copied, so
                the caller's object is never modified by training.
            diagonal (bool): If True (and no initial weights are given), rule j
                gets a +2.0 boost on GMM component ``j % n_gmm_components`` on
                top of small Gaussian noise.
            use_background_component (bool): Adds one extra column for a
                background expert.
        """
        super().__init__()
        self.n_rules = n_rules  # Number of interpretable rules
        self.n_gmm_components = n_gmm_components
        self.use_background_component = use_background_component

        n_total_components = (
            self.n_rules + 1 if self.use_background_component else self.n_rules
        )

        # Shape: [n_gmm_components, n_total_components]
        if initial_mixing_weights is not None:
            if isinstance(initial_mixing_weights, torch.Tensor):
                # Copy: nn.Parameter would otherwise share storage with the
                # caller's tensor and optimizer steps would silently mutate it.
                init_weights = initial_mixing_weights.detach().clone()
            else:
                # Lists / numpy arrays: torch.tensor always copies.
                init_weights = torch.tensor(
                    initial_mixing_weights, dtype=torch.float32
                )
        elif diagonal:
            # Initialize interpretable components diagonally, background randomly
            init_weights = torch.randn(n_gmm_components, n_total_components) * 0.1
            if n_gmm_components > 0:  # guards the modulo below
                for j_rule_idx in range(self.n_rules):
                    i_gmm_idx = j_rule_idx % n_gmm_components
                    init_weights[i_gmm_idx, j_rule_idx] += 2.0
        else:
            init_weights = torch.rand(
                [n_gmm_components, n_total_components], dtype=torch.float32
            )

        self.mixing_weights = nn.Parameter(init_weights, requires_grad=True)
        self.cached_densities = None

    def reorder_rules(self, order_indices):
        """
        Reorders the rule-specific mixing weights according to the provided indices.
        The background column, when present, stays in the last position.

        Args:
            order_indices (torch.Tensor or list): A tensor or list containing the new
                                                  order of rule indices.
        """
        with torch.no_grad():
            current_weights = self.mixing_weights.data

            # Only the first n_rules columns belong to rules; permute those.
            sorted_rule_weights = current_weights[:, : self.n_rules][:, order_indices]

            if self.use_background_component:
                background_weights = current_weights[:, -1].unsqueeze(1)
                new_weights = torch.cat(
                    [sorted_rule_weights, background_weights], dim=1
                )
            else:
                new_weights = sorted_rule_weights

            self.mixing_weights.data = new_weights

    def forward(self, rule_probs, component_densities, mean_reduce=True):
        """
        Args:
            rule_probs (torch.Tensor): Gating probabilities s_j(x).
                                       Shape: [batch_size, n_total_components]
            component_densities (torch.Tensor): Pre-calculated densities phi_m(y)
                                                for each global GMM component m.
                                                Shape: [batch_size, n_gmm_components]
            mean_reduce (bool): If True, returns the mean negative log-likelihood.
                                If False, returns the per-sample log-likelihood.

        Returns:
            If mean_reduce is False:
                torch.Tensor: Per-sample log-likelihood, log p(y|x). Shape: [batch_size]
                torch.Tensor: L1 sparsity penalty for the mixing weights (scalar).
            If mean_reduce is True:
                torch.Tensor: Mean negative log-likelihood for the batch (scalar).
                torch.Tensor: L1 sparsity penalty for the mixing weights (scalar).
        """
        # Softmax over GMM components (dim=0), so each rule's weights sum to 1.
        # After the transpose: [n_total_components, n_gmm_components]
        norm_weights = torch.softmax(self.mixing_weights, dim=0).T

        # Expert densities p_j(y) = sum_m w_jm * phi_m(y)
        # [n_total, n_gmm] @ [n_gmm, batch] -> [n_total, batch] -> [batch, n_total]
        expert_densities = (norm_weights @ component_densities.T).T

        # Mixture density p(y|x) = sum_j s_j(x) * p_j(y)
        mixture_density = torch.sum(rule_probs * expert_densities, dim=1)
        # The 1e-9 offset keeps the log finite for vanishing densities.
        log_likelihood = torch.log(mixture_density + 1e-9)

        # L1 sparsity penalty on the raw (unnormalized) mixing weights
        l1_penalty = torch.mean(torch.abs(self.mixing_weights))

        if mean_reduce:
            return -torch.mean(log_likelihood), l1_penalty
        return log_likelihood, l1_penalty

    def get_mixing_weights(self):
        """Returns the normalized (softmax over GMM components) mixing weights, detached."""
        return torch.softmax(self.mixing_weights, dim=0).detach()

    def set_densities(self, densities):
        """Sets the cached densities for the GMM components."""
        self.cached_densities = densities


class Discretizing_Layer(nn.Module):
    """
    Learnable soft interval predicates over continuous features.

    Every predicate is an interval [lower, upper] per feature; its initial
    position comes from the supplied initializer. ``cut_points`` is indexed as
    [feature, 0=lower / 1=upper, predicate] throughout this class.
    """

    def __init__(
        self,
        n_features: int,
        predicates_per_feature: int,
        data_limits: torch.Tensor,
        initializer: CutPointInitializer,
        temperature: float = 0.1,
        discrete_features=None,
    ):
        """
        Args:
            n_features: Number of input features.
            predicates_per_feature: Number of interval predicates per feature.
            data_limits: Tensor of shape (n_features, 2); a detached copy is kept
                in ``self.limits`` and handed to the initializer.
            initializer: Object whose ``initialize(n_features,
                predicates_per_feature, limits)`` result becomes ``cut_points``.
            temperature: Divisor of the sigmoid arguments in ``forward``.
            discrete_features: Optional per-feature flags; when omitted they are
                derived from ``initializer.X`` if that is a matching 2-D tensor
                (a feature with fewer than 5 unique values counts as discrete),
                otherwise every feature is flagged as non-discrete.

        Raises:
            TypeError: If ``data_limits`` is not a tensor.
            ValueError: If ``data_limits`` or ``discrete_features`` has the
                wrong size.
        """
        super().__init__()
        if not isinstance(data_limits, torch.Tensor):
            raise TypeError("data_limits must be a torch.Tensor")
        if data_limits.shape != (n_features, 2):
            raise ValueError(f"data_limits must have shape ({n_features}, 2)")

        self.temperature = temperature
        self.predicates_per_feature = predicates_per_feature
        self.limits = data_limits.detach().clone()

        self.cut_points = nn.Parameter(
            initializer.initialize(n_features, predicates_per_feature, self.limits),
            requires_grad=True,
        )

        # Explicit flags from the caller take precedence.
        if discrete_features is not None:
            if len(discrete_features) != n_features:
                raise ValueError("Length of discrete_features must match n_features")
            self.is_discrete = discrete_features
            return

        # Otherwise try to infer them from the data the initializer carries.
        reference = getattr(initializer, "X", None)
        has_matching_data = (
            isinstance(reference, torch.Tensor)
            and reference.ndim == 2
            and reference.shape[1] == n_features
        )
        if has_matching_data:
            self.is_discrete = [
                torch.unique(reference[:, col]).numel() < 5
                for col in range(n_features)
            ]
        else:
            self.is_discrete = [False] * n_features

    def forward(self, x):
        """
        Calculates the fuzzy activation for each interval predicate.

        Args:
            x (torch.Tensor): Input of shape [batch_size, n_features].

        Returns:
            torch.Tensor: Product of the "above lower bound" and "below upper
            bound" sigmoids, shape [batch_size, n_features, predicates_per_feature].
        """
        # [n_features, predicates_per_feature] each
        lows = self.cut_points[:, 0, :]
        highs = self.cut_points[:, 1, :]

        # [batch_size, n_features, 1] so that x broadcasts over the predicates
        x_col = x.unsqueeze(2)

        above_low = torch.sigmoid((x_col - lows) / self.temperature)
        below_high = torch.sigmoid((highs - x_col) / self.temperature)

        return above_low * below_high

    def fix_parameters(self):
        """No-op; this layer has nothing to fix."""
        pass

    def get_cutpoints(self, scaler=None):
        """
        Returns the raw cutpoints, optionally unscaled.

        Args:
            scaler (object, optional): A fitted scaler whose ``inverse_transform``
                                       is applied to the cutpoints. If None, the
                                       cutpoints are returned as stored.

        Returns:
            numpy.ndarray: The cutpoints, shape [n_features, 2, predicates_per_feature].
        """
        raw = self.cut_points.data.detach().cpu().numpy()
        if scaler is None:
            return raw

        restored = np.zeros_like(raw)
        for pred in range(self.predicates_per_feature):
            # The scaler receives rows of length n_features, so each
            # [n_features, 2] slice is transposed before and after the call.
            restored[:, :, pred] = scaler.inverse_transform(raw[:, :, pred].T).T
        return restored


class And_Layer(nn.Module):
    """
    Soft conjunction of per-feature predicate activations.

    Each rule holds one non-negative (ReLU-ed) weight per feature. The output is
    a smoothed weighted harmonic-style mean of the activations:

        res = sum_f w_f / sum_f w_f * (1 + eta) / (x_f + eta),  eta = epsilon / sum_f w_f
    """

    def __init__(self, n_features, n_rules, epsilon=0.001):
        """
        Args:
            n_features (int): Number of input features (one weight each).
            n_rules (int): Number of rules computed by this layer.
            epsilon (float): Smoothing constant used to build ``eta`` in forward.
        """
        super().__init__()
        self.epsilon = epsilon
        self.n_rules = n_rules
        self.n_features = n_features
        # The random draw is immediately overwritten with ones; it is kept so
        # that constructing the layer consumes the same RNG state as before.
        self.and_weights = nn.Parameter(
            torch.rand([n_rules, n_features], dtype=torch.float32),
            requires_grad=True,
        )
        self.and_weights.data[:] = 1

        self.relu = nn.ReLU()
        self.disabled = False

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Predicate activations. Axes 1 and 2 are swapped on
                entry and the last axis is then matched against the feature
                weights, so presumably [batch, n_features, n_predicates] with
                n_predicates broadcastable to n_rules (TODO confirm with caller).

        Returns:
            torch.Tensor: Rule activations, shape [batch, n_rules].
            torch.Tensor: Entropy of each rule's normalized weights, shape [n_rules].
        """
        relu_weights = self.relu(self.and_weights)  # [n_rules, n_features]

        # swap 1 and 2 axes of x
        x = x.permute(0, 2, 1)

        weight_sum = torch.sum(relu_weights, dim=1) + 1e-8  # [n_rules]

        eta = (self.epsilon / weight_sum).unsqueeze(1).expand(-1, x.shape[2])
        inverse_sum = (1 + eta) / (x + eta)

        # Flatten (rule, feature) so the flat weight vector lines up element-wise.
        inverse_sum = inverse_sum.reshape([x.shape[0], -1])
        weighted_inverse_sum = inverse_sum * relu_weights.reshape(-1)
        weighted_inverse_sum = weighted_inverse_sum.reshape(
            [x.shape[0], self.n_rules, -1]
        )
        weighted_inverse_sum = torch.sum(weighted_inverse_sum, dim=2) + 1e-8

        res = weight_sum / weighted_inverse_sum

        # Normalize with the stabilised sum: a rule whose weights are all <= 0
        # would otherwise give 0 / 0 = NaN and poison the entropy term.
        normalized_weights = relu_weights / weight_sum.unsqueeze(1)
        entropy_loss = -torch.sum(
            normalized_weights * torch.log(normalized_weights + 1e-8), dim=1
        )
        return res, entropy_loss

    def fix_parameters(self):
        """Returns the ``disabled`` flag; no parameter is modified."""
        return self.disabled


class SimpleMixtureRules(nn.Module):
    def __init__(
        self,
        X,
        n_components=2,
        temperature=0.2,
        epsilon=0.001,
        optimize_weights=True,
        component_configs=None,
        initializer=None,
        width_normalization_alpha=0.0,
        discrete_features=None,
        use_background_component: bool = False,
        background_epsilon: float = 0.1,
    ):
        super().__init__()

        data_limits = torch.tensor(
            [[X[:, i].min(), X[:, i].max()] for i in range(X.shape[1])],
            dtype=torch.float32,
            device=X.device,
        )
        self.rules = nn.ModuleList()
        for _ in range(n_components):
            self.rules.append(
                SimpleNeuralAndFinder(
                    data_limits=data_limits,
                    temperature=temperature,
                    epsilon=epsilon,
                    initializer=initializer,
                    discrete_features=discrete_features,
                )
            )

        self.n_components = n_components
        self.width_normalization_alpha = width_normalization_alpha
        self.use_background_component = use_background_component
        self.background_epsilon = background_epsilon

    def set_temperature(self, temperature):
        for rule in self.rules:
            rule.discretizer.temperature = temperature

    def set_discrete_features(self, discrete_features):
        for rule in self.rules:
            rule.discretizer.is_discrete = discrete_features

    def get_disabled_rules(self):
        """
        Returns a list of boolean values indicating whether each rule is disabled.
        """
        return [rule.disabled for rule in self.rules]

    def forward(self, x):
        """
        Calculates gating probabilities for all components.
        If use_background_component is True, it returns K+1 probabilities,
        with the last one corresponding to the background component.
        """
        rule_probs = []
        l1_losses = []
        for rule in self.rules:
            prob, l1 = rule(x)
            prob = prob.squeeze()
            rule_probs.append(prob)
            if not rule.disabled:
                l1_losses.append(l1)

        # Stack to get [batch_size, n_components]
        raw_rule_probs = torch.stack(rule_probs, dim=1)

        # Calculate the mean L1 loss across all active rules
        total_l1_loss = (
            torch.mean(torch.stack(l1_losses))
            if l1_losses
            else torch.tensor(0.0, device=x.device)
        )

        if self.use_background_component:
            # Denominator is sum of rule activations + epsilon
            denominator = (
                raw_rule_probs.sum(dim=1, keepdim=True) + self.background_epsilon
            )

            # Gating probs for interpretable rules
            rule_gate_probs = raw_rule_probs / (denominator + 1e-8)

            # Gating prob for background component
            background_gate_prob = self.background_epsilon / (denominator + 1e-8)

            # Concatenate to get all gating probabilities
            final_probs = torch.cat([rule_gate_probs, background_gate_prob], dim=1)
        else:
            final_probs = raw_rule_probs / (
                raw_rule_probs.sum(dim=1, keepdim=True) + 1e-8
            )

        if torch.isnan(final_probs).any():
            print("Warning: NaN detected in final gating probabilities.")
            nan_mask = torch.isnan(final_probs)
            final_probs[nan_mask] = 1.0 / final_probs.shape[1]

        return final_probs, total_l1_loss

    def get_responsibilities(self, x):
        with torch.no_grad():
            rule_probs, _ = self.forward(x)
        return rule_probs.cpu().numpy()

    def forward_raw(self, x):
        rule_probs = []

        for rule in self.rules:
            # Each rule returns activation probability
            prob = rule(x)[0].squeeze()
            rule_probs.append(prob)
            if torch.isnan(prob).any():
                print("NaN in rule probabilities")
                print(rule_probs)
            if torch.isinf(prob).any():
                print("Inf in rule probabilities")
                print(rule_probs)
        # Stack to get [batch_size, n_components]
        raw_rule_probs = torch.stack(rule_probs, dim=1)
        return raw_rule_probs

    def compute_rule_means(self, X, Y):
        """Compute mean target value for each rule.

        Args:
            X: Input features tensor
            Y: Target values tensor

        Returns:
            means: List of mean target vectors for each rule
        """
        with torch.no_grad():
            if Y.ndim == 1:
                Y = Y.unsqueeze(1)
            rule_probs, _ = self.forward(X)

            means = []
            # Only compute means for the interpretable rules
            for i in range(self.n_components):
                weights = rule_probs[:, i]  # shape: [n_samples]
                total_weight = weights.sum()

                if total_weight == 0:
                    means.append(
                        torch.full((Y.shape[1],), float("inf"), device=Y.device)
                    )
                    continue

                # Weighted sum for each dimension of Y
                # weights.unsqueeze(1) has shape [n_samples, 1]
                # Y has shape [n_samples, y_dim]
                weighted_sum = (weights.unsqueeze(1) * Y).sum(dim=0)
                mean = weighted_sum / (total_weight + 1e-8)
                means.append(mean)

            return means

    def count_active_parameters(self, threshold=0.1):
        total_params = 0
        for rule in self.rules:
            if rule.disabled:
                continue
            rule_params = sum(p.numel() for p in rule.parameters() if p.requires_grad)

            total_params += rule_params

        return total_params

    def count_active_rules(self, threshold=0.1):
        """Count the number of active rules based on the and_weights threshold."""
        active_rules = 0
        for rule in self.rules:
            if rule.disabled:
                continue
            and_weights = rule.and_layer.and_weights[0]
            active_features = torch.where(and_weights > threshold)[0]
            if len(active_features) > 0:
                active_rules += 1

        return active_rules

    def sort_rules(self, X, Y):
        """
        Sorts rules in-place based on the mean of their target variable Y.
        It also returns the indices of the new sort order.

        Args:
            X (torch.Tensor): Input features tensor.
            Y (torch.Tensor): Target values tensor. For multidimensional Y, sorting is
                              based on the L2 norm of the mean vector.

        Returns:
            torch.Tensor: A tensor of indices representing the new order of the rules.
        """
        if Y.ndim == 1:
            Y = Y.unsqueeze(1)

        means = self.compute_rule_means(X, Y)

        mean_norms = []
        for m in means:
            if torch.isinf(m).any():
                mean_norms.append(float("inf"))
            else:
                mean_norms.append(torch.linalg.norm(m).item())

        sort_idx = torch.tensor(mean_norms).argsort()

        sorted_rules = [self.rules[i] for i in sort_idx]
        self.rules = nn.ModuleList(sorted_rules)

        return sort_idx

    def debug_print_cutpoints(
        self,
        feature_names=None,
        scaler=None,
        include_disabled=False,
        simple_format=False,
    ):
        """
        Prints the cutpoints for all rules.

        Args:
            feature_names (list, optional): Names of the features.
            scaler (object, optional): A fitted scaler to show unscaled values.
            include_disabled (bool): Whether to include disabled rules in the output.
            simple_format (bool): If True, prints in a compact, one-line-per-rule CSV format.
        """
        if feature_names is None and self.rules:
            n_features = self.rules[0].discretizer.cut_points.shape[0]
            feature_names = [f"Feature_{i}" for i in range(n_features)]

        output_lines = []
        if simple_format:
            for i, rule in enumerate(self.rules):
                if rule.disabled and not include_disabled:
                    continue
                status = "DISABLED" if rule.disabled else "ACTIVE"
                rule_str = rule.get_rule_string_for_csv(
                    feature_names=feature_names, scaler=scaler
                )
                output_lines.append(f'{i+1},{status},"{rule_str}"')
                # print(f'{i+1},{status},"{rule_str}"')
        else:
            for i, rule in enumerate(self.rules):
                if rule.disabled and not include_disabled:
                    continue

                status = "DISABLED" if rule.disabled else "ACTIVE"
                rule_str = rule.debug_print_cutpoints(
                    feature_names=feature_names, scaler=scaler
                )
                output_lines.append(f"Component {i+1} ({status}):\n{rule_str}")

        return "\n".join(output_lines)

    def fix_parameters(self):
        for rule in self.rules:
            rule.fix_parameters()

    def merge_adjacent_components(
        self,
        X,
        Y,
        density_model,
        density_model_type="gmm_remix",
        gmm_model=None,
        adjacency_tol=0.1,
        iou_threshold=0.8,
        density_jsd_threshold=0.1,
        verbose=True,
        use_data_driven_bounds=True,
        disabled_components=None,
    ):
        """
        Iteratively finds all adjacent rules and merges the best on in each pass.

        Args:
            X (torch.Tensor): Input data to calculate responsibilities.
            Y (torch.Tensor): Target data, used to determine the grid for comparison.
            density_model: The trained density model (GMMRemixer or FlowMixtureExperts).
            density_model_type (str): 'gmm_remix' or 'flow'.
            gmm_model (sklearn.mixture.GaussianMixture): The trained global GMM, required if density_model_type is 'gmm_remix'.
            adjacency_tol (float): Tolerance for considering two rule boundaries adjacent (for continuous features).
            iou_threshold (float): IoU threshold for considering other dimensions similar.
            density_jsd_threshold (float): JSD threshold for considering densities similar.
            verbose (bool): If True, prints information about merges.
            use_data_driven_bounds (bool): If True, uses the min/max of assigned data for IoU checks and merging.
            disabled_components (list): A list indicating which components are disabled.
        """
        if density_model_type == "gmm_remix" and gmm_model is None:
            raise ValueError("The 'gmm_model' argument must be provided for gmm_remix.")

        def get_log_prob_for_merge(flow, data):
            """Gets log probability from either a FlowTorch or Zuko model."""
            if isinstance(flow, dist.Flow):
                return flow.log_prob(data)
            else:  # Assuming zuko-style
                return flow().log_prob(data)

        n_features = self.rules[0].discretizer.cut_points.shape[0]
        device = self.rules[0].discretizer.cut_points.device
        did_merge_overall = False

        X.to(device)
        Y.to(device)
        while True:
            with torch.no_grad():
                full_rule_probs, _ = self(X)
                total_responsibilities = full_rule_probs[:, : self.n_components].sum(
                    dim=0
                )
                assignments = torch.argmax(full_rule_probs, dim=1)

            active_indices = [
                i for i, rule in enumerate(self.rules) if not rule.disabled
            ]

            if len(active_indices) < 2:
                if verbose:
                    print("Not enough active components to consider merging.")
                break

            data_bounds_map = {}
            if use_data_driven_bounds:
                for idx in active_indices:
                    data_bounds_map[idx] = _get_data_driven_bounds(X, assignments, idx)

            # Find all merge candidates
            merge_candidates = []

            for idx1, idx2 in combinations(active_indices, 2):
                rule1, rule2 = self.rules[idx1], self.rules[idx2]
                found_adjacent_dim = -1

                for k in range(n_features):  # k is the potential adjacent dimension
                    is_k_adjacent = False
                    is_k_discrete = rule1.discretizer.is_discrete[k]

                    bounds1 = data_bounds_map.get(idx1)
                    bounds2 = data_bounds_map.get(idx2)

                    if use_data_driven_bounds and (bounds1 is None or bounds2 is None):
                        continue

                    cuts1_k = (
                        bounds1[k]
                        if use_data_driven_bounds
                        else rule1.discretizer.cut_points[k, :, 0]
                    )
                    cuts2_k = (
                        bounds2[k]
                        if use_data_driven_bounds
                        else rule2.discretizer.cut_points[k, :, 0]
                    )

                    if is_k_discrete:
                        if _calculate_iou(cuts1_k, cuts2_k) == 0:
                            is_k_adjacent = True
                    else:
                        if (dist1 := abs(cuts1_k[1] - cuts2_k[0])) < adjacency_tol or (
                            dist2 := abs(cuts2_k[1] - cuts1_k[0])
                        ) < adjacency_tol:
                            is_k_adjacent = True

                    if not is_k_adjacent:
                        continue

                    similar_on_others = True
                    for i in range(n_features):
                        if i == k:
                            continue
                        if (
                            rule1.and_layer.and_weights.data[0, i] <= 0
                            or rule2.and_layer.and_weights.data[0, i] <= 0
                        ):
                            continue
                        cuts1_i = (
                            bounds1[i]
                            if use_data_driven_bounds
                            else rule1.discretizer.cut_points[i, :, 0]
                        )
                        cuts2_i = (
                            bounds2[i]
                            if use_data_driven_bounds
                            else rule2.discretizer.cut_points[i, :, 0]
                        )
                        if _calculate_iou(cuts1_i, cuts2_i) < iou_threshold:
                            similar_on_others = False
                            break

                    if similar_on_others:
                        found_adjacent_dim = k
                        break

                if found_adjacent_dim != -1:
                    if verbose:
                        print(
                            f"Found adjacent pair: Comp {idx1 + 1} and Comp {idx2 + 1} on feature {found_adjacent_dim}"
                        )
                    # This pair is adjacent. Now check density similarity.
                    y_min, y_max = Y.min(), Y.max()
                    grid_points = torch.linspace(
                        y_min, y_max, 200, device=device
                    ).unsqueeze(1)

                    if density_model_type == "gmm_remix":
                        grid_points_np = grid_points.cpu().numpy()
                        global_densities_on_grid_np = np.exp(
                            gmm_model._estimate_log_prob(grid_points_np)
                        )
                        global_densities_on_grid = torch.tensor(
                            global_densities_on_grid_np,
                            dtype=torch.float32,
                            device=device,
                        )
                        all_weights = density_model.get_mixing_weights()
                        w1, w2 = all_weights[:, idx1], all_weights[:, idx2]
                        p1_on_grid = global_densities_on_grid @ w1
                        p2_on_grid = global_densities_on_grid @ w2
                    elif density_model_type == "flow":
                        with torch.no_grad():
                            flow1 = density_model.component_flows[idx1]
                            flow2 = density_model.component_flows[idx2]
                            p1_on_grid = torch.exp(
                                get_log_prob_for_merge(flow1, grid_points)
                            )
                            p2_on_grid = torch.exp(
                                get_log_prob_for_merge(flow2, grid_points)
                            )

                    jsd = _js_divergence(
                        p1_on_grid / (p1_on_grid.sum() + 1e-10),
                        p2_on_grid / (p2_on_grid.sum() + 1e-10),
                    )

                    if jsd < density_jsd_threshold:
                        # It's a valid candidate, add to our list with its JSD score
                        merge_candidates.append((jsd, idx1, idx2, found_adjacent_dim))

            # Execute best merge
            merged_in_pass = False
            if merge_candidates:
                # Sort candidates by JSD (lower is better) and pick the best one
                merge_candidates.sort(key=lambda x: x[0])
                best_jsd, idx1, idx2, adjacent_dim = merge_candidates[0]

                absorb_idx, keep_idx = (
                    (idx2, idx1)
                    if total_responsibilities[idx1] > total_responsibilities[idx2]
                    else (idx1, idx2)
                )

                if verbose:
                    print(
                        f"Best merge candidate: Comp {absorb_idx + 1} into {keep_idx + 1} (adjacent on feat {adjacent_dim}, JSD: {best_jsd.item():.4f})"
                    )

                absorb_rule, keep_rule = self.rules[absorb_idx], self.rules[keep_idx]

                # Merge Rule Boundaries
                for k in range(n_features):
                    bounds_absorb = data_bounds_map.get(absorb_idx)
                    bounds_keep = data_bounds_map.get(keep_idx)
                    cuts_absorb = (
                        bounds_absorb[k]
                        if use_data_driven_bounds
                        else absorb_rule.discretizer.cut_points[k, :, 0]
                    )
                    cuts_keep = (
                        bounds_keep[k]
                        if use_data_driven_bounds
                        else keep_rule.discretizer.cut_points[k, :, 0]
                    )

                    new_lower = (
                        torch.min(cuts_absorb[0], cuts_keep[0])
                        if k == adjacent_dim
                        else torch.max(cuts_absorb[0], cuts_keep[0])
                    )
                    new_upper = (
                        torch.max(cuts_absorb[1], cuts_keep[1])
                        if k == adjacent_dim
                        else torch.min(cuts_absorb[1], cuts_keep[1])
                    )

                    keep_rule.discretizer.cut_points.data[k, 0, 0] = new_lower
                    keep_rule.discretizer.cut_points.data[k, 1, 0] = new_upper

                # Merge AND weights
                resp_absorb = total_responsibilities[absorb_idx]
                resp_keep = total_responsibilities[keep_idx]
                w_absorb = absorb_rule.and_layer.and_weights.data
                w_keep = keep_rule.and_layer.and_weights.data
                new_weights = (w_absorb * resp_absorb + w_keep * resp_keep) / (
                    resp_absorb + resp_keep + 1e-8
                )
                keep_rule.and_layer.and_weights.data = new_weights

                # Merge Experts
                if density_model_type == "gmm_remix":
                    w_absorb_expert = density_model.mixing_weights.data[:, absorb_idx]
                    w_keep_expert = density_model.mixing_weights.data[:, keep_idx]
                    new_expert_weights = (
                        w_absorb_expert * resp_absorb + w_keep_expert * resp_keep
                    ) / (resp_absorb + resp_keep + 1e-8)
                    density_model.mixing_weights.data[:, keep_idx] = new_expert_weights
                elif density_model_type == "flow":
                    density_model.disable_components([absorb_idx])
                    if verbose:
                        print(
                            f"  - Flow merge: Relying on settling phase to adapt flow {keep_idx + 1}."
                        )
                    pass

                absorb_rule.disabled = True
                if disabled_components is not None:
                    disabled_components[absorb_idx] = True

                merged_in_pass = True
                did_merge_overall = True

            if not merged_in_pass:
                if verbose:
                    print("No more mergeable components found in this pass.")
                break

        return did_merge_overall

    def merge_overlapping_components(
        self,
        X,
        density_model,
        density_model_type="gmm_remix",
        containment_threshold=0.9,
        verbose=True,
        disabled_components=None,
    ):
        """
        Merges rules whose data-driven bounding box lies mostly inside another's.

        Each pass re-evaluates the model on ``X``, assigns every sample to its
        arg-max column, and builds a bounding box per active rule from the samples
        assigned to it. For every pair of active rules the containment of each box
        in the other is computed; a pair becomes a merge candidate when either
        containment exceeds ``containment_threshold`` (if both do, the first rule
        of the pair is the one absorbed). Candidates are processed from highest to
        lowest containment, and a rule takes part in at most one merge per pass.
        Passes repeat until one produces no merge.

        A merge disables the absorbed rule. For ``density_model_type ==
        "gmm_remix"`` the kept rule's column of ``density_model.mixing_weights``
        is additionally replaced by the responsibility-weighted average of the
        two columns. For any other model type the density model is left untouched.

        Args:
            X (torch.Tensor): Input data used to compute responsibilities and bounds.
            density_model: The trained density model (e.g., GMMRemixer).
            density_model_type (str): The type of density model; only 'gmm_remix'
                triggers an expert-weight merge.
            containment_threshold (float): The fraction of a rule's volume that must be
                inside another to be considered for merging.
            verbose (bool): If True, prints per-pair containment values and merges.
            disabled_components (list, optional): A list/mask indexed by component;
                the entry of every absorbed component is set to True in place.

        Returns:
            bool: True if any merge was performed, False otherwise.
        """
        # NOTE(review): a former guard raised NotImplementedError for anything but
        # "gmm_remix". Other model types are now accepted, but for them only the
        # rule is disabled; confirm the caller disables the matching expert.

        did_merge_overall = False
        X = X.to(next(self.parameters()).device)
        while True:
            with torch.no_grad():
                full_rule_probs, _ = self(X)
                total_responsibilities = full_rule_probs[:, : self.n_components].sum(
                    dim=0
                )
                assignments = torch.argmax(full_rule_probs, dim=1)

            active_indices = [
                i for i, rule in enumerate(self.rules) if not rule.disabled
            ]
            if len(active_indices) < 2:
                if verbose:
                    print("Not enough active components to consider merging.")
                break

            # A value of None means no sample was assigned to that rule.
            data_bounds_map = {
                idx: _get_data_driven_bounds(X, assignments, idx)
                for idx in active_indices
            }

            merge_candidates = []
            for idx1, idx2 in combinations(active_indices, 2):
                bounds1 = data_bounds_map.get(idx1)
                bounds2 = data_bounds_map.get(idx2)

                if bounds1 is None or bounds2 is None:
                    continue

                containment_1_in_2 = _calculate_containment(bounds2, bounds1)
                containment_2_in_1 = _calculate_containment(bounds1, bounds2)
                # Diagnostic output only; it must respect the verbose flag.
                if verbose:
                    print(
                        f"Containment {idx1 + 1} in {idx2 + 1}: {containment_1_in_2:.2f}, {idx2 + 1} in {idx1 + 1}: {containment_2_in_1:.2f}"
                    )

                if containment_1_in_2 > containment_threshold:
                    # Rule 1 is contained in Rule 2. Merge 1 into 2.
                    merge_candidates.append((idx1, idx2, containment_1_in_2))
                elif containment_2_in_1 > containment_threshold:
                    # Rule 2 is contained in Rule 1. Merge 2 into 1.
                    merge_candidates.append((idx2, idx1, containment_2_in_1))

            merged_in_pass = False
            if merge_candidates:
                # Sort by containment value, highest first, to merge most obvious cases
                merge_candidates.sort(key=lambda x: x[2], reverse=True)

                # To avoid chain merges in one pass, keep track of what's been touched
                merged_this_pass = set()

                for absorb_idx, keep_idx, containment_val in merge_candidates:
                    if absorb_idx in merged_this_pass or keep_idx in merged_this_pass:
                        continue

                    if verbose:
                        print(
                            f"Merging Comp {absorb_idx + 1} into Comp {keep_idx + 1} "
                            f"(containment: {containment_val.item():.2f})"
                        )

                    # Merge Experts (responsibility-weighted average)
                    if density_model_type == "gmm_remix":
                        resp_absorb = total_responsibilities[absorb_idx]
                        resp_keep = total_responsibilities[keep_idx]
                        # Epsilon guards against division by zero when both
                        # rules carry no responsibility.
                        total_resp = resp_absorb + resp_keep + 1e-8

                        w_absorb_expert = density_model.mixing_weights.data[
                            :, absorb_idx
                        ]
                        w_keep_expert = density_model.mixing_weights.data[:, keep_idx]

                        new_expert_weights = (
                            w_absorb_expert * resp_absorb + w_keep_expert * resp_keep
                        ) / total_resp
                        density_model.mixing_weights.data[:, keep_idx] = (
                            new_expert_weights
                        )

                    # Disable the absorbed rule
                    self.rules[absorb_idx].disabled = True
                    if disabled_components is not None:
                        disabled_components[absorb_idx] = True

                    merged_this_pass.add(absorb_idx)
                    merged_this_pass.add(keep_idx)
                    merged_in_pass = True
                    did_merge_overall = True

            if not merged_in_pass:
                if verbose and did_merge_overall:
                    print("No more overlapping components to merge.")
                break

        return did_merge_overall


class SimpleNeuralAndFinder(nn.Module):
    """
    A single rule built from a discretizer followed by an AND layer.

    The discretizer is created with one predicate per feature and the AND layer
    with one rule. The rule can be switched off through the ``disabled`` flag,
    in which case ``forward`` returns zeros without evaluating either layer.
    """

    def __init__(
        self,
        data_limits,
        temperature=0.2,
        epsilon=0.001,
        initializer=None,
        discrete_features=None,
    ):
        """
        Args:
            data_limits (torch.Tensor): Per-feature limits; the first dimension is
                the number of features. Converted to float32 if necessary.
            temperature (float): Temperature passed to the discretizer.
            epsilon (float): Epsilon passed to the AND layer.
            initializer (optional): Cut-point initializer; defaults to
                ``RandomInitializer()`` when None.
            discrete_features (optional): Forwarded unchanged to the discretizer.
        """
        super().__init__()
        if data_limits.dtype != torch.float32:
            data_limits = data_limits.to(torch.float32)

        n_features = data_limits.shape[0]
        # Keep a private copy so later changes by the caller do not leak in.
        self.limits = data_limits.detach().clone()

        if initializer is None:
            initializer = RandomInitializer()
        self.discretizer = Discretizing_Layer(
            n_features=n_features,
            predicates_per_feature=1,
            data_limits=data_limits,
            temperature=temperature,
            initializer=initializer,
            discrete_features=discrete_features,
        )

        self.and_layer = And_Layer(
            n_features=n_features,
            n_rules=1,
            epsilon=epsilon,
        )

        self.disabled = False

    def set_temperature(self, temperature):
        """Sets the temperature used by the discretizer."""
        self.discretizer.temperature = temperature

    def forward(self, x):
        """
        Evaluates the rule on a batch.

        Args:
            x (torch.Tensor): Input batch; cast to float32 if necessary.

        Returns:
            tuple: ``(rule_activation, l1_loss)`` as produced by the AND layer,
            with a 0-dim activation expanded to 1-dim. When the rule is
            disabled, a zero vector of length ``x.shape[0]`` and a scalar 0.0,
            both on ``x``'s device.
        """
        if x.dtype != torch.float32:
            x = x.to(torch.float32)

        if self.disabled:
            return torch.zeros(x.shape[0], device=x.device), torch.tensor(
                0.0, device=x.device
            )
        predicates = self.discretizer(x)

        rule_activation, l1_loss = self.and_layer(predicates)

        # Guarantee at least one dimension so callers can always index the result.
        if rule_activation.dim() == 0:
            rule_activation = rule_activation.unsqueeze(0)

        return rule_activation, l1_loss

    def debug_print_cutpoints(self, feature_names=None, scaler=None):
        """
        Builds a multi-line description of this rule's cutpoints for debugging.

        Despite its name, this method does not print; it returns the text.

        Args:
            feature_names (list, optional): Names of the features. Defaults to
                ``Feature_<i>``.
            scaler (object, optional): A fitted scaler, forwarded to the
                discretizer's ``get_cutpoints`` to show unscaled values.

        Returns:
            str: One line per feature with its lower/upper cutpoint and weight.
        """
        cutpoints = self.discretizer.get_cutpoints(scaler=scaler)
        weights = self.and_layer.and_weights[0].detach().cpu().numpy()

        if feature_names is None:
            feature_names = [f"Feature_{i}" for i in range(len(weights))]

        output_lines = []
        for i in range(len(weights)):
            lower = cutpoints[i, 0, 0]
            upper = cutpoints[i, 1, 0]
            output_lines.append(
                f"   - {feature_names[i]:<15}: [{lower:3.2f}, {upper:3.2f}] (Weight: {weights[i]:.2f})"
            )
        return "\n".join(output_lines)

    def get_rule_string_for_csv(self, feature_names=None, scaler=None):
        """
        Generates a single-line string representation of the rule for CSV output.

        Args:
            feature_names (list, optional): Names of the features. Defaults to
                ``Feature_<i>``.
            scaler (object, optional): A fitted scaler to show unscaled values.

        Returns:
            str: A semicolon-separated string of the rule's predicates.
        """
        cutpoints = self.discretizer.get_cutpoints(scaler=scaler)
        weights = self.and_layer.and_weights[0].detach().cpu().numpy()

        if feature_names is None:
            feature_names = [f"Feature_{i}" for i in range(len(weights))]

        parts = []
        for i in range(len(weights)):
            lower = cutpoints[i, 0, 0]
            upper = cutpoints[i, 1, 0]
            part_str = (
                f"{feature_names[i]}=[{lower:.2f},{upper:.2f}],w={weights[i]:.2f}"
            )
            parts.append(part_str)
        return "; ".join(parts)

    def get_utilized_features(self, threshold=0.0):
        """
        Returns the indices of features whose AND weight exceeds ``threshold``.

        Args:
            threshold (float): Strict lower bound on the weight.

        Returns:
            list[int]: Feature indices in ascending order.
        """
        weights = self.and_layer.and_weights[0]
        return torch.where(weights > threshold)[0].tolist()

    def fix_parameters(self):
        """
        Calls ``fix_parameters`` on both sub-layers and disables the rule if
        either call returns a truthy value. Both layers are always visited.
        """
        if self.discretizer.fix_parameters():
            self.disabled = True
        if self.and_layer.fix_parameters():
            self.disabled = True

    def get_effective_intervals(self):
        """
        Returns a dictionary mapping feature indices to (lower, upper) interval bounds
        for actively used features with weight > 0.1

        Returns:
            dict[int, tuple[float, float]]: One entry per active feature; empty
            when no feature weight exceeds 0.1.
        """
        weights = self.and_layer.and_weights[0].detach()
        active_features = torch.where(weights > 0.1)[0]

        intervals = {}
        for i in active_features.tolist():
            lower = self.discretizer.cut_points[i, 0, 0].item()
            upper = self.discretizer.cut_points[i, 1, 0].item()
            intervals[i] = (lower, upper)

        # Return only after every active feature has been collected, so the
        # result is complete and is a dict even when there are no active features.
        return intervals
