import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple, Optional, Any
from dataclasses import dataclass

# Import the Component class
from ..mixture_gen import Component


@dataclass
class ComponentMetadata:
    """Metadata about a component's feature usage and properties.

    One instance is produced per successfully placed component by
    ``sparse_distributed_component_generator``; failed placements get none.
    """

    # Index of the component in the generation loop (0-based), not its
    # position in the returned list (which skips unplaced components).
    component_id: int
    # Feature indices this component constrains with a tight interval.
    active_features: List[int]
    # Always len(active_features) as filled by the generator.
    n_active_features: int
    # feature index -> (low, high) interval for each active feature.
    active_feature_bounds: Dict[int, Tuple[float, float]]
    # feature index -> (low, high) noise interval for every non-active feature.
    inactive_feature_bounds: Dict[int, Tuple[float, float]]
    # Distribution name drawn from the caller's ``distributions`` list.
    distribution: str
    # Keyword parameters associated with ``distribution``.
    dist_params: Dict[str, Any]
    # Mixture weight; the generator assigns 1 / n_components to every component.
    weight: float
    # Side length of the active-feature intervals actually used (may be
    # smaller than the requested size after the "reduced" stage).
    final_size: float
    # Total placement trials across all stages for this component.
    placement_attempts: int
    # Stage in which placement succeeded: "original", "expanded" or "reduced".
    placement_stage: str


# Default sampling parameters for every supported distribution name. A fresh
# copy is handed to each component so mutating one does not affect others.
_DEFAULT_DIST_PARAMS: Dict[str, Dict[str, float]] = {
    "normal": {"loc": 0.0, "scale": 0.1},
    "uniform": {"low": -0.5, "high": 0.5},
    "gamma": {"shape": 2.0, "scale": 0.2},
    "exponential": {"scale": 0.3},
}


def _components_overlap(comp1: Dict[str, Any], comp2: Dict[str, Any]) -> bool:
    """Return True if two candidate components intersect.

    Each argument is a dict with ``"rules"`` (feature index -> (low, high)) and
    ``"active_features"``. Only features active in *both* components are
    compared; components that share no active feature never overlap. Intervals
    that merely touch (``max1 == min2``) count as overlapping.
    """
    common_features = set(comp1["active_features"]).intersection(
        comp2["active_features"]
    )
    if not common_features:
        return False
    for dim in common_features:
        min1, max1 = comp1["rules"][dim]
        min2, max2 = comp2["rules"][dim]
        # A gap in any shared dimension separates the two boxes.
        if max1 < min2 or max2 < min1:
            return False
    return True


def _unravel_c_order(flat_index: int, base: int, n_dims: int) -> List[int]:
    """Decode ``flat_index`` into per-dimension cell indices (C order).

    Equivalent to ``np.unravel_index(flat_index, [base] * n_dims)`` (last
    dimension varies fastest) but uses Python ints, so it cannot overflow when
    ``base ** n_dims`` exceeds the platform integer range.
    """
    digits = [0] * n_dims
    for dim in range(n_dims - 1, -1, -1):
        flat_index, digits[dim] = divmod(flat_index, base)
    return digits


def _generate_grid_positions(
    component_feature_specs: List[Dict[str, Any]],
    n_components: int,
    n_features: int,
) -> List[List[float]]:
    """Return shuffled, jittered grid-cell centres clipped to [0.1, 0.9]."""
    # Grid resolution is chosen from the *average* active dimensionality.
    avg_active_features = np.mean(
        [spec["n_active"] for spec in component_feature_specs]
    )
    grid_dim = int(np.ceil(n_components ** (1 / avg_active_features)))
    grid_spacing = 1.0 / grid_dim

    positions = []
    for flat_idx in range(min(grid_dim**n_features, n_components * 3)):
        cell_indices = _unravel_c_order(flat_idx, grid_dim, n_features)
        base_pos = [(cell + 0.5) * grid_spacing for cell in cell_indices]
        # Jitter of up to 30% of the grid spacing avoids perfectly regular layouts.
        jitter = [
            (np.random.random() - 0.5) * 0.3 * grid_spacing
            for _ in range(n_features)
        ]
        positions.append(
            [min(max(p + j, 0.1), 0.9) for p, j in zip(base_pos, jitter)]
        )

    # Shuffle to avoid predictable patterns.
    np.random.shuffle(positions)
    return positions


def _generate_poisson_disk_positions(
    n_components: int, n_features: int, min_dist: float
) -> List[np.ndarray]:
    """Simple Poisson-disk style sampling of up to ``2 * n_components`` points.

    New points are spawned around a randomly chosen active point at distance
    ``[min_dist, 1.5 * min_dist)``, must lie in [0.1, 0.9] on every axis and be at
    least ``min_dist`` from all existing points. A point is retired from the
    active list after 30 consecutive failed spawns.
    """
    positions = [np.random.random(n_features)]
    active_list = [0]

    while active_list and len(positions) < n_components * 2:
        idx = np.random.choice(active_list)
        current_pos = positions[idx]

        for _spawn_try in range(30):
            direction = np.random.randn(n_features)
            direction = direction / np.linalg.norm(direction)
            distance = min_dist * (1 + 0.5 * np.random.random())
            new_pos = current_pos + direction * distance

            if (
                np.all(new_pos >= 0.1)
                and np.all(new_pos <= 0.9)
                and all(np.linalg.norm(new_pos - p) >= min_dist for p in positions)
            ):
                positions.append(new_pos)
                active_list.append(len(positions) - 1)
                break
        else:
            # Only retire the point when *no* try succeeded (the original
            # ``_ == 29`` check also retired it on a success at the last try).
            active_list.remove(idx)

    return positions


