import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributions as D


def batch_KL_diag_gaussian_std(mu_1, std_1, mu_2, std_2):
    """Compute KL( N(mu_1, diag(std_1**2)) || N(mu_2, diag(std_2**2)) ).

    The Gaussians are diagonal along the last dimension; every leading
    dimension is treated as a batch dimension.

    Args:
        mu_1, std_1: mean and standard deviation of the first Gaussian.
        mu_2, std_2: mean and standard deviation of the second Gaussian
            (must broadcast against ``mu_1`` / ``std_1``).

    Returns:
        Tensor with the last dimension reduced: one KL value per batch element.
    """
    var_1 = std_1 ** 2
    var_2 = std_2 ** 2
    ratio = var_1 / var_2
    # Size of the reduced (last) dimension. Using dim 1 here was only correct
    # for 2-D inputs; -1 matches the dim=-1 reductions below for any rank.
    event_dim = mu_1.size(-1)
    return 0.5 * (
        torch.sum((mu_1 - mu_2) ** 2 / var_2, dim=-1)
        + torch.sum(ratio, dim=-1)
        - torch.sum(torch.log(ratio), dim=-1)
        - event_dim
    )


def discretized_mix_logistic_uniform(x, l, alpha=0.0001):
    """Per-image log-likelihood under a discretized logistic mixture plus a uniform floor.

    Each sub-pixel probability is
    ``(1 - alpha) * sum_k pi_k * P_k(bin) + alpha / 256``, where ``P_k`` is the
    mass the k-th logistic puts on the bin of half-width 1/510 around the
    observed value. The first and last bins are open-ended (values above
    0.9995 / below 0.0005 get cdf 1 / 0).

    Args:
        x: images of shape (B, 3, H, W); assumed to lie on the k/255 grid in
            [0, 1], which is what the bin width and edge thresholds expect.
        l: parameters of shape (B, 10 * K, H, W). The first K channels are
            mixture logits (softmax over K). The remaining 9 * K channels are
            viewed as (3, 3 * K) per pixel: for each colour channel, K means,
            K log-scales (clamped at -7) and K coefficients (tanh) that shift
            the green/blue means by the observed red/green values.
        alpha: weight of the uniform component; keeps the log finite.

    Returns:
        Tensor of shape (B,): log-likelihood summed over channels and pixels.

    NOTE(review): the log is taken of the per-channel mixture (sum over K
    inside the log) and then summed over channels, so each colour channel
    mixes independently. ``discretized_mix_logistic_sample`` instead picks a
    single component per pixel for all three channels. Confirm this is intended.
    """
    batch, channels, height, width = x.shape
    nr_mix = l.size(1) // 10
    x = x.unsqueeze(2)  # (B, 3, 1, H, W): broadcasts against the K mixture axis

    # Mixture weights are shared by the colour channels. Broadcasting over
    # dim 1 gives the same values as an explicit repeat without copying.
    pi = torch.softmax(l[:, :nr_mix], dim=1).unsqueeze(1)  # (B, 1, K, H, W)
    params = l[:, nr_mix:].view(batch, channels, -1, height, width)
    means = params[:, :, :nr_mix]
    inv_stdv = torch.exp(-torch.clamp(params[:, :, nr_mix:2 * nr_mix], min=-7.))
    coeffs = torch.tanh(params[:, :, 2 * nr_mix:])

    # Autoregressive mean shift: G depends on R, B depends on R and G.
    x_r = x[:, 0:1]
    x_g = x[:, 1:2]
    mean_g = means[:, 1:2] + coeffs[:, 0:1] * x_r
    mean_b = means[:, 2:3] + coeffs[:, 1:2] * x_r + coeffs[:, 2:3] * x_g
    means = torch.cat((means[:, 0:1], mean_g, mean_b), dim=1)

    centered_x = x - means
    cdf_plus = torch.sigmoid(inv_stdv * (centered_x + 1. / 510.))
    cdf_plus = cdf_plus.masked_fill(x > 0.9995, 1.0)  # top bin extends to +inf
    cdf_min = torch.sigmoid(inv_stdv * (centered_x - 1. / 510.))
    cdf_min = cdf_min.masked_fill(x < 0.0005, 0.0)  # bottom bin extends to -inf

    mix_prob = (pi * (cdf_plus - cdf_min)).sum(2)
    log_probs = torch.log((1 - alpha) * mix_prob + alpha * (1 / 256))
    return log_probs.sum([1, 2, 3])


