"""Mixture density network (MDN) layer and a Gaussian mixture model (GMM) distribution wrapper."""
import torch
import torch.nn as nn
from torch.distributions import Categorical

from spirl.modules.variational_inference import MultivariateGaussian
from spirl.utils.pytorch_utils import ten2ar


class MDN(nn.Module):
    """A mixture density network layer.

    Projects input features onto the parameters of a ``num_gaussians``-component mixture over
    ``output_size``-dimensional outputs. Inputs are expected to be 2-D, (batch, input_size):
    the mixture-weight softmax is taken over dim 1.
    """

    def __init__(self, input_size, output_size, num_gaussians):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.num_gaussians = num_gaussians
        # NOTE: submodules are registered in the order pi, log_sigma, mu; keep that order so the
        # state_dict keys, parameter order and seeded initialisation stay unchanged.
        per_component_width = output_size * num_gaussians
        self.pi = nn.Sequential(nn.Linear(input_size, num_gaussians), nn.Softmax(dim=1))
        self.log_sigma = nn.Linear(input_size, per_component_width)
        self.mu = nn.Linear(input_size, per_component_width)

    def forward(self, inputs):
        """Compute the mixture parameters for ``inputs``.

        Returns:
            A tuple ``(pi, mu, log_sigma)``:
              pi: mixture weights, (batch, num_gaussians), floored at 1e-6 (not renormalised).
              mu: component means, reshaped to (-1, num_gaussians, output_size).
              log_sigma: component log-scales, same shape as ``mu``, clamped to [-10, 2].
        """
        component_shape = (-1, self.num_gaussians, self.output_size)
        weights = self.pi(inputs).clamp(min=1e-6)
        means = self.mu(inputs).reshape(component_shape)
        log_scales = self.log_sigma(inputs).reshape(component_shape).clamp(-10, 2)
        return weights, means, log_scales


