import torch
from  torch import Tensor
from torch.nn.parameter import Parameter
from torch.nn import Module
from torch.nn import init
from torch.distributions import transforms
import math

class AdditiveCoupling(Module):
    """Additive coupling layer operating on ``(tensor, direction)`` pairs.

    The features selected by ``mask`` are fed through a small fully connected
    network; its output is scattered onto the *unselected* features and added
    to them, scaled by ``direction``.  The selected features pass through
    unchanged, so applying the layer again with the negated direction undoes
    the shift.

    Args:
        features: Total number of features (size of the last input dimension).
        mask: Optional tensor of length ``features``; non-zero entries mark the
            features used as network input.  When ``None`` a random Bernoulli(0.5)
            mask is drawn and nudged so that it is neither all-zero nor all-one.
        depth: Number of hidden ``LeakyReLU -> Linear(width, width)`` stages.
        width: Hidden width; defaults to the number of selected features.
        device: Device on which the mask tensors and the network are created.
    """

    def __init__(self, features, mask=None, depth=1, width=None, device=torch.device('cpu')):
        super(AdditiveCoupling, self).__init__()
        self.num_features = features
        self.device = device

        if mask is None:
            mask = torch.empty(features).bernoulli_(0.5)
            # A coupling layer needs at least one selected and one unselected
            # feature, so flip the first entry when the draw is degenerate.
            if mask.sum() == 0:
                mask[0] = 1.0
            if mask.sum() == self.num_features:
                mask[0] = 0.0
        mask = mask.bool().to(self.device)

        # Buffers (rather than plain attributes) follow ``module.to(...)``.
        # They are non-persistent so that ``state_dict`` keys are unchanged;
        # note this means a randomly drawn mask is not saved in checkpoints.
        self.register_buffer('mask', mask, persistent=False)

        self.depth = depth
        num_selected = int(mask.sum())
        self.width = width if width is not None else num_selected

        layers = [torch.nn.Linear(num_selected, self.width)]
        for _ in range(depth):
            layers += [torch.nn.LeakyReLU(), torch.nn.Linear(self.width, self.width)]
        layers.append(torch.nn.Linear(self.width, features - num_selected))

        selected_diag = torch.diag(mask.float())
        # (num_selected, features): picks the selected features via F.linear.
        self.register_buffer('mask_mat', selected_diag[mask], persistent=False)
        # (features, num_selected): transpose of ``mask_mat``; not used by
        # ``forward`` but kept as an attribute for existing callers.
        self.register_buffer('mask_mat_t', selected_diag[:, mask], persistent=False)
        # (features, features - num_selected): scatters the network output
        # back onto the unselected feature positions via F.linear.
        self.register_buffer(
            'comp_mask_mat',
            torch.diag(1.0 - mask.float())[:, ~mask],
            persistent=False,
        )

        self.network = torch.nn.Sequential(*layers).to(self.device)

    def forward(self, input):
        """Apply the coupling shift.

        Args:
            input: Pair ``(x, direction)``.  ``x`` carries the features in its
                last dimension; ``direction`` multiplies the shift (the
                enclosing flow is expected to pass ``+1`` / ``-1`` — confirm
                against the caller).

        Returns:
            Pair ``(x + direction * shift, direction)``.
        """
        x, direction = input[0], input[1]
        conditioner = torch.nn.functional.linear(x, self.mask_mat)
        shift = torch.nn.functional.linear(self.network(conditioner), self.comp_mask_mat)
        return direction * shift + x, direction

    def extra_repr(self):
        """Return the summary shown inside this module's ``repr``."""
        return 'features={}, depth={}'.format(self.num_features, self.depth)

class Scaling(Module):
    """Learnable per-feature log-scale applied to ``(tensor, direction)`` pairs.

    Args:
        features: Number of features, i.e. the length of the weight vector.
        device: Device on which the weight parameter is created.
    """

    def __init__(self, features, device=torch.device('cpu')):
        super(Scaling, self).__init__()
        self.num_features = features
        self.device = device
        # Log-scales start at zero, so the layer is initially the identity.
        self.weights = Parameter(torch.zeros(self.num_features, device=device))

    def forward(self, input):
        """Scale ``input[0]`` by ``exp(input[1] * weights)``.

        The exponent is clamped to ``[-20, 20]`` before exponentiation to
        keep the scale factor finite.

        Args:
            input: Pair ``(x, direction)``; ``direction`` multiplies the
                log-scales (presumably ``+1`` / ``-1`` — confirm against the
                caller).

        Returns:
            Pair ``(scaled_x, direction)``.
        """
        x, direction = input[0], input[1]
        log_scale = torch.clamp(direction * self.weights, min=-20.0, max=20.0)
        return x * torch.exp(log_scale), direction

    def extra_repr(self):
        """Return the summary shown inside this module's ``repr``."""
        # This class has no ``depth`` attribute; reporting one used to make
        # ``repr(module)`` raise AttributeError.
        return 'features={}'.format(self.num_features)


