import torch


class IncrementalPCAModule(torch.nn.Module):
    """
    Incremental Principal Components Analysis (IPCA) implemented as a PyTorch module.

    The model is fitted batch by batch by calling the module (``forward``) on each batch;
    ``transform`` projects data onto the components learned so far. All fitted state is
    stored in registered buffers, so it follows ``.to(device)`` and is part of the
    ``state_dict``.

    Attributes:
        n_components (int): Number of principal components to keep.
        n_features (int): Number of features (columns) of the input data.
        copy (bool): If False, tensors passed to ``forward`` may be overwritten in place.
            Defaults to True.
        svd_driver (str, optional): Passed as ``driver`` to ``torch.linalg.svd`` (name of the
            cuSOLVER method; only meaningful for CUDA inputs). Defaults to None.

    Buffers:
        mean_ (float64, (n_features,)): Running per-feature mean.
        var_ (float64, (n_features,)): Running per-feature (biased) variance.
        n_samples_seen_ (int64, (1,)): Number of samples processed so far.
        components_ (float32, (n_components, n_features)): Principal axes.
        singular_values_ (float32, (n_components,)): Singular values of the kept components.
        explained_variance_ (float32, (n_components,)): Variance explained by each component.
        explained_variance_ratio_ (float32, (n_components,)): Fraction of total variance
            explained by each component.
        noise_variance_ (float32, (1,)): Mean explained variance of the discarded components.
    """

    def __init__(
        self, n_components, n_features, copy=True, svd_driver=None
    ):
        super().__init__()
        self.n_components = n_components
        self.n_features = n_features
        self.copy = copy
        self.svd_driver = svd_driver

        # Buffers are zero-initialised (not `torch.empty`) so the state is deterministic and
        # finite before the first batch; `n_samples_seen_ == 0` marks the unfitted model.
        self.register_buffer("mean_", torch.zeros(n_features, dtype=torch.float64))
        self.register_buffer("var_", torch.zeros(n_features, dtype=torch.float64))
        self.register_buffer("n_samples_seen_", torch.zeros(1, dtype=torch.int64))
        self.register_buffer("components_", torch.zeros(n_components, n_features, dtype=torch.float32))
        self.register_buffer("singular_values_", torch.zeros(n_components, dtype=torch.float32))
        self.register_buffer("explained_variance_", torch.zeros(n_components, dtype=torch.float32))
        self.register_buffer("explained_variance_ratio_", torch.zeros(n_components, dtype=torch.float32))
        self.register_buffer("noise_variance_", torch.zeros(1, dtype=torch.float32))

    def _validate_data(self, X, dtype=torch.float32):
        """
        Converts the input data `X` to a tensor of the requested dtype.

        Whether an existing tensor is copied is controlled by ``self.copy``. The device of
        `X` is left unchanged.

        Args:
            X (torch.Tensor or array-like): Input data.
            dtype (torch.dtype, optional): Desired data type for the tensor. Defaults to torch.float32.

        Returns:
            torch.Tensor: Tensor of dtype `dtype`. It is a fresh tensor unless `X` already
            is a tensor of that dtype and ``self.copy`` is False, in which case `X` itself
            is returned.
        """
        if not isinstance(X, torch.Tensor):
            return torch.tensor(X, dtype=dtype)

        if X.dtype != dtype:
            # A dtype conversion already yields a new tensor, so no extra clone is needed.
            return X.to(dtype)

        return X.clone() if self.copy else X

    @staticmethod
    def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count):
        """
        Computes the incremental mean and variance for the data `X`.

        Args:
            X (torch.Tensor): The batch input data tensor with shape (n_samples, n_features).
            last_mean (torch.Tensor or None): The previous mean tensor with shape (n_features,).
            last_variance (torch.Tensor or None): The previous variance tensor with shape (n_features,).
            last_sample_count (torch.Tensor): The count tensor of samples processed before the current batch.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Updated mean, updated (biased)
            variance and updated sample count. If `X` has no rows the inputs are returned
            unchanged.
        """
        if X.shape[0] == 0:
            return last_mean, last_variance, last_sample_count

        has_history = bool(last_sample_count > 0)
        if has_history:
            assert (
                last_mean is not None
            ), "last_mean should not be None if last_sample_count > 0."
            assert (
                last_variance is not None
            ), "last_variance should not be None if last_sample_count > 0."
        else:
            # With zero previous samples any supplied statistics carry no information.
            # Ignoring them also avoids the 0/0 division in the merge formula below.
            last_mean = None
            last_variance = None

        new_sample_count = torch.tensor([X.shape[0]], device=X.device)
        updated_sample_count = last_sample_count + new_sample_count

        if last_mean is None:
            last_sum = torch.zeros(X.shape[1], dtype=torch.float64, device=X.device)
        else:
            last_sum = last_mean * last_sample_count

        new_sum = X.sum(dim=0, dtype=torch.float64)

        updated_mean = (last_sum + new_sum) / updated_sample_count

        # Sum of squared deviations of the batch, with a correction term for the rounding
        # error of the batch mean.
        T = new_sum / new_sample_count
        temp = X - T
        correction = temp.sum(dim=0, dtype=torch.float64).square()
        temp.square_()
        new_unnormalized_variance = temp.sum(dim=0, dtype=torch.float64)
        new_unnormalized_variance -= correction / new_sample_count
        if last_variance is None:
            updated_variance = new_unnormalized_variance / updated_sample_count
        else:
            # Merge the previous and the new sums of squared deviations.
            last_unnormalized_variance = last_variance * last_sample_count
            last_over_new_count = last_sample_count.double() / new_sample_count
            updated_unnormalized_variance = (
                last_unnormalized_variance
                + new_unnormalized_variance
                + last_over_new_count
                / updated_sample_count
                * (last_sum / last_over_new_count - new_sum).square()
            )
            updated_variance = updated_unnormalized_variance / updated_sample_count

        return updated_mean, updated_variance, updated_sample_count

    @staticmethod
    def _svd_flip(u, v, u_based_decision=True):
        """
        Adjusts the signs of the singular vectors from the SVD decomposition for deterministic output.

        `u` and `v` are modified in place and also returned.

        Args:
            u (torch.Tensor): Left singular vectors tensor.
            v (torch.Tensor): Right singular vectors tensor.
            u_based_decision (bool, optional): If True, the sign of the largest-magnitude entry of
                each column of `u` decides the flip; otherwise the largest-magnitude entry of each
                row of `v` is used. Defaults to True.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Adjusted left and right singular vectors tensors.
        """
        if u_based_decision:
            max_abs_cols = torch.argmax(torch.abs(u), dim=0)
            col_index = torch.arange(u.shape[1], device=u.device)
            signs = torch.sign(u[max_abs_cols, col_index])
        else:
            max_abs_rows = torch.argmax(torch.abs(v), dim=1)
            row_index = torch.arange(v.shape[0], device=v.device)
            signs = torch.sign(v[row_index, max_abs_rows])
        u *= signs.view(1, -1)
        v *= signs.view(-1, 1)
        return u, v

    def forward(self, X, check_input=True):
        """
        Incrementally fits the model with batch data `X`.

        The SVD of each step must yield at least `n_components` singular vectors; for the
        first batch this means it needs at least `n_components` samples.

        Args:
            X (torch.Tensor): The batch input data tensor with shape (n_samples, n_features).
            check_input (bool, optional): If True, `X` is passed through `_validate_data`
                first. If False, `X` is used as is and is centred in place. Defaults to True.

        Returns:
            IncrementalPCAModule: The module itself, updated with the batch. An empty batch
            leaves the state untouched.
        """
        if check_input:
            X = self._validate_data(X)
        n_samples, n_features = X.shape

        if n_samples == 0:
            # Nothing to learn from; an SVD of an empty batch would corrupt the state.
            return self

        # The buffers always exist, so the first batch is detected from the sample counter
        # rather than from the presence of `components_`.
        first_pass = int(self.n_samples_seen_.item()) == 0

        if first_pass:
            last_mean, last_var = None, None
        else:
            last_mean, last_var = self.mean_, self.var_

        col_mean, col_var, n_total_samples = self._incremental_mean_and_var(
            X, last_mean, last_var, self.n_samples_seen_
        )

        if first_pass:
            X -= col_mean
        else:
            col_batch_mean = torch.mean(X, dim=0)
            X -= col_batch_mean
            mean_correction_factor = torch.sqrt(
                (self.n_samples_seen_.double() / n_total_samples) * n_samples
            )
            mean_correction = mean_correction_factor * (self.mean_ - col_batch_mean)
            # Stack the previous decomposition, the centred batch and a row accounting for
            # the shift between the old mean and the batch mean.
            X = torch.vstack(
                (
                    self.singular_values_.view((-1, 1)) * self.components_,
                    X,
                    mean_correction,
                )
            )

        U, S, Vt = torch.linalg.svd(X, full_matrices=False, driver=self.svd_driver)
        U, Vt = self._svd_flip(U, Vt, u_based_decision=False)
        explained_variance = S**2 / (n_total_samples - 1)
        explained_variance_ratio = S**2 / torch.sum(col_var * n_total_samples)

        self.n_samples_seen_.copy_(n_total_samples)
        self.components_.copy_(Vt[:self.n_components])
        self.singular_values_.copy_(S[:self.n_components])
        self.mean_.copy_(col_mean)
        self.var_.copy_(col_var)
        self.explained_variance_.copy_(explained_variance[:self.n_components])
        self.explained_variance_ratio_.copy_(explained_variance_ratio[:self.n_components])
        if self.n_components not in (n_samples, n_features):
            exp_var = explained_variance[self.n_components :].mean().view(-1)
            self.noise_variance_.copy_(exp_var)
        else:
            self.noise_variance_[0] = 0
        return self

    def transform(self, X):
        """
        Applies dimensionality reduction to `X`.

        The input data `X` is centred with the learned mean and projected on the principal
        components extracted so far. `X` itself is not modified.

        Args:
            X (torch.Tensor): New data tensor with shape (n_samples, n_features) to be transformed.

        Returns:
            torch.Tensor: Transformed data tensor with shape (n_samples, n_components), in the
            dtype of `components_`.
        """
        # Out-of-place subtraction keeps the caller's tensor intact; the cast aligns the
        # float64 mean arithmetic with the dtype of the components for `torch.mm`.
        centered = (X - self.mean_).to(self.components_.dtype)
        return torch.mm(centered, self.components_.T)


