import torch
from typing import Optional, List
from .base_intializer import CutPointInitializer


class RandomInitializer(CutPointInitializer):
    """Initializes cut points randomly within feature limits.

    Each (feature, predicate) pair gets one interval: its relative width is
    drawn uniformly from the configured size range and its relative start is
    drawn uniformly so the interval fits inside [0, 1]; both are then scaled
    to the feature's actual limits.
    """

    def __init__(
        self,
        X=None,
        interval_range: Optional[List[float]] = None,
        n_components: Optional[int] = None,
        **kwargs,
    ):
        """Validate and store the relative interval-size range.

        Args:
            X: Not used by this initializer.
            interval_range: A pair [min_size, max_size] (list or tuple)
                specifying the relative minimum and maximum size of
                intervals. When omitted, the default depends on
                ``n_components``: if ``n_components`` is given, the range is
                ``[fraction, fraction + 0.01]`` with
                ``fraction = max(1 / (n_components + 1), 0.04)``; otherwise
                it is ``[0.2, 0.4]``.
            n_components: Only consulted when ``interval_range`` is None, to
                derive the default range described above. Note that
                ``n_components=0`` yields a fraction of 1.0, which the range
                validation below rejects.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``interval_range`` is not a two-element list or
                tuple, or if its values violate
                ``0 < min < 1``, ``max >= min``, ``max <= 1``.
        """
        # NOTE(review): the base class __init__ is not invoked here — confirm
        # CutPointInitializer needs no setup of its own.
        if interval_range is None:
            if n_components is not None:
                fraction = max(1 / (n_components + 1), 0.04)
                interval_range = [fraction, fraction + 0.01]
            else:
                interval_range = [0.2, 0.4]

        # Tuples are accepted alongside lists: both unpack to the same pair.
        if not (
            isinstance(interval_range, (list, tuple)) and len(interval_range) == 2
        ):
            raise ValueError(
                "interval_range must be a list of two floats [min_size, max_size]"
            )

        min_interval_size, max_interval_size = interval_range
        interval_span = max_interval_size - min_interval_size

        if not (
            0 < min_interval_size < 1
            and interval_span >= 0
            and min_interval_size + interval_span <= 1
        ):
            raise ValueError(
                "Invalid interval_range values. Must satisfy 0 < min < 1, span >= 0, min + span <= 1."
            )

        # Stored as (minimum, span) because sizes are sampled as
        # rand * span + minimum in initialize().
        self.min_interval_size = min_interval_size
        self.interval_span = interval_span

    def initialize(
        self,
        n_features: int,
        predicates_per_feature: int,
        data_limits: torch.Tensor,
    ) -> torch.Tensor:
        """Generates random initial cut points.

        Args:
            n_features: Number of features; rows ``0..n_features-1`` of
                ``data_limits`` are read.
            predicates_per_feature: Number of intervals generated per feature.
            data_limits: Tensor indexed as ``[feature, 0]`` for the feature
                minimum and ``[feature, 1]`` for the feature maximum.

        Returns:
            A float32 tensor of shape
            ``[n_features, 2, predicates_per_feature]`` where index 0 on the
            middle axis holds interval starts and index 1 holds interval
            ends. It is allocated without an explicit device, i.e. it is not
            moved to the device of ``data_limits``.
        """
        # Cut points layout: [feature, start/end, predicate_index]
        init = torch.zeros(n_features, 2, predicates_per_feature, dtype=torch.float32)

        min_feature = data_limits[:, 0]
        feature_range = data_limits[:, 1] - min_feature

        # A non-positive range would collapse (or invert) every interval for
        # that feature, so fall back to a unit range in that case.
        feature_range = torch.where(
            feature_range > 0,
            feature_range,
            torch.tensor(1.0, device=feature_range.device),
        )

        for f in range(n_features):
            # Per-feature values do not change across predicates; index once.
            low = min_feature[f]
            width = feature_range[f]

            for p in range(predicates_per_feature):
                # Relative interval size in [min_size, min_size + span).
                size = (
                    torch.rand(1).item() * self.interval_span + self.min_interval_size
                )

                # Relative start in [0, 1 - size) so the interval fits in [0, 1].
                start = torch.rand(1).item() * (1 - size)

                # Scale to the feature's range and offset by its minimum.
                scaled_start = start * width + low
                scaled_end = (start + size) * width + low

                # Defensive ordering so that start <= end is stored.
                init[f, 0, p] = min(scaled_start, scaled_end)
                init[f, 1, p] = max(scaled_start, scaled_end)

        return init