def _sample_center(domain_min: float, domain_max: float, n_features: int) -> List[float]:
    """Draw a uniform random centre in [domain_min, domain_max) on every axis."""
    return [
        domain_min + (domain_max - domain_min) * np.random.random()
        for _ in range(n_features)
    ]


def _build_rules(
    center: Any,
    half_size: float,
    active_features: List[int],
    n_features: int,
    noise_bounds: Tuple[float, float],
) -> Dict[int, Tuple[float, float]]:
    """Tight interval around ``center`` on active features, noise bounds elsewhere."""
    active = set(active_features)
    return {
        dim: (
            (center[dim] - half_size, center[dim] + half_size)
            if dim in active
            else noise_bounds
        )
        for dim in range(n_features)
    }


def _try_place(
    size: float,
    domain: Tuple[float, float],
    *,
    preferred_center: Any,
    active_features: List[int],
    n_features: int,
    noise_bounds: Tuple[float, float],
    placed_components: List[Dict[str, Any]],
    max_attempts: int,
) -> Tuple[Optional[Dict[int, Tuple[float, float]]], int]:
    """Run up to ``max_attempts`` placement trials at a fixed size and domain.

    Trial 0 uses ``preferred_center`` when it is not None; every other trial
    samples a random centre in ``domain``. On success the candidate is appended
    to ``placed_components``.

    Returns:
        ``(rules, attempts_used)``; ``rules`` is None when every trial overlapped.
    """
    half_size = size / 2
    domain_min, domain_max = domain
    for attempt_num in range(max_attempts):
        if attempt_num == 0 and preferred_center is not None:
            center = preferred_center
        else:
            center = _sample_center(domain_min, domain_max, n_features)

        rules = _build_rules(center, half_size, active_features, n_features, noise_bounds)
        candidate = {"rules": rules, "active_features": active_features}
        if not any(_components_overlap(candidate, placed) for placed in placed_components):
            placed_components.append(candidate)
            return rules, attempt_num + 1
    return None, max(max_attempts, 0)


