import torch.nn as nn
import numpy as np
import torch
import time
import math
import torch.nn.functional as F


#NF model
class NF(nn.Module):
    """Normalizing flow: a stack of invertible layers over a standard-normal base.

    ``forward`` maps data to the latent space and ``inverse`` maps latent
    samples back to data space. Each layer is expected to expose an ``nparam``
    attribute and to return ``(output, logj)`` from both ``forward`` and
    ``inverse``, where ``logj`` is the per-sample log|det| of the layer's
    data -> latent Jacobian; the flow therefore simply sums the contributions
    in both directions.

    Args:
        ndim: total number of latent dimensions per sample.
        layers: iterable of flow layers, applied in order by ``forward``.
        Nconditional: size of the conditioning ``label`` vector; falsy for an
            unconditional flow. When truthy, an ``MLP`` maps ``label`` to the
            parameters of every layer with ``nparam > 0``, concatenated in
            layer order.
        width, depth: hidden width and number of linear layers of that MLP.
    """

    def __init__(self, ndim, layers, Nconditional, width=None, depth=None):
        super().__init__()

        self.layer = nn.ModuleList(layers)
        self.ndim = ndim
        if Nconditional:
            # Total number of conditioning parameters consumed by all layers.
            self.nparam = sum(layer.nparam for layer in self.layer)
            self.net = MLP(Nconditional, width, depth, self.nparam)
            self.conditional = True
        else:
            self.conditional = False

    def _split_params(self, label):
        """Return one conditioning-parameter slice per layer (None where unused).

        Slices are taken from the conditioner output in layer order.

        Raises:
            ValueError: if a layer needs parameters but the flow is unconditional.
        """
        param = self.net(label) if self.conditional else None
        slices = []
        start = 0
        for layer in self.layer:
            if layer.nparam > 0:
                if param is None:
                    raise ValueError(
                        "a layer requires conditional parameters, but the flow "
                        "was built with Nconditional=0"
                    )
                end = start + layer.nparam
                slices.append(param[:, start:end])
                start = end
            else:
                slices.append(None)
        return slices

    def _base_logprob(self, z):
        """Per-sample log density of ``z`` under an ``ndim``-dim standard normal.

        ``z`` must hold exactly ``ndim`` elements per sample; it is flattened
        to ``(batch, ndim)``.
        """
        z = z.reshape(len(z), self.ndim)
        return -self.ndim / 2 * math.log(2 * math.pi) - torch.sum(z**2, dim=1) / 2

    def forward(self, data, label=None):
        """Map ``data`` to latent space.

        Returns:
            (z, logj): the latent tensor and the per-sample log|det| of the
            data -> latent Jacobian (same dtype and device as ``data``).
        """
        # Match the input dtype so float64 flows are not truncated to float32.
        logj = torch.zeros(data.shape[0], dtype=data.dtype, device=data.device)
        for layer, param in zip(self.layer, self._split_params(label)):
            if param is None:
                data, log_j = layer(data)
            else:
                data, log_j = layer(data, param=param)
            logj += log_j
        return data, logj

    def inverse(self, data, label=None):
        """Map latent ``data`` back to data space (layers applied in reverse).

        Returns:
            (x, logj): the data-space tensor and the per-sample log|det| of the
            data -> latent Jacobian accumulated along the inverse pass.
        """
        logj = torch.zeros(data.shape[0], dtype=data.dtype, device=data.device)
        layers = list(self.layer)
        params = self._split_params(label)
        for layer, param in zip(reversed(layers), reversed(params)):
            if param is None:
                data, log_j = layer.inverse(data)
            else:
                data, log_j = layer.inverse(data, param=param)
            logj += log_j
        return data, logj

    def evaluate_density(self, data, label=None):
        """Per-sample log density log p(data | label) under the flow."""
        z, logj = self.forward(data, label)
        return logj + self._base_logprob(z)

    def loss(self, data, label=None):
        """Mean negative log-likelihood of ``data``."""
        return -torch.mean(self.evaluate_density(data, label))

    def sample(self, label):
        """Draw ``len(label)`` samples and return ``(x, log p(x))``.

        ``label`` sets the sample count and device, and is the conditioning
        input for a conditional flow. Samples are drawn as ``(N, ndim)``.
        """
        z = torch.randn(len(label), self.ndim, device=label.device)
        logq = self._base_logprob(z)
        x, logj = self.inverse(z, label)
        return x, logj + logq


