import torch
import matplotlib.pyplot as plt

import torch
import matplotlib.pyplot as plt


class GaussianMixtureDenoisingModel(torch.nn.Module):
    """Analytic noise predictor for a 2-D mixture of isotropic Gaussians.

    The data distribution is a fixed three-component Gaussian mixture. Under
    the forward process ``x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps``
    the marginal of ``x_t`` is again a Gaussian mixture, so its score (and
    hence the noise prediction) is available in closed form; nothing is
    learned.

    Args:
        num_steps: Number of diffusion steps in the schedule.
        epsilon: Offset of the cosine variance schedule.
        device: Device on which all schedule and cluster tensors are created.
    """

    def __init__(self, num_steps=1000, epsilon=0.008, device='cuda'):
        super().__init__()
        self.device = torch.device(device)
        self.num_steps = num_steps

        # Every derived schedule tensor inherits the device of ``betas``.
        self.betas = self.cosine_variance_schedule(num_steps, epsilon).to(self.device)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

        cluster_specs = [
            {"center": (-0.0, -0.6), "std": 0.06, "prob": 0.989},
            {"center": (-0.6, 0.6), "std": 0.06, "prob": 0.01},
            {"center": (0.6, 0.6), "std": 0.06, "prob": 0.001},
        ]
        # Each cluster stores a (1, 2) centre plus scalar std / prior tensors.
        self.clusters = [
            {
                "center": torch.tensor(
                    spec["center"], dtype=torch.float32, device=self.device
                ).unsqueeze(0),
                "std": torch.tensor(spec["std"], dtype=torch.float32, device=self.device),
                "prob": torch.tensor(spec["prob"], dtype=torch.float32, device=self.device),
            }
            for spec in cluster_specs
        ]

    def cosine_variance_schedule(self, num_steps, epsilon=0.008):
        """Return the ``num_steps`` betas of the cosine schedule, clamped to [0, 0.999]."""
        steps = torch.linspace(0, num_steps, num_steps + 1, dtype=torch.float32, device=self.device)
        f_t = torch.cos(((steps / num_steps + epsilon) / (1.0 + epsilon)) * torch.pi * 0.5) ** 2
        betas = torch.clamp(1.0 - f_t[1:] / f_t[:num_steps], 0.0, 0.999)
        return betas

    def compute_mog_params(self, t_index):
        """Return per-component means and covariances of the noised mixture.

        Args:
            t_index: 1-D int64 tensor of step indices, one per batch element.

        Returns:
            ``(new_means, new_covs)``: two lists with one entry per cluster,
            of shapes ``(batch, 1, 2)`` and ``(batch, 2, 2)`` respectively.
        """
        # (batch, 1, 1) so it broadcasts against both means and covariances.
        alpha_bar_t = torch.gather(self.alphas_cumprod, 0, t_index).view(-1, 1, 1)
        sqrt_alpha_bar_t = torch.sqrt(alpha_bar_t)
        eye = torch.eye(2, device=self.device).unsqueeze(0)

        new_means = [c["center"] * sqrt_alpha_bar_t for c in self.clusters]
        # Isotropic variance: signal part shrinks by a_bar, noise part adds (1 - a_bar).
        new_covs = [
            eye * (c["std"] ** 2 * alpha_bar_t + (1 - alpha_bar_t))
            for c in self.clusters
        ]
        return new_means, new_covs

    def score_function_batch(self, X, t):
        """Return the score ``grad_x log p_t(x)`` of the noised mixture.

        Responsibilities are computed as a softmax over
        ``log(prior) + log N(x; mean, cov)``. Working in log space keeps them
        well defined even when every component density underflows to zero
        (points far from all centres at low noise levels), where exponentiating
        first would yield an all-zero score.

        Args:
            X: ``(batch, 2)`` tensor of points.
            t: 1-D int64 tensor of step indices, one per batch element.

        Returns:
            ``(batch, 2)`` tensor with the mixture score at each point.
        """
        new_means, new_covs = self.compute_mog_params(t)

        log_weights = []
        component_scores = []
        for cluster, mean, cov in zip(self.clusters, new_means, new_covs):
            mean = mean.squeeze(1)
            gauss = torch.distributions.MultivariateNormal(mean, cov)
            log_weights.append(torch.log(cluster["prob"]) + gauss.log_prob(X))
            # Score of a single Gaussian: -(x - mu) @ Sigma^{-1}.
            component_scores.append(
                -torch.einsum('bi,bij->bj', X - mean, torch.inverse(cov))
            )

        # (num_clusters, batch): posterior probability of each component.
        responsibilities = torch.softmax(torch.stack(log_weights), dim=0)
        return torch.sum(responsibilities.unsqueeze(-1) * torch.stack(component_scores), dim=0)

    def forward(self, x_t, t):
        """Predict the noise in ``x_t`` as ``-sqrt(1 - a_bar_t) * score(x_t, t)``.

        Args:
            x_t: ``(batch, 2)`` tensor of noised samples.
            t: int64 tensor of step indices; flattened to one index per sample.

        Returns:
            ``(batch, 2)`` tensor of predicted noise.
        """
        t = t.to(self.device).view(-1)  # Ensure t has a batch dimension.
        sqrt_one_minus_alpha_cumprod_t = torch.gather(self.sqrt_one_minus_alphas_cumprod, 0, t).view(-1, 1)
        pred_noise = -self.score_function_batch(x_t, t) * sqrt_one_minus_alpha_cumprod_t
        return pred_noise