class LinearWithPCA(torch.nn.Module):
    """
    Wraps a linear layer and incrementally fits a PCA on the inputs it receives.

    Every forward call feeds the layer input to an `IncrementalPCAModule` until the
    principal components are considered converged, i.e. until a sufficient share of them
    barely changes between two consecutive updates. The output is always that of the
    wrapped linear layer.

    Args:
        linear: Layer to wrap; must expose `in_features` and be callable on the input.
        early_stop_p (float): Fraction of components whose similarity must reach the
            threshold for the PCA to count as converged.
        early_stop_sim_thresh (float): Cosine-similarity threshold between a component and
            its value after the previous update.
        n_components (int): Number of principal components to track.
        *args, **kwargs: Forwarded to `IncrementalPCAModule` after `n_components` and
            `linear.in_features`.
    """

    def __init__(self, linear, early_stop_p, early_stop_sim_thresh, n_components, *args, **kwargs):
        super().__init__()
        self.linear = linear
        self.early_stop_p = early_stop_p
        self.early_stop_sim_thresh = early_stop_sim_thresh
        self.pca = IncrementalPCAModule(n_components, linear.in_features, *args, **kwargs)
        self.sim = torch.nn.CosineSimilarity(dim=1, eps=1e-08)
        self.pca_has_converged = False
        self.pca_n_converged = 0

        # Components after the previous PCA update. The -inf fill makes the first
        # similarity NaN, so convergence cannot be reported before two updates exist.
        self.register_buffer(
            "previous_components_",
            torch.full((n_components, linear.in_features), fill_value=float("-inf"), dtype=torch.float32),
        )

    def has_converged_check(self):
        """
        Updates `pca_n_converged` and `pca_has_converged`.

        A component counts as converged when its cosine similarity to the corresponding row
        of `previous_components_` is at least `early_stop_sim_thresh`; the PCA counts as
        converged when at least `early_stop_p` of all components are converged.
        """
        sim = self.sim(self.pca.components_, self.previous_components_)
        self.pca_n_converged = (sim >= self.early_stop_sim_thresh).sum()
        self.pca_has_converged = self.pca_n_converged >= self.early_stop_p * len(sim)

    def forward(self, x):
        """
        Updates the PCA with `x` (until it has converged) and applies the linear layer.

        Args:
            x (torch.Tensor): Input whose last dimension is the feature dimension. Leading
                dimensions are flattened into the sample dimension for the PCA only.

        Returns:
            torch.Tensor: `self.linear(x)`.
        """
        if not self.pca_has_converged:
            # The PCA is a statistics side channel: keep it out of the autograd graph.
            with torch.no_grad():
                samples = x.detach().reshape(-1, x.shape[-1]).to(torch.float32)
                self.pca(samples)
                self.has_converged_check()
                # Remember the current components so the next update is compared against
                # them; without this the reference would stay at its -inf initial value.
                self.previous_components_.copy_(self.pca.components_)
        x = self.linear(x)
        return x