#flow transform

class log_transform(nn.Module):
    """Elementwise log transform, [0, inf) -> (-inf, inf).

    forward:  z = log(x + lambd)
    inverse:  x = exp(z) - lambd
    Both directions return the forward log-Jacobian, log|dz/dx| = -z,
    summed over every non-batch dimension. ``lambd`` keeps log finite at x = 0.
    """

    def __init__(self, lambd=1e-5):
        super().__init__()
        self.lambd = lambd
        self.nparam = 0

    @staticmethod
    def _per_sample_sum(values):
        """Sum over all dimensions except the leading (batch) one."""
        return torch.sum(values.reshape(len(values), -1), axis=1)

    def forward(self, data):
        assert torch.min(data) >= 0
        z = torch.log(self.lambd + data)
        return z, self._per_sample_sum(-z)

    def inverse(self, data):
        logj = self._per_sample_sum(-data)
        x = torch.exp(data) - self.lambd
        return x, logj



class exp_transform(nn.Module):
    """Elementwise exp transform, (-inf, inf) -> (-lambd, inf).

    forward:  z = exp(x) - lambd
    inverse:  x = log(z + lambd)   (requires z >= 0)
    Both directions return the forward log-Jacobian, log|dz/dx| = x,
    summed over every non-batch dimension.
    """

    def __init__(self, lambd=1e-5):
        super().__init__()
        self.lambd = lambd
        self.nparam = 0

    @staticmethod
    def _per_sample_sum(values):
        """Sum over all dimensions except the leading (batch) one."""
        return torch.sum(values.reshape(len(values), -1), axis=1)

    def forward(self, data):
        logj = self._per_sample_sum(data)
        z = torch.exp(data) - self.lambd
        return z, logj

    def inverse(self, data):
        assert torch.min(data) >= 0
        x = torch.log(self.lambd + data)
        return x, self._per_sample_sum(x)



class Softplus_transform(nn.Module):
    """Softplus transform, (-inf, inf) -> (0, inf) for eps = 0; linear for large x.

    forward:  z = log(exp(beta*x) + 1 - eps) / beta   where x < threshold / beta
              z = x                                   elsewhere
    inverse:  x = log(exp(beta*z) + eps - 1) / beta   where z < threshold / beta
              x = z                                   elsewhere
    Both directions return the forward log-Jacobian, beta * (x - z), summed
    over every non-batch dimension.

    The smooth branch is evaluated with ``log1p`` / ``expm1``. Very negative x
    therefore does not round to z = 0, and the inverse of a tiny z does not
    collapse to -inf.
    """

    def __init__(self, beta=1, threshold=20, eps=0):
        super().__init__()
        self.beta = beta
        self.threshold = threshold
        self.nparam = 0
        self.eps = eps

    def forward(self, data):
        select = data < self.threshold / self.beta
        z = data.clone()
        # log(exp(bx) + 1 - eps) == log1p(exp(bx) - eps); log1p keeps precision
        # when exp(bx) is far below 1.
        z[select] = torch.log1p(torch.exp(self.beta * data[select]) - self.eps) / self.beta

        logj = self.beta * (data - z)
        logj = torch.sum(logj.reshape(len(data), -1), dim=1)
        return z, logj

    def inverse(self, data):
        select = data < self.threshold / self.beta
        x = data.clone()
        # log(exp(bz) + eps - 1) == log(expm1(bz) + eps); expm1 avoids the
        # catastrophic cancellation of exp(bz) - 1 for small bz.
        x[select] = torch.log(torch.expm1(self.beta * data[select]) + self.eps) / self.beta

        logj = self.beta * (x - data)
        logj = torch.sum(logj.reshape(len(data), -1), dim=1)
        return x, logj