def sparse_distributed_component_generator(
    n_components: int,
    n_features: int,
    distributions: List[str],
    min_features_per_component: int = 2,
    max_features_per_component: Optional[int] = None,
    base_size: float = 0.5,
    spacing_factor: float = 1.5,
    placement_strategy: str = "random",  # "random", "grid", or "poisson_disk"
    adaptive_sizing: bool = True,
    vary_size: bool = False,
    vary_factor: float = 0.2,
    max_attempts: int = 1000,
    random_seed: int | None = None,
    domain_expansion_factor: float = 0.2,
    max_domain_expansion_steps: int = 5,
    noise_feature_bounds: Tuple[float, float] = (-2.0, 2.0),
    auto_scale_noise: bool = True,
    noise_scale_factor: float = 1.0,
    verbose: bool = False,
    return_metadata: bool = False,
) -> Tuple[List[Component], List[ComponentMetadata]] | List[Component]:
    """
    Generate components with sparse feature support distributed throughout feature space.

    Each component only uses a subset of features for its rules, making the model more
    realistic and testing the ability to ignore irrelevant dimensions.

    Placement runs in up to three stages per component: the original size in
    the [0, 1] domain, then the original size in progressively expanded
    domains, then progressively reduced sizes (down to 20% of the original)
    in the last domain tried. Components that still cannot be placed without
    overlapping are dropped.

    Args:
        n_components: Number of components to generate
        n_features: Total number of features available
        distributions: Distribution names to sample from; each must be one of
            "normal", "uniform", "gamma", "exponential"
        min_features_per_component: Minimum number of features each component uses
        max_features_per_component: Maximum number of features each component uses
                                  (defaults to n_features if None)
        base_size: Base size of each component in used dimensions
        spacing_factor: Minimum-distance multiplier; only used by the
            "poisson_disk" strategy (min distance = base_size * spacing_factor).
            Overlap rejection for every strategy only checks box intersection.
        placement_strategy: Strategy for placing components ("random", "grid", "poisson_disk")
        adaptive_sizing: Automatically adjust size based on number of components
        vary_size: Whether to vary the size of individual components
        vary_factor: How much to vary component sizes (if vary_size is True)
        max_attempts: Maximum attempts to place a component within a given stage
        random_seed: Seed for the global NumPy RNG (for reproducibility)
        domain_expansion_factor: Factor by which to expand the domain for random center generation
        max_domain_expansion_steps: Number of steps for domain expansion
        noise_feature_bounds: Bounds for unused features (wide range for noise sampling)
        auto_scale_noise: Replace noise bounds with an interval centred on 0
            whose width is the span of all informative intervals times
            ``noise_scale_factor``
        noise_scale_factor: Scale factor for noise range relative to informative range
        verbose: Print verbose output
        return_metadata: Currently ignored; the metadata tuple is always
            returned. Kept for API compatibility.

    Returns:
        Tuple of (List of Component objects, List of ComponentMetadata objects)

    Raises:
        ValueError: If ``distributions`` contains an unsupported name.
    """
    components = []
    metadata_list = []

    # Set random seed if provided (seeds the *global* NumPy RNG).
    if random_seed is not None:
        np.random.seed(random_seed)

    # Unknown names previously produced stale params from the prior component
    # (or an UnboundLocalError on the first one); fail loudly instead.
    unsupported = [d for d in distributions if d not in _DEFAULT_DIST_PARAMS]
    if unsupported:
        raise ValueError(
            f"Unsupported distribution(s) {unsupported!r}; "
            f"expected any of {sorted(_DEFAULT_DIST_PARAMS)}"
        )

    # Validate parameters
    if max_features_per_component is None:
        max_features_per_component = n_features

    min_features_per_component = max(1, min(min_features_per_component, n_features))
    max_features_per_component = max(
        min_features_per_component, min(max_features_per_component, n_features)
    )

    # Adapt base size if needed: smaller components when we have more of them.
    if adaptive_sizing:
        adapted_size = base_size * (1.0 - 0.08 * (n_components - 2))
        base_size = max(0.1, min(0.8, adapted_size))
    else:
        base_size = min(max(0.1, base_size), 0.8)

    if n_components <= 0:
        return components, metadata_list

    # Per-component side lengths (one RNG draw per component when varying).
    component_sizes = [
        base_size * (1.0 + vary_factor * (np.random.random() - 0.5))
        if vary_size
        else base_size
        for _ in range(n_components)
    ]

    # Pre-generate feature selections for each component.
    component_feature_specs = []
    for i in range(n_components):
        n_features_for_component = np.random.randint(
            min_features_per_component, max_features_per_component + 1
        )
        active_features = np.random.choice(
            n_features, size=n_features_for_component, replace=False
        ).tolist()
        component_feature_specs.append(
            {"n_active": n_features_for_component, "active_features": active_features}
        )
        if verbose:
            print(
                f"Component {i+1}: will use {n_features_for_component} features: {active_features}"
            )

    # Candidate centres for the first trial of each stage (empty for "random").
    if placement_strategy == "grid":
        candidate_centers = _generate_grid_positions(
            component_feature_specs, n_components, n_features
        )
    elif placement_strategy == "poisson_disk":
        candidate_centers = _generate_poisson_disk_positions(
            n_components, n_features, base_size * spacing_factor
        )
    else:
        candidate_centers = []

    placed_components: List[Dict[str, Any]] = []
    weight = 1.0 / n_components  # equal weights for all components

    for i in range(n_components):
        original_size = component_sizes[i]
        active_features = component_feature_specs[i]["active_features"]

        dist = np.random.choice(distributions)
        dist_params = dict(_DEFAULT_DIST_PARAMS[dist])

        place_kwargs = dict(
            preferred_center=candidate_centers[i] if i < len(candidate_centers) else None,
            active_features=active_features,
            n_features=n_features,
            noise_bounds=noise_feature_bounds,
            placed_components=placed_components,
            max_attempts=max_attempts,
        )

        # --- Stage 1: original size, original domain ---
        placement_stage = "original"
        final_size = original_size
        domain = (0.0, 1.0)
        rules, placement_attempts = _try_place(original_size, domain, **place_kwargs)
        if rules is not None and verbose:
            print(f"  ✓ Component {i+1}/{n_components} placed in original domain.")

        # --- Stage 2: original size, progressively expanded domain ---
        if rules is None:
            if verbose:
                print(
                    f"Info: Component {i+1}/{n_components} not placed in original domain. Trying domain expansion."
                )
            placement_stage = "expanded"
            for expansion_step in range(max_domain_expansion_steps):
                margin = (expansion_step + 1) * domain_expansion_factor
                domain = (0.0 - margin, 1.0 + margin)
                if verbose:
                    print(
                        f"  - Attempting domain [{domain[0]:.2f}, {domain[1]:.2f}] for component {i+1}/{n_components}"
                    )
                rules, used = _try_place(original_size, domain, **place_kwargs)
                placement_attempts += used
                if rules is not None:
                    if verbose:
                        print(
                            f"  ✓ Component {i+1}/{n_components} placed with domain expansion."
                        )
                    break

        # --- Stage 3: shrinking size, last domain tried ---
        if rules is None:
            if verbose:
                print(
                    f"Warning: Component {i+1}/{n_components} not placed after domain expansion. Trying with reduced sizes."
                )
            placement_stage = "reduced"
            reduction_factor = 0.8
            reduced_size = original_size
            min_allowable_size = 0.2 * original_size
            if reduced_size < min_allowable_size:
                # Only reachable for negative sizes (e.g. vary_factor > 2).
                min_allowable_size = reduced_size * reduction_factor

            while reduced_size > min_allowable_size:
                reduced_size = max(reduced_size * reduction_factor, min_allowable_size)
                if verbose:
                    print(
                        f"  - Attempting with size reduced to {reduced_size/original_size:.1%} of original for component {i+1}/{n_components}"
                    )
                rules, used = _try_place(reduced_size, domain, **place_kwargs)
                placement_attempts += used
                if rules is not None:
                    final_size = reduced_size
                    if verbose:
                        print(
                            f"  ✓ Successfully placed component {i+1}/{n_components} at {reduced_size/original_size:.1%} of original size."
                        )
                    break
                if reduced_size == min_allowable_size:
                    break

        if rules is None:
            if verbose:
                print(
                    f"Critical: Failed to place component {i+1}/{n_components} even with progressive domain expansion and size reduction."
                )
            continue

        components.append(
            Component(
                rules=rules,
                distribution=dist,
                dist_params=dist_params,
                weight=weight,
            )
        )
        active_set = set(active_features)
        metadata_list.append(
            ComponentMetadata(
                component_id=i,
                active_features=active_features,
                n_active_features=len(active_features),
                active_feature_bounds={dim: rules[dim] for dim in active_features},
                inactive_feature_bounds={
                    dim: rules[dim] for dim in range(n_features) if dim not in active_set
                },
                distribution=dist,
                dist_params=dist_params,
                weight=weight,
                final_size=final_size,
                placement_attempts=placement_attempts,
                placement_stage=placement_stage,
            )
        )

    # Auto-scale noise feature bounds based on informative feature ranges.
    if auto_scale_noise and components:
        if verbose:
            print("\nAuto-scaling noise feature bounds...")

        all_informative_bounds = [
            bound
            for component, metadata in zip(components, metadata_list)
            for feat_idx in metadata.active_features
            for bound in component.rules[feat_idx]
        ]

        if all_informative_bounds:
            min_informative = min(all_informative_bounds)
            max_informative = max(all_informative_bounds)
            informative_range = max_informative - min_informative

            # NOTE: the new noise interval is centred on 0, not on the
            # informative region; kept as-is to preserve existing outputs.
            noise_range = informative_range * noise_scale_factor
            new_noise_bounds = (-noise_range / 2, noise_range / 2)

            if verbose:
                print(
                    f"  Informative range: [{min_informative:.3f}, {max_informative:.3f}] (span: {informative_range:.3f})"
                )
                print(f"  Old noise bounds: {noise_feature_bounds}")
                print(
                    f"  New noise bounds: [{new_noise_bounds[0]:.3f}, {new_noise_bounds[1]:.3f}]"
                )

            # Update inactive feature bounds in both the components and metadata.
            for component, metadata in zip(components, metadata_list):
                active_set = set(metadata.active_features)
                for feat_idx in range(n_features):
                    if feat_idx not in active_set:
                        component.rules[feat_idx] = new_noise_bounds
                        metadata.inactive_feature_bounds[feat_idx] = new_noise_bounds

    if verbose:
        print("\nComponent Summary:")
        for metadata in metadata_list:
            print(
                f"  Component {metadata.component_id+1}: uses {metadata.n_active_features} features {metadata.active_features}, "
                f"placed in {metadata.placement_stage} stage after {metadata.placement_attempts} attempts"
            )

    # ``return_metadata`` is intentionally not honoured: existing callers
    # unpack a (components, metadata) tuple even with the default False.
    return components, metadata_list