class NICE(transforms.Transform, Module):
    """NICE-style flow: one ``Scaling`` layer followed by additive couplings.

    The same layer objects are held in two ``Sequential`` containers, one in
    construction order (forward pass) and one in reversed order (inverse
    pass).  Each layer receives and returns a ``(tensor, direction)`` pair;
    the forward pass uses ``+direction`` and the inverse ``-direction``.

    Args:
        features: Number of features handled by every layer.
        alternating_mask: When true, coupling layers alternate between the two
            interleaved (even/odd index) masks; otherwise each coupling layer
            is built with ``mask=None``.
        num_layers: Number of ``AdditiveCoupling`` layers.
        layer_depth: ``depth`` argument forwarded to each coupling layer.
        layer_width: ``width`` argument forwarded to each coupling layer.
        device: Device passed to every layer and used for internal tensors.
    """

    def __init__(self, features, alternating_mask=True,num_layers=3, layer_depth=1, layer_width=None, device=torch.device('cpu')):
        super(NICE, self).__init__()
        self.domain = torch.distributions.constraints.Constraint()
        self.codomain = torch.distributions.constraints.Constraint()
        self.num_features = features
        self.num_layers = num_layers
        self.layer_depth = layer_depth
        self.layer_width = layer_width
        self.alternating_mask = alternating_mask
        self.device = device

        layers = [Scaling(features, device=device)]
        # Start with the even indices selected; the mask is flipped *before*
        # each coupling layer is built, so the first layer selects odd indices.
        current_mask = torch.zeros(features).to(device)
        current_mask[::2] = 1
        current_mask = current_mask.bool()
        for _ in range(num_layers):
            current_mask = ~current_mask if alternating_mask else None
            layers.append(
                AdditiveCoupling(
                    features,
                    mask=current_mask,
                    depth=layer_depth,
                    width=layer_width,
                    device=device,
                )
            )
        self.forward_network = torch.nn.Sequential(*layers)
        self.reverse_network = torch.nn.Sequential(*layers[::-1])
        # Non-persistent buffer: moves with ``module.to(...)`` without adding
        # a key to ``state_dict``.
        self.register_buffer('direction', torch.ones(1).to(device), persistent=False)

    def __hash__(self):
        """Hash by identity, using ``Module``'s implementation."""
        return Module.__hash__(self)

    def network_application(self, input, reverse):
        """Run ``input`` through the layer stack in the requested direction.

        This method previously referenced a nonexistent ``self.network`` and
        always raised ``AttributeError``.

        Args:
            input: Tensor to transform.
            reverse: Truthy to run the reversed stack with negated direction,
                falsy to run the forward stack.  NOTE(review): interpreting
                ``reverse`` as a flag is an assumption — confirm with callers.

        Returns:
            The ``(tensor, direction)`` pair produced by the selected stack.
        """
        if reverse:
            return self.reverse_network((input, -self.direction))
        return self.forward_network((input, self.direction))

    def forward(self, input):
        """Apply the forward stack and return the transformed tensor."""
        transformed, _ = self.forward_network((input, self.direction))
        return transformed

    def _call(self, input):
        """``Transform`` hook for the forward map; delegates to ``forward``."""
        return self.forward(input)

    def _inverse(self, input):
        """``Transform`` hook for the inverse map (reversed stack, negated direction)."""
        restored, _ = self.reverse_network((input, -self.direction))
        return restored

    def log_abs_det_jacobian(self, x, z):
        """Return the sum of the clamped log-scales of the first (scaling) layer.

        ``x`` and ``z`` are not used; the result is a scalar tensor that does
        not depend on the input.
        """
        scaling_layer = self.forward_network[0]
        return torch.clamp(scaling_layer.weights, min=-20.0, max=20.0).sum()

    def extra_repr(self):
        """Return the summary shown inside this module's ``repr``."""
        # The per-layer depth is stored as ``layer_depth``; reading the
        # nonexistent ``self.depth`` used to make ``repr(module)`` raise.
        return 'features={}, depth={}'.format(self.num_features, self.layer_depth)

