import math
import warnings
from collections import deque
from dataclasses import dataclass, field
from typing import List, Dict, Tuple, Optional, Any

import numpy as np

# Assuming this lives alongside other generators that import from a parent module
from ..mixture_gen import Component


@dataclass
class TreeComponentMetadata:
    """Metadata about a component generated from a tree leaf node.

    As populated by ``tree_based_component_generator`` in this module, one
    instance is created per active leaf, alongside the matching ``Component``.
    """

    # Position of the component in the generator's list of active leaves.
    component_id: int
    # Depth of the source leaf in the splitting tree (the root has depth 0).
    leaf_depth: int
    # Axis-aligned box of the leaf: feature index -> (low, high).
    leaf_bounds: Dict[int, Tuple[float, float]]
    # Weight assigned at generation time (the generator uses
    # 1 / number of active leaves and notes it may be re-scaled later).
    weight: float
    # Name of the target (y) distribution drawn for this component.
    y_distribution: str
    # Parameters of the y distribution; keys depend on y_distribution
    # (e.g. "loc"/"scale" for "normal"); empty for names the generator
    # does not recognise.
    y_dist_params: Dict[str, Any]
    # Name of the feature-space (x) distribution inside the leaf,
    # e.g. "uniform" or "gaussian".
    x_distribution: str
    # x distribution parameters: for "gaussian", per-feature "loc" and
    # "covariance" dicts; empty otherwise.
    x_dist_params: Dict[str, Any]


@dataclass
class TreeNode:
    """A node in the recursive axis-aligned splitting tree.

    Each node owns a box (``bounds``) in feature space. Splitting a node on
    ``split_feature`` at ``split_value`` produces child nodes whose boxes
    partition the parent's box along that feature.

    ``parent`` is excluded from the generated ``__eq__`` and ``__repr__``.
    The parent's ``children`` list points back at this node. Including it
    made equality between two structurally identical trees recurse without
    end (``RecursionError``), and made every ``repr`` dump most of the tree.
    """

    # Box of this node: feature index -> (low, high).
    bounds: Dict[int, Tuple[float, float]]
    # Distance from the root (the root has depth 0).
    depth: int
    # Node this one was split from; None for the root. Kept out of
    # comparison and repr to break the parent <-> children reference cycle.
    parent: Optional["TreeNode"] = field(default=None, repr=False, compare=False)
    # Child nodes produced by a split; empty while the node is a leaf.
    children: List["TreeNode"] = field(default_factory=list)
    # Feature index and threshold used to split this node; None until split.
    split_feature: Optional[int] = None
    split_value: Optional[float] = None

    @property
    def is_leaf(self) -> bool:
        """Return True when the node has no children (it has not been split)."""
        return not self.children


def _split_leaf_node(
    node: "TreeNode",
    rng: np.random.Generator,
    n_features: int,
    min_feature_width: float,
) -> bool:
    """Try to split ``node`` into two children along one randomly chosen feature.

    Features are visited in a random order. The first one whose extent leaves
    at least ``min_feature_width`` on each side of the cut is used, and the
    cut point is drawn uniformly from the admissible interval.

    Returns:
        True if the node was split (its ``children``, ``split_feature`` and
        ``split_value`` are then populated). False if no feature was wide
        enough, in which case the node is left unchanged.
    """
    candidate_features = list(range(n_features))
    rng.shuffle(candidate_features)

    for feature in candidate_features:
        low, high = node.bounds[feature]
        if not (high - low >= min_feature_width * 2):
            continue  # too narrow to keep min_feature_width on both sides
        split_low = low + min_feature_width
        split_high = high - min_feature_width
        # Width exactly 2 * min_feature_width leaves a single admissible point.
        if split_low >= split_high:
            continue

        split_value = rng.uniform(split_low, split_high)

        left_bounds = node.bounds.copy()
        left_bounds[feature] = (low, split_value)
        right_bounds = node.bounds.copy()
        right_bounds[feature] = (split_value, high)

        node.children = [
            TreeNode(bounds=left_bounds, depth=node.depth + 1, parent=node),
            TreeNode(bounds=right_bounds, depth=node.depth + 1, parent=node),
        ]
        node.split_feature = feature
        node.split_value = split_value
        return True

    return False


def _leaf_x_dist_params(
    bounds: Dict[int, Tuple[float, float]],
    leaf_x_distribution: str,
    gaussian_std_fraction: float,
) -> Dict[str, Any]:
    """Build the feature-space (x) distribution parameters for one leaf.

    Only ``"gaussian"`` carries parameters. ``loc`` places each feature's
    mean at the centre of the leaf box. ``covariance`` holds one variance
    per feature, equal to ``(width * gaussian_std_fraction) ** 2``. Any other
    distribution name yields ``{}``.
    """
    if leaf_x_distribution != "gaussian":
        return {}
    loc = {feat: (lo + hi) / 2 for feat, (lo, hi) in bounds.items()}
    covariance = {
        feat: ((hi - lo) * gaussian_std_fraction) ** 2
        for feat, (lo, hi) in bounds.items()
    }
    return {"loc": loc, "covariance": covariance}


