import numpy as np
from typing import List
from ..mixture_gen import Component


# Default parameters for each supported distribution name. Every Component gets
# its own copy (see distributed_component_generator) so no dict is shared.
_DIST_PARAMS = {
    "normal": {"loc": 0.0, "scale": 0.1},
    "uniform": {"low": -0.5, "high": 0.5},
    "gamma": {"shape": 2.0, "scale": 0.2},
    "exponential": {"scale": 0.3},
}

# Per-feature (low, high) bounds of an axis-aligned box, keyed by feature index.
_Rules = dict[int, tuple[float, float]]


def _boxes_overlap(rules_a: _Rules, rules_b: _Rules, n_features: int) -> bool:
    """Return True unless the two boxes are separated by a gap on some axis.

    Boxes that merely touch (``high == low`` on an axis) count as overlapping.
    """
    for dim in range(n_features):
        low_a, high_a = rules_a[dim]
        low_b, high_b = rules_b[dim]
        if high_a < low_b or high_b < low_a:
            return False
    return True


def _try_place(
    size: float,
    domain_min: float,
    domain_max: float,
    preset_center,
    n_features: int,
    placed_rules: list,
    max_attempts: int,
) -> _Rules | None:
    """Try up to ``max_attempts`` centres for a box of side ``size``.

    The first attempt uses ``preset_center`` when it is not None. Later attempts
    (or all attempts, without a preset) draw every coordinate uniformly from
    ``[domain_min, domain_max)``. Returns the first box that overlaps none of
    ``placed_rules``, or None if every attempt collides.
    """
    half = size / 2
    span = domain_max - domain_min
    for attempt in range(max_attempts):
        if attempt == 0 and preset_center is not None:
            center = preset_center
        else:
            center = [domain_min + span * np.random.random() for _ in range(n_features)]
        rules = {
            dim: (center[dim] - half, center[dim] + half) for dim in range(n_features)
        }
        if not any(_boxes_overlap(rules, other, n_features) for other in placed_rules):
            return rules
    return None


def _grid_centers(n_components: int, n_features: int) -> list:
    """Return shuffled, jittered grid-cell centres clipped to [0.1, 0.9]."""
    grid_dim = int(np.ceil(n_components ** (1 / n_features)))
    grid_spacing = 1.0 / grid_dim
    # NOTE(review): only the first `n_cells` cells in C order are generated. When
    # grid_dim**n_features greatly exceeds 3 * n_components (high dimensions),
    # those cells share low indices on the leading axes, so candidates do not
    # cover the whole grid.
    n_cells = min(grid_dim**n_features, n_components * 3)
    centers = []
    for cell in range(n_cells):
        cell_indices = np.unravel_index(cell, [grid_dim] * n_features)
        base_pos = [(k + 0.5) * grid_spacing for k in cell_indices]
        # Jitter of up to +/-15% of the grid spacing on each axis.
        jitter = [
            (np.random.random() - 0.5) * 0.3 * grid_spacing for _ in range(n_features)
        ]
        centers.append([min(max(p + j, 0.1), 0.9) for p, j in zip(base_pos, jitter)])
    # Shuffle so components are not assigned cells in a predictable order.
    np.random.shuffle(centers)
    return centers


def _poisson_disk_centers(n_components: int, n_features: int, min_dist: float) -> list:
    """Return up to ``2 * n_components`` centres at least ``min_dist`` apart.

    This is a simple dart-throwing variant of Poisson disk sampling. Each new
    point is offset from an active point by a random direction, at a distance in
    ``[min_dist, 1.5 * min_dist)``, and must lie inside [0.1, 0.9] on every axis.
    An active point is retired once 30 consecutive tries around it fail.
    """
    # NOTE(review): the seed point comes from [0, 1) while later points must lie
    # in [0.1, 0.9]; confirm whether that asymmetry is intended.
    positions = [np.random.random(n_features)]
    active = [0]
    while active and len(positions) < n_components * 2:
        idx = np.random.choice(active)
        origin = positions[idx]
        for _ in range(30):
            direction = np.random.randn(n_features)
            direction = direction / np.linalg.norm(direction)
            distance = min_dist * (1 + 0.5 * np.random.random())
            candidate = origin + direction * distance
            if (
                np.all(candidate >= 0.1)
                and np.all(candidate <= 0.9)
                and all(np.linalg.norm(candidate - p) >= min_dist for p in positions)
            ):
                positions.append(candidate)
                active.append(len(positions) - 1)
                break
        else:
            # Only retire the point when all 30 tries failed (a success on the
            # final try must keep it active).
            active.remove(idx)
    return positions