def print_component_summary(
    components: List[Component], metadata_list: List[ComponentMetadata]
):
    """Print a detailed summary of generated components.

    Prints the component count, then (when metadata is available) feature
    usage statistics, placement statistics, size and distribution statistics,
    and a per-component table.

    Args:
        components: Generated components; only their count is used.
        metadata_list: Metadata records, one per component.

    Count dictionaries are printed sorted by key so the output does not
    depend on set/hash iteration order.
    """
    from collections import Counter

    print(f"\n{'='*60}")
    print("SPARSE COMPONENT GENERATOR SUMMARY")
    print(f"{'='*60}")
    print(f"Total components generated: {len(components)}")

    if not metadata_list:
        print("No metadata available.")
        return

    # Feature usage statistics
    feature_counts = Counter(
        feat for meta in metadata_list for feat in meta.active_features
    )

    print("\nFeature Usage Statistics:")
    print(f"  Features used: {sorted(feature_counts)}")
    print(f"  Feature usage counts: {dict(sorted(feature_counts.items()))}")

    n_features_per_comp = [meta.n_active_features for meta in metadata_list]
    print(
        f"  Features per component: min={min(n_features_per_comp)}, max={max(n_features_per_comp)}, avg={np.mean(n_features_per_comp):.1f}"
    )

    # Placement statistics
    stage_counts = dict(
        sorted(Counter(meta.placement_stage for meta in metadata_list).items())
    )
    print("\nPlacement Statistics:")
    print(f"  Placement stages: {stage_counts}")

    total_attempts = sum(meta.placement_attempts for meta in metadata_list)
    avg_attempts = total_attempts / len(metadata_list)
    print(f"  Average placement attempts: {avg_attempts:.1f}")

    # Size statistics
    sizes = [meta.final_size for meta in metadata_list]
    print(
        f"  Component sizes: min={min(sizes):.3f}, max={max(sizes):.3f}, avg={np.mean(sizes):.3f}"
    )

    # Distribution usage. Names may be ``np.str_`` (from np.random.choice);
    # convert to ``str`` so the dict repr shows plain names, not np.str_('...').
    dist_counts = dict(
        sorted(Counter(str(meta.distribution) for meta in metadata_list).items())
    )
    print(f"  Distribution usage: {dist_counts}")

    print("\nDetailed Component Information:")
    print(
        f"{'ID':<3} {'Active Features':<20} {'Distribution':<12} {'Size':<8} {'Stage':<10} {'Attempts':<8}"
    )
    print(f"{'-'*70}")

    for meta in metadata_list:
        features_repr = str(meta.active_features)
        # Truncate long feature lists so the table columns stay aligned.
        features_str = (
            features_repr[:18] + ".." if len(features_repr) > 20 else features_repr
        )
        print(
            f"{meta.component_id:<3} {features_str:<20} {meta.distribution:<12} {meta.final_size:<8.3f} {meta.placement_stage:<10} {meta.placement_attempts:<8}"
        )