class InvSoftplus_transform(nn.Module):
    """Inverse-softplus transform, (0, inf) -> (-inf, inf) for eps = 0; linear for large x.

    forward:  z = log(exp(beta*x) + eps - 1) / beta   where x < threshold / beta
              z = x                                   elsewhere
    inverse:  x = log(exp(beta*z) + 1 - eps) / beta   where z < threshold / beta
              x = z                                   elsewhere
    Both directions return the forward log-Jacobian, beta * (x - z), summed
    over the transformed dimensions only.

    The smooth branch is evaluated with ``expm1`` / ``log1p`` so small positive
    x does not collapse to z = -inf through exp(beta*x) - 1 rounding to 0.

    Args:
        mask: optional index used on dim 1 (``data[:, mask]``). When given, only
            those columns are transformed and the rest pass through unchanged.
    """

    def __init__(self, beta=1, threshold=20, mask=None, eps=0):
        super().__init__()
        self.beta = beta
        self.threshold = threshold
        if mask is not None:
            self.register_buffer('mask', mask)
        else:
            self.mask = None
        self.eps = eps
        self.nparam = 0

    def forward(self, data):
        if self.mask is not None:
            data0 = data.clone()
            data = data[:, self.mask]
        select = data < self.threshold / self.beta
        z = data.clone()
        # log(exp(bx) + eps - 1) == log(expm1(bx) + eps); expm1 avoids the
        # catastrophic cancellation of exp(bx) - 1 for small bx.
        z[select] = torch.log(torch.expm1(self.beta * data[select]) + self.eps) / self.beta

        logj = self.beta * (data - z)
        logj = torch.sum(logj.reshape(len(data), -1), dim=1)
        if self.mask is not None:
            data0[:, self.mask] = z
            return data0, logj
        return z, logj

    def inverse(self, data):
        if self.mask is not None:
            data0 = data.clone()
            data = data[:, self.mask]
        select = data < self.threshold / self.beta
        x = data.clone()
        # log(exp(bz) + 1 - eps) == log1p(exp(bz) - eps); log1p keeps precision
        # when exp(bz) is far below 1.
        x[select] = torch.log1p(torch.exp(self.beta * data[select]) - self.eps) / self.beta

        logj = self.beta * (x - data)
        logj = torch.sum(logj.reshape(len(data), -1), dim=1)
        if self.mask is not None:
            data0[:, self.mask] = x
            return data0, logj
        return x, logj



class logit_transform(nn.Module):
    """Logit transform, [0, 1] -> (-inf, inf).

    The input is first squeezed to y = lambd + (1 - 2*lambd) * x so that the
    endpoints 0 and 1 map to finite values; then z = log(y) - log(1 - y).
    Both directions return the forward log-Jacobian,
    -log(y * (1 - y)) + log(1 - 2*lambd) per transformed element, summed over
    the transformed dimensions.

    Args:
        lambd: squeeze margin; ``math.log(1 - 2*lambd)`` requires lambd < 0.5.
        mask: optional index used on dim 1 (``data[:, mask]``). When given, only
            those columns are transformed and the rest pass through unchanged.
    """

    def __init__(self, lambd=1e-5, mask=None):
        super().__init__()
        self.lambd = lambd
        if mask is not None:
            self.register_buffer('mask', mask)
        else:
            self.mask = None
        self.nparam = 0

    def forward(self, data):
        if self.mask is not None:
            data0 = data.clone()
            data = data[:, self.mask]
        assert torch.min(data) >= 0 and torch.max(data) <= 1

        # Squeeze [0, 1] into [lambd, 1 - lambd] so the logit stays finite.
        data = self.lambd + (1 - 2 * self.lambd) * data
        logj = torch.sum(
            -torch.log(data * (1 - data)).reshape(len(data), -1) + math.log(1 - 2 * self.lambd),
            axis=1,
        )
        data = torch.log(data) - torch.log1p(-data)
        if self.mask is not None:
            data0[:, self.mask] = data
            return data0, logj
        return data, logj

    def inverse(self, data):
        if self.mask is not None:
            data0 = data.clone()
            data = data[:, self.mask]
        # -log(s * (1 - s)) with s = sigmoid(z) equals softplus(z) + softplus(-z).
        # Unlike the product form, it stays finite when sigmoid(z) rounds to 0 or 1.
        neg_log_dsigmoid = F.softplus(data) + F.softplus(-data)
        logj = torch.sum(
            neg_log_dsigmoid.reshape(len(data), -1) + math.log(1 - 2 * self.lambd),
            axis=1,
        )
        data = (torch.sigmoid(data) - self.lambd) / (1. - 2 * self.lambd)
        if self.mask is not None:
            data0[:, self.mask] = data
            return data0, logj
        return data, logj