def _place_component(
    i: int,
    n_components: int,
    size: float,
    preset_center,
    n_features: int,
    placed_rules: list,
    max_attempts: int,
    domain_expansion_factor: float,
    max_domain_expansion_steps: int,
    verbose: bool,
) -> _Rules | None:
    """Run the three placement stages for component ``i``; return its box or None."""
    # --- Stage 1: original size, original domain ---
    domain_min, domain_max = 0.0, 1.0
    rules = _try_place(
        size, domain_min, domain_max, preset_center, n_features, placed_rules, max_attempts
    )
    if rules is not None:
        if verbose:
            print(f"  ✓ Component {i+1}/{n_components} placed in original domain.")
        return rules

    # --- Stage 2: original size, progressively expanded domain ---
    if verbose:
        print(
            f"Info: Component {i+1}/{n_components} not placed in original domain. Trying domain expansion."
        )
    for step in range(1, max_domain_expansion_steps + 1):
        domain_min = 0.0 - step * domain_expansion_factor
        domain_max = 1.0 + step * domain_expansion_factor
        if verbose:
            print(
                f"  - Attempting domain [{domain_min:.2f}, {domain_max:.2f}] for component {i+1}/{n_components} (expansion step {step})"
            )
        # A precomputed centre may lie outside the expanded domain; it is still
        # tried first. Expansion mainly benefits the random draws.
        rules = _try_place(
            size, domain_min, domain_max, preset_center, n_features, placed_rules, max_attempts
        )
        if rules is not None:
            if verbose:
                print(
                    f"  ✓ Component {i+1}/{n_components} placed with domain expansion."
                )
            return rules

    # --- Stage 3: shrinking size in the last domain tried ---
    if verbose:
        print(
            f"Warning: Component {i+1}/{n_components} not placed after domain expansion. Trying with reduced sizes in domain [{domain_min:.2f}, {domain_max:.2f}]."
        )
    min_size = 0.2 * size
    reduced = size
    while reduced > min_size:
        # Shrink by 20% per round, never below 20% of the planned size.
        reduced = max(reduced * 0.8, min_size)
        if verbose:
            print(
                f"  - Attempting with size reduced to {reduced/size:.1%} of original for component {i+1}/{n_components}"
            )
        rules = _try_place(
            reduced, domain_min, domain_max, preset_center, n_features, placed_rules, max_attempts
        )
        if rules is not None:
            if verbose:
                print(
                    f"  ✓ Successfully placed component {i+1}/{n_components} at {reduced/size:.1%} of original size."
                )
            return rules

    if verbose:
        print(
            f"Critical: Failed to place component {i+1}/{n_components} even with progressive domain expansion and size reduction. "
            "Consider reducing base_size, decreasing n_components, increasing spacing_factor, or using a different placement strategy."
        )
    return None


