import torch
from  torch import Tensor
from torch.nn.parameter import Parameter
from torch.nn import Module
from torch.nn import init
from torch.distributions import transforms
import math

class InvLinear(transforms.Transform, Module):
    """Invertible affine transform ``y = x @ W.T + b`` with a triangular ``W``.

    ``W`` (see :meth:`weight_mat`) is the strictly lower triangle of
    ``self.weight`` plus a diagonal of ``|diag(self.weight)| + stability_factor``,
    so every diagonal entry is positive and ``W`` is always invertible.

    Args:
        features: size of the last input dimension (``W`` is ``features x features``).
        bias: if ``True`` a learnable bias is created, otherwise a zero bias is used.
        stability_factor: value added to the diagonal magnitudes of ``W``.
        device: device on which the parameters are created.
    """

    def __init__(self, features, bias=True, stability_factor=1e-8, device=torch.device('cpu')):
        super(InvLinear, self).__init__()
        # NOTE(review): a bare Constraint() has event_dim 0 although this transform
        # mixes the last dimension; confirm this is what downstream code expects.
        self.domain = torch.distributions.constraints.Constraint()
        self.codomain = torch.distributions.constraints.Constraint()
        self.stability_factor = stability_factor
        self.num_features = features
        self.device = device
        # Contents are overwritten by reset_parameters() below.
        self.weight = Parameter(torch.empty(features, features, device=device))
        if bias:
            self.bias = Parameter(torch.empty(features, device=device))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def __hash__(self):
        # The Transform base defines __eq__, which makes Python drop the inherited
        # __hash__; fall back to Module's identity-based hash.
        return Module.__hash__(self)

    def current_bias(self):
        """Return the bias added in :meth:`forward` and removed in :meth:`_inverse`.

        Without a learned bias a zero vector is returned. It is created from
        ``self.weight`` so it always matches the module's current device and
        dtype (e.g. after ``.to(...)`` or ``.double()``).
        """
        if self.bias is not None:
            return self.bias
        return self.weight.new_zeros(self.num_features)

    def weight_mat(self):
        """Return ``W``: strict lower triangle of ``weight`` plus a positive diagonal."""
        return torch.tril(self.weight, diagonal=-1) + torch.diag(
            torch.diag(self.weight).abs() + self.stability_factor
        )

    def reset_parameters(self):
        """Initialise ``weight`` to the identity and ``bias`` to U(-1/sqrt(F), 1/sqrt(F))."""
        # Identity start: the transform begins as (almost) a pure shift.
        init.eye_(self.weight)
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        """Apply ``input @ W.T + bias`` over the last dimension."""
        return torch.nn.functional.linear(input, self.weight_mat(), self.current_bias())

    def _call(self, input):
        # Transform.__call__ dispatches here.
        return self.forward(input)

    def _inverse(self, input):
        """Invert :meth:`forward` by solving ``W x = y - bias`` for ``x``.

        The result has the same shape as ``input``, including batches of
        size one and ``features == 1``.
        """
        # Treat each row as a column vector: (..., features) -> (..., features, 1).
        rhs = (input - self.current_bias()).unsqueeze(-1)
        solution = torch.triangular_solve(rhs, self.weight_mat(), upper=False)[0]
        # Drop only the column dim added above; a bare squeeze() would also
        # collapse singleton batch/feature dimensions.
        return solution.squeeze(-1)

    def log_abs_det_jacobian(self, x, z):
        """Return ``log|det W|``, the sum of the log diagonal of the triangular ``W``.

        The Jacobian does not depend on ``x``/``z``, so a 0-dim tensor is returned
        rather than one value per sample.
        """
        # The extra 1e-9 guards against log(0) if stability_factor is 0.
        return torch.log(torch.diag(self.weight_mat()).abs() + 1e-9).sum(-1)

    def extra_repr(self):
        return 'features={}, bias={}'.format(
            self.num_features, self.bias is not None
        )