def visualize_components_2d(
    components: List[Component],
    metadata_list: List[ComponentMetadata],
    feature_pair: Tuple[int, int] = (0, 1),
    figsize: Tuple[int, int] = (12, 8),
):
    """Visualize components in 2D for a specific pair of features.

    Left panel: for each component, a rectangle when it constrains both
    features, a vertical/horizontal band when it constrains only one, or a
    legend-only entry when it constrains neither. Right panel: a
    component-by-feature usage heatmap. Calls ``plt.show()``; prints a message
    and returns early when ``components`` is empty.

    Args:
        components: Components whose ``rules`` map feature index -> (low, high).
        metadata_list: Metadata aligned with ``components`` (same order).
        feature_pair: Feature indices plotted on the x and y axes.
        figsize: Matplotlib figure size.
    """
    if not components:
        print("No components to visualize.")
        return

    feat_x, feat_y = feature_pair

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

    # Plot 1: All components with their active regions
    ax1.set_title(f"Component Active Regions (Features {feat_x} vs {feat_y})")
    ax1.set_xlabel(f"Feature {feat_x}")
    ax1.set_ylabel(f"Feature {feat_y}")

    colors = plt.cm.Set3(np.linspace(0, 1, len(components)))

    for i, (comp, meta) in enumerate(zip(components, metadata_list)):
        color = colors[i]
        uses_x = feat_x in meta.active_features
        uses_y = feat_y in meta.active_features

        if uses_x and uses_y:
            x_min, x_max = comp.rules[feat_x]
            y_min, y_max = comp.rules[feat_y]

            # Use facecolor (not color): passing color= together with
            # edgecolor= makes matplotlib warn and discard the black edge.
            rect = plt.Rectangle(
                (x_min, y_min),
                x_max - x_min,
                y_max - y_min,
                fill=True,
                alpha=0.3,
                facecolor=color,
                edgecolor="black",
                linewidth=1,
            )
            ax1.add_patch(rect)

            center_x, center_y = (x_min + x_max) / 2, (y_min + y_max) / 2
            ax1.text(
                center_x,
                center_y,
                f"C{i}",
                ha="center",
                va="center",
                fontweight="bold",
                fontsize=8,
            )

        elif uses_x:
            # Uses feature X but not Y - vertical band
            x_min, x_max = comp.rules[feat_x]
            ax1.axvspan(x_min, x_max, alpha=0.2, color=color, label=f"C{i} (X only)")

        elif uses_y:
            # Uses feature Y but not X - horizontal band
            y_min, y_max = comp.rules[feat_y]
            ax1.axhspan(y_min, y_max, alpha=0.2, color=color, label=f"C{i} (Y only)")
        else:
            # Uses neither feature - legend-only entry
            ax1.plot(
                [], [], color=color, alpha=0.5, linewidth=5, label=f"C{i} (neither)"
            )

    ax1.grid(True, alpha=0.3)
    # Rectangles are unlabeled; only draw a legend when some artist has a label
    # (otherwise matplotlib warns "No artists with labels found").
    handles, _labels = ax1.get_legend_handles_labels()
    if handles:
        ax1.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

    # Plot 2: Feature usage heatmap
    ax2.set_title("Feature Usage Heatmap")

    if metadata_list:
        n_features = max(max(meta.active_features) for meta in metadata_list) + 1
        usage_matrix = np.zeros((len(components), n_features))

        for row, meta in enumerate(metadata_list):
            for feat in meta.active_features:
                usage_matrix[row, feat] = 1

        im = ax2.imshow(usage_matrix, cmap="Blues", aspect="auto")
        ax2.set_xlabel("Feature Index")
        ax2.set_ylabel("Component Index")
        ax2.set_xticks(range(n_features))
        ax2.set_yticks(range(len(components)))

        # Cell annotations: a check mark on used features.
        for row in range(len(components)):
            for col in range(n_features):
                used = bool(usage_matrix[row, col])
                ax2.text(
                    col,
                    row,
                    "✓" if used else "",
                    ha="center",
                    va="center",
                    color="red" if used else "lightgray",
                )

        plt.colorbar(im, ax=ax2, label="Feature Used")

    plt.tight_layout()
    plt.show()


