import matplotlib.pyplot as plt
import itertools
import numpy as np
import torch
import torch.distributions as D
from torch.distributions.mixture_same_family import MixtureSameFamily

from .base_set import BaseSet


class DistortedGaussianMixture(BaseSet):
    """Equal-weight Gaussian mixture on a regular grid with randomly distorted covariances.

    The ``5**dim`` component means are the Cartesian product of
    ``[-10, -5, 0, 5, 10]`` over every axis. Each component covariance is
    ``(L + E).T @ (L + E)`` where ``L = cholesky(0.3 * I)`` and ``E`` is a
    ``dim x dim`` matrix of standard-normal entries scaled by
    ``distortion_coef``. ``E`` is drawn from a generator seeded with 42, so the
    covariances are the same for every instantiation with the same arguments.

    Args:
        device: Device on which the means, covariances and samples live.
        distortion_coef: Scale of the covariance perturbation, in ``[0, 1]``.
            With ``0`` every component has covariance ``0.3 * I``.
        dim: Dimensionality of the data space.

    Raises:
        ValueError: If ``distortion_coef`` is outside ``[0, 1]``.
    """

    def __init__(self, device, distortion_coef=0.1, dim=2):
        super().__init__()
        # Single placeholder item served by __getitem__.
        self.data = torch.tensor([0.0])
        self.device = device

        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if not 0 <= distortion_coef <= 1:
            raise ValueError(f"distortion_coef must be in [0, 1], got {distortion_coef}")

        gen = torch.Generator(device=self.device)
        gen.manual_seed(42)

        self.nmode = 5**dim
        self.data_ndim = dim

        # Means: Cartesian product of the 1-D grid over every axis -> (5**dim, dim).
        grid_values = [-10, -5, 0, 5, 10]
        grid_points = [grid_values] * dim  # one copy of the 1-D grid per axis
        self.means = torch.Tensor(list(itertools.product(*grid_points))).to(self.device)
        # Means are not perturbed; only the covariances below are distorted.

        # cholesky(0.3 * I) is identical for every mode, so compute it once.
        base_chol = torch.linalg.cholesky(0.3 * torch.eye(dim, device=self.device))
        self.covariance_matrices = []
        for _ in range(self.nmode):
            # Draw from ``gen`` in this exact order so the seeded covariances stay reproducible.
            noise = distortion_coef * torch.randn((dim, dim), device=self.device, generator=gen)
            factor = base_chol + noise
            # F.T @ F is symmetric positive semi-definite by construction.
            self.covariance_matrices.append(factor.T @ factor)

        self.gmm_weights = np.ones(self.nmode) / self.nmode

        self.gmm = [
            D.MultivariateNormal(loc=mean, covariance_matrix=cov)
            for mean, cov in zip(self.means, self.covariance_matrices)
        ]
        # Uniform distribution over component indices, used by ``sample``.
        self.mode_sampler = D.Categorical(torch.ones(self.nmode) / self.nmode)

    @property
    def gt_logz(self):
        """Log normalizing constant of the target.

        This is 0 because ``energy`` is the negative log of a normalized density.
        """
        return 0.0

    def energy(self, x):
        """Return ``-log p(x)`` under the equal-weight mixture.

        Args:
            x: Points to evaluate, presumably of shape ``(..., dim)`` on
                ``self.device`` (TODO confirm with callers).

        Returns:
            Tensor of negative log-densities, with the component axis reduced away.
        """
        # Per-component log-densities are stacked on a new leading axis. They are
        # then reduced with logsumexp for numerical stability. Subtracting
        # log(nmode) applies the uniform mixture weights.
        component_log_probs = torch.stack([mvn.log_prob(x) for mvn in self.gmm])
        log_prob = torch.logsumexp(component_log_probs, dim=0) - torch.log(
            torch.tensor(self.nmode, device=self.device)
        )
        return -log_prob

    def sample(self, batch_size):
        """Draw ``batch_size`` samples from the mixture.

        Component indices are drawn uniformly, and each component is sampled as
        many times as it was selected. The result is then randomly permuted so
        that rows are not grouped by component.

        Args:
            batch_size: Number of samples to draw.

        Returns:
            Tensor of shape ``(batch_size, dim)`` on ``self.device``.
        """
        modes = self.mode_sampler.sample((batch_size,))
        # One host sync for all per-mode counts instead of one ``.item()`` per mode.
        counts = torch.bincount(modes, minlength=self.nmode).tolist()
        samples = torch.cat(
            [mvn.sample((count,)) for mvn, count in zip(self.gmm, counts)],
            dim=0,
        ).to(self.device)
        # Without this shuffle the batch is sorted by component index. Any prefix
        # slice of it would then be biased towards low-index modes.
        return samples[torch.randperm(batch_size, device=samples.device)]

    def viz_pdf(self, fsave="distorted_gmm_density.png"):
        """Evaluate the unnormalized density on a 100x100 grid over ``[-20, 20]^2``.

        The grid is 2-D, so this only makes sense for ``dim == 2``. ``fsave`` is
        currently unused, and nothing is plotted or written to disk here.

        Returns:
            ``(points, density)``, where ``points`` has shape ``(10000, 2)``.
        """
        x = torch.linspace(-20, 20, 100).to(self.device)
        y = torch.linspace(-20, 20, 100).to(self.device)
        # "ij" is torch's historical default. Passing it explicitly silences the
        # deprecation warning without changing the grid layout.
        X, Y = torch.meshgrid(x, y, indexing="ij")
        points = torch.stack([X.flatten(), Y.flatten()], dim=1)  # (100*100, 2) query points

        # NOTE(review): ``unnorm_pdf`` is not defined in this class; presumably
        # it is inherited from BaseSet. Confirm.
        density = self.unnorm_pdf(points)
        return points, density

    def __getitem__(self, idx):
        """Return the single placeholder item. ``idx`` is ignored."""
        del idx
        return self.data[0]