def discretized_mix_logistic_sample(l):
    """Draw 8-bit RGB images from a discretized mixture of logistics.

    Args:
        l: parameters of shape (B, 10 * K, H, W). The first K channels are
            mixture logits. The remaining 9 * K channels are viewed per pixel
            as (3, 3 * K): for each colour channel, K means, K log-scales
            (clamped at -7) and K coefficients (tanh) that shift the
            green/blue values by the sampled red/green values.

    Returns:
        Float tensor of shape (B, 3, H, W) whose values lie on the grid
        {0, 1/255, ..., 1}.
    """
    nr_mix = l.size(1) // 10
    l = l.permute(0, 2, 3, 1)  # (B, H, W, 10K)
    xs = list(l.shape[:-1]) + [3]  # [B, H, W, 3]
    logit_probs = l[..., :nr_mix]
    l = l[..., nr_mix:].contiguous().view(xs + [nr_mix * 3])

    # Gumbel-max trick: one mixture component per pixel, shared by all channels.
    u_mix = torch.clamp(torch.rand_like(logit_probs), 1e-5, 1. - 1e-5)
    gumbel = logit_probs.detach() - torch.log(-torch.log(u_mix))
    _, argmax = gumbel.max(dim=3)
    sel = F.one_hot(argmax, nr_mix).to(l.dtype).view(xs[:-1] + [1, nr_mix])

    means = torch.sum(l[..., :nr_mix] * sel, dim=4)
    log_scales = torch.clamp(torch.sum(l[..., nr_mix:2 * nr_mix] * sel, dim=4), min=-7.)
    coeffs = torch.sum(torch.tanh(l[..., 2 * nr_mix:3 * nr_mix]) * sel, dim=4)

    # Inverse-CDF sample from the selected logistic, then the autoregressive
    # channel coupling, clamped to the valid [0, 1] range.
    u = torch.clamp(torch.rand_like(means), 1e-5, 1. - 1e-5)
    x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))
    x0 = torch.clamp(x[..., 0], min=0., max=1.)
    x1 = torch.clamp(x[..., 1] + coeffs[..., 0] * x0, min=0., max=1.)
    x2 = torch.clamp(x[..., 2] + coeffs[..., 1] * x0 + coeffs[..., 2] * x1, min=0., max=1.)
    out = torch.stack([x0, x1, x2], dim=3).permute(0, 3, 1, 2)

    # Round to the nearest of the 256 levels. The likelihood's bins are centred
    # on k/255 (half-width 1/510), so truncating with .int() would bias every
    # sample down by half a level.
    return torch.round(out * 255) / 255.


