# This script is originally generated by Gemini-2.0 Experimental Advanced
# With some minor optimizations
import torch
import torch.nn.functional as F

from ..utils import logger
from .base import _BatchFitMixin


class GaussianMixtureEM(_BatchFitMixin):
    """Fit a Gaussian mixture model with expectation-maximization (EM).

    During fitting, covariances are stored as
    ``(n_components, n_features, n_features)`` matrices but are always
    diagonal: they start as ``scale * I`` and, when ``learnable_covariance``
    is enabled, only their diagonal entries are re-estimated in the M-step.

    Args:
        n_components: Number of mixture components.
        n_features: Dimensionality of each sample.
        max_iter: Maximum number of EM iterations.
        tol: Convergence threshold on the absolute change of the average
            per-sample log-likelihood between consecutive iterations.
        learnable_covariance: If True, re-estimate the (diagonal) covariances;
            otherwise keep them fixed at ``scale * I``.
        scale: Initial isotropic variance of every component.
        verbose: If True, log convergence / maximum-iteration messages.
    """

    def __init__(
        self,
        n_components,
        n_features,
        max_iter=100,
        tol=1e-7,
        learnable_covariance=False,
        scale: float = 1.0,
        verbose: bool = False,
    ):
        self.n_components = n_components
        self.n_features = n_features
        self.max_iter = max_iter
        self.tol = tol
        self.learnable_covariance = learnable_covariance
        # Currently some related ops are not supported in ``mps`` mode; use the
        # CPU since EM approaches are typically computationally light.
        self.device = torch.device("cpu")
        self.scale = scale
        self.verbose = verbose

    def _initialize_params(self):
        """Create the initial mixture parameters.

        Returns:
            Tuple ``(means, covariances, weights)`` with shapes
            ``(n_components, n_features)``,
            ``(n_components, n_features, n_features)`` and ``(n_components,)``:
            means sampled from a standard normal, every covariance equal to
            ``scale * I`` and uniform mixture weights.
        """
        weights = (
            torch.ones(self.n_components, device=self.device) / self.n_components
        )
        # NOTE: means are sampled from N(0, I) without looking at the data.
        means = torch.randn(self.n_components, self.n_features, device=self.device)
        # Diagonal (isotropic) start; one identical matrix per component.
        covariances = (
            torch.eye(self.n_features, device=self.device).repeat(
                self.n_components, 1, 1
            )
            * self.scale
        )
        return means, covariances, weights

    def _weighted_log_prob(self, X, means, covariances, weights):
        """Compute ``log(weights[k]) + log N(X[i] | means[k], covariances[k])``.

        Shared by the E-step and the log-likelihood so both use the same
        density.

        Args:
            X: Samples of shape ``(n_samples, n_features)``.
            means: Component means, ``(n_components, n_features)``.
            covariances: Component covariances,
                ``(n_components, n_features, n_features)``.
            weights: Mixture weights, ``(n_components,)``.

        Returns:
            Tensor of shape ``(n_samples, n_components)``.
        """
        columns = []
        for k in range(self.n_components):
            diff = X - means[k]
            # Solve Sigma_k z = diff^T instead of forming an explicit inverse;
            # this is cheaper and numerically more stable.
            solved = torch.linalg.solve(covariances[k], diff.T)
            mahalanobis = torch.sum(diff * solved.T, dim=1)
            # log|2*pi*Sigma_k| == n_features * log(2*pi) + log|Sigma_k|.
            log_det = torch.logdet(2 * torch.pi * covariances[k])
            columns.append(torch.log(weights[k]) - 0.5 * (mahalanobis + log_det))
        return torch.stack(columns, dim=1)

    def _e_step(self, X, means, covariances, weights):
        """Expectation step: compute per-sample responsibilities.

        Returns:
            Tensor of shape ``(n_samples, n_components)`` whose rows sum to 1.
        """
        log_weighted = self._weighted_log_prob(X, means, covariances, weights)
        # Softmax over components normalizes the posteriors in log space.
        return F.softmax(log_weighted, dim=1)

    def _m_step(self, X, responsibilities, means, covariances, weights):
        """Maximization step: re-estimate parameters from responsibilities.

        ``means`` (and ``covariances`` when ``learnable_covariance`` is set)
        are updated in place and returned; ``weights`` is recomputed as a new
        tensor and the incoming value is not used.
        """
        n_samples = X.shape[0]
        effective_samples = torch.sum(responsibilities, dim=0)
        # Tiny epsilon avoids 0/0 for a component with no responsibility mass.
        denominators = effective_samples + 1e-15

        weights = effective_samples / n_samples

        for k in range(self.n_components):
            resp_k = responsibilities[:, k].view(-1, 1)
            means[k] = torch.sum(resp_k * X, dim=0) / denominators[k]

        if self.learnable_covariance:
            # 1e-6 jitter keeps each (non-negative) diagonal strictly positive.
            jitter = 1e-6 * torch.eye(self.n_features, device=self.device)
            for k in range(self.n_components):
                resp_k = responsibilities[:, k].view(-1, 1)
                diff = X - means[k]
                # Only the diagonal is updated (axis-aligned covariance), so
                # compute the per-feature weighted variance directly.
                variances = torch.sum(resp_k * diff.pow(2), dim=0) / denominators[k]
                covariances[k] = torch.diag(variances) + jitter
        return means, covariances, weights

    def fit(self, X, *args, **kwargs):
        """Fit the GMM to the data.

        Args:
            X: Samples of shape ``(n_samples, n_features)``; moved to
                ``self.device``.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            Tuple ``(weights, means, covariances, iteration)``. ``covariances``
            holds the diagonal of each component covariance, shape
            ``(n_components, n_features)``. ``iteration`` is a float tensor
            containing the zero-based index of the last EM iteration run
            (0 when ``max_iter`` is 0).
        """
        X = X.to(self.device)

        means, covariances, weights = self._initialize_params()

        previous_log_likelihood = None
        iteration = 0
        for iteration in range(self.max_iter):
            responsibilities = self._e_step(X, means, covariances, weights)
            means, covariances, weights = self._m_step(
                X, responsibilities, means, covariances, weights
            )

            log_likelihood = self.compute_log_likelihood(X, means, covariances, weights)
            if (
                previous_log_likelihood is not None
                and torch.abs(log_likelihood - previous_log_likelihood) < self.tol
            ):
                if self.verbose:
                    # ``iteration`` is zero-based, so iteration + 1 EM steps ran.
                    logger.info(f"Converged after {iteration + 1} iterations.")
                break
            previous_log_likelihood = log_likelihood

            if iteration == self.max_iter - 1 and self.verbose:
                logger.warning(f"Reach maximum iterations: {iteration + 1}.")

        # Covariances are exactly diagonal here (see _initialize_params and
        # _m_step), so only their diagonals are returned.
        covariances = torch.diagonal(covariances, dim1=-2, dim2=-1).clone()
        return weights, means, covariances, torch.tensor(iteration, dtype=torch.float)

    def compute_log_likelihood(self, X, means, covariances, weights):
        """Return the average per-sample log-likelihood of ``X`` under the model.

        Computes ``(1 / n_samples) * sum_i log sum_k weights[k] *
        N(X[i] | means[k], covariances[k])`` as a 0-d tensor.
        """
        log_weighted = self._weighted_log_prob(X, means, covariances, weights)
        # Marginalize over components (dim=1) per sample, then average samples.
        return torch.logsumexp(log_weighted, dim=1).sum() / X.shape[0]
