import torch
from .base_intializer import CutPointInitializer
import zuko


class GMMInitializer(CutPointInitializer):
    """
    Fits a GMM to the data once, then initializes cut points for individual
    rules based on the GMM components in a cycling manner. Assumes input
    data X is already scaled.

    Each call to ``initialize`` uses the next mixture component (wrapping
    around after the last one) and builds, per feature, the interval
    ``mean +/- c * std`` with a small random jitter, clamped to the
    supplied data limits.
    """

    def __init__(
        self,
        X: torch.Tensor,
        gmm_components: int,
        gmm_std_dev_multiplier: float = 1.0,
        gmm_iterations: int = 1000,
        gmm_lr: float = 1e-3,
        verbose: bool = False,
        **kwargs,
    ):
        """
        Fits the GMM upon initialization using pre-scaled data.

        Args:
            X: Input data tensor of shape [n_samples, n_features],
               assumed to be already scaled (e.g., standard scaled).
               Used for fitting the GMM. It is detached before fitting, so no
               gradients flow back into the caller's tensor.
            gmm_components: The number of components for the GMM.
            gmm_std_dev_multiplier: Multiplier 'c' for std deviation used to
               define interval boundaries (mean +/- c*std_dev). Defaults to 1.0.
            gmm_iterations: Number of training iterations for the GMM.
               Defaults to 1000.
            gmm_lr: Learning rate for the GMM optimizer. Defaults to 1e-3.
            verbose: If True, print GMM fitting progress. Defaults to False.
            **kwargs: Accepted for interface compatibility; not used here.

        Raises:
            TypeError: If X is not a torch.Tensor.
            ValueError: If X is not 2-D, or if gmm_components or
               gmm_std_dev_multiplier is not positive.

        Note:
            'zuko' is imported at module level, so a missing installation
            raises ImportError when this module is imported, not here.
        """
        if not isinstance(X, torch.Tensor):
            raise TypeError("Input data 'X' must be a torch.Tensor.")
        if X.ndim != 2:
            raise ValueError(
                f"Input data 'X' must have shape [n_samples, n_features], but got {X.shape}."
            )
        if gmm_components <= 0:
            raise ValueError("gmm_components must be positive.")
        if gmm_std_dev_multiplier <= 0:
            raise ValueError("gmm_std_dev_multiplier must be positive.")

        self.gmm_components = gmm_components
        self.gmm_std_dev_multiplier = gmm_std_dev_multiplier
        self.n_features_data = X.shape[1]
        self.device = X.device  # Store device from data

        # Parameters live in the same (scaled) space as X.
        # Shapes: [gmm_components, n_features] each.
        self.gmm_means, self.gmm_stds = self._fit_gmm(
            X.detach(), gmm_components, gmm_iterations, gmm_lr, verbose
        )

        if verbose:
            print("GMM component means (scaled space): ", self.gmm_means)
        # --- State for cycling through components ---
        self._current_component_index = 0

    @staticmethod
    def _fit_gmm(
        X: torch.Tensor,
        n_components: int,
        iterations: int,
        lr: float,
        verbose: bool,
    ) -> tuple:
        """
        Fit a zuko GMM to X with Adam on the mean negative log-likelihood.

        Returns:
            (means, stds): per-component means and per-feature standard
            deviations (sqrt of the covariance diagonal, variance clamped to
            >= 1e-6), each of shape [n_components, n_features], on X.device.
        """
        if verbose:
            print(f"GMMInitializer: Fitting GMM with {n_components} components...")

        flow = zuko.flows.GMM(features=X.shape[1], components=n_components)
        flow.to(X.device)
        optimizer = torch.optim.Adam(flow.parameters(), lr=lr)

        flow.train()
        for i in range(iterations):
            optimizer.zero_grad()
            loss = -flow().log_prob(X).mean()
            if not torch.isfinite(loss):
                # Stop instead of stepping the optimizer with non-finite
                # gradients, which would corrupt all parameters.
                print(
                    f"Warning: GMMInitializer: Non-finite loss ({loss.item()}) encountered at iteration {i}. Stopping GMM training."
                )
                break
            loss.backward()
            optimizer.step()
            if verbose and (i + 1) % 200 == 0:
                print(
                    f"  GMM Iteration {i+1}/{iterations}, Loss: {loss.item():.4f}"
                )

        if verbose:
            print("GMMInitializer: GMM fitting complete.")
        flow.eval()

        with torch.no_grad():
            components = flow().base
            means = components.loc
            variances = torch.diagonal(
                components.covariance_matrix, dim1=-2, dim2=-1
            )
            # Floor the variance so a collapsed component still yields a
            # non-zero interval width.
            stds = torch.sqrt(torch.clamp(variances, min=1e-6))
        return means, stds

    def initialize(
        self,
        n_features: int,
        predicates_per_feature: int,
        data_limits: torch.Tensor,
    ) -> torch.Tensor:
        """
        Generates initial cut points for one rule using the next GMM component's stats.
        Assumes data_limits correspond to the scaled data space used during fitting.

        Args:
            n_features: Number of features; must equal the feature dimension
               of the data used for fitting.
            predicates_per_feature: Number of predicate slots per feature.
               The GMM yields one interval per feature. Values > 1 print a
               warning and replicate that interval into every slot.
            data_limits: Per-feature [min, max] bounds, indexed as
               data_limits[:, 0] / data_limits[:, 1] (expected shape
               [n_features, 2]). Interval boundaries are clamped to them.

        Returns:
            A float32 tensor of shape [n_features, 2, predicates_per_feature]
            on the fitting device, where [:, 0, :] are interval starts and
            [:, 1, :] are interval ends. It owns its storage, so the predicate
            slots do not alias each other.

        Raises:
            ValueError: If n_features does not match the fitted data, or
               predicates_per_feature < 1.
        """
        if n_features != self.n_features_data:
            # This check might be too strict if rules operate on subsets of features.
            # However, the current SimpleNeuralAndFinder uses all features.
            # If rules could use subsets, data_limits would need careful handling.
            raise ValueError(
                f"GMMInitializer.initialize: n_features ({n_features}) does not match "
                f"feature dimension of data used for GMM fitting ({self.n_features_data})."
            )
        if predicates_per_feature < 1:
            raise ValueError(
                f"GMMInitializer.initialize: predicates_per_feature must be >= 1, "
                f"got {predicates_per_feature}."
            )
        if predicates_per_feature != 1:
            print(
                f"Warning: GMMInitializer.initialize called with predicates_per_feature={predicates_per_feature}. Expected 1."
            )
            # The single GMM interval is replicated into every slot below.

        # Select the GMM component for this rule based on cycling.
        component_idx = self._current_component_index
        self._current_component_index = (component_idx + 1) % self.gmm_components

        dtype = torch.float32  # output dtype, as before
        # Out-of-place ops only below: .to() may return the stored tensor itself.
        mean_vec = self.gmm_means[component_idx].to(dtype)  # [n_features]
        std_vec = self.gmm_stds[component_idx].to(dtype)  # [n_features]

        # Jitter the centre per feature and the half-width per rule so rules
        # drawn from the same component get different intervals.
        # (Draw order matches the original: per-feature offset, then scale.)
        offset = (torch.rand(n_features, device=self.device) - 0.5) * 0.1
        width_scale = (torch.rand(1, device=self.device) - 0.5) * 0.1 + 1
        center = mean_vec + offset
        half_width = self.gmm_std_dev_multiplier * std_vec * width_scale
        lower = center - half_width
        upper = center + half_width
        starts = torch.minimum(lower, upper)
        ends = torch.maximum(lower, upper)

        # Clamp to the data limits (cast so all arithmetic shares one dtype).
        limits = data_limits.to(device=self.device, dtype=dtype)
        min_limits = limits[:, 0]
        max_limits = limits[:, 1]
        starts = torch.clamp(starts, min=min_limits, max=max_limits)
        ends = torch.clamp(ends, min=min_limits, max=max_limits)

        # Clamping can collapse an interval (start >= end). Re-open it by
        # epsilon: end = min(start + eps, max_limit), then
        # start = max(min(start, end - eps), min_limit).
        epsilon = 1e-6
        collapsed = starts >= ends
        fixed_ends = torch.minimum(starts + epsilon, max_limits)
        fixed_starts = torch.maximum(
            torch.minimum(starts, fixed_ends - epsilon), min_limits
        )
        ends = torch.where(collapsed, fixed_ends, ends)
        starts = torch.where(collapsed, fixed_starts, starts)

        init = torch.stack((starts, ends), dim=1).unsqueeze(-1)  # [n_features, 2, 1]

        if predicates_per_feature > 1:
            # Use repeat, not expand: an expanded view shares one memory
            # location across slots, so in-place updates (e.g. optimizer steps
            # on a Parameter built from it) would fail or hit every slot.
            init = init.repeat(1, 1, predicates_per_feature)

        return init