def distributed_component_generator(
    n_components: int,
    n_features: int,
    distributions: List[str],
    base_size: float = 0.5,
    spacing_factor: float = 1.5,
    placement_strategy: str = "random",  # "random", "grid", or "poisson_disk"
    adaptive_sizing: bool = True,
    vary_size: bool = False,
    vary_factor: float = 0.2,
    max_attempts: int = 1000,
    random_seed: int | None = None,
    domain_expansion_factor: float = 0.2,
    max_domain_expansion_steps: int = 5,
    verbose: bool = False,
) -> List[Component]:
    """
    Generate non-overlapping, axis-aligned hypercube components in feature space.

    Each component is a box with the same side length on every axis. Placement
    of each component proceeds in up to three stages and stops at the first
    success:

    1. Original size; centres drawn uniformly from [0, 1] on every axis.
    2. Original size; centres drawn from a domain widened by
       ``domain_expansion_factor`` on each side per step, for up to
       ``max_domain_expansion_steps`` steps.
    3. Size repeatedly multiplied by 0.8 (never below 20% of the planned size).
       Centres are drawn from the last domain used in stage 2, or from [0, 1]
       when no expansion steps ran.

    Each domain/size gets ``max_attempts`` attempts. With the "grid" and
    "poisson_disk" strategies, the first attempt uses the component's
    precomputed centre when one exists. Boxes that share a boundary count as
    overlapping. Component bounds are not clipped to [0, 1].

    A component that cannot be placed is skipped, so fewer than
    ``n_components`` components may be returned. Their weights (each
    ``1 / n_components``) then sum to less than 1.

    Args:
        n_components: Number of components to generate. Values <= 0 yield [].
        n_features: Number of features (box dimensions); must be >= 1.
        distributions: Names to choose from uniformly at random per component.
            Supported: "normal", "uniform", "gamma", "exponential".
        base_size: Planned side length of each box. If ``adaptive_sizing`` is
            True, it is first scaled by ``1 - 0.08 * (n_components - 2)``. In
            both cases it is then clamped to [0.1, 0.8].
        spacing_factor: Used only by "poisson_disk". Candidate centres are at
            least ``base_size * spacing_factor`` apart in Euclidean distance.
            This alone does not guarantee non-overlapping boxes; the per-axis
            overlap check still applies.
        placement_strategy: "random" (uniform random centres), "grid" (jittered
            grid cells) or "poisson_disk" (spaced sampling). Any other value
            behaves like "random".
        adaptive_sizing: Shrink ``base_size`` as ``n_components`` grows.
        vary_size: If True, each box's size is ``base_size`` times a uniform
            factor in ``[1 - vary_factor / 2, 1 + vary_factor / 2)``.
        vary_factor: Total relative width of that size variation.
        max_attempts: Attempts per stage step (per domain or per reduced size).
        random_seed: If given, passed to ``np.random.seed``. This seeds NumPy's
            global RNG and therefore affects other users of ``np.random``.
        domain_expansion_factor: Per-step widening of the random-centre domain
            (e.g. 0.2 turns [0, 1] into [-0.2, 1.2] on step 1).
        max_domain_expansion_steps: Number of domain-expansion steps.
        verbose: Print progress messages.

    Returns:
        List of Component objects, in placement order.

    Raises:
        ValueError: If ``n_components > 0`` and ``n_features < 1``,
            ``distributions`` is empty, or it contains an unsupported name.
    """
    if random_seed is not None:
        np.random.seed(random_seed)

    if n_components <= 0:
        return []
    if n_features < 1:
        raise ValueError(f"n_features must be at least 1, got {n_features}")
    if len(distributions) == 0:
        raise ValueError("distributions must contain at least one name")
    unknown = sorted(set(distributions).difference(_DIST_PARAMS))
    if unknown:
        raise ValueError(
            f"Unsupported distribution(s) {unknown}; expected one of {sorted(_DIST_PARAMS)}"
        )

    if adaptive_sizing:
        # More components -> smaller boxes, so they fit without overlapping.
        adapted_size = base_size * (1.0 - 0.08 * (n_components - 2))
        base_size = max(0.1, min(0.8, adapted_size))
    else:
        base_size = min(max(0.1, base_size), 0.8)

    # Sizes are drawn before any centres so the RNG stream order is stable.
    component_sizes = [
        base_size * (1.0 + vary_factor * (np.random.random() - 0.5))
        if vary_size
        else base_size
        for _ in range(n_components)
    ]

    if placement_strategy == "grid":
        preset_centers = _grid_centers(n_components, n_features)
    elif placement_strategy == "poisson_disk":
        preset_centers = _poisson_disk_centers(
            n_components, n_features, base_size * spacing_factor
        )
    else:  # random placement
        preset_centers = []

    weight = 1.0 / n_components  # equal weights; not renormalized on skips
    components = []
    placed_rules = []
    for i, size in enumerate(component_sizes):
        dist = np.random.choice(distributions)
        preset_center = preset_centers[i] if i < len(preset_centers) else None
        rules = _place_component(
            i,
            n_components,
            size,
            preset_center,
            n_features,
            placed_rules,
            max_attempts,
            domain_expansion_factor,
            max_domain_expansion_steps,
            verbose,
        )
        if rules is None:
            continue  # component skipped
        placed_rules.append(rules)
        components.append(
            Component(
                rules=rules,
                distribution=dist,
                dist_params=dict(_DIST_PARAMS[dist]),
                weight=weight,
            )
        )

    return components
