import numpy as np


class GaussianMixtureDistribution:
    """One-dimensional Gaussian mixture p(x) = sum_k pi_k * N(x; mu_k, sigma_k^2).

    Attributes:
        means: component means mu_k, shape (K,).
        stds: component standard deviations sigma_k, shape (K,).
        pi: mixture weights pi_k, shape (K,). ``sample`` passes them to
            ``np.random.choice`` as ``p``, so they must be non-negative and
            sum to 1.

    Methods taking ``x`` treat it as a 1-D collection of points and return one
    value per point; a scalar is treated as a single point (length-1 result).
    """

    def __init__(self, mu, sigma, pi):
        # Copy into float arrays: avoids aliasing the caller's arrays and keeps
        # later arithmetic in floating point.
        self.means = np.array(mu, dtype=float)
        self.stds = np.array(sigma, dtype=float)
        self.pi = np.array(pi, dtype=float)

    def sample(self, N):
        """Draw ``N`` samples: pick a component per sample, then draw from it.

        The draw order (one ``choice`` call, then one ``normal`` call per
        component) is kept fixed so results are reproducible under a seed.
        """
        z = np.random.choice(len(self.pi), size=N, p=self.pi)
        samples = np.zeros(N)
        for i in range(len(self.pi)):
            mask = z == i
            samples[mask] = np.random.normal(
                self.means[i], self.stds[i], size=np.sum(mask)
            )
        return samples

    def _weighted_log_components(self, x):
        """Return log(pi_k) + log N(x_n; mu_k, sigma_k^2) as an (n, K) array."""
        x = np.atleast_1d(np.asarray(x, dtype=float))
        z = (x[:, None] - self.means) / self.stds
        # A zero mixture weight legitimately gives log(0) = -inf; the
        # log-space reductions below handle -inf, so silence the warning.
        with np.errstate(divide="ignore"):
            log_pi = np.log(self.pi)
        return log_pi - np.log(self.stds * np.sqrt(2 * np.pi)) - 0.5 * z**2

    def log_prob(self, x):
        """Log-density of the mixture at each point of ``x``.

        Computed as logsumexp_k(log pi_k + log N(x; mu_k, sigma_k^2)), which is
        log(sum_k pi_k N_k(x)) evaluated without underflow for points far from
        every component.
        """
        return np.logaddexp.reduce(self._weighted_log_components(x), axis=1)

    def pdf(self, x):
        """Density of the mixture at each point of ``x``."""
        return np.exp(self.log_prob(x))

    def gaussian_perturbation(self, accumulated_drift, var):
        """Return the marginal p_t of a * X + sqrt(var) * eps, X ~ self, eps ~ N(0, 1).

        Each component N(mu_k, sigma_k^2) maps to
        N(a * mu_k, a^2 * sigma_k^2 + var) with ``a = accumulated_drift``;
        the weights are unchanged.
        """
        perturbed_means = accumulated_drift * self.means
        perturbed_stds = np.sqrt(accumulated_drift**2 * self.stds**2 + var)
        return GaussianMixtureDistribution(perturbed_means, perturbed_stds, self.pi)

    def score(self, x):
        """Score d/dx log p(x) of the mixture at each point of ``x``.

        Equals sum_k r_k(x) * (-(x - mu_k) / sigma_k^2), where r_k(x) are the
        posterior responsibilities. The responsibilities are normalized in log
        space so points far from all modes give finite values, not NaN.
        """
        x = np.atleast_1d(np.asarray(x, dtype=float))
        log_w = self._weighted_log_components(x)
        resp = np.exp(log_w - np.logaddexp.reduce(log_w, axis=1, keepdims=True))
        # d/dx log N(x; mu_k, sigma_k^2) = -(x - mu_k) / sigma_k^2
        component_scores = -(x[:, None] - self.means) / self.stds**2
        return (resp * component_scores).sum(axis=1)

    def var(self, N):
        """Monte Carlo estimate of the mixture variance from ``N`` fresh samples."""
        samples = self.sample(N)
        return np.var(samples)

    def stochastic_interpolation(self, k, sigma, other):
        """
        Computes q(x_t) = ∫∫ N(x_t | k*x0 + (1-k)*x1, sigma^2) dp0(x0) dp1(x1)
        where `self` is p0, and `other` is p1.

        Each pair of components (i from p0, j from p1) yields one component of
        q with mean k*mu0_i + (1-k)*mu1_j, variance
        k^2*sigma0_i^2 + (1-k)^2*sigma1_j^2 + sigma^2 and weight pi0_i*pi1_j.
        Components are ordered with i as the outer index and j as the inner one.

        Returns:
            GaussianMixtureDistribution representing q(x_t)
        """
        # np.add.outer / np.outer produce (K0, K1) grids; ravel() flattens them
        # row-major, i.e. i outer and j inner.
        new_means = np.add.outer(k * self.means, (1 - k) * other.means).ravel()
        new_vars = (
            np.add.outer(k**2 * self.stds**2, (1 - k) ** 2 * other.stds**2)
            + sigma**2
        ).ravel()
        new_weights = np.outer(self.pi, other.pi).ravel()

        # Normalize weights to sum to 1 (guards against rounding error).
        new_weights = new_weights / np.sum(new_weights)

        return GaussianMixtureDistribution(new_means, np.sqrt(new_vars), new_weights)
