import matplotlib.pyplot as plt

import torch
import torch.distributions as D
from torch.distributions.mixture_same_family import MixtureSameFamily

from .base_set import BaseSet


class TwentyFiveGaussianMixture(BaseSet):
    """Equally weighted mixture of 25 two-dimensional Gaussians.

    The component means are every pair ``(a, b)`` with ``a`` and ``b`` taken
    from ``{-10, -5, 0, 5, 10}`` (a 5x5 grid); every component shares the
    covariance ``0.3 * I``.
    """

    def __init__(self, device, dim=2):
        """Build the mixture components on ``device``.

        Args:
            device: torch device on which the component parameters are placed.
            dim: data dimensionality. Only ``2`` is supported because the
                component means are fixed 2-D grid points.

        Raises:
            ValueError: if ``dim`` is not 2.
        """
        if dim != 2:
            # The means below are 2-D; any other covariance size would only
            # fail later with an opaque shape error inside torch.
            raise ValueError(
                f"TwentyFiveGaussianMixture only supports dim=2, got dim={dim}"
            )

        super().__init__()
        # Placeholder payload; __getitem__ always returns its single element.
        self.data = torch.tensor([0.0])
        self.device = device

        # Must be set before the covariance matrix is built from it.
        self.data_ndim = dim

        grid = [-10.0, -5.0, 0.0, 5.0, 10.0]
        modes = torch.tensor([(a, b) for a in grid for b in grid], device=self.device)
        self.nmode = modes.shape[0]  # 25 components

        self.covariance_matrix = 0.3 * torch.eye(self.data_ndim, device=self.device)

        # One MultivariateNormal per grid point, all with the same covariance.
        self.gmm = [
            D.MultivariateNormal(loc=mode, covariance_matrix=self.covariance_matrix)
            for mode in modes
        ]
        # Uniform categorical over the component indices.
        self.mode_sampler = D.Categorical(torch.ones(self.nmode) / self.nmode)

    @property
    def gt_logz(self):
        """Ground-truth log partition value; 0.0 since ``energy`` is a normalized negative log-density."""
        return 0.0

    def energy(self, x):
        """Return the negative log-density of the mixture at ``x``.

        Args:
            x: tensor of points passed to each component's ``log_prob``
                (the last dimension holds the 2-D coordinates).

        Returns:
            Tensor of ``-log p(x)``, where ``p`` is the uniform mixture of the
            25 components.
        """
        component_log_probs = torch.stack([mvn.log_prob(x) for mvn in self.gmm])
        # Uniform weights 1/nmode contribute a constant -log(nmode).
        log_nmode = torch.log(torch.tensor(float(self.nmode), device=self.device))
        log_prob = torch.logsumexp(component_log_probs, dim=0, keepdim=False) - log_nmode
        return -log_prob

    def sample(self, batch_size):
        """Draw ``batch_size`` samples from the mixture.

        Component indices are drawn uniformly, then each component is sampled
        as many times as it was selected. The returned rows are grouped by
        component index (component 0 first), not shuffled.

        Args:
            batch_size: number of samples to draw.

        Returns:
            Tensor of shape ``(batch_size, 2)`` on ``self.device``.
        """
        component_ids = self.mode_sampler.sample((batch_size,))
        # Count all components in one pass instead of 25 separate mask scans.
        counts = torch.bincount(component_ids, minlength=self.nmode).tolist()
        draws = [mvn.sample((count,)) for mvn, count in zip(self.gmm, counts)]
        return torch.cat(draws, dim=0).to(self.device)

    def viz_pdf(self, fsave="25gmm-density.png"):
        """Evaluate ``self.unnorm_pdf`` on a 100x100 grid over [-15, 15]^2.

        Args:
            fsave: currently unused; nothing is written to disk here.

        Returns:
            Tuple ``(points, density)`` where ``points`` has shape
            ``(10000, 2)`` and ``density`` is ``self.unnorm_pdf(points)``.
        """
        # NOTE(review): unnorm_pdf is not defined in this class; presumably
        # it is provided by BaseSet -- confirm.
        axis_x = torch.linspace(-15, 15, 100).to(self.device)
        axis_y = torch.linspace(-15, 15, 100).to(self.device)
        # "ij" is the historical default; stating it avoids the deprecation
        # warning without changing the grid layout.
        X, Y = torch.meshgrid(axis_x, axis_y, indexing="ij")
        points = torch.stack([X.flatten(), Y.flatten()], dim=1)

        density = self.unnorm_pdf(points)
        return points, density

    def __getitem__(self, idx):
        """Return the placeholder element regardless of ``idx``."""
        del idx
        return self.data[0]