def validate_component_properties(
    components: List[Component], metadata_list: List[ComponentMetadata]
) -> Dict[str, Any]:
    """Validate that components have expected properties.

    Builds a summary dictionary with these keys:

    - ``total_components`` / ``total_metadata``: the two list lengths.
      ``metadata_matches_components`` records whether they agree.
    - ``overlapping_pairs``: ``(i, j, common_features)`` tuples for each pair of
      components that share at least one active feature and whose rule
      intervals intersect (touching endpoints count) on *every* shared active
      feature. Pairs with no shared active feature are not reported.
    - ``feature_coverage``: how many components use each active feature.
      ``feature_usage_counts`` is ordered by feature index.
    - ``size_distribution``: min/max/mean/std of ``meta.final_size``.
    - ``placement_success_rate``: 1.0 whenever metadata is present. Every
      metadata entry is assumed to describe a successful placement.

    Args:
        components: Generated components. ``components[k].rules[f]`` must be a
            ``(min, max)`` pair for each active feature ``f`` of component ``k``.
        metadata_list: Metadata entries, parallel to ``components``.

    Returns:
        The summary dictionary described above. If the two lists differ in
        length, only the first ``min(len(components), len(metadata_list))``
        components are checked for overlap instead of raising IndexError.
    """
    validation_results: Dict[str, Any] = {
        "total_components": len(components),
        "total_metadata": len(metadata_list),
        "metadata_matches_components": len(components) == len(metadata_list),
        "overlapping_pairs": [],
        "feature_coverage": {},
        "size_distribution": {},
        "placement_success_rate": 0.0,
    }

    if not components:
        return validation_results

    # Overlap checks need both a component and its metadata; restricting to the
    # paired prefix avoids indexing past the end of a shorter metadata_list.
    n_paired = min(len(components), len(metadata_list))
    active_sets = [set(meta.active_features) for meta in metadata_list[:n_paired]]

    overlapping_pairs = []
    for i in range(n_paired):
        for j in range(i + 1, n_paired):
            common_features = active_sets[i].intersection(active_sets[j])
            if not common_features:
                continue

            # Closed intervals intersect unless one ends strictly before the
            # other begins; the pair overlaps only if that holds on every
            # shared active feature.
            overlaps = True
            for feat in common_features:
                min_i, max_i = components[i].rules[feat]
                min_j, max_j = components[j].rules[feat]
                if max_i < min_j or max_j < min_i:
                    overlaps = False
                    break

            if overlaps:
                overlapping_pairs.append((i, j, list(common_features)))

    validation_results["overlapping_pairs"] = overlapping_pairs

    if metadata_list:
        # Number of components using each feature; a feature listed twice in
        # one component's active_features still counts once for it.
        feature_counts: Dict[int, int] = {}
        for meta in metadata_list:
            for feat in set(meta.active_features):
                feature_counts[feat] = feature_counts.get(feat, 0) + 1
        feature_counts = dict(sorted(feature_counts.items()))

        validation_results["feature_coverage"] = {
            "total_features_used": len(feature_counts),
            "feature_usage_counts": feature_counts,
            "unused_features": [],  # We don't know total n_features here
            "features_used_once": [f for f, c in feature_counts.items() if c == 1],
            "features_used_multiple": [f for f, c in feature_counts.items() if c > 1],
        }

        sizes = [meta.final_size for meta in metadata_list]
        validation_results["size_distribution"] = {
            "min": min(sizes),
            "max": max(sizes),
            "mean": np.mean(sizes),
            "std": np.std(sizes),
        }

        # Every entry in metadata_list is assumed to be a successful placement,
        # so the rate is trivially 1.0 whenever metadata is present.
        validation_results["placement_success_rate"] = 1.0

    return validation_results


def _sample_active_feature(
    component: Component, rule_min: float, rule_max: float, n: int
) -> np.ndarray:
    """Draw ``n`` values for one active feature of ``component`` within its rule bounds."""
    params = component.dist_params
    dist = component.distribution

    if dist == "normal":
        # The draw is clipped, not truncated: mass outside the bounds lands
        # exactly on rule_min/rule_max.
        # NOTE(review): loc/scale are used as absolute coordinates here; confirm
        # this matches how the generator chooses dist_params.
        loc = params.get("loc", 0.0)
        scale = params.get("scale", 0.1)
        return np.clip(np.random.normal(loc, scale, n), rule_min, rule_max)

    if dist == "gamma":
        shape = params.get("shape", 2.0)
        scale = params.get("scale", 0.2)
        raw = np.random.gamma(shape, scale, n)
        # Rough normalization by 3x the mean, clipped to [0, 1], then mapped
        # linearly onto the rule bounds.
        unit = np.clip(raw / (shape * scale * 3), 0, 1)
        return rule_min + (rule_max - rule_min) * unit

    if dist == "exponential":
        scale = params.get("scale", 0.3)
        raw = np.random.exponential(scale, n)
        # Same rough normalization as gamma: divide by 3x the mean, clip, map.
        unit = np.clip(raw / (scale * 3), 0, 1)
        return rule_min + (rule_max - rule_min) * unit

    # "uniform" and any unrecognized distribution: uniform over the rule bounds.
    return np.random.uniform(rule_min, rule_max, n)


