import torch
import numpy as np
from torch import nn
from torch.nn import Module
from torch.nn import functional as F

# Additive constant -log(2*pi)/2 of the Gaussian log-density, kept as a
# 0-dim tensor built from a NumPy scalar.
c = torch.tensor(-np.log(2 * np.pi) / 2)


def log_normal(x, mean=0., log_var=None, eps=0.00001):
    """Element-wise log-density of a diagonal Gaussian evaluated at ``x``.

    Args:
        x: Points at which the density is evaluated.
        mean: Mean of the Gaussian; defaults to zero.
        log_var: Log-variance of the Gaussian. When ``None`` a unit variance
            is used and the log-variance term is zero.
        eps: Small constant added to the denominator ``2 * var``.

    Returns:
        ``-(x - mean)**2 / (2 * var + eps) - log_var / 2 + c`` computed
        element-wise; no reduction over any axis is performed here.
    """
    if log_var is None:
        # Unit variance: the log-variance term drops out.
        log_var, variance = 0., 1.
    else:
        variance = torch.exp(log_var)

    squared_error = (x - mean) ** 2
    quadratic_term = -squared_error / (2. * variance + eps)
    return quadratic_term - log_var / 2. + c


class StandardGaussian(Module):
    """Zero-mean, unit-variance Gaussian over ``input_size`` dimensions."""

    def __init__(self, input_size, conditioned_size=0):
        """Store the dimensionality.

        Args:
            input_size: Number of dimensions of a sample.
            conditioned_size: Accepted but not used by this class.
        """
        super(StandardGaussian, self).__init__()
        self.ndim = input_size

    def forward(self, z, _=None):
        """Return the log-density of ``z`` summed over its last axis.

        The second positional argument is ignored.
        """
        per_dimension = log_normal(z)
        return per_dimension.sum(dim=-1)

    def sample(self, batch_size=None, device=None):
        """Draw ``batch_size`` standard-normal samples of width ``ndim``.

        The noise is generated with ``torch.randn`` and then moved to
        ``device``. ``batch_size`` has to be supplied by the caller; the
        ``None`` default is not a usable size.
        """
        shape = (batch_size, self.ndim)
        noise = torch.randn(shape)
        return noise.to(device)




class Logistic(nn.Module):
    """Logistic distribution whose argument is rescaled by a fixed factor.

    A value ``z`` is multiplied by the buffer ``corr_f`` before the standard
    logistic density / CDF is applied.
    NOTE(review): 1.81 is presumably an approximation of pi / sqrt(3), which
    would give the rescaled variable roughly unit variance -- confirm.
    """

    def __init__(self, input_size=None):
        """Register the scale factor; ``input_size`` is accepted but unused."""
        super(Logistic, self).__init__()
        self.register_buffer('corr_f', torch.tensor(1.81))

    @staticmethod
    def _base(z_0):
        """Element-wise standard logistic log-density.

        Computed as ``logsigmoid(z_0) + logsigmoid(-z_0)``. Inputs that lead
        to a non-finite result are printed for debugging.
        """
        log_density = F.logsigmoid(z_0) + F.logsigmoid(-z_0)

        is_finite = torch.isfinite(log_density)
        offending = ~is_finite
        if offending.any():
            print(z_0[offending])
        return log_density

    def inv_cdf(self, p):
        """Map probabilities ``p`` to quantiles (inverse of :meth:`cdf`).

        Both ``p`` and ``1 / p - 1`` are clamped from below at 1e-8 before
        the logarithm is taken.
        """
        floor = 1e-8
        safe_p = torch.clamp(p, min=floor)
        odds_against = torch.clamp(torch.reciprocal(safe_p) - 1, min=floor)
        return -torch.log(odds_against) / self.corr_f

    def cdf(self, z):
        """Return ``sigmoid(z * corr_f)``."""
        return torch.sigmoid(z * self.corr_f)

    def sample(self, trunc_params=None, eps=None):
        """Draw a sample by inverse-CDF transform of uniform noise.

        Args:
            trunc_params: Only used to size the uniform noise (via
                ``trunc_params[..., 0]``) when ``eps`` is not given.
            eps: Optional uniform noise to transform.

        Returns:
            Tuple of the sample and its log-probability (summed over the
            last axis).
        """
        uniform = torch.rand_like(trunc_params[..., 0]) if eps is None else eps
        draw = self.inv_cdf(uniform)
        return draw, self.log_probability(draw)

    def log_probability(self, z, trunc_params=None, return_zero_mask=False):
        """Log-density of ``z`` summed over the last axis.

        Args:
            z: Points to evaluate.
            trunc_params: Accepted but not used by this method.
            return_zero_mask: When true, additionally return an all-``False``
                boolean tensor shaped like the log-probability.
        """
        # Change of variables: density of z_0 = z * corr_f plus log|corr_f|.
        per_dimension = self._base(z * self.corr_f) + torch.log(self.corr_f)
        total = per_dimension.sum(-1)
        if not return_zero_mask:
            return total
        return total, torch.zeros_like(total, dtype=torch.bool)

    def forward(self, z):
        """Alias for :meth:`log_probability`."""
        return self.log_probability(z)


