import torch
from typing import Optional
from .base_intializer import CutPointInitializer


class SampleBasedInitializer(CutPointInitializer):
    """
    Initializes cut points by first selecting samples, then building intervals around them.
    This approach guarantees coverage and scales well to high-dimensional data.

    This initializer solves the curse of dimensionality problem that affects random approaches:
    - Random intervals become exponentially unlikely to contain samples in high dimensions
    - GuidedRandomInitializer requires expensive retries that often fail
    - SampleBasedInitializer guarantees 100% rule coverage by construction

    Performance Characteristics:
    - Interval construction is vectorized over rules: O(n_rules * n_features) tensor work,
      plus one sort per feature (O(n_samples log n_samples)) when auto-sizing is enabled
    - 'diverse' selection costs O(n_rules * n_samples * n_features)
    - Every rule's interval is centered on (and so contains) its selected sample, provided
      that sample lies within ``data_limits``
    - Scales efficiently to hundreds of dimensions
    - No retry loops or convergence issues

    When to Use:
    - High-dimensional data (>10 features): Always prefer over random approaches
    - When guaranteed coverage is required
    - When fast initialization is needed (selection is random; seed torch's RNG for
      reproducible results)
    - For sparse or challenging data distributions with discrete or gappy features.

    Sample Selection Strategies:
    - 'random': Fast uniform sampling. Good for most cases.
    - 'diverse': K-means++ style selection for spatial diversity. Better for clustering-like rules.
    - 'y_diverse': Selects samples with diverse target values. Good for classification tasks.

    Example Usage:
        # Basic usage with auto-sizing enabled by default
        initializer = SampleBasedInitializer(X=X, sample_strategy='random')

        # For classification with diverse target coverage
        initializer = SampleBasedInitializer(X=X, y=y, sample_strategy='y_diverse')

        # Forcing a fixed interval size by disabling auto-sizing
        initializer = SampleBasedInitializer(X=X, sample_strategy='diverse', interval_size=0.15, auto_size_intervals=False)

        # Via factory function
        initializer = get_initializer('sample_based', X=X, sample_strategy='random')
    """

    def __init__(
        self,
        X: torch.Tensor,
        sample_strategy: str = "diverse",
        interval_size: float = 0.1,
        auto_size_intervals: bool = True,
        jitter_strength: float = 0.0,
        y: Optional[torch.Tensor] = None,
        verbose: bool = False,
        **kwargs,
    ):
        """
        Args:
            X: Input data tensor of shape [n_samples, n_features].
            sample_strategy: Strategy for selecting samples. Options:
                - "random": Random sample selection
                - "diverse": Try to select diverse samples across feature space
                - "y_diverse": Select samples with diverse y values (requires y)
            interval_size: The base relative size of intervals, used as a *minimum* size
                         if auto_size_intervals is True.
            auto_size_intervals: If True, dynamically expands interval width to include at
                                 least one neighboring sample in each dimension. This helps
                                 handle discrete and sparse features. Defaults to True.
            jitter_strength: Strength of the random variation applied to the interval size.
                             A value of 0.1 means the size can vary by +/- 10%. Defaults to 0.0.
            y: Optional target values for y_diverse strategy.
            verbose: If True, print additional information.
            **kwargs: Accepted and ignored (lets a factory pass a shared kwargs dict).

        Raises:
            ValueError: If X is not 2D, or any option is out of range / unknown, or
                "y_diverse" is requested without y.
        """
        if X is None or X.ndim != 2:
            raise ValueError("X must be a 2D tensor with shape [n_samples, n_features]")

        self.X = X
        self.sample_strategy = sample_strategy
        self.interval_size = interval_size
        self.auto_size_intervals = auto_size_intervals
        self.jitter_strength = jitter_strength
        self.y = y
        self.verbose = verbose

        # Store selected samples after initialization for verification
        self.selected_indices = None
        self.selected_samples = None

        if not (0 < interval_size <= 1):
            raise ValueError("interval_size must be between 0 and 1")

        if not (0.0 <= jitter_strength <= 1.0):
            raise ValueError("jitter_strength must be between 0.0 and 1.0")

        if sample_strategy not in ["random", "diverse", "y_diverse"]:
            raise ValueError(
                "sample_strategy must be 'random', 'diverse', or 'y_diverse'"
            )

        if sample_strategy == "y_diverse" and y is None:
            raise ValueError("y_diverse strategy requires y to be provided")

    def _select_samples(self, n_rules: int) -> torch.Tensor:
        """Select ``n_rules`` row indices of ``self.X`` using the configured strategy.

        Args:
            n_rules: Number of samples to select (one per rule).

        Returns:
            1-D ``torch.long`` tensor of row indices (repeats are possible).

        Raises:
            ValueError: If ``self.X`` has no rows or the strategy is unknown.
        """
        n_samples = self.X.shape[0]

        if n_samples == 0:
            raise ValueError("Cannot select samples from empty dataset")

        if n_rules <= 0:
            # Nothing to select; still return a valid (long) index tensor.
            return torch.empty(0, dtype=torch.long, device=self.X.device)

        if self.sample_strategy == "random":
            # Simple random sampling with replacement
            selected_indices = torch.randint(
                0, n_samples, (n_rules,), device=self.X.device
            )

        elif self.sample_strategy == "diverse":
            # k-means++ style initialization for spatial diversity
            selected_indices = self._select_diverse_samples(n_rules)

        elif self.sample_strategy == "y_diverse":
            # Select samples with diverse y values
            selected_indices = self._select_y_diverse_samples(n_rules)
        else:
            # This should never happen due to validation in __init__, but needed for type checker
            raise ValueError(f"Unknown sample_strategy: {self.sample_strategy}")

        return selected_indices

    def _select_diverse_samples(self, n_rules: int) -> torch.Tensor:
        """Select diverse samples using a k-means++ style approach.

        The first index is drawn uniformly. Each following index is drawn with probability
        proportional to the squared Euclidean distance to the nearest already-selected
        sample. The nearest-selected distance is updated incrementally, so each step needs
        a single distance computation against the newly selected sample.

        Falls back to uniform sampling with replacement once every row has been selected,
        or when the k-means++ weights are undefined (all remaining distances are zero,
        e.g. duplicated rows, or non-finite because X contains inf/NaN).

        Args:
            n_rules: Number of indices to select.

        Returns:
            1-D ``torch.long`` tensor of ``n_rules`` row indices.
        """
        n_samples = self.X.shape[0]
        device = self.X.device
        if n_rules <= 0:
            return torch.empty(0, dtype=torch.long, device=device)

        # torch.cdist only supports floating-point inputs.
        points = self.X if self.X.is_floating_point() else self.X.float()

        # First sample is random
        first_idx = int(torch.randint(0, n_samples, (1,), device=device).item())
        selected_indices = [first_idx]
        # Distance from every sample to its closest selected sample: [n_samples]
        min_distances = torch.cdist(points, points[first_idx : first_idx + 1]).squeeze(1)

        for _ in range(n_rules - 1):
            weights = min_distances**2
            total_weight = weights.sum().item()

            # NaN fails both comparisons, so it also triggers the fallback.
            if len(selected_indices) >= n_samples or not (
                0.0 < total_weight < float("inf")
            ):
                # Weights exhausted or undefined: sample uniformly with replacement.
                next_idx = int(torch.randint(0, n_samples, (1,), device=device).item())
                selected_indices.append(next_idx)
                continue

            # Sample proportional to squared distance (k-means++ style)
            next_idx = int(torch.multinomial(weights / total_weight, 1).item())
            selected_indices.append(next_idx)

            new_distances = torch.cdist(
                points, points[next_idx : next_idx + 1]
            ).squeeze(1)
            min_distances = torch.minimum(min_distances, new_distances)

        return torch.tensor(selected_indices, dtype=torch.long, device=device)

    def _select_y_diverse_samples(self, n_rules: int) -> torch.Tensor:
        """Select samples whose targets cover as many distinct y values as possible.

        If there are at most ``n_rules`` distinct y values, rules cycle through all of them
        in sorted order (class-balanced selection). If there are more distinct values than
        rules (e.g. regression targets), the chosen values are evenly spaced over the sorted
        distinct values, so the full range of y is represented. Within a chosen y value,
        the sample is drawn uniformly.

        Args:
            n_rules: Number of indices to select.

        Returns:
            1-D ``torch.long`` tensor of ``n_rules`` row indices.

        Raises:
            ValueError: If ``self.y`` is missing or its first dimension does not match X.
        """
        if self.y is None:
            raise ValueError("y_diverse strategy requires y to be provided")

        n_samples = self.X.shape[0]
        device = self.X.device
        # Candidate indices come from y, so y must live on X's device.
        y = self.y.to(device)
        if y.ndim == 0 or y.shape[0] != n_samples:
            raise ValueError(
                f"y must have {n_samples} entries along its first dimension to match X"
            )

        if n_rules <= 0:
            return torch.empty(0, dtype=torch.long, device=device)

        unique_y = torch.unique(y)  # sorted
        n_unique = len(unique_y)

        if n_unique == 1:
            # Only one class, fall back to random sampling
            if self.verbose:
                print("Only one unique y value found, falling back to random sampling")
            return torch.randint(0, n_samples, (n_rules,), device=device)

        if n_unique <= n_rules:
            # Cycle through every distinct value.
            positions = torch.arange(n_rules, device=device) % n_unique
        else:
            # More distinct values than rules: spread picks across the sorted values.
            # The spacing is >= 1, so the rounded positions are distinct.
            positions = (
                torch.linspace(
                    0, n_unique - 1, n_rules, dtype=torch.float64, device=device
                )
                .round()
                .long()
            )

        selected_indices = []
        for pos in positions.tolist():
            candidates = torch.where(y == unique_y[pos])[0]
            pick = torch.randint(0, len(candidates), (1,), device=device)
            selected_indices.append(int(candidates[pick].item()))

        return torch.tensor(selected_indices, dtype=torch.long, device=device)

    @staticmethod
    def _neighbor_widths(sorted_vals: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
        """Return ``2 * distance to the nearest distinct neighbour`` for each entry of ``values``.

        Args:
            sorted_vals: Sorted unique values of one feature (1-D, contiguous).
            values: 1-D tensor of sample values for that feature.

        Returns:
            Tensor shaped like ``values``. Entries with no neighbour, or with a neighbour
            closer than 1e-8, get 0.
        """
        n_vals = sorted_vals.numel()
        if n_vals <= 1:
            return torch.zeros_like(values)

        inf = torch.full_like(values, float("inf"))
        # First position whose element is >= value (left insertion point).
        insertion_idx = torch.searchsorted(sorted_vals, values)

        # Lower neighbour: the element just before the insertion point (strictly smaller).
        lower = sorted_vals[(insertion_idx - 1).clamp(min=0)]
        dist_lower = torch.where(insertion_idx > 0, values - lower, inf)

        # Upper neighbour: if the value itself sits at the insertion point, skip past it.
        at_insertion = sorted_vals[insertion_idx.clamp(max=n_vals - 1)]
        skip_self = (insertion_idx < n_vals) & (at_insertion <= values)
        upper_idx = insertion_idx + skip_self.long()
        upper = sorted_vals[upper_idx.clamp(max=n_vals - 1)]
        dist_upper = torch.where(upper_idx < n_vals, upper - values, inf)

        min_dist = torch.minimum(dist_lower, dist_upper)
        # A symmetric box must have half-width `min_dist` to reach the neighbour.
        valid = torch.isfinite(min_dist) & (min_dist > 1e-8)
        return torch.where(valid, 2.0 * min_dist, torch.zeros_like(min_dist))

    def _build_intervals_around_samples(
        self, selected_samples: torch.Tensor, data_limits: torch.Tensor
    ) -> torch.Tensor:
        """
        Builds intervals around selected samples.

        Each interval is centered on its sample. Its width is
        ``interval_size * feature_range``, raised to ``2 * nearest-neighbour distance``
        in that dimension when ``auto_size_intervals`` is set (handles discrete and
        sparse features). When ``jitter_strength > 0``, the width is scaled by a random
        factor in ``[1 - jitter_strength, 1 + jitter_strength]``. The interval is then
        clamped to the data limits. If clamping collapses an interval (start >= end), it
        is replaced by a tiny interval around the sample, or by the full [min, max] range
        when that also collapses or the feature has no range.

        Arithmetic is done in float64, vectorized over rules, which avoids a host sync
        per (rule, feature) pair.

        Args:
            selected_samples: [n_rules, n_features] tensor; row r is rule r's sample.
            data_limits: Per-feature limits; column 0 holds minima, column 1 maxima.

        Returns:
            Tensor of shape [n_features, 2, n_rules] (default float dtype) with interval
            starts at ``[:, 0, :]`` and ends at ``[:, 1, :]``.
        """
        n_rules, n_features = selected_samples.shape
        device = self.X.device
        f64 = torch.float64

        min_vals = data_limits[:, 0]
        max_vals = data_limits[:, 1]
        feature_ranges = max_vals - min_vals
        # Handle features with zero range to avoid division by zero
        feature_ranges = torch.where(
            feature_ranges > 1e-8,
            feature_ranges,
            torch.tensor(1.0, device=feature_ranges.device),
        )

        # Column vectors [n_features, 1] broadcast over rules.
        lo = min_vals.to(device=device, dtype=f64).unsqueeze(1)
        hi = max_vals.to(device=device, dtype=f64).unsqueeze(1)
        ranges = feature_ranges.to(device=device, dtype=f64).unsqueeze(1)
        # [n_features, n_rules]; contiguous rows are needed by torch.searchsorted.
        samples_t = selected_samples.to(device=device, dtype=f64).t().contiguous()

        # 1. Base interval width from `interval_size`
        widths = (self.interval_size * ranges).expand(n_features, n_rules)

        # 2. Neighbour-based width if auto-sizing is enabled
        if self.auto_size_intervals:
            if self.verbose:
                print(
                    "Auto-sizing intervals: Pre-computing sorted unique values for features."
                )
            auto_widths = torch.zeros_like(samples_t)
            for feat_idx in range(n_features):
                sorted_vals = torch.unique(self.X[:, feat_idx], sorted=True).to(f64)
                auto_widths[feat_idx] = self._neighbor_widths(
                    sorted_vals, samples_t[feat_idx]
                )
            # 3. Final width is the larger of the two
            widths = torch.maximum(widths, auto_widths)

        # 4. Apply jitter: one draw per (rule, feature) from the default generator.
        if self.jitter_strength > 0:
            random_scaling = (
                torch.rand(n_rules, n_features).t().to(device=device, dtype=f64) - 0.5
            ) * 2
            widths = (widths * (1.0 + random_scaling * self.jitter_strength)).clamp(
                min=0.0
            )

        # 5-6. Center on the sample, then clamp to the data limits
        half_widths = widths / 2.0
        starts = torch.maximum(samples_t - half_widths, lo)
        ends = torch.minimum(samples_t + half_widths, hi)

        # 7. Repair collapsed intervals (start >= end)
        collapsed = starts >= ends
        epsilon = 1e-6 * ranges
        tiny_starts = torch.maximum(samples_t - epsilon, lo)
        tiny_ends = torch.minimum(samples_t + epsilon, hi)
        # Full range when the feature has a single value or the tiny interval also collapses.
        use_full = collapsed & (~(hi > lo) | (tiny_starts >= tiny_ends))
        use_tiny = collapsed & ~use_full
        starts = torch.where(use_tiny, tiny_starts, starts)
        ends = torch.where(use_tiny, tiny_ends, ends)
        starts = torch.where(use_full, lo, starts)
        ends = torch.where(use_full, hi, ends)

        intervals = torch.zeros(n_features, 2, n_rules, device=device)
        intervals[:, 0, :] = starts
        intervals[:, 1, :] = ends
        return intervals

    def initialize(
        self,
        n_features: int,
        predicates_per_feature: int,
        data_limits: torch.Tensor,
    ) -> torch.Tensor:
        """Initialize cut points by selecting samples and building intervals around them.

        Args:
            n_features: Expected number of features; must equal ``X.shape[1]``.
            predicates_per_feature: Number of rules, i.e. number of samples to select.
            data_limits: Per-feature limits; column 0 holds minima, column 1 maxima.
                X (and y, if given) are moved to this tensor's device.

        Returns:
            Tensor of shape [n_features, 2, predicates_per_feature] with interval starts
            at ``[:, 0, :]`` and ends at ``[:, 1, :]``.

        Raises:
            ValueError: If ``predicates_per_feature`` is negative or ``n_features`` does
                not match X.
        """
        n_rules = predicates_per_feature
        if n_rules < 0:
            raise ValueError(
                f"predicates_per_feature must be non-negative, got {n_rules}"
            )

        # Validate input data
        if self.X.shape[1] != n_features:
            raise ValueError(
                f"X has {self.X.shape[1]} features but n_features={n_features}"
            )

        # Move X to same device as data_limits if needed
        if self.X.device != data_limits.device:
            if self.verbose:
                print(f"Moving X from {self.X.device} to {data_limits.device}")
            self.X = self.X.to(data_limits.device)
        # y is checked separately: it may be on a different device than X was.
        if self.y is not None and self.y.device != data_limits.device:
            self.y = self.y.to(data_limits.device)

        # Select samples using specified strategy
        selected_indices = self._select_samples(n_rules)
        selected_samples = self.X[selected_indices]  # [n_rules, n_features]

        # Store for verification
        self.selected_indices = selected_indices
        self.selected_samples = selected_samples

        if self.verbose:
            print(f"Selected {n_rules} samples using '{self.sample_strategy}' strategy")
            if self.y is not None and self.sample_strategy == "y_diverse":
                selected_y = self.y[selected_indices]
                unique_selected_y = torch.unique(selected_y)
                print(
                    f"Selected samples cover {len(unique_selected_y)} unique y values"
                )

        # Build intervals around selected samples
        intervals = self._build_intervals_around_samples(selected_samples, data_limits)

        return intervals