def generate_data_from_sparse_components(
    components: List[Component],
    metadata_list: List[ComponentMetadata],
    n_samples: int = 1000,
    random_seed: int | None = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Generate data points from sparse components.

    Samples are allocated to components by a multinomial draw proportional to
    ``component.weight``. For each component:

    - Active features (those in the matching metadata's ``active_features``)
      are drawn from the component's distribution and kept inside the
      component's rule bounds for that feature.
    - Inactive features are drawn uniformly from the component's rule bounds
      for that feature.

    The combined data is shuffled before returning. When ``random_seed`` is
    given, NumPy's global random state is seeded, which affects other code
    using ``np.random``.

    Args:
        components: List of Component objects. Each needs ``rules`` (a
            ``(min, max)`` pair per feature), ``weight``, ``distribution`` and
            ``dist_params``.
        metadata_list: List of ComponentMetadata objects, parallel to
            ``components``.
        n_samples: Total number of samples to generate
        random_seed: Seed for random number generation

    Returns:
        Tuple of (X, y) where:
        - X: np.ndarray of shape (n_samples, n_features) containing the feature values
        - y: np.ndarray of shape (n_samples,) containing component labels

    Raises:
        ValueError: If no components or metadata are given, if the two lists
            differ in length, or if the weights do not sum to a positive
            finite value.
    """
    if random_seed is not None:
        np.random.seed(random_seed)

    if not components or not metadata_list:
        raise ValueError("No components provided")
    if len(components) != len(metadata_list):
        # zip() would otherwise silently drop samples allocated to unmatched
        # components and return fewer than n_samples rows.
        raise ValueError(
            "components and metadata_list must have the same length "
            f"(got {len(components)} and {len(metadata_list)})"
        )

    # Determine number of features from the first component
    n_features = len(components[0].rules)

    component_weights = np.array([comp.weight for comp in components], dtype=float)
    total_weight = component_weights.sum()
    if not np.isfinite(total_weight) or total_weight <= 0:
        raise ValueError("Component weights must sum to a positive finite value")
    component_weights = component_weights / total_weight

    samples_per_component = np.random.multinomial(n_samples, component_weights)

    X_list = []
    y_list = []

    for comp_idx, (component, metadata, n_comp_samples) in enumerate(
        zip(components, metadata_list, samples_per_component)
    ):
        if n_comp_samples == 0:
            continue

        active = set(metadata.active_features)
        X_comp = np.zeros((n_comp_samples, n_features))

        for feat_idx in range(n_features):
            rule_min, rule_max = component.rules[feat_idx]
            if feat_idx in active:
                X_comp[:, feat_idx] = _sample_active_feature(
                    component, rule_min, rule_max, n_comp_samples
                )
            else:
                # Inactive feature: uniform noise over this feature's rule
                # bounds (expected to be the generator's noise bounds).
                X_comp[:, feat_idx] = np.random.uniform(
                    rule_min, rule_max, n_comp_samples
                )

        X_list.append(X_comp)
        y_list.append(np.full(n_comp_samples, comp_idx))

    if X_list:
        X = np.vstack(X_list)
        y = np.concatenate(y_list)

        shuffle_idx = np.random.permutation(len(X))
        X = X[shuffle_idx]
        y = y[shuffle_idx]
    else:
        # Only reachable when n_samples == 0.
        X = np.empty((0, n_features))
        y = np.empty(0, dtype=int)

    return X, y


def test_data_generation():
    """Smoke-test data generation from sparse components.

    Builds three components over four features and draws 500 samples with
    ``generate_data_from_sparse_components``. It then prints:

    - the data shape and per-component label counts,
    - each feature's observed range,
    - for each component's active features, whether its samples lie inside
      the rule bounds.
    """
    print("Testing Data Generation from Sparse Components")
    print("=" * 50)

    # return_metadata=True is required: the generator's declared return type
    # without it is a bare component list, which cannot be unpacked into
    # (components, metadata).
    components, metadata = sparse_distributed_component_generator(
        n_components=3,
        n_features=4,
        distributions=["normal", "uniform"],
        min_features_per_component=2,
        max_features_per_component=3,
        random_seed=42,
        verbose=True,
        return_metadata=True,
    )

    X, y = generate_data_from_sparse_components(
        components, metadata, n_samples=500, random_seed=123
    )

    print("\nGenerated Data:")
    print(f"  Data shape: {X.shape}")
    print(f"  Labels shape: {y.shape}")
    print(f"  Label distribution: {np.bincount(y)}")
    print("  Feature ranges:")
    for i in range(X.shape[1]):
        print(f"    Feature {i}: [{X[:, i].min():.3f}, {X[:, i].max():.3f}]")

    # Check that data respects component rules
    print("\nValidating data against component rules:")
    for comp_idx, (component, meta) in enumerate(zip(components, metadata)):
        comp_data = X[y == comp_idx]
        print(f"  Component {comp_idx} ({len(comp_data)} samples):")

        if len(comp_data) == 0:
            # The multinomial allocation can give a component zero samples;
            # min()/max() on an empty array would raise ValueError.
            continue

        for feat_idx in meta.active_features:
            rule_min, rule_max = component.rules[feat_idx]
            feat_data = comp_data[:, feat_idx]
            in_bounds = np.all((feat_data >= rule_min) & (feat_data <= rule_max))
            print(
                f"    Feature {feat_idx}: rule=[{rule_min:.3f}, {rule_max:.3f}], "
                f"data=[{feat_data.min():.3f}, {feat_data.max():.3f}], in_bounds={in_bounds}"
            )


def test_sparse_generator():
    """Run the sparse distributed component generator over several configurations.

    For each configuration this generates components and prints a summary.
    It then reports validation results: counts, overlapping pairs and feature
    usage. A 2D plot of features 0 vs 1 is shown for the first configuration
    only. Errors in one configuration are printed with a traceback and do not
    stop the remaining ones.
    """

    print("Testing Sparse Distributed Component Generator")
    print("=" * 50)

    # Test configuration
    test_configs = [
        {
            "name": "Basic Test",
            "n_components": 4,
            "n_features": 6,
            "min_features_per_component": 2,
            "max_features_per_component": 4,
            "distributions": ["normal", "uniform"],
            "random_seed": 42,
            "verbose": True,
        },
        {
            "name": "High Sparsity Test",
            "n_components": 5,
            "n_features": 10,
            "min_features_per_component": 1,
            "max_features_per_component": 3,
            "distributions": ["normal", "gamma", "exponential"],
            "random_seed": 123,
            "verbose": False,
        },
        {
            "name": "Low Sparsity Test",
            "n_components": 3,
            "n_features": 5,
            "min_features_per_component": 3,
            "max_features_per_component": 5,
            "distributions": ["uniform", "normal"],
            "random_seed": 456,
            "verbose": False,
        },
    ]

    for i, config in enumerate(test_configs):
        print(f"\n{'-'*20} {config['name']} {'-'*20}")

        try:
            # Drop the display-only "name" key. Explicitly request metadata,
            # because without return_metadata=True the generator's declared
            # return type is a bare component list that cannot be unpacked
            # into two names.
            generator_config = {k: v for k, v in config.items() if k != "name"}
            generator_config["return_metadata"] = True
            components, metadata = sparse_distributed_component_generator(
                **generator_config
            )

            print_component_summary(components, metadata)

            validation = validate_component_properties(components, metadata)
            print("\nValidation Results:")
            print(f"  Components generated: {validation['total_components']}")
            print(f"  Metadata entries: {validation['total_metadata']}")
            print(f"  Overlapping pairs: {len(validation['overlapping_pairs'])}")
            if validation["overlapping_pairs"]:
                print(f"    {validation['overlapping_pairs']}")
            print(
                f"  Features used: {validation['feature_coverage']['total_features_used']}"
            )
            print(
                f"  Feature usage: {validation['feature_coverage']['feature_usage_counts']}"
            )

            # Create visualization for first test only
            if i == 0:
                print(f"\nCreating visualization for {config['name']}...")
                visualize_components_2d(components, metadata, feature_pair=(0, 1))

        except Exception as e:
            # Best-effort test driver: report and continue with the next config.
            print(f"Error in {config['name']}: {e}")
            import traceback

            traceback.print_exc()

    print(f"\n{'='*60}")
    print("Testing completed!")


def quick_demo(
    n_components: int = 3,
    n_features: int = 5,
    min_features_per_component: int = 2,
    max_features_per_component: int = 4,
    show_plot: bool = True,
    verbose: bool = True,
):
    """
    Quick demo function for interactive testing of the sparse generator.

    Generates components with distributions ["normal", "uniform", "gamma"],
    ``base_size=0.4`` and a fixed seed of 42. It prints a summary and a short
    validation report, and optionally plots features 0 vs 1.

    Args:
        n_components: Number of components to generate
        n_features: Total number of features
        min_features_per_component: Min features per component
        max_features_per_component: Max features per component
        show_plot: Whether to show visualization (only when n_features >= 2)
        verbose: Whether to print detailed output

    Returns:
        ``(components, metadata)`` on success. ``(None, None)`` if any
        exception was raised; the error and traceback are printed.
    """
    print(f"Quick Demo: {n_components} components, {n_features} features")
    print(
        f"Feature sparsity: {min_features_per_component}-{max_features_per_component} features per component"
    )
    print("=" * 60)

    try:
        # return_metadata=True is required: the generator's declared return
        # type without it is a bare component list, which cannot be unpacked
        # into (components, metadata).
        components, metadata = sparse_distributed_component_generator(
            n_components=n_components,
            n_features=n_features,
            distributions=["normal", "uniform", "gamma"],
            min_features_per_component=min_features_per_component,
            max_features_per_component=max_features_per_component,
            base_size=0.4,
            random_seed=42,
            verbose=verbose,
            return_metadata=True,
        )

        print_component_summary(components, metadata)

        validation = validate_component_properties(components, metadata)
        print("\nQuick Validation:")
        print(f"  ✓ {validation['total_components']} components generated successfully")
        print(
            f"  ✓ {validation['feature_coverage']['total_features_used']} unique features used"
        )
        print(f"  ✓ {len(validation['overlapping_pairs'])} overlapping pairs")

        if show_plot and n_features >= 2:
            print("\nShowing 2D visualization (features 0 vs 1)...")
            visualize_components_2d(components, metadata, feature_pair=(0, 1))

        return components, metadata

    except Exception as e:
        # Interactive demo: report the failure instead of propagating it.
        print(f"Error in demo: {e}")
        import traceback

        traceback.print_exc()
        return None, None


if __name__ == "__main__":
    # Select the entry point from the first command-line argument:
    #   "test" -> full generator test suite
    #   "data" -> data-generation check
    #   anything else, or no argument -> quick demo with usage hints
    import sys

    mode = sys.argv[1] if len(sys.argv) > 1 else None
    handlers = {
        "test": test_sparse_generator,
        "data": test_data_generation,
    }
    selected = handlers.get(mode)

    if selected is not None:
        selected()
    else:
        print(
            "Running quick demo (use 'python sparse_distributed.py test' for full tests)"
        )
        print("Use 'python sparse_distributed.py data' to test data generation")
        quick_demo()