class GMM:
    """Gaussian Mixture Model class.

    Stores mixture weights ``pi`` plus per-component ``mu`` / ``log_sigma``. The component index is
    dim -2 of ``mu`` / ``log_sigma`` and the last dim of ``pi``. The batched methods (``log_prob``,
    ``sample``, ``tensor``) assume ``pi`` is (batch, num_gaussians) and ``mu`` / ``log_sigma`` are
    (batch, num_gaussians, dim), i.e. the layout returned by ``MDN.forward``.
    """
    def __init__(self, pi, mu=None, log_sigma=None):
        """Build a GMM from its parameters.

        Args:
            pi: mixture weights; or, when ``mu`` and ``log_sigma`` are both omitted, either a
                ``(pi, mu, log_sigma)`` tuple or a flat tensor produced by :meth:`tensor`.
            mu: component means (component index on dim -2).
            log_sigma: component log-scales, same shape as ``mu``.

        Raises:
            ValueError: if exactly one of ``mu`` / ``log_sigma`` is given.
        """
        if mu is None and log_sigma is None:
            if isinstance(pi, tuple):
                pi, mu, log_sigma = pi      # in case inputs are passed in as tuple (e.g. MDN output)
            else:
                pi, mu, log_sigma = self.tensor2gmm(pi)     # flat representation from tensor()
        elif mu is None or log_sigma is None:
            raise ValueError("GMM needs both mu and log_sigma, or neither of them.")
        self.pi = pi
        self.mu = mu
        self.log_sigma = log_sigma
        # One Gaussian per mixture component, sliced along the component axis (dim -2).
        self._components = [MultivariateGaussian(mu[..., idx, :], log_sigma[..., idx, :])
                            for idx in range(mu.shape[-2])]

    def nll(self, x):
        """Negative log-likelihood of ``x`` under the mixture (see :meth:`log_prob`)."""
        return -1 * self.log_prob(x)

    def log_prob(self, x):
        """Log-density of ``x`` under the mixture.

        ``x[:, None]`` inserts a component axis so ``x`` is scored against every component; the
        log-weights are added and the result is reduced with logsumexp over dim 1 (the component
        axis). Assumes ``x`` is (batch, dim) and that ``MultivariateGaussian.log_prob`` reduces the
        feature axis -- TODO confirm.
        """
        return torch.logsumexp(torch.log(self.pi) +
                               MultivariateGaussian(self.mu, self.log_sigma).log_prob(x[:, None]), dim=1)

    def sample(self):
        """Draw one sample per batch element.

        Every component is sampled, then one component per batch element is selected with a one-hot
        mask drawn from ``Categorical(self.pi)``. The component choice is discrete, so no gradient
        flows to ``pi``. Whether gradients reach ``mu`` / ``log_sigma`` depends on
        ``MultivariateGaussian.sample`` -- NOTE(review): confirm it is reparameterised.
        """
        return (MultivariateGaussian(self.mu, self.log_sigma).sample() *
                torch.nn.functional.one_hot(Categorical(self.pi).sample(),
                                            num_classes=self.pi.shape[-1])[..., None].float()).sum(dim=1)

    def rsample(self):
        """Alias of :meth:`sample` (same gradient caveats)."""
        return self.sample()

    def entropy(self):
        """!!! This is not the true entropy of the GMM (there is no closed form) but only an indicator. !!!

        Returns the per-component entropies stacked along dim 1.
        """
        return torch.stack([c.entropy() for c in self._components], dim=1)

    def detach(self):
        """Return a copy whose parameters are detached from the autograd graph."""
        return GMM(self.pi.detach(), self.mu.detach(), self.log_sigma.detach())

    def tensor(self):
        """Returns flat tensor representation of GMM.

        Each row is ``[pi (K) | mu flattened (K*nz) | log_sigma flattened (K*nz) | K]``. The trailing
        column stores the number of components K (as a float) so :meth:`tensor2gmm` can undo the
        flattening. Requires batched parameters (2-D ``pi``).
        """
        return torch.cat((self.pi, self.mu.flatten(start_dim=1), self.log_sigma.flatten(start_dim=1),
                          self.pi.shape[1] * torch.ones((self.pi.shape[0], 1), device=self.pi.device)), dim=-1)

    @staticmethod
    def tensor2gmm(tensor):
        """Unwraps flattened tensor representation generated by tensor() function.

        Returns:
            ``(pi, mu, log_sigma)`` with shapes (B, K), (B, K, nz) and (B, K, nz).

        Raises:
            ValueError: if the row width does not match the layout written by :meth:`tensor`.
        """
        # K is stored as a float in the last column. Convert it to a Python int: slicing and
        # reshaping with a float (or float tensor) size raises.
        num_gaussians = int(round(tensor[0, -1].item()))
        num_params = tensor.shape[1] - 1 - num_gaussians      # width of the mu + log_sigma section
        if num_gaussians <= 0 or num_params % (2 * num_gaussians) != 0:
            raise ValueError(f"Tensor of width {tensor.shape[1]} is not a valid flattened GMM "
                             f"with {num_gaussians} components.")
        nz = num_params // (2 * num_gaussians)                # integer division: nz is used as a shape
        mu_end = num_gaussians + num_gaussians * nz
        pi = tensor[:, :num_gaussians]
        mu = tensor[:, num_gaussians:mu_end].reshape(-1, num_gaussians, nz)
        log_sigma = tensor[:, mu_end:-1].reshape(-1, num_gaussians, nz)
        return pi, mu, log_sigma

    def to_numpy(self):
        """Convert internal variables to numpy arrays.

        NOTE(review): the result is rebuilt through ``GMM.__init__``, which wraps the numpy slices in
        ``MultivariateGaussian``; confirm that class accepts numpy inputs.
        """
        return GMM(ten2ar(self.pi), ten2ar(self.mu), ten2ar(self.log_sigma))

    @staticmethod
    def stack(*argv, dim):
        """Stack several GMMs' parameters along a new dimension ``dim`` (``torch.stack``)."""
        return GMM._combine(torch.stack, *argv, dim=dim)

    @staticmethod
    def cat(*argv, dim):
        """Concatenate several GMMs' parameters along the existing dimension ``dim`` (``torch.cat``)."""
        return GMM._combine(torch.cat, *argv, dim=dim)

    @staticmethod
    def _combine(fcn, *argv, dim):
        """Apply ``fcn(list_of_tensors, dim)`` to pi, mu and log_sigma of all GMMs in ``argv``."""
        pi = fcn([g.pi for g in argv], dim)
        mu = fcn([g.mu for g in argv], dim)
        log_sigma = fcn([g.log_sigma for g in argv], dim)
        return GMM(pi, mu, log_sigma)

    def __getitem__(self, item):
        """Index the leading dimension(s) of pi, mu and log_sigma alike (e.g. ``gmm[0]``)."""
        return GMM(self.pi[item], self.mu[item], self.log_sigma[item])

    def __iter__(self):
        """Yield ``(weight, component)`` for each mixture component.

        Weights are taken along the component axis (last dim of ``pi``). For a single GMM such as
        ``gmm[0]`` each weight is a scalar; for a batched GMM it is the per-batch weight vector.
        """
        for idx, component in enumerate(self._components):
            yield self.pi[..., idx], component


