import matplotlib.pyplot as plt

import torch
import torch.distributions as D
from torch.distributions.mixture_same_family import MixtureSameFamily

from .base_set import BaseSet


class RandomGaussianMixture(BaseSet):
    """Equal-weight mixture of ``nmode`` Gaussians with random means and
    random diagonal covariances, laid out deterministically from ``seed``.

    ``energy`` is the exact negative log-density of the mixture (weights
    ``1/nmode``), so the target is normalized and ``gt_logz`` is 0.

    Args:
        device: Device that component parameters and samples are moved to.
        nmode: Number of mixture components.
        dim: Dimensionality of the data space.
        seed: Seed for a private ``torch.Generator`` used to draw the means and
            covariances; construction does not consume the global RNG for these.
    """

    def __init__(self, device, nmode=16, dim=2, seed=4):
        super().__init__()
        # Single placeholder item returned by ``__getitem__``.
        self.data = torch.tensor([0.0])
        self.device = device

        # Private generator so the mixture layout depends only on ``seed``.
        gen = torch.Generator()
        gen.manual_seed(seed)

        self.nmode = nmode
        self.data_ndim = dim
        # Means drawn uniformly from [-7.5, 7.5) in every coordinate.
        self.modes = 15 * (torch.rand((nmode, dim), generator=gen) - 0.5).to(device)
        # Diagonal covariances with variances 0.3 * |N(0, 1)| + 0.05, so every
        # variance is at least 0.05 and no component is degenerate.
        self.covariance_matrices = [
            torch.diag(torch.randn(dim, generator=gen).abs().to(self.device) * 0.3 + 0.05)
            for _ in range(nmode)
        ]

        # One MultivariateNormal per mixture component.
        self.gmm = [
            D.MultivariateNormal(loc=mode, covariance_matrix=covariance_matrix)
            for mode, covariance_matrix in zip(self.modes, self.covariance_matrices)
        ]
        # Uniform component weights (only used to draw component indices).
        self.mode_sampler = D.Categorical(torch.ones(self.nmode) / self.nmode)

    @property
    def gt_logz(self):
        """Ground-truth log normalizing constant.

        It is 0 because ``energy`` is the normalized negative log-density.
        """
        return 0.0

    def energy(self, x):
        """Return the energy ``-log p(x)`` of the mixture at ``x``.

        Args:
            x: Points whose last dimension is ``data_ndim``; leading dimensions
                are treated as batch dimensions (presumably ``(batch, dim)``).

        Returns:
            Tensor of energies with shape ``x.shape[:-1]``.
        """
        # log p(x) = logsumexp_k log N_k(x) - log K, since every weight is 1/K.
        log_prob = torch.logsumexp(
            torch.stack([mvn.log_prob(x) for mvn in self.gmm]), dim=0, keepdim=False
        ) - torch.log(torch.tensor(self.nmode, device=self.device))
        return -log_prob

    def sample(self, batch_size):
        """Draw ``batch_size`` i.i.d. samples from the mixture.

        A component index is drawn per sample, and each component is sampled
        as many times as it was chosen. The rows are then randomly permuted so
        that the batch is not grouped by component.

        Args:
            batch_size: Number of samples to draw.

        Returns:
            Tensor of shape ``(batch_size, data_ndim)`` on ``self.device``.
        """
        modes = self.mode_sampler.sample((batch_size,))
        # Per-component draw counts in one pass (instead of nmode comparisons).
        counts = torch.bincount(modes, minlength=self.nmode).tolist()
        samples = torch.cat(
            [mvn.sample((count,)) for mvn, count in zip(self.gmm, counts)], dim=0
        ).to(self.device)
        # torch.cat leaves rows grouped by component; shuffle so that any
        # prefix or sub-batch of the result is itself an unbiased sample.
        perm = torch.randperm(samples.shape[0], device=samples.device)
        return samples[perm]

    def viz_pdf(self, fsave="rgmm-density.png"):
        """Evaluate ``self.unnorm_pdf`` on a 100 x 100 grid over [-15, 15]^2.

        Args:
            fsave: Not used by this method (no figure is written here).

        Returns:
            Tuple ``(grid, density)``. ``grid`` has shape ``(10000, 2)`` and
            ``density`` is whatever ``unnorm_pdf`` returns for those points.

        Raises:
            ValueError: If the mixture is not 2-dimensional.
        """
        if self.data_ndim != 2:
            raise ValueError(
                f"viz_pdf only supports 2-D mixtures, got dim={self.data_ndim}"
            )
        axis = torch.linspace(-15, 15, 100).to(self.device)
        # "ij" is the historical default; passing it explicitly silences the
        # deprecation warning in newer torch without changing the grid order.
        X, Y = torch.meshgrid(axis, axis, indexing="ij")
        grid = torch.stack([X.flatten(), Y.flatten()], dim=1)

        density = self.unnorm_pdf(grid)
        return grid, density

    def __getitem__(self, idx):
        """Return the single placeholder item, whatever ``idx`` is."""
        del idx  # every index maps to the same dummy item
        return self.data[0]