class CondInvLinear(InvLinear):
    """InvLinear whose bias is shifted by an externally supplied conditioner.

    :meth:`condition` stores a tensor that :meth:`current_bias` adds to the
    (optional) learned bias, so it shifts both the forward and inverse maps.
    The conditioner defaults to zeros.
    """

    def __init__(self, features, bias=True, stability_factor=1e-8, device=torch.device('cpu')):
        # The parent constructor sets domain/codomain and calls
        # self.reset_parameters(), which dispatches to the override below and
        # therefore also initialises self.conditioner. Repeating that work here
        # would only re-draw the bias a second time.
        super(CondInvLinear, self).__init__(features, bias, stability_factor, device)

    def reset_parameters(self):
        """Reset weight/bias exactly as InvLinear does and clear the conditioner."""
        super(CondInvLinear, self).reset_parameters()
        # Built from the weight so the default shift matches its device/dtype.
        # NOTE(review): this is a plain attribute, not a buffer, so a later
        # .to()/.double() does not move it; call condition()/reset_parameters() again.
        self.conditioner = self.weight.new_zeros(self.num_features)

    def condition(self, conditioner):
        """Set the shift that is added to the bias (must broadcast with it)."""
        self.conditioner = conditioner

    def current_bias(self):
        """Return learned bias + conditioner, or just the conditioner without a bias."""
        if self.bias is not None:
            return self.bias + self.conditioner
        return self.conditioner


class CondScale(transforms.Transform, Module):
    """Conditioned elementwise affine transform ``y = s * x + b + c``.

    ``s = |weight| + stability_factor`` is one learned, strictly positive scale
    shared by all features. ``b`` is an optional learned bias, and ``c`` is a
    conditioner set through :meth:`condition` (zeros by default).

    Args:
        features: size of the last input dimension.
        bias: if ``True`` a learnable bias is created.
        stability_factor: value added to ``|weight|`` to keep the scale positive.
        device: device on which the parameters and default conditioner are created.
    """

    def __init__(self, features, bias=True, stability_factor=1e-8, device=torch.device('cpu')):
        super(CondScale, self).__init__()
        # NOTE(review): bare Constraint() has event_dim 0; confirm intended.
        self.domain = torch.distributions.constraints.Constraint()
        self.codomain = torch.distributions.constraints.Constraint()
        self.device = device
        self.stability_factor = stability_factor
        self.num_features = features
        # Contents are set by reset_parameters() below.
        self.weight = Parameter(torch.empty(1, device=device))
        if bias:
            self.bias = Parameter(torch.empty(self.num_features, device=device))
        else:
            self.register_parameter('bias', None)
        self.conditioner = torch.zeros(self.num_features, device=device)

        self.reset_parameters()

    def __hash__(self):
        # The Transform base defines __eq__, which makes Python drop the inherited
        # __hash__; fall back to Module's identity-based hash.
        return Module.__hash__(self)

    def _scale(self):
        """Return the strictly positive scale ``|weight| + stability_factor`` (shape ``(1,)``)."""
        return self.weight.abs() + self.stability_factor

    def weight_mat(self):
        """Return the scale as an explicit ``features x features`` diagonal matrix."""
        # Follow the parameter's device/dtype rather than the construction-time device.
        eye = torch.eye(self.num_features, dtype=self.weight.dtype, device=self.weight.device)
        return self._scale() * eye

    def reset_parameters(self):
        """Set the scale parameter to 1 and draw the bias from N(0, 1).

        Initialisation is in place, so the parameters keep their current
        device and dtype.
        """
        init.ones_(self.weight)
        if self.bias is not None:
            init.normal_(self.bias)

    def forward(self, input):
        """Apply ``input * scale + current_bias()``."""
        return input * self._scale() + self.current_bias()

    def _call(self, input):
        # Transform.__call__ dispatches here.
        return self.forward(input)

    def _inverse(self, input):
        """Invert :meth:`forward`: ``(input - current_bias()) / scale``."""
        return (input - self.current_bias()) / self._scale()

    def log_abs_det_jacobian(self, x, z):
        """Return ``num_features * log(scale)``; independent of ``x``/``z`` (shape ``(1,)``)."""
        return self.num_features * torch.log(self._scale())

    def extra_repr(self):
        return 'features={}, bias={}'.format(
            self.num_features, self.bias is not None
        )

    def condition(self, conditioner):
        """Set the shift that is added to the bias (must broadcast with it)."""
        self.conditioner = conditioner

    def current_bias(self):
        """Return learned bias + conditioner, or just the conditioner without a bias."""
        if self.bias is not None:
            return self.bias + self.conditioner
        return self.conditioner