if __name__ == "__main__":
    # Demo: fit an MDN-parameterised GMM to samples from a fixed 4-component 2-D GMM and save a plot
    # of the fitted components to gmm_fit.png.
    import numpy as np      # imported up front: _draw_gaussian below relies on np
    from matplotlib import pyplot as plt
    from matplotlib.patches import Ellipse

    from spirl.modules.layers import LayerBuilderParams
    from spirl.modules.subnetworks import Predictor
    from spirl.utils.general_utils import AttrDict, split_along_axis

    def _draw_gaussian(ax, gauss_tensor, color, weight=None):
        """Draw a 1-std, axis-aligned ellipse for a 2-D Gaussian onto ``ax``.

        ``gauss_tensor`` is split along axis 0 into [mean_x, mean_y, log_sigma_x, log_sigma_y].
        If ``weight`` is given it sets the ellipse alpha; otherwise only the outline is drawn.
        """
        px, py, p_logsig_x, p_logsig_y = split_along_axis(ten2ar(gauss_tensor), axis=0)
        ell = Ellipse(xy=(px, py),
                      width=2 * np.exp(p_logsig_x), height=2 * np.exp(p_logsig_y),
                      angle=0, color=color)     # this assumes diagonal gaussian
        if weight is not None:
            ell.set_alpha(weight)
        else:
            ell.set_facecolor('none')
        ax.add_artist(ell)

    ### TRAIN
    # Ground-truth mixture: means (1, 0), (-1, 0), (0, 1), (0, -1), std 0.1 and weights
    # 0.7 / 0.1 / 0.1 / 0.1, replicated over the batch.
    batch_size = 256
    pi = torch.tensor([0.7, 0.1, 0.1, 0.1])[None].repeat(batch_size, 1)
    mu = (torch.tensor([[1.0, -1.0, 0.0, 0.0], [0.0, 0.0, 1.0, -1.0]])[None]
          .repeat(batch_size, 1, 1).transpose(-1, -2))     # (batch, 4 components, 2 dims)
    log_sigma = torch.zeros_like(mu) + torch.tensor(np.log(0.1))
    data_dist = GMM(pi=pi, mu=mu, log_sigma=log_sigma)

    # MDN model: Predictor trunk + MDN head with 4 components over 2-D outputs. The input is a
    # constant zero tensor that is never handed to the optimizer, so the network effectively learns
    # one input-independent mixture.
    model_input = torch.zeros((batch_size, 2))
    hp = AttrDict({
        'nz_mid': 32,
        'n_processing_layers': 3,
    })
    hp.builder = LayerBuilderParams(False, 'batch')
    model = torch.nn.Sequential(
        Predictor(hp, input_size=2, output_size=hp.nz_mid),
        MDN(input_size=hp.nz_mid, output_size=2, num_gaussians=4),
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

    num_iters, num_mc_samples = 6000, 10
    for i in range(num_iters):
        optimizer.zero_grad()
        gmm_dist = GMM(model(model_input))
        # Monte-Carlo estimate of the forward KL, KL(data || model), using samples of the data
        # mixture. A reverse-KL variant would instead score gmm_dist.rsample() under both models.
        kl_terms = []
        for _ in range(num_mc_samples):
            data_sample = data_dist.sample()
            kl_terms.append(data_dist.log_prob(data_sample) - gmm_dist.log_prob(data_sample))
        loss = torch.cat(kl_terms).mean()
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            print(f"Iter: {i}\tKL: {loss.item():.2f}\t")

    ### VISUALIZE
    # Draw the fitted components of the first batch element; the fill alpha is the mixture weight.
    plt.figure()
    ax = plt.subplot(111)
    plt.xlim(-2, 2); plt.ylim(-2, 2)
    for weight, component in gmm_dist[0]:
        _draw_gaussian(ax, component.tensor(), 'green', ten2ar(weight))
    plt.axis("equal")
    plt.savefig("gmm_fit.png")