class Normalization_transform(nn.Module):
    """Global normalization with scalar running statistics.

    forward:  z = (x - running_mean) / sqrt(running_var + eps)
    inverse:  x = sqrt(running_var + eps) * z + running_mean
    The statistics are scalars taken over every element of the input tensor.
    In training mode they are updated during the first ``Nestimate`` forward
    calls: the first call copies the batch mean/variance, and later calls use
    an exponential moving average with weight ``momentum``. After that they
    stay frozen.

    Both directions return the forward log-Jacobian,
    -0.5 * log(running_var + eps) * (elements per sample), as a shape-(1,)
    tensor. Callers that add it to per-sample log-Jacobians rely on broadcasting.

    NOTE(review): ``mask`` is accepted but unused. ``count`` is a plain
    attribute rather than a buffer, so it is not saved in ``state_dict`` and
    restarts at 0 on a newly constructed module.
    """

    def __init__(self, eps=1e-5, momentum=0.01, Nestimate=500, mask=None):
        super().__init__()

        self.momentum = momentum
        self.eps = eps
        self.nparam = 0
        self.count = 0
        self.Nestimate = Nestimate

        self.register_buffer("running_mean", torch.zeros(1))
        self.register_buffer("running_var", torch.ones(1))

    def _update_statistics(self, inputs):
        """Fold the mean/variance of ``inputs`` into the running buffers."""
        batch_mean = torch.mean(inputs).detach()
        batch_var = torch.var(inputs).detach()
        if self.count == 0:
            # First estimate: take the batch statistics as they are.
            self.running_mean[:] = batch_mean
            self.running_var[:] = batch_var
        else:
            keep = 1 - self.momentum
            self.running_mean.mul_(keep).add_(batch_mean * self.momentum)
            self.running_var.mul_(keep).add_(batch_var * self.momentum)
        self.count += 1

    def _log_det(self, inputs):
        """Forward log|det J| for a batch shaped like ``inputs``."""
        per_sample = np.prod(inputs[0].shape)
        return -0.5 * torch.log(self.running_var + self.eps) * per_sample

    def forward(self, inputs, param=None, label=None):
        if self.training and self.count < self.Nestimate:
            self._update_statistics(inputs)
        scale = torch.sqrt(self.running_var + self.eps)
        return (inputs - self.running_mean) / scale, self._log_det(inputs)

    def inverse(self, inputs, param=None, label=None):
        scale = torch.sqrt(self.running_var + self.eps)
        return scale * inputs + self.running_mean, self._log_det(inputs)