class Mixture(nn.Module):
    """Per-dimension mixture of shifted and scaled base distributions.

    Every input dimension has ``components`` mixture components, each
    described by three numbers: a mixture logit, a shift and a log-scale.
    These are supplied at call time as a flat ``params`` tensor.
    """

    def __init__(self, input_size=None,
                 components=5, base_distribution=None):
        """Set up the mixture.

        Args:
            input_size: Number of dimensions of ``z``. Must be provided;
                the ``None`` default cannot be multiplied into
                ``parameter_count``.
            components: Number of mixture components per dimension.
            base_distribution: Module exposing ``log_probability``. When
                ``None`` a new ``Logistic`` is created for this instance.
        """
        super(Mixture, self).__init__()
        self.components = components
        self.input_size = input_size
        # Three parameters (logit, shift, log-scale) per dimension/component.
        self.parameter_count = input_size * components * (2 + 1)
        # A module used as a default argument is built once at definition
        # time and would be shared (buffers, device moves) by all mixtures.
        if base_distribution is None:
            base_distribution = Logistic()
        self.base = base_distribution

    def log_probability(self, z, params):
        """Log-density of ``z`` under the mixture described by ``params``.

        Args:
            z: Tensor whose last axis has ``input_size`` entries.
            params: Tensor whose last axis has ``parameter_count`` entries;
                it is viewed as ``(..., input_size, components, 3)`` with
                the last axis holding (mixture logit, shift, log-scale).

        Returns:
            The mixture log-density summed over the ``input_size`` axis.
        """
        assert params.size(-1) == self.parameter_count, (
            "last dimension of params must equal parameter_count"
        )
        params = params.view(*(params.size()[:-1]),
                             self.input_size,
                             self.components, 3)

        log_pi = F.log_softmax(params[..., 0], dim=-1)
        shift = params[..., 1]
        log_scale = params[..., 2]
        inv_scale = torch.exp(-log_scale)

        # Standardize z against every component of its dimension.
        z_0 = (z[..., None] - shift) * inv_scale
        # The base is called with a trailing singleton axis because its
        # log_probability is expected to reduce over the last axis; the
        # "- log_scale" term is the log-Jacobian of the standardization.
        log_base_probs = self.base.log_probability(z_0[..., None]) - log_scale
        log_probs = torch.logsumexp(log_pi + log_base_probs, dim=-1)
        return log_probs.sum(-1)

    def forward(self, z, params):
        """Alias for :meth:`log_probability`."""
        return self.log_probability(z, params)


if __name__ == "__main__":
    # Smoke test: evaluate a 3-component logistic mixture over 5 dimensions
    # on random inputs and print the resulting log-probabilities.
    n_dims, n_comps = 5, 3
    demo_mixture = Mixture(input_size=n_dims, components=n_comps,
                           base_distribution=Logistic())
    points = torch.randn(2, 2, n_dims)
    # Three parameters per (dimension, component) pair: 5 * 3 * 3 = 45.
    flat_params = torch.randn(2, 2, n_dims * n_comps * 3)
    print(demo_mixture.log_probability(points, flat_params))




from .made import *
from .flow import *