# Visualization and testing
def main():
    """Run ancestral reverse diffusion with the analytic model and plot snapshots.

    Starts from standard normal samples, applies the DDPM posterior update for
    every step from ``num_steps - 1`` down to 0, and scatter-plots the samples
    captured at five evenly spaced steps plus the final result.
    """
    # Fall back to CPU so the demo also runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = GaussianMixtureDenoisingModel(device=device)
    num_samples = 10000
    x_t = torch.randn(num_samples, 2, device=model.device)

    # Plain Python ints: membership tests need no device sync and titles format cleanly.
    vis_steps = set(
        torch.linspace(0, model.num_steps - 1, 5, dtype=torch.int32, device=model.device).tolist()
    )
    snapshots = []  # (step, samples on CPU), in the order they were captured

    for step in range(model.num_steps - 1, -1, -1):
        t = torch.full((num_samples,), step, dtype=torch.int64, device=model.device)
        beta_t = model.betas[t].unsqueeze(-1)
        alpha_cumprod_t = model.alphas_cumprod[t].unsqueeze(-1)
        alpha_t = model.alphas[t].unsqueeze(-1)

        pred_noise = model(x_t, t)
        x_0_pred = torch.sqrt(1 / alpha_cumprod_t) * x_t - torch.sqrt(1 / alpha_cumprod_t - 1) * pred_noise
        x_0_pred = torch.clamp(x_0_pred, -1.0, 1.0)

        if step > 0:
            # Only index step - 1 here; at step 0 it would wrap around to the last entry.
            alpha_cumprod_t_prev = model.alphas_cumprod[t - 1].unsqueeze(-1)
            mean = (beta_t * torch.sqrt(alpha_cumprod_t_prev) / (1 - alpha_cumprod_t)) * x_0_pred + \
                   ((1 - alpha_cumprod_t_prev) * torch.sqrt(alpha_t) / (1 - alpha_cumprod_t)) * x_t
            std = torch.sqrt(beta_t * (1 - alpha_cumprod_t_prev) / (1 - alpha_cumprod_t))
            x_t = mean + torch.randn(num_samples, 2, device=model.device) * std
        else:
            # Final step is deterministic: no noise is added.
            x_t = (beta_t / (1 - alpha_cumprod_t)) * x_0_pred

        if step in vis_steps:
            snapshots.append((step, x_t.clone().cpu()))

    num_panels = len(snapshots) + 1
    _, axes = plt.subplots(1, num_panels, figsize=(3 * num_panels, 3))
    for ax, (snap_step, samples) in zip(axes[:-1], snapshots):
        ax.scatter(samples[:, 0].numpy(), samples[:, 1].numpy(), alpha=0.5, s=10)
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)
        # Label each panel with the step whose update produced these samples.
        ax.set_title(f"t = {snap_step}")

    final_samples = x_t.cpu()
    axes[-1].scatter(final_samples[:, 0].numpy(), final_samples[:, 1].numpy(), alpha=0.5, s=10, color='red')
    axes[-1].set_xlim(-1, 1)
    axes[-1].set_ylim(-1, 1)
    axes[-1].set_title("Final Samples")
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    main()
