'''
Code taken from https://github.com/zalandoresearch/pytorch-ts
With some modifications made


'''

import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from nflows.flows.autoregressive import MaskedAutoregressiveFlow

'''
def create_masks(
    input_size, hidden_size, n_hidden, input_order="sequential", input_degrees=None
):
    """Build the binary connectivity masks of a MADE network (MADE paper, sec. 4).

    Every unit is given an integer degree. A unit may read a unit of the previous
    layer only when its own degree is greater than or equal to that unit's degree.
    Output degrees are the input degrees minus one, so output k can only be reached
    from inputs whose degree is strictly smaller than the degree of input k.

    Args:
        input_size: number of input units (the output layer has the same size).
        hidden_size: number of units in every hidden layer.
        n_hidden: ``n_hidden + 1`` hidden layers of degrees are generated.
        input_order: ``"sequential"`` or ``"random"``.
        input_degrees: optional 1-D integer tensor of input degrees (for example the
            flipped degrees of the previous MADE in a stack); when None the degrees
            are ``arange`` (sequential) or ``randperm`` (random).

    Returns:
        ``(masks, degrees)`` where ``masks`` is a list of ``n_hidden + 2`` float
        tensors of shape ``(n_out, n_in)`` and ``degrees`` are the input degrees used.

    Raises:
        ValueError: if ``input_order`` is neither ``"sequential"`` nor ``"random"``.
    """
    if input_order == "sequential":
        in_degrees = (
            torch.arange(input_size) if input_degrees is None else input_degrees
        )
        # With a single input the modulus would be zero; clamp it so the hidden
        # degrees are all 0 and the (degree -1) output stays disconnected.
        modulus = max(input_size - 1, 1)
        hidden_degrees = [
            torch.arange(hidden_size) % modulus for _ in range(n_hidden + 1)
        ]
        out_degrees = in_degrees % input_size - 1
    elif input_order == "random":
        in_degrees = (
            torch.randperm(input_size) if input_degrees is None else input_degrees
        )
        hidden_degrees = []
        previous = in_degrees
        for _ in range(n_hidden + 1):
            low = min(previous.min().item(), input_size - 1)
            previous = torch.randint(low, input_size, (hidden_size,))
            hidden_degrees.append(previous)
        # The output degrees must be tied to the input permutation. Drawing them at
        # random (as was done when no input_degrees were given) lets an output see
        # its own input, which breaks the autoregressive property.
        out_degrees = in_degrees - 1
    else:
        raise ValueError(
            "input_order must be 'sequential' or 'random', got {!r}".format(input_order)
        )

    degrees = [in_degrees] + hidden_degrees + [out_degrees]

    # mask[i, j] == 1 when unit i of the next layer may read unit j of this layer
    masks = [
        (nxt.unsqueeze(-1) >= cur.unsqueeze(0)).float()
        for cur, nxt in zip(degrees[:-1], degrees[1:])
    ]

    return masks, degrees[0]



class BatchNorm(nn.Module):
    """RealNVP-style batch normalisation written as an invertible flow step.

    ``forward`` normalises over every dimension except the last one and returns the
    per-element log|det Jacobian|; ``inverse`` undoes the transformation. In training
    mode the statistics of the current batch are used (and folded into the running
    averages); in eval mode the running averages are used.
    """

    def __init__(self, input_size, momentum=0.9, eps=1e-5):
        """
        Args:
            input_size -- scalar; size of the last (feature) dimension
            momentum -- scalar; weight kept from the old running statistics
            eps -- scalar; added to the variance for numerical stability
        """
        super().__init__()
        self.momentum = momentum
        self.eps = eps

        self.log_gamma = nn.Parameter(torch.zeros(input_size))
        self.beta = nn.Parameter(torch.zeros(input_size))

        self.register_buffer("running_mean", torch.zeros(input_size))
        self.register_buffer("running_var", torch.ones(input_size))

        # Statistics of the latest training batch; None until forward() has been
        # called in training mode.
        self.batch_mean = None
        self.batch_var = None

    def forward(self, x, cond_y=None):
        """Normalise ``x``; ``cond_y`` is accepted for interface parity and ignored.

        Returns:
            ``(y, log_abs_det_jacobian)``, both shaped like ``x``.
        """
        if self.training:
            # reshape (not view) so non-contiguous inputs are accepted as well
            flat = x.reshape(-1, x.shape[-1])
            self.batch_mean = flat.mean(0)
            # note MAF paper uses biased variance estimate; ie x.var(0, unbiased=False)
            self.batch_var = flat.var(0)

            # exponential moving average of the statistics, kept out of autograd
            with torch.no_grad():
                self.running_mean.mul_(self.momentum).add_(
                    self.batch_mean * (1 - self.momentum)
                )
                self.running_var.mul_(self.momentum).add_(
                    self.batch_var * (1 - self.momentum)
                )

            mean = self.batch_mean
            var = self.batch_var
        else:
            mean = self.running_mean
            var = self.running_var

        # compute normalized input (cf original batch norm paper algo 1)
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        y = self.log_gamma.exp() * x_hat + self.beta

        # compute log_abs_det_jacobian (cf RealNVP paper)
        log_abs_det_jacobian = self.log_gamma - 0.5 * torch.log(var + self.eps)
        return y, log_abs_det_jacobian.expand_as(x)

    def inverse(self, y, cond_y=None):
        """Undo ``forward``; ``cond_y`` is accepted for interface parity and ignored.

        In training mode the statistics stored by the latest ``forward`` call are
        used; if there has been none yet, the running statistics are used instead.

        Returns:
            ``(x, log_abs_det_jacobian)``, both shaped like ``y``.
        """
        if self.training and self.batch_mean is not None:
            mean = self.batch_mean
            var = self.batch_var
        else:
            mean = self.running_mean
            var = self.running_var

        x_hat = (y - self.beta) * torch.exp(-self.log_gamma)
        x = x_hat * torch.sqrt(var + self.eps) + mean

        log_abs_det_jacobian = 0.5 * torch.log(var + self.eps) - self.log_gamma

        return x, log_abs_det_jacobian.expand_as(x)


class LinearMaskedCoupling(nn.Module):
    """Affine coupling layer (RealNVP, modified per the MAF paper).

    Entries where ``mask`` is 1 pass through unchanged and feed two small MLPs that
    produce a bounded log-scale and a shift for the remaining entries.
    """

    def __init__(self, input_size, hidden_size, n_hidden, mask, cond_label_size=None):
        """
        Args:
            input_size -- scalar; dim of inputs
            hidden_size -- scalar; dim of hidden layers
            n_hidden -- scalar; number of extra hidden Linear layers
            mask -- tensor; 1 for pass-through entries, 0 for transformed entries
            cond_label_size -- scalar or None; dim of the conditioning vector
        """
        super().__init__()

        self.register_buffer("mask", mask)

        # scale network: Linear, then (Tanh, Linear) pairs
        n_cond = 0 if cond_label_size is None else cond_label_size
        scale_layers = [nn.Linear(input_size + n_cond, hidden_size)]
        for _ in range(n_hidden):
            scale_layers.extend([nn.Tanh(), nn.Linear(hidden_size, hidden_size)])
        scale_layers.extend([nn.Tanh(), nn.Linear(hidden_size, input_size)])
        self.s_net = nn.Sequential(*scale_layers)

        # translation network: copies of the Linear layers (same initial weights)
        # with every activation replaced by ReLU, per the MAF paper
        self.t_net = nn.Sequential(
            *[
                copy.deepcopy(layer) if isinstance(layer, nn.Linear) else nn.ReLU()
                for layer in self.s_net
            ]
        )

    def _log_scale_and_shift(self, masked, y):
        """Return ``(log_s, t)`` for the transformed entries given the masked input."""
        net_in = masked if y is None else torch.cat([y, masked], dim=-1)
        free = 1 - self.mask
        t = self.t_net(net_in) * free
        # tanh keeps the log-scale in (-1, 1); masked entries get log_s == 0
        log_s = torch.tanh(self.s_net(net_in)) * free
        return log_s, t

    def forward(self, x, y=None):
        """Map ``x`` to ``u = x * exp(log_s) + t``.

        Returns ``(u, log_abs_det_jacobian)``; the log-det is per element and is
        not summed over the feature dimension here.
        """
        log_s, t = self._log_scale_and_shift(x * self.mask, y)
        u = x * torch.exp(log_s) + t
        return u, log_s

    def inverse(self, u, y=None):
        """Map ``u`` back to ``x = (u - t) * exp(-log_s)``.

        Returns ``(x, log_abs_det_jacobian)`` with the per-element log-det of the
        inverse map, i.e. ``-log_s``.
        """
        log_s, t = self._log_scale_and_shift(u * self.mask, y)
        x = (u - t) * torch.exp(-log_s)
        return x, -log_s


class MaskedLinear(nn.Linear):
    """Linear layer whose weight is multiplied by a fixed mask (MADE building block).

    When ``cond_label_size`` is given, an extra unmasked weight matrix projects a
    conditioning vector ``y`` and its output is added to the masked linear output.
    """

    def __init__(self, input_size, n_outputs, mask, cond_label_size=None):
        """
        Args:
            input_size -- scalar; dim of inputs
            n_outputs -- scalar; dim of outputs
            mask -- tensor; multiplied elementwise with the weight on every call
            cond_label_size -- scalar or None; dim of the conditioning vector
        """
        super().__init__(input_size, n_outputs)

        self.register_buffer("mask", mask)
        self.cond_label_size = cond_label_size

        if cond_label_size is None:
            return

        # uniform [0, 1) initialisation scaled by 1 / sqrt(cond_label_size)
        initial = torch.rand(n_outputs, cond_label_size) / math.sqrt(cond_label_size)
        self.cond_weight = nn.Parameter(initial)

    def forward(self, x, y=None):
        """Return the masked linear map of ``x`` plus the projection of ``y`` if given."""
        masked_out = F.linear(x, self.mask * self.weight, self.bias)
        if y is None:
            return masked_out
        return masked_out + F.linear(y, self.cond_weight)


class MADE(nn.Module):
    """Masked autoregressive network used as one flow step (mean-only variant).

    The network only predicts a per-dimension shift ``m``; the log-scale is fixed
    to zero, so ``forward`` computes ``u = x - m`` and every log-det term is zero.
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        n_hidden,
        cond_label_size=None,
        activation="ReLU",
        input_order="sequential",
        input_degrees=None,
    ):
        """
        Args:
            input_size -- scalar; dim of inputs
            hidden_size -- scalar; dim of hidden layers
            n_hidden -- scalar; number of hidden layers
            cond_label_size -- scalar or None; dim of the conditioning vector
            activation -- str; activation function to use ("ReLU" or "Tanh")
            input_order -- str; variable order for creating the autoregressive masks (sequential|random)
            input_degrees -- tensor or None; input degrees to use, e.g. the order flipped
                            from the previous layer in a stack of MADEs

        Raises:
            ValueError -- if activation is not one of the supported names
        """
        super().__init__()
        # base distribution for calculation of log prob under the model
        self.register_buffer("base_dist_mean", torch.zeros(input_size))
        self.register_buffer("base_dist_var", torch.ones(input_size))

        # create masks
        masks, self.input_degrees = create_masks(
            input_size, hidden_size, n_hidden, input_order, input_degrees
        )

        # setup activation
        if activation == "ReLU":
            activation_fn = nn.ReLU()
        elif activation == "Tanh":
            activation_fn = nn.Tanh()
        else:
            raise ValueError("Check activation function.")

        # construct model
        self.net_input = MaskedLinear(
            input_size, hidden_size, masks[0], cond_label_size
        )
        layers = []
        for m in masks[1:-1]:
            layers += [activation_fn, MaskedLinear(hidden_size, hidden_size, m)]
        # The last layer outputs only the means (input_size units), not log(std) values.
        layers += [activation_fn, MaskedLinear(hidden_size, input_size, masks[-1])]
        self.net = nn.Sequential(*layers)

    @property
    def base_dist(self):
        """Normal base distribution built from the two base_dist buffers.

        Note that ``base_dist_var`` is handed to ``Normal`` as its second positional
        argument, i.e. it is used as the scale (standard deviation).
        """
        return Normal(self.base_dist_mean, self.base_dist_var)

    def set_base_dist(self, mu, sigma):
        """Replace the location and scale buffers of the base distribution."""
        self.base_dist_mean = mu
        self.base_dist_var = sigma

    def forward(self, x, y=None):
        """Map data ``x`` to noise ``u`` (MAF eq 4 with the log-scale fixed to zero).

        Returns:
            ``(u, log_abs_det_jacobian)``; the log-det is per element and all zeros.
        """
        m = self.net(self.net_input(x, y))
        loga = torch.zeros_like(m)

        u = (x - m) * torch.exp(-loga)
        # MAF eq 5
        log_abs_det_jacobian = -loga
        return u, log_abs_det_jacobian

    def inverse(self, u, y=None, return_dist_params=None):
        """Map noise ``u`` back to data ``x`` one dimension at a time (MAF eq 3).

        Dimensions are filled in order of increasing input degree: the mean of a
        dimension may depend on every dimension of smaller degree, so those must
        already hold their final values. Iterating over ``argsort(input_degrees)``
        gives that order for any permutation; iterating over the degree values
        themselves is only correct for the sequential and the flipped order.

        Returns:
            ``(x, log_abs_det_jacobian)``, plus ``(m, loga)`` from the last pass
            through the network when ``return_dist_params`` is truthy.
        """
        x = torch.zeros_like(u)
        # defaults so the return values exist even when there is nothing to iterate
        m = torch.zeros_like(u)
        loga = torch.zeros_like(u)

        # run through reverse model
        for i in torch.argsort(self.input_degrees).tolist():
            m = self.net(self.net_input(x, y))
            loga = torch.zeros_like(m)  # log of ones vector

            x[..., i] = u[..., i] * torch.exp(loga[..., i]) + m[..., i]

        log_abs_det_jacobian = loga
        if return_dist_params:
            return x, log_abs_det_jacobian, (m, loga)
        return x, log_abs_det_jacobian

    def log_prob(self, x, y=None):
        """Log-density of ``x``: base log-prob of ``u`` plus log-det, summed over the last dim."""
        u, log_abs_det_jacobian = self.forward(x, y)
        return torch.sum(self.base_dist.log_prob(u) + log_abs_det_jacobian, dim=-1)


class Flow(nn.Module):
    """Base class of a normalizing flow with an optional data scale.

    Subclasses are expected to set ``self.net`` to a module whose call returns
    ``(u, log_abs_det_jacobian)`` and which offers a matching ``inverse`` method.
    """

    def __init__(self, input_size):
        """
        Args:
            input_size -- scalar; dim of inputs (size of the base distribution)
        """
        super().__init__()
        self.__scale = None
        self.net = None

        # base distribution for calculation of log prob under the model
        self.register_buffer("base_dist_mean", torch.zeros(input_size))
        self.register_buffer("base_dist_var", torch.ones(input_size))

    @property
    def base_dist(self):
        """Normal base distribution; ``base_dist_var`` is passed as the scale argument."""
        return Normal(self.base_dist_mean, self.base_dist_var)

    @property
    def scale(self):
        """Optional tensor the data is divided by before entering the flow (None = off)."""
        return self.__scale

    @scale.setter
    def scale(self, scale):
        self.__scale = scale

    def forward(self, x, cond):
        """Map data ``x`` to noise ``u``.

        The division by ``scale`` is done out of place so the caller's tensor is
        left untouched.

        NOTE(review): ``inverse`` adds ``log|scale|`` to its log-det, but no
        matching ``-log|scale|`` term is added here; confirm whether ``log_prob``
        is meant to be the density of the scaled or of the unscaled data.

        Returns:
            ``(u, log_abs_det_jacobian)`` as produced by ``self.net``.
        """
        if self.scale is not None:
            x = x / self.scale
        u, log_abs_det_jacobian = self.net(x, cond)
        return u, log_abs_det_jacobian

    def inverse(self, u, cond, return_dist_params=False):
        """Map noise ``u`` back to data space and undo the scaling.

        Returns:
            ``(x, log_abs_det_jacobian)``, plus whatever distribution parameters
            ``self.net.inverse`` reports when ``return_dist_params`` is truthy.
        """
        list_dist_params = None
        if return_dist_params:
            x, log_abs_det_jacobian, list_dist_params = self.net.inverse(
                u, cond, return_dist_params
            )
        else:
            x, log_abs_det_jacobian = self.net.inverse(u, cond)

        if self.scale is not None:
            # out of place: avoids mutating tensors owned by self.net and lets the
            # log-det broadcast against the scale
            x = x * self.scale
            log_abs_det_jacobian = log_abs_det_jacobian + torch.log(
                torch.abs(self.scale)
            )

        if return_dist_params:
            return x, log_abs_det_jacobian, list_dist_params
        return x, log_abs_det_jacobian

    def log_prob(self, x, cond):
        """Base log-prob of ``u`` plus the log-det terms, summed over the last dim."""
        u, sum_log_abs_det_jacobians = self.forward(x, cond)
        return torch.sum(self.base_dist.log_prob(u) + sum_log_abs_det_jacobians, dim=-1)

    def sample(self, sample_shape=torch.Size(), cond=None):
        """Draw samples by pushing base-distribution noise through ``inverse``.

        When ``cond`` is given, one sample is drawn per conditioning vector (the
        batch shape is ``cond.shape[:-1]``) and ``sample_shape`` is ignored.
        """
        if cond is not None:
            shape = cond.shape[:-1]
        else:
            shape = sample_shape

        u = self.base_dist.sample(shape)
        sample, _ = self.inverse(u, cond)
        return sample


class FlowSequential(nn.Sequential):
    """Container for layers of a normalizing flow.

    Each layer is called as ``layer(x, y)`` and must return the transformed
    tensor and its log|det Jacobian|; the container threads the tensor through
    the layers and sums the log-det terms.
    """

    @staticmethod
    def _reports_dist_params(module):
        """True when ``module.inverse`` has a ``return_dist_params`` parameter."""
        import inspect

        return "return_dist_params" in inspect.signature(module.inverse).parameters

    def forward(self, x, y):
        """Apply every layer in order; returns ``(u, sum_log_abs_det_jacobians)``."""
        sum_log_abs_det_jacobians = 0
        for module in self:
            x, log_abs_det_jacobian = module(x, y)
            # out-of-place sum so terms of different (broadcastable) shapes add up
            sum_log_abs_det_jacobians = sum_log_abs_det_jacobians + log_abs_det_jacobian
        return x, sum_log_abs_det_jacobians

    def inverse(self, u, y, return_dist_params=False):
        """Apply the inverse of every layer in reverse order.

        With ``return_dist_params`` truthy, the distribution parameters reported by
        the last inverted layer that supports the ``return_dist_params`` argument
        are returned as a third element (None when no layer supports it). Layers
        whose ``inverse`` lacks that parameter are inverted without it.

        Returns:
            ``(x, sum_log_abs_det_jacobians)`` or
            ``(x, sum_log_abs_det_jacobians, dist_params)``.
        """
        sum_log_abs_det_jacobians = 0
        dist_params = None
        for module in reversed(self):
            if return_dist_params and self._reports_dist_params(module):
                u, log_abs_det_jacobian, dist_params = module.inverse(
                    u, y, return_dist_params=return_dist_params
                )
            else:
                u, log_abs_det_jacobian = module.inverse(u, y)

            sum_log_abs_det_jacobians = sum_log_abs_det_jacobians + log_abs_det_jacobian

        x = u
        if return_dist_params:
            return x, sum_log_abs_det_jacobians, dist_params
        return x, sum_log_abs_det_jacobians
#
#
# class RealNVP(Flow):
#     def __init__(
#         self,
#         n_blocks,
#         input_size,
#         hidden_size,
#         n_hidden,
#         cond_label_size=None,
#         batch_norm=True,
#     ):
#         super().__init__(input_size)
#
#         # construct model
#         modules = []
#         mask = torch.arange(input_size).float() % 2
#         for i in range(n_blocks):
#             modules += [
#                 LinearMaskedCoupling(
#                     input_size, hidden_size, n_hidden, mask, cond_label_size
#                 )
#             ]
#             mask = 1 - mask
#             modules += batch_norm * [BatchNorm(input_size)]
#
#         self.net = FlowSequential(*modules)


class MAF(Flow):
    """Masked autoregressive flow: a stack of MADE blocks with alternating order."""

    def __init__(
        self,
        n_blocks,
        input_size,
        hidden_size,
        n_hidden,
        cond_label_size=None,
        activation="ReLU",
        input_order="sequential",
        batch_norm=True,
    ):
        """
        Args:
            n_blocks -- scalar; number of MADE blocks in the stack
            input_size -- scalar; dim of inputs
            hidden_size -- scalar; dim of hidden layers of every block
            n_hidden -- scalar; number of hidden layers of every block
            cond_label_size -- scalar or None; dim of the conditioning vector
            activation -- str; activation name handed to every MADE block
            input_order -- str; mask ordering handed to every MADE block
            batch_norm -- bool; put one BatchNorm layer in front of the stack, but
                          only when conditioning information is used
        """
        super().__init__(input_size)

        layers = []
        # Normalize input if batch norm is true and there is conditioning information
        if cond_label_size:
            layers.extend(batch_norm * [BatchNorm(input_size)])

        # Each block receives the flipped input degrees of the block before it
        # (None for the first one, so it derives them from input_order).
        self.input_degrees = None
        for _ in range(n_blocks):
            block = MADE(
                input_size,
                hidden_size,
                n_hidden,
                cond_label_size,
                activation,
                input_order,
                self.input_degrees,
            )
            layers.append(block)
            self.input_degrees = block.input_degrees.flip(0)

        self.net = FlowSequential(*layers)

'''
class WrappedMAF(MaskedAutoregressiveFlow):
    """MaskedAutoregressiveFlow behind a pytorch-ts style MAF constructor.

    The constructor only translates argument names and fixes a few options:
    residual blocks and random masks are off, random permutations between layers
    are on, dropout is 0 and batch norm is only applied within layers.
    """

    def __init__(self,
                n_blocks,
                input_size,
                hidden_size,
                n_hidden,
                cond_label_size=None,
                activation="ReLU",
                periodic=False,
                input_order="sequential",
                batch_norm=True):
        """
        Args:
            n_blocks -- scalar; passed on as ``num_layers``
            input_size -- scalar; passed on as ``features``
            hidden_size -- scalar; passed on as ``hidden_features``
            n_hidden -- scalar; passed on as ``num_blocks_per_layer``
            cond_label_size -- scalar or None; passed on as ``context_features``
            activation -- str or callable; a name is lower-cased and looked up in
                          ``torch.nn.functional`` (e.g. "ReLU" -> ``F.relu``); a
                          callable is used as is
            periodic -- forwarded unchanged to the parent constructor
            input_order -- accepted for signature compatibility; not used here
            batch_norm -- bool; passed on as ``batch_norm_within_layers``

        Raises:
            ValueError -- if ``activation`` names nothing callable in
                          ``torch.nn.functional``

        NOTE(review): ``context_features`` and ``periodic`` must be accepted by the
        installed ``MaskedAutoregressiveFlow``; confirm against the nflows version
        (or fork) this project uses.
        """
        if callable(activation):
            activation_fn = activation
        else:
            # Plain attribute lookup instead of eval(): same result for valid
            # names, but a caller-supplied string can no longer run arbitrary code.
            activation_fn = getattr(F, str(activation).lower(), None)
            if not callable(activation_fn):
                raise ValueError(f"Unknown activation function: {activation!r}")

        super().__init__(features=input_size,
                        hidden_features=hidden_size,
                        context_features=cond_label_size,
                        num_layers=n_blocks,
                        periodic=periodic,
                        num_blocks_per_layer=n_hidden,
                        use_residual_blocks=False,
                        use_random_masks=False,
                        use_random_permutations=True,
                        activation=activation_fn,
                        dropout_probability=0.0,
                        batch_norm_within_layers=batch_norm,
                        batch_norm_between_layers=False)


class Snake(torch.nn.Module):
    """Periodic activation ``sin(a * x) ** 2 / a`` with a learnable frequency ``a``.

    ``a`` is a single-element parameter initialised to 0.1.

    NOTE(review): the Snake activation as usually published also has an additive
    ``x`` term; this implementation returns only the squared-sine part. Confirm
    that this is intended.
    """

    def __init__(self):
        super().__init__()
        initial_a = torch.Tensor([0.1])
        self.a = torch.nn.Parameter(initial_a)

    def forward(self, x):
        """Apply the activation elementwise to ``x``."""
        squared_sine = torch.sin(self.a * x).pow(2)
        return squared_sine / self.a