import torch
from typing import Optional, List
from .base_intializer import CutPointInitializer
import warnings


class GuidedRandomInitializer(CutPointInitializer):
    """
    Initializes cut points randomly, ensuring each rule's combined intervals
    cover at least one data sample from X.

    For every rule an interval ``[start, end]`` is drawn for each feature. When X
    is available and at least one feature is constrained, a rule is accepted only
    if some row of X lies inside the intervals of *all* constrained features at
    the same time; otherwise the rule is redrawn, up to ``max_retries_per_rule``
    extra times. Unconstrained features are never checked against X.
    """

    def __init__(
        self,
        X: torch.Tensor,
        interval_range: Optional[List[float]] = None,
        max_retries_per_rule: int = 10000,
        verbose: bool = False,
        n_components: Optional[int] = None,
        discrete_features: Optional[List[bool]] = None,
        max_constrained_features: Optional[int] = None,
        unconstrained_feature_interval_range: Optional[List[float]] = None,
    ):
        """
        Args:
            X: Input data tensor of shape [n_samples, n_features]. May be None, in
                which case initialization is unguided.
            interval_range: A list (or tuple) [min_size, max_size] specifying the
                relative minimum and maximum size of intervals, as fractions of each
                feature's range. Defaults to [0.2, 0.4], or to [f, f + 0.03] with
                f = max(1 / (n_components + 1), 0.04) when n_components is given.
            max_retries_per_rule: Maximum number of extra attempts to find a rule
                that satisfies the coverage constraint. If every attempt fails, the
                last generated candidate is kept (a RuntimeWarning is emitted when
                verbose). Negative values are treated as 0. Defaults to 10000.
            verbose: If True, print/warn additional information. Defaults to False.
            n_components: Optional number of components, used to determine
                interval_range if not explicitly provided.
            discrete_features: Optional list of booleans indicating which features
                are discrete. Constrained discrete features get intervals centered on
                a value actually present in X. If None and X is a 2-D tensor, a
                feature is auto-detected as discrete when it has at most two distinct
                values in X.
            max_constrained_features: Optional. If set, limits the number of features
                on which the "at least one sample" constraint must hold. These
                features are chosen randomly (once per ``initialize`` call). If None
                or >= n_features, all features are constrained; if <= 0, none are.
            unconstrained_feature_interval_range: Optional. A list (or tuple)
                [min_size, max_size] for features not selected for the coverage
                constraint. If None, these features use ``interval_range``.

        Raises:
            ValueError: If an interval range is not a two-element list/tuple, or
                its values do not satisfy 0 < min < 1, max >= min, max <= 1.
        """
        if interval_range is None:
            if n_components is not None:
                fraction = max(1 / (n_components + 1), 0.04)
                interval_range = [fraction, fraction + 0.03]
            else:
                interval_range = [0.2, 0.4]
        self.min_interval_size, self.interval_span = self._parse_interval_range(
            interval_range, "interval_range"
        )

        self.verbose = verbose
        self.X = X  # May be moved to data_limits' device by initialize().
        self.max_retries_per_rule = max_retries_per_rule
        self.max_constrained_features = max_constrained_features

        if discrete_features is None:
            self.discrete_features = self._detect_discrete_features()
        else:
            self.discrete_features = discrete_features

        # None means "unconstrained features reuse the main interval_range".
        self.unconstrained_feature_min_interval_size: Optional[float] = None
        self.unconstrained_feature_interval_span: Optional[float] = None
        if unconstrained_feature_interval_range is not None:
            (
                self.unconstrained_feature_min_interval_size,
                self.unconstrained_feature_interval_span,
            ) = self._parse_interval_range(
                unconstrained_feature_interval_range,
                "unconstrained_feature_interval_range",
            )

    @staticmethod
    def _parse_interval_range(interval_range, name: str):
        """Validate a [min_size, max_size] pair and return (min_size, span).

        Raises:
            ValueError: If the pair is malformed or out of bounds; ``name`` is
                used in the error message.
        """
        if not (
            isinstance(interval_range, (list, tuple)) and len(interval_range) == 2
        ):
            raise ValueError(
                f"{name} must be a list of two floats [min_size, max_size]"
            )
        min_size = interval_range[0]
        span = interval_range[1] - interval_range[0]
        if not (0 < min_size < 1 and span >= 0 and min_size + span <= 1):
            raise ValueError(
                f"Invalid {name} values. Must satisfy 0 < min < 1, span >= 0, min + span <= 1."
            )
        return min_size, span

    def _detect_discrete_features(self) -> List[bool]:
        """Flag each column of self.X that has at most two distinct values as discrete."""
        if self.X is None or self.X.ndim != 2 or self.X.shape[1] == 0:
            # Nothing to inspect. A non-2-D X is rejected with a clear error later
            # in initialize(), rather than failing here on shape[1].
            return []
        flags = [
            torch.unique(self.X[:, i]).numel() <= 2 for i in range(self.X.shape[1])
        ]
        if self.verbose:
            print(f"Auto-detected discrete_features: {flags}")
        return flags

    def _is_discrete_and_guidable(self, feature_idx: int) -> bool:
        """Return True if X is available and ``feature_idx`` is flagged as discrete."""
        return (
            self.X is not None
            and self.discrete_features is not None
            and feature_idx < len(self.discrete_features)
            and bool(self.discrete_features[feature_idx])
        )

    def _generate_one_rule_intervals(
        self,
        n_features: int,
        min_feature_vals: torch.Tensor,  # Shape [n_features, 1]
        feature_range_vals: torch.Tensor,  # Shape [n_features, 1]
        device: torch.device,
        constrained_feature_indices: torch.Tensor,
        unique_values_per_feature: Optional[List[Optional[torch.Tensor]]] = None,
    ) -> torch.Tensor:
        """Generate one candidate rule: an interval for every feature.

        Constrained discrete features get an interval centered on a randomly
        chosen value of that feature in X; all other features get an interval
        placed uniformly at random inside [min, min + range].

        Args:
            n_features: Number of features.
            min_feature_vals: Per-feature lower limits, shape [n_features, 1].
            feature_range_vals: Per-feature range (max - min), shape [n_features, 1].
            device: Device for the returned tensor.
            constrained_feature_indices: 1-D long tensor of constrained features.
            unique_values_per_feature: Optional precomputed ``torch.unique(X[:, f])``
                per feature (None where not needed). Missing entries are computed
                on demand from self.X.

        Returns:
            Tensor of shape [n_features, 2]: column 0 = start, column 1 = end.
        """
        intervals = torch.zeros((n_features, 2), device=device)

        constrained_mask = torch.zeros(n_features, dtype=torch.bool, device=device)
        if constrained_feature_indices.numel() > 0:
            constrained_mask[constrained_feature_indices] = True

        # Per-feature size parameters: main range by default, optionally
        # overridden for unconstrained features.
        current_min_sizes = torch.full(
            (n_features, 1), self.min_interval_size, device=device
        )
        current_spans = torch.full((n_features, 1), self.interval_span, device=device)
        if self.unconstrained_feature_min_interval_size is not None:
            unconstrained_mask = ~constrained_mask
            current_min_sizes[unconstrained_mask] = (
                self.unconstrained_feature_min_interval_size
            )
            current_spans[unconstrained_mask] = (
                self.unconstrained_feature_interval_span
            )

        # Relative interval size per feature, in [min_size, min_size + span].
        size_per_feature = (
            torch.rand(n_features, 1, device=device) * current_spans + current_min_sizes
        )

        # Convert per-feature scalars to Python once instead of one .item() per
        # feature (each .item() is a device sync on GPU).
        is_constrained = constrained_mask.tolist()
        sizes = size_per_feature[:, 0].tolist()
        mins = min_feature_vals[:, 0].tolist()
        ranges = feature_range_vals[:, 0].tolist()

        for f in range(n_features):
            unique_values = None
            if is_constrained[f] and self._is_discrete_and_guidable(f):
                if (
                    unique_values_per_feature is not None
                    and unique_values_per_feature[f] is not None
                ):
                    unique_values = unique_values_per_feature[f]
                else:
                    unique_values = torch.unique(self.X[:, f])

            if unique_values is not None and unique_values.numel() > 0:
                # Center the interval on a value observed in X so this feature is
                # guaranteed to cover at least that value.
                pick = torch.randint(0, unique_values.numel(), (1,), device=device)
                selected_value = unique_values[pick].item()
                half_size_val = sizes[f] * ranges[f] / 2.0
                intervals[f, 0] = selected_value - half_size_val
                intervals[f, 1] = selected_value + half_size_val
            else:
                # Unconstrained, continuous, or discrete with nothing to guide on.
                self._set_random_interval(
                    intervals, f, sizes[f], mins[f], ranges[f], device
                )
        return intervals

    def _set_random_interval(
        self,
        intervals: torch.Tensor,
        feature_idx: int,
        size: float,  # Relative size in [0, 1]
        min_val: float,
        range_val: float,
        device: torch.device,
    ):
        """Write a uniformly placed interval of relative ``size`` into ``intervals[feature_idx]``.

        The start is drawn in [0, 1 - size] (relative), so the interval stays
        within [min_val, min_val + range_val].
        """
        start_relative = torch.rand(1, device=device).item() * (1.0 - size)
        intervals[feature_idx, 0] = start_relative * range_val + min_val
        intervals[feature_idx, 1] = (start_relative + size) * range_val + min_val

    def _check_rule_validity(
        self,
        rule_intervals: torch.Tensor,  # Shape [n_features, 2]
        X_data_to_check: torch.Tensor,  # Shape [n_samples, n_features]
        constrained_feature_indices: torch.Tensor,  # Shape [k]
    ) -> bool:
        """Return True if some row of X lies within the rule on all constrained features.

        An empty (or None) X, or an empty set of constrained features, is
        treated as trivially valid.
        """
        if X_data_to_check is None or X_data_to_check.shape[0] == 0:
            if self.verbose:
                warnings.warn(
                    "Input data X is empty, cannot guide initialization. Rule considered valid."
                )
            return True

        if constrained_feature_indices.numel() == 0:
            return True

        X_constrained = X_data_to_check[:, constrained_feature_indices]
        lower_bounds = rule_intervals[constrained_feature_indices, 0]
        upper_bounds = rule_intervals[constrained_feature_indices, 1]

        # [n_samples, k] membership, then "all features" per sample, then "any sample".
        within_bounds = (X_constrained >= lower_bounds) & (X_constrained <= upper_bounds)
        return torch.any(torch.all(within_bounds, dim=1)).item()

    def _select_constrained_features(
        self, n_features: int, device: torch.device
    ) -> torch.Tensor:
        """Choose the feature indices that must satisfy the coverage constraint.

        One random subset is drawn per initialize() call and shared by all rules.
        """
        num_to_constrain = n_features
        if (
            self.max_constrained_features is not None
            and self.max_constrained_features < n_features
        ):
            num_to_constrain = self.max_constrained_features

        if num_to_constrain <= 0:
            return torch.tensor([], dtype=torch.long, device=device)
        if num_to_constrain < n_features:
            return torch.randperm(n_features, device=device)[:num_to_constrain]
        return torch.arange(n_features, device=device)

    def initialize(
        self,
        n_features: int,
        predicates_per_feature: int,
        data_limits: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Build the cut-point tensor.

        Args:
            n_features: Number of features; must match X.shape[1] when X is set.
            predicates_per_feature: Number of rules to generate (name kept for the
                base-class interface).
            data_limits: Per-feature limits; ``data_limits[:, 0]`` is read as the
                minimum and ``data_limits[:, 1]`` as the maximum. Its device is
                used for all outputs.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            float32 tensor of shape [n_features, 2, n_rules], where
            ``[:, 0, r]`` / ``[:, 1, r]`` are the interval starts / ends of rule r.

        Raises:
            ValueError: If self.X is set but is not [n_samples, n_features].
        """
        n_rules = predicates_per_feature
        device = data_limits.device

        current_X = self.X
        if current_X is not None:
            if current_X.ndim != 2 or current_X.shape[1] != n_features:
                raise ValueError(
                    f"Instance data 'X' must have shape [n_samples, n_features], "
                    f"expected n_features={n_features}, but got shape {current_X.shape}"
                )
            if current_X.device != device:
                if self.verbose:
                    warnings.warn(
                        f"Instance data 'X' is on device {current_X.device} but data_limits is on {data_limits.device}. Moving X to {data_limits.device}."
                    )
                current_X = current_X.to(device)
                self.X = current_X  # Keep later helper calls on the same device.

        constrained_feature_indices = self._select_constrained_features(
            n_features, device
        )

        if (
            current_X is None
            and constrained_feature_indices.numel() > 0
            and self.verbose
        ):
            warnings.warn(
                "Input data 'X' not provided to initializer, but constraints were expected for some features. "
                "Proceeding with unguided random initialization for all features within rules.",
                RuntimeWarning,
            )

        init_tensor = torch.zeros(
            n_features, 2, n_rules, dtype=torch.float32, device=device
        )

        min_f_vals = data_limits[:, 0].unsqueeze(1)
        max_f_vals = data_limits[:, 1].unsqueeze(1)
        range_f_vals = max_f_vals - min_f_vals
        # Ranges <= 1e-6 (constant or inverted limits) are replaced by 1.0 so
        # interval sizes never collapse to zero.
        range_f_vals = torch.where(
            range_f_vals > 1e-6, range_f_vals, torch.ones_like(range_f_vals)
        )

        needs_check = current_X is not None and constrained_feature_indices.numel() > 0

        # Unique values for constrained discrete features are fixed for the whole
        # call; compute them once instead of on every attempt of every rule.
        unique_values_per_feature: List[Optional[torch.Tensor]] = [None] * n_features
        if needs_check:
            for f in constrained_feature_indices.tolist():
                if self._is_discrete_and_guidable(f):
                    unique_values_per_feature[f] = torch.unique(current_X[:, f])

        # Always at least one attempt, so no rule is left as all zeros even if
        # max_retries_per_rule is negative.
        n_attempts = max(self.max_retries_per_rule, 0) + 1

        for r in range(n_rules):
            for _ in range(n_attempts):
                rule_intervals = self._generate_one_rule_intervals(
                    n_features,
                    min_f_vals,
                    range_f_vals,
                    device,
                    constrained_feature_indices,
                    unique_values_per_feature,
                )
                if not needs_check or self._check_rule_validity(
                    rule_intervals, current_X, constrained_feature_indices
                ):
                    break
            else:
                # Every attempt failed the coverage check (only possible when
                # needs_check is True); fall back to the last candidate.
                if self.verbose:
                    warnings.warn(
                        f"Rule {r}: Failed to find an initialization satisfying constraints on "
                        f"{constrained_feature_indices.numel()} features after {self.max_retries_per_rule} retries. "
                        f"Using last generated intervals.",
                        RuntimeWarning,
                    )
            init_tensor[:, :, r] = rule_intervals
        return init_tensor