class BatchNorm_transform(nn.Module):
    """Batch normalization as a flow layer, using scalar statistics.

    Mean and variance are taken over every element of the batch (translational
    invariance is assumed, so all pixels share one statistic).
        forward:  z = w * (x - mean) / sqrt(var + eps) + b
        inverse:  x = sqrt(running_var + eps) * (z - b) / w + running_mean
    In training mode the forward pass uses the batch statistics and updates the
    running buffers by an exponential moving average with weight ``momentum``.
    Both buffers start at zero. In eval mode the running buffers are used.
    The inverse is only available in eval mode.

    Affine parameters, with w = softplus(u) + eps:
        * affine and conditional: (u, b) = (param[:, 0], param[:, 1]), nparam = 2.
        * affine only: learned ``weight_bias = (u, b)``, initialised to zeros.
        * otherwise: w = 1, b = 0.
    Both directions return the forward log-Jacobian per sample,
    (log w - 0.5 * log(var + eps)) * (elements per sample).
    """

    def __init__(self, eps=1e-5, momentum=0.1, affine=True, conditional=False):
        super().__init__()

        self.momentum = momentum
        self.eps = eps
        self.affine = affine
        self.conditional = conditional
        self.nparam = 2 if (self.affine and conditional) else 0
        if self.affine and not conditional:
            self.weight_bias = nn.Parameter(torch.zeros(2))

        self.register_buffer("running_mean", torch.zeros(1))
        self.register_buffer("running_var", torch.zeros(1))

    def weight(self, unconstrained_weight):
        """Map an unconstrained value to a strictly positive scale."""
        return F.softplus(unconstrained_weight) + self.eps

    def _affine(self, flat, param):
        """Return per-sample (weight, bias), each shaped (batch, 1)."""
        n = len(flat)
        if self.affine and self.conditional:
            return self.weight(param[:, 0]).reshape(-1, 1), param[:, 1].reshape(-1, 1)
        if self.affine:
            w = self.weight(self.weight_bias[0]).reshape(1, 1)
            b = self.weight_bias[1].reshape(1, 1)
            return torch.repeat_interleave(w, n, dim=0), torch.repeat_interleave(b, n, dim=0)
        ones = torch.ones(n, device=flat.device).reshape(-1, 1)
        zeros = torch.zeros(n, device=flat.device).reshape(-1, 1)
        return ones, zeros

    def _update_running(self, mean, var):
        """Exponential-moving-average update of the running buffers."""
        keep = 1 - self.momentum
        self.running_mean.mul_(keep).add_(mean.detach() * self.momentum)
        self.running_var.mul_(keep).add_(var.detach() * self.momentum)

    def forward(self, inputs, param=None):
        if not self.training:
            mean, var = self.running_mean, self.running_var
        else:
            mean, var = torch.mean(inputs), torch.var(inputs)
            self._update_running(mean, var)

        original_shape = inputs.shape
        flat = inputs.reshape(len(inputs), -1)
        weight, bias = self._affine(flat, param)
        normalized = (flat - mean) / torch.sqrt((var + self.eps))
        outputs = weight * normalized + bias

        half_log_var = 0.5 * torch.log(var + self.eps)
        logabsdet = (torch.log(weight.reshape(-1)) - half_log_var) * flat.shape[1]
        return outputs.reshape(*original_shape), logabsdet

    def inverse(self, inputs, param=None):
        if self.training:
            raise NotImplementedError(
                "Batch norm inverse is only available in eval mode, not in training mode."
            )

        original_shape = inputs.shape
        flat = inputs.reshape(len(inputs), -1)
        weight, bias = self._affine(flat, param)
        std = torch.sqrt(self.running_var + self.eps)
        outputs = std * ((flat - bias) / weight) + self.running_mean

        half_log_var = 0.5 * torch.log(self.running_var + self.eps)
        logabsdet = (torch.log(weight.reshape(-1)) - half_log_var) * flat.shape[1]
        return outputs.reshape(*original_shape), logabsdet



class MLP(nn.Module):
    """Fully connected network with ReLU between layers.

    ``depth`` counts the linear layers. With ``depth == 1`` the network is a
    single ``Linear(input, output)``. Otherwise it is
    ``input -> width -> ... -> width -> output`` with ``max(depth - 2, 0)``
    hidden ``width -> width`` layers, i.e. always at least two layers.
    ReLU follows every layer except the last.
    """

    def __init__(self, input, width, depth, output):
        super().__init__()
        if depth == 1:
            sizes = [input, output]
        else:
            sizes = [input, width] + [width] * max(depth - 2, 0) + [output]
        self.fc = nn.ModuleList(
            nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])
        )

    def forward(self, x):
        *hidden, last = self.fc
        for layer in hidden:
            x = F.relu(layer(x))
        return last(x)