def _sample_y_dist_params(rng: np.random.Generator, dist_name: str) -> Dict[str, Any]:
    """Draw random parameters for the target (y) distribution ``dist_name``.

    Recognised names are ``"normal"``, ``"uniform"``, ``"gamma"`` and
    ``"exponential"``. Any other name returns ``{}`` without drawing from
    ``rng``.
    """
    if dist_name == "normal":
        return {"loc": rng.uniform(-1, 1), "scale": rng.uniform(0.1, 0.5)}
    if dist_name == "uniform":
        return {"low": 0, "high": rng.uniform(0.5, 1.5)}
    if dist_name == "gamma":
        return {"shape": rng.uniform(1, 5), "scale": rng.uniform(0.1, 0.5)}
    if dist_name == "exponential":
        return {"scale": rng.uniform(0.2, 1.0)}
    return {}


def tree_based_component_generator(
    n_components: int,
    n_features: int,
    distributions: List[str],
    max_depth: int = 8,
    expansion_strategy: str = "breadth_first",
    domain_bounds: Tuple[float, float] = (0.0, 1.0),
    min_leaf_size_fraction: float = 0.1,
    empty_leaf_probability: float = 0.0,
    leaf_x_distribution: str = "uniform",
    gaussian_std_fraction: float = 0.25,
    random_seed: Optional[int] = None,
    **kwargs,
) -> Tuple[List[Component], List[TreeComponentMetadata], List["TreeNode"]]:
    """
    Generates non-overlapping components and identifies empty regions.

    A box covering ``domain_bounds`` on every feature is split recursively
    at random axis-aligned cuts until enough leaves exist. ``n_components``
    of the leaves become active components; the remaining leaves are
    returned as empty regions for a higher-level generator to fill.

    Args:
        n_components: Number of active components (leaves) to produce.
        n_features: Number of features; every leaf has bounds for feature
            indices ``0 .. n_features - 1``.
        distributions: Names of target (y) distributions. One is drawn
            uniformly at random per component. "normal", "uniform", "gamma"
            and "exponential" get random parameters; other names get ``{}``.
        max_depth: Nodes at this depth or deeper are never split.
        expansion_strategy: "breadth_first" (FIFO frontier) or
            "depth_first" (LIFO frontier).
        domain_bounds: ``(low, high)`` of the root box, used for every feature.
        min_leaf_size_fraction: Minimum width each child keeps along the split
            feature, as a fraction of the root domain width. Must be >= 0.
        empty_leaf_probability: In ``[0.0, 1.0)``. The leaf target is inflated
            to ``ceil(n_components / (1 - p))`` so surplus leaves stay empty.
        leaf_x_distribution: "gaussian" produces per-leaf ``loc``/``covariance``
            parameters; any other value produces empty x parameters.
        gaussian_std_fraction: Per-feature standard deviation of the gaussian
            x distribution, as a fraction of the leaf width.
        random_seed: Seed for ``np.random.default_rng``; None gives
            non-reproducible output.
        **kwargs: Catches unused parameters like background noise settings,
                  which are now handled by the DatasetGenerator. Only
                  ``background_noise_fraction`` is read here: if it is > 0,
                  at least one leaf beyond ``n_components`` is targeted so an
                  empty region can exist.

    Returns:
        A tuple containing:
        - A list of active Component objects.
        - A list of TreeComponentMetadata for the active components.
        - A list of TreeNode objects representing the empty leaves.

    Raises:
        ValueError: On an out-of-range probability, negative ``n_components``,
            negative/NaN ``min_leaf_size_fraction``, ``domain_bounds`` with
            ``low >= high``, an unknown ``expansion_strategy``, or empty
            ``distributions`` when components are requested.

    Warns:
        UserWarning: When the tree cannot grow enough leaves for the
            requested components or empty regions.
    """
    # --- Validate Parameters ---
    if not (0.0 <= empty_leaf_probability < 1.0):
        raise ValueError(
            "empty_leaf_probability must be between 0.0 and 1.0 "
            f"(1.0 excluded), got {empty_leaf_probability!r}"
        )
    if n_components < 0:
        raise ValueError(f"n_components must be non-negative, got {n_components!r}")
    # Negative widths would push split points outside the parent box.
    if not (min_leaf_size_fraction >= 0.0):
        raise ValueError(
            f"min_leaf_size_fraction must be >= 0, got {min_leaf_size_fraction!r}"
        )
    domain_low, domain_high = domain_bounds[0], domain_bounds[1]
    if not (domain_low < domain_high):
        raise ValueError(f"domain_bounds must satisfy low < high, got {domain_bounds!r}")
    if expansion_strategy not in ("breadth_first", "depth_first"):
        raise ValueError(
            "expansion_strategy must be 'breadth_first' or 'depth_first', "
            f"got {expansion_strategy!r}"
        )
    if n_components > 0 and not distributions:
        raise ValueError("distributions must be non-empty when n_components > 0")

    rng = np.random.default_rng(random_seed)  # None -> fresh OS entropy

    # --- Decide how many leaves to grow ---
    if empty_leaf_probability > 0:
        # Grow surplus leaves so that, once n_components are made active,
        # roughly empty_leaf_probability of all leaves remain empty.
        target_total_leaves = math.ceil(n_components / (1.0 - empty_leaf_probability))
    else:
        target_total_leaves = n_components

    # Ensure there's at least one leaf to be designated as empty if needed
    if (
        target_total_leaves <= n_components
        and (kwargs.get("background_noise_fraction") or 0) > 0
    ):
        target_total_leaves = n_components + 1

    # --- Grow the tree ---
    root = TreeNode(bounds={i: domain_bounds for i in range(n_features)}, depth=0)
    min_feature_width = (domain_high - domain_low) * min_leaf_size_fraction

    leaves = [root]
    frontier = deque([root])
    take_next = frontier.popleft if expansion_strategy == "breadth_first" else frontier.pop

    while len(leaves) < target_total_leaves and frontier:
        node = take_next()
        if node.depth >= max_depth:
            continue
        if not _split_leaf_node(node, rng, n_features, min_feature_width):
            continue  # every feature too narrow; the node stays a leaf
        # Replace the split node by its children, keeping leaf order stable
        # (the order feeds the seeded selection below). Identity is used so
        # this never depends on TreeNode's field-wise equality.
        leaves = [leaf for leaf in leaves if leaf is not node]
        leaves.extend(node.children)
        frontier.extend(node.children)

    # --- Convert Leaves to Components and Identify Empty Leaves ---
    if len(leaves) >= n_components:
        # Sampling indices uses the same population size as sampling the list
        # itself, but avoids coercing TreeNode objects into a NumPy object array.
        chosen = rng.choice(len(leaves), size=n_components, replace=False)
        active_leaves = [leaves[int(idx)] for idx in chosen]
    else:
        active_leaves = list(leaves)
        warnings.warn(
            f"Could not generate the target {n_components} components. "
            f"Generated {len(active_leaves)} instead. This may be due to "
            f"max_depth ({max_depth}) or min_leaf_size_fraction ({min_leaf_size_fraction}) constraints.",
            UserWarning,
        )

    active_ids = {id(leaf) for leaf in active_leaves}
    empty_leaves = [leaf for leaf in leaves if id(leaf) not in active_ids]

    if n_components <= len(leaves) < target_total_leaves:
        # All components fit, but fewer empty regions exist than were targeted.
        warnings.warn(
            f"Tree growth stopped at {len(leaves)} leaves, short of the "
            f"{target_total_leaves} targeted, so only {len(empty_leaves)} empty "
            f"leaves are available. This may be due to max_depth ({max_depth}) "
            f"or min_leaf_size_fraction ({min_leaf_size_fraction}) constraints.",
            UserWarning,
        )

    if not active_leaves:
        return [], [], empty_leaves

    weight = 1.0 / len(active_leaves)  # Initial weight, will be re-scaled later
    components = []
    metadata_list = []

    for component_id, leaf in enumerate(active_leaves):
        x_dist_params = _leaf_x_dist_params(
            leaf.bounds, leaf_x_distribution, gaussian_std_fraction
        )
        # rng.choice on a list of names yields numpy.str_; keep a plain str.
        y_dist_name = str(rng.choice(distributions))
        y_dist_params = _sample_y_dist_params(rng, y_dist_name)

        # Each consumer gets its own copy so the TreeNode, the Component and
        # the metadata never share one mutable dict.
        components.append(
            Component(
                rules=dict(leaf.bounds),
                distribution=y_dist_name,
                dist_params=y_dist_params,
                weight=weight,
            )
        )
        metadata_list.append(
            TreeComponentMetadata(
                component_id=component_id,
                leaf_depth=leaf.depth,
                leaf_bounds=dict(leaf.bounds),
                weight=weight,
                y_distribution=y_dist_name,
                y_dist_params=dict(y_dist_params),
                x_distribution=leaf_x_distribution,
                x_dist_params=x_dist_params,
            )
        )

    return components, metadata_list, empty_leaves