class Residual(nn.Module):
    """Shape-preserving residual block: ``x + BN(Conv1x1(ReLU(BN(Conv3x3(ReLU(x))))))``.

    Input and output are (B, channels, H, W): the 3x3 conv uses stride 1 and
    padding 1, and the 1x1 conv keeps the spatial size.
    """

    def __init__(self, channels):
        super(Residual, self).__init__()
        self.block = nn.Sequential(
            # Deliberately NOT in-place. An in-place ReLU here would overwrite
            # the caller's tensor, so forward() would add ReLU(x) instead of x
            # on the skip path.
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(True),  # in-place is fine: acts on an intermediate tensor
            nn.Conv2d(channels, channels, 1, bias=False),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        return x + self.block(x)

class fc_encoder(nn.Module):
    """Convolutional encoder that outputs a diagonal-Gaussian posterior.

    Two stride-2 4x4 convolutions shrink each spatial side by 4 (32x32 -> 8x8).
    Two residual blocks follow, then a 1x1 projection to
    ``2 * latent_channels`` maps. The first ``latent_channels`` maps are the
    mean; the rest pass through softplus to give a non-negative std. Both are
    flattened to (B, -1).

    NOTE(review): softplus can underflow to exactly 0 for very negative
    inputs, which a Normal distribution built from this std may reject.
    """

    def __init__(self, channels=256, latent_channels=64):
        super(fc_encoder, self).__init__()
        self.latent_channels = latent_channels
        downsample = [
            nn.Conv2d(3, channels, 4, 2, 1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(True),
            nn.Conv2d(channels, channels, 4, 2, 1, bias=False),
            nn.BatchNorm2d(channels),
        ]
        residual = [Residual(channels), Residual(channels)]
        head = [nn.Conv2d(channels, latent_channels * 2, 1)]
        # Same module order as a single flat Sequential, so state_dict keys
        # (encoder.0 ... encoder.7) are unchanged.
        self.encoder = nn.Sequential(*downsample, *residual, *head)

    def forward(self, x):
        batch = x.size(0)
        stats = self.encoder(x)
        mean_maps, std_maps = stats.split(self.latent_channels, dim=1)
        mean = mean_maps.view(batch, -1)
        std = F.softplus(std_maps.view(batch, -1))
        return mean, std


class fc_decoder(nn.Module):
    """Convolutional decoder from a flat latent to per-pixel likelihood parameters.

    The latent is reshaped to (B, latent_channels, 8, 8), lifted to
    ``channels`` maps with a 1x1 conv, passed through two residual blocks, and
    upsampled twice with stride-2 transposed convolutions (8x8 -> 32x32). A
    final 1x1 conv produces ``out_channels`` maps.
    """

    def __init__(self, channels=256, latent_channels=64, out_channels=100):
        super(fc_decoder, self).__init__()
        stem = [
            nn.Conv2d(latent_channels, channels, 1, bias=False),
            nn.BatchNorm2d(channels),
        ]
        residual = [Residual(channels), Residual(channels)]
        upsample = []
        for _ in range(2):
            upsample += [
                nn.ConvTranspose2d(channels, channels, 4, 2, 1, bias=False),
                nn.BatchNorm2d(channels),
                nn.ReLU(True),
            ]
        head = [nn.Conv2d(channels, out_channels, 1)]
        # Flat Sequential in the same order as before, so state_dict keys
        # (decoder.0 ... decoder.10) are unchanged.
        self.decoder = nn.Sequential(*stem, *residual, *upsample, *head)

    def forward(self, z):
        latent_grid = z.view(z.size(0), -1, 8, 8)
        return self.decoder(latent_grid)




# CIFAR w/ mix logistic
class VAE_L(nn.Module):
    """VAE for 32x32 RGB images with a discretized mixture-of-logistics likelihood.

    The encoder maps an image to a diagonal Gaussian over
    ``z_channels * 8 * 8`` latent dimensions. The decoder maps a latent to 100
    channels of mixture parameters (10 logistic components), scored by
    ``discretized_mix_logistic_uniform`` and sampled by
    ``discretized_mix_logistic_sample``. The prior is a standard normal.
    """

    def __init__(self, device, channels=256, z_channels=2):
        super().__init__()
        self.z_channels = z_channels
        self.channels = channels
        self.z_dims = self.z_channels * 8 * 8
        # Kept for callers that read it. Noise tensors below follow the device
        # and dtype of the model's own tensors instead, so moving or casting
        # the model after construction does not cause mismatches.
        self.device = device
        self.img_channels = 3
        self.img_HW = 32
        self.likelihood_family = "MoL"

        self.encoder = fc_encoder(channels=self.channels, latent_channels=self.z_channels)
        self.decoder = fc_decoder(channels=self.channels, latent_channels=self.z_channels, out_channels=100)
        # Plain module-level functions rather than lambdas: same call
        # signature, but picklable (e.g. torch.save(model)).
        self.criterion = discretized_mix_logistic_uniform
        self.sample_op = discretized_mix_logistic_sample

    def reparameterize(self, mu, std):
        """Return ``mu + std * eps`` with ``eps ~ N(0, I)`` (differentiable in mu and std).

        ``eps`` is drawn with ``randn_like`` so it matches ``mu``'s device and dtype.
        """
        return mu + torch.randn_like(mu) * std

    def loglikelihood_x_y(self, x, fz):
        """Per-example log p(x | z) given decoder output ``fz``."""
        return self.criterion(x, fz)

    def q_z(self, x):
        """Encode ``x``; return (sampled z, posterior mean, posterior std), each flattened to (B, -1)."""
        z_mu, z_std = self.encoder(x)
        return self.reparameterize(z_mu, z_std), z_mu, z_std

    def p_x(self, z):
        """Decode latent ``z`` into likelihood parameters (B, 100, 32, 32)."""
        return self.decoder(z)

    def forward(self, x):
        """Return (negative mean ELBO, mean log-likelihood, mean KL) for batch ``x``.

        The KL term KL[q(z|x) || N(0, I)] is computed analytically by
        ``torch.distributions``; ``Independent(..., 1)`` sums it over latent dims.
        """
        z, qz_mu, qz_std = self.q_z(x)
        fz = self.p_x(z)
        ll = self.loglikelihood_x_y(x, fz)

        qz = D.independent.Independent(D.normal.Normal(qz_mu, qz_std), 1)
        pz = D.independent.Independent(
            D.normal.Normal(torch.zeros_like(z), torch.ones_like(z)), 1
        )
        kl = D.kl.kl_divergence(qz, pz)

        elbo = ll - kl
        return -elbo.mean(), ll.mean(), kl.mean()

    def sample_x(self, num=100):
        """Decode ``num`` prior samples and draw images from the resulting likelihood."""
        with torch.no_grad():
            ref = next(self.decoder.parameters())
            eps = torch.randn(num, self.z_dims, device=ref.device, dtype=ref.dtype)
            fz = self.decoder(eps)
            return self.sample_op(fz)

    def reconstruction(self, x, use_sample=False):
        """Reconstruct ``x`` by decoding the posterior mean (or a posterior sample) and sampling."""
        with torch.no_grad():
            z_sample, z_mu, _ = self.q_z(x)
            fz = self.p_x(z_sample if use_sample else z_mu)
            return self.sample_op(fz)





    
    
    
    