import math

import torch
import torch.distributions as D

from .base_set import BaseSet


class TwoCloseGaussianMixture(BaseSet):
    """Equal-weight mixture of two anisotropic 2-D Gaussians.

    Component 0 is centred at (0, 3) with covariance diag(1, 0.3); component 1
    is centred at (0, -3) with covariance diag(0.3, 1). Both components are
    normalized Gaussians and the mixture weights are uniform, so the density
    used by ``energy`` is normalized and ``gt_logz`` is 0.
    """

    def __init__(self, device):
        super().__init__()
        # Placeholder single-element "dataset"; __getitem__ always returns it.
        self.data = torch.tensor([0.0])
        self.device = device

        self.nmode = 2
        self.data_ndim = 2
        self.modes = torch.tensor([[0, 3], [0, -3]], dtype=torch.float).to(device)
        small_cov, big_cov = 0.3, 1
        self.covariance_matrices = [
            torch.diag(torch.tensor([big_cov, small_cov], dtype=torch.float)).to(device),
            torch.diag(torch.tensor([small_cov, big_cov], dtype=torch.float)).to(device),
        ]

        self.gmm = [
            D.MultivariateNormal(loc=mode, covariance_matrix=covariance_matrix)
            for mode, covariance_matrix in zip(self.modes, self.covariance_matrices)
        ]
        # Uniform mixture weights (kept on CPU). Must stay consistent with the
        # -log(nmode) term in energy().
        self.mode_sampler = D.Categorical(torch.ones(self.nmode) / self.nmode)

    @property
    def gt_logz(self):
        """Log normalizing constant of the target; 0 because the mixture density is normalized."""
        return 0.0

    def energy(self, x):
        """Return the negative log-density of the mixture at ``x``.

        Args:
            x: Points whose trailing dimension is 2 (matching the 2-D component
                means); leading dimensions are treated as batch dimensions.

        Returns:
            Tensor of -log p(x) with the batch shape of ``x``.
        """
        # Shape (nmode, *batch): one log-density per component.
        component_log_probs = torch.stack([mvn.log_prob(x) for mvn in self.gmm])
        # log p(x) = logsumexp_k log N_k(x) - log(nmode) for uniform weights.
        log_prob = torch.logsumexp(component_log_probs, dim=0) - math.log(self.nmode)
        return -log_prob

    def sample(self, batch_size):
        """Draw ``batch_size`` i.i.d. samples from the mixture.

        Returns:
            Tensor of shape (batch_size, 2) on ``self.device``.
        """
        modes = self.mode_sampler.sample((batch_size,))
        samples = torch.cat(
            [
                self.gmm[mode_idx].sample(((modes == mode_idx).sum().item(),))
                for mode_idx in range(self.nmode)
            ],
            dim=0,
        ).to(self.device)
        # torch.cat leaves rows grouped by component; shuffle so that any slice
        # of the batch is still an unbiased draw from the mixture.
        perm = torch.randperm(samples.shape[0], device=samples.device)
        return samples[perm]

    def viz_pdf(self, fsave="2cgmm-density.png"):
        """Evaluate ``self.unnorm_pdf`` on a 100x100 grid over [-15, 15]^2.

        NOTE(review): ``unnorm_pdf`` is not defined in this class; presumably it
        is provided by ``BaseSet`` — confirm. ``fsave`` is accepted for interface
        compatibility but is not used here.

        Returns:
            (grid, density): ``grid`` is a (10000, 2) tensor of grid points and
            ``density`` is whatever ``unnorm_pdf`` returns for them.
        """
        xs = torch.linspace(-15, 15, 100).to(self.device)
        ys = torch.linspace(-15, 15, 100).to(self.device)
        # "ij" is the historical default layout; passing it explicitly pins the
        # layout and silences the torch>=1.10 deprecation warning.
        X, Y = torch.meshgrid(xs, ys, indexing="ij")
        grid = torch.stack([X.flatten(), Y.flatten()], dim=1)

        density = self.unnorm_pdf(grid)
        return grid, density

    def __getitem__(self, idx):
        """Return the placeholder datum; ``idx`` is ignored."""
        del idx
        return self.data[0]
