"""A module for a mixture density network layer."""
import torch
import torch.nn as nn
from torch.distributions import Categorical

from environments.kitchen.spirl.modules.variational_inference import (
    MultivariateGaussian,
)
from environments.kitchen.spirl.utils.pytorch_utils import ten2ar


class MDN(nn.Module):
    """A mixture density network layer"""

    def __init__(self, input_size, output_size, num_gaussians):
        super(MDN, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.num_gaussians = num_gaussians
        self.pi = nn.Sequential(nn.Linear(input_size, num_gaussians), nn.Softmax(dim=1))
        self.log_sigma = nn.Linear(input_size, output_size * num_gaussians)
        self.mu = nn.Linear(input_size, output_size * num_gaussians)

    def forward(self, inputs):
        return (
            torch.clamp(self.pi(inputs), min=1e-6),
            self.mu(inputs).reshape(-1, self.num_gaussians, self.output_size),
            torch.clamp(
                self.log_sigma(inputs).reshape(
                    -1, self.num_gaussians, self.output_size
                ),
                -10,
                2,
            ),
        )


class GMM:
    """Gaussian Mixture Model class.

    Stores the mixture weights ``pi`` together with per-component ``mu`` and
    ``log_sigma`` and wraps every component in a ``MultivariateGaussian``.

    The batched methods (``log_prob``, ``sample``, ``tensor``) index the
    component axis as dim 1, i.e. they expect ``pi`` to be (batch, K) and
    ``mu`` / ``log_sigma`` to be (batch, K, D).
    """

    def __init__(self, pi, mu=None, log_sigma=None):
        """Create a GMM from explicit parameters, a tuple, or a flat tensor.

        Args:
            pi: mixture weights. When ``mu`` and ``log_sigma`` are both
                omitted this may instead be a ``(pi, mu, log_sigma)`` tuple
                or a flat tensor as produced by ``tensor()``.
            mu: component means; the component axis is second to last.
            log_sigma: component log-sigmas, same layout as ``mu``.
        """
        if mu is None and log_sigma is None:
            if isinstance(pi, tuple):
                pi, mu, log_sigma = pi  # in case inputs are passed in as tuple
            else:
                pi, mu, log_sigma = self.tensor2gmm(pi)
        self.pi = pi
        self.mu = mu
        self.log_sigma = log_sigma
        # One distribution object per mixture component, sliced along the
        # second-to-last axis of mu / log_sigma.
        self._components = [
            MultivariateGaussian(mu[..., idx, :], log_sigma[..., idx, :])
            for idx in range(mu.shape[-2])
        ]

    def nll(self, x):
        """Return the negative of ``log_prob(x)``."""
        return -1 * self.log_prob(x)

    def log_prob(self, x):
        """Return the mixture log-likelihood of ``x``.

        Adds ``log(pi)`` to the per-component log-probabilities and reduces
        over the component axis (dim 1) with ``logsumexp``.
        """
        # x[:, None] inserts a component axis so that x can be evaluated
        # against all components at once.
        # NOTE(review): presumably MultivariateGaussian.log_prob reduces over
        # the feature axis and returns (batch, K) - confirm.
        return torch.logsumexp(
            torch.log(self.pi)
            + MultivariateGaussian(self.mu, self.log_sigma).log_prob(x[:, None]),
            dim=1,
        )

    def sample(self):
        """Draw one sample per batch element.

        Every component is sampled, a component index is drawn from
        ``Categorical(pi)``, and a one-hot mask followed by a sum over dim 1
        keeps only the selected component's sample.

        NOTE(review): gradients w.r.t. ``mu`` / ``log_sigma`` depend on
        ``MultivariateGaussian.sample()`` being reparameterized - confirm.
        The categorical draw itself carries no gradient to ``pi``.
        """
        return (
            MultivariateGaussian(self.mu, self.log_sigma).sample()
            * torch.nn.functional.one_hot(
                Categorical(self.pi).sample(), num_classes=self.pi.shape[-1]
            )[..., None].float()
        ).sum(dim=1)

    def rsample(self):
        """Alias for ``sample()``."""
        return self.sample()

    def entropy(self):
        """!!! This is not the true entropy of the GMM (there is no closed form) but only an indicator. !!!

        Returns the individual component entropies stacked along dim 1.
        """
        return torch.stack([c.entropy() for c in self._components], dim=1)

    def detach(self):
        """Return a new GMM built from detached copies of the parameters."""
        return GMM(self.pi.detach(), self.mu.detach(), self.log_sigma.detach())

    def tensor(self):
        """Returns flat tensor representation of GMM.

        Each row is laid out as
        ``[pi (K) | mu (K * D) | log_sigma (K * D) | K]``; the trailing
        column stores the number of components so that ``tensor2gmm`` can
        recover the layout.
        """
        return torch.cat(
            (
                self.pi,
                self.mu.flatten(start_dim=1),
                self.log_sigma.flatten(start_dim=1),
                self.pi.shape[1]
                * torch.ones((self.pi.shape[0], 1), device=self.pi.device),
            ),
            dim=-1,
        )

    @staticmethod
    def tensor2gmm(tensor):
        """Unwraps flattened tensor representation generated by tensor() function.

        Args:
            tensor: 2-D tensor whose rows are laid out as
                ``[pi (K) | mu (K * nz) | log_sigma (K * nz) | K]``.

        Returns:
            Tuple ``(pi, mu, log_sigma)`` shaped (batch, K), (batch, K, nz)
            and (batch, K, nz).

        Raises:
            ValueError: if the row width does not match the stored ``K``.
        """
        # Work with plain Python ints. Keeping K as a tensor turns the
        # divisions below into float tensors, which cannot be used as slice
        # bounds or reshape sizes.
        num_gaussians = int(tensor[0, -1].item())
        payload = tensor.shape[1] - 1 - num_gaussians  # columns of mu + log_sigma
        if num_gaussians <= 0 or payload < 0 or payload % (2 * num_gaussians) != 0:
            raise ValueError(
                f"Flat GMM tensor of width {tensor.shape[1]} is inconsistent "
                f"with {num_gaussians} mixture components."
            )
        nz = payload // (2 * num_gaussians)

        batch = tensor.shape[0]
        mu_end = num_gaussians + num_gaussians * nz
        pi = tensor[:, :num_gaussians]
        mu = tensor[:, num_gaussians:mu_end].reshape(batch, num_gaussians, nz)
        # The last column holds K itself, so log_sigma stops one short of it.
        log_sigma = tensor[:, mu_end:-1].reshape(batch, num_gaussians, nz)
        return pi, mu, log_sigma

    def to_numpy(self):
        """Convert internal variables to numpy arrays."""
        return GMM(ten2ar(self.pi), ten2ar(self.mu), ten2ar(self.log_sigma))

    @staticmethod
    def stack(*argv, dim):
        """Combine several GMMs by ``torch.stack``-ing their parameters along ``dim``."""
        return GMM._combine(torch.stack, *argv, dim=dim)

    @staticmethod
    def cat(*argv, dim):
        """Combine several GMMs by ``torch.cat``-ing their parameters along ``dim``."""
        return GMM._combine(torch.cat, *argv, dim=dim)

    @staticmethod
    def _combine(fcn, *argv, dim):
        """Apply ``fcn(list_of_tensors, dim)`` to pi, mu and log_sigma of all GMMs."""
        pi, mu, log_sigma = [], [], []
        for g in argv:
            pi.append(g.pi)
            mu.append(g.mu)
            log_sigma.append(g.log_sigma)
        pi, mu, log_sigma = fcn(pi, dim), fcn(mu, dim), fcn(log_sigma, dim)
        return GMM(pi, mu, log_sigma)

    def __getitem__(self, item):
        """Return a GMM whose parameters are each indexed by ``item``."""
        return GMM(self.pi[item], self.mu[item], self.log_sigma[item])

    def __iter__(self):
        """Yield ``(weight, component)`` pairs.

        Zips the first axis of ``pi`` with the component list.
        NOTE(review): the pairing is per component only when the first axis
        of ``pi`` is the component axis, e.g. after selecting one batch
        element with ``gmm[i]`` - confirm before iterating a batched GMM.
        """
        for pi, c in zip(self.pi, self._components):
            yield pi, c


if __name__ == "__main__":
    # Demo / smoke test: fit an MDN-parameterised GMM to a fixed 4-component
    # target mixture and save a plot of the fitted components to gmm_fit.png.
    #
    # All script-only imports live at the top of the block so that every name
    # used below (notably `np` inside `_draw_gaussian`) is bound before the
    # code that refers to it.
    import numpy as np
    from matplotlib import pyplot as plt
    from matplotlib.patches import Ellipse

    from environments.kitchen.spirl.modules.layers import LayerBuilderParams
    from environments.kitchen.spirl.modules.subnetworks import Predictor
    from environments.kitchen.spirl.utils.general_utils import (
        AttrDict,
        split_along_axis,
    )

    ### VISUALIZE
    def _draw_gaussian(ax, gauss_tensor, color, weight=None):
        """Draw one 2-D Gaussian as an axis-aligned ellipse on ``ax``.

        ``gauss_tensor`` is split along axis 0 into
        ``(mean_x, mean_y, log_sigma_x, log_sigma_y)``. If ``weight`` is
        given it is used as the ellipse alpha, otherwise only the outline
        is drawn.
        """
        px, py, p_logsig_x, p_logsig_y = split_along_axis(ten2ar(gauss_tensor), axis=0)

        def logsig2std(logsig):
            return np.exp(logsig)

        ell = Ellipse(
            xy=(px, py),
            width=2 * logsig2std(p_logsig_x),
            height=2 * logsig2std(p_logsig_y),
            angle=0,
            color=color,
        )  # this assumes diagonal gaussian
        if weight is not None:
            ell.set_alpha(weight)
        else:
            ell.set_facecolor("none")
        ax.add_artist(ell)

    ### TRAIN
    # generate data: target mixture with weights (0.7, 0.1, 0.1, 0.1), means
    # at (1, 0), (-1, 0), (0, 1), (0, -1) and log_sigma = log(0.1) everywhere,
    # repeated over a batch of 256.
    pi = torch.tensor([0.7, 0.1, 0.1, 0.1])[None].repeat(256, 1)
    mu = (
        torch.tensor([[1.0, -1.0, 0.0, 0.0], [0.0, 0.0, 1.0, -1.0]])[None]
        .repeat(256, 1, 1)
        .transpose(-1, -2)
    )
    log_sigma = torch.zeros_like(mu) + torch.tensor(np.log(0.1))
    data_dist = GMM(pi=pi, mu=mu, log_sigma=log_sigma)

    data = data_dist.sample().detach().numpy()

    # set up MDN model (a Predictor feeding the MDN head)
    trainable_input = torch.zeros((256, 2), requires_grad=True)
    hp = AttrDict(
        {
            "nz_mid": 32,
            "n_processing_layers": 3,
        }
    )
    hp.builder = LayerBuilderParams(False, "batch")
    model = torch.nn.Sequential(
        Predictor(hp, input_size=2, output_size=hp.nz_mid),
        MDN(input_size=hp.nz_mid, output_size=2, num_gaussians=4),
    )

    # Only needed by the alternative NLL objective noted in the loop below.
    pydata = torch.tensor(data, dtype=torch.float32)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

    # train MDN model
    for i in range(6000):
        optimizer.zero_grad()
        gmm_dist = GMM(model(trainable_input))
        loss_samples = []
        for _ in range(10):
            data_sample = data_dist.sample()
            # Only the alternative objectives below use gmm_sample; it is
            # still drawn so the random stream matches across objectives.
            gmm_sample = gmm_dist.rsample()
            # Alternative objectives that were tried:
            # loss = gmm_dist.nll(pydata).mean()
            # loss = (gmm_dist.log_prob(gmm_sample) - data_dist.log_prob(gmm_sample))
            # loss = (gmm_dist.log_prob(gmm_sample) - data_dist.log_prob(gmm_sample)) + \
            #        (data_dist.log_prob(data_sample) - gmm_dist.log_prob(data_sample))
            # Active objective: log-likelihood gap on samples from the target.
            loss = data_dist.log_prob(data_sample) - gmm_dist.log_prob(data_sample)
            loss_samples.append(loss)
        loss = torch.cat(loss_samples).mean()
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            print(f"Iter: {i}\tNLL: {loss.item():.2f}\t")

    # visualize the fitted components of the first batch element
    samples = gmm_dist.sample().detach().numpy()
    fig = plt.figure()
    ax = plt.subplot(111)
    plt.xlim(-2, 2)
    plt.ylim(-2, 2)
    # Optional overlays of raw samples (disabled):
    # plt.scatter(data[:, 0], data[:, 1], c='black', alpha=0.1)
    # plt.scatter(samples[:, 0], samples[:, 1], c='green', alpha=0.5)
    for weight, component in gmm_dist[0]:
        _draw_gaussian(ax, component.tensor(), "green", ten2ar(weight))
    plt.axis("equal")
    plt.savefig("gmm_fit.png")
    # plt.show()
