# All of the following flows are taken or adapted from:
# https://github.com/ikostrikov/pytorch-flows/blob/master/flows.py
# Credit where credit is due.

import math
import numpy as np
import scipy as sp
import scipy.linalg
import torch
import torch.nn as nn
import torch.nn.functional as F

import helpers.layers as layers


def get_mask(in_features, out_features, in_flow_features, mask_type=None):
    """Build a MADE connectivity mask.

    Each input unit and each output unit is given a "degree". Output unit ``o``
    may see input unit ``i`` iff ``deg_out[o] >= deg_in[i]``.

    Args:
        in_features: number of units feeding the layer.
        out_features: number of units produced by the layer.
        in_flow_features: dimensionality D of the flow being modelled.
        mask_type: 'input' (first layer), 'output' (last layer) or None
            (hidden layer).

    Returns:
        A float tensor of shape (out_features, in_features) holding 0/1 entries.

    See Figure 1 for a better illustration:
    https://arxiv.org/pdf/1502.03509.pdf
    """
    hidden_period = in_flow_features - 1

    # Input degrees cycle through all D values only for the data-facing layer.
    in_period = in_flow_features if mask_type == 'input' else hidden_period
    in_degrees = torch.arange(in_features) % in_period

    out_degrees = torch.arange(out_features)
    if mask_type == 'output':
        # Shift by one so output d only depends on inputs with degree < d.
        out_degrees = out_degrees % in_flow_features - 1
    else:
        out_degrees = out_degrees % hidden_period

    connected = out_degrees[:, None] >= in_degrees[None, :]
    return connected.float()


def get_layer_fn(modifier_str):
    """Return the linear-layer constructor registered under ``modifier_str``.

    Recognised names: 'none' (plain ``nn.Linear``), 'gated', 'sine' and
    'spectralnorm' (project layers from ``helpers.layers``).
    An unknown name trips an assertion.
    """
    constructors = dict(
        none=nn.Linear,
        gated=layers.GatedDense,
        sine=layers.SineLinear,
        spectralnorm=layers.SNLinear,
    )
    assert modifier_str in constructors, "unknown layer type {}".format(modifier_str)
    return constructors[modifier_str]


class MaskedLinear(nn.Module):
    """Linear layer whose weight is multiplied elementwise by a fixed mask.

    Used by MADE to enforce autoregressive connectivity.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        mask: tensor broadcastable to the weight shape (out_features,
            in_features), e.g. as produced by ``get_mask``. It is registered
            as a buffer, so it follows ``.to()`` and is saved in ``state_dict``.
        cond_in_features: if not None, an unmasked, bias-free projection of
            the conditioning inputs is added to the output.
        bias: whether the masked linear map has a learnable bias.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 mask,
                 cond_in_features=None,
                 bias=True):
        super(MaskedLinear, self).__init__()
        # Honour `bias`: it used to be accepted but silently ignored.
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        if cond_in_features is not None:
            self.cond_linear = nn.Linear(
                cond_in_features, out_features, bias=False)

        self.register_buffer('mask', mask)

    def forward(self, inputs, cond_inputs=None):
        """Apply the masked linear map, plus the conditioning term if given."""
        # F.linear accepts bias=None, so the bias-free case needs no branch.
        output = F.linear(inputs, self.linear.weight * self.mask,
                          self.linear.bias)
        if cond_inputs is not None:
            output = output + self.cond_linear(cond_inputs)
        return output


class MADESplit(nn.Module):
    """ An implementation of MADE
    (https://arxiv.org/abs/1502.03509) using two separate masked networks.

    The "s" network (``s_joiner`` + ``s_trunk``) produces the shift ``m``.
    The "t" network (``t_joiner`` + ``t_trunk``) produces the log-scale ``a``.
    Direct mode: u = (x - m(x)) * exp(-a(x)), log-det = -sum(a).
    Inverse mode: x is rebuilt one column at a time (``num_inputs`` passes),
    log-det = +sum(a).

    Args:
        num_inputs: dimensionality D of the flow.
        num_hidden: width of the masked hidden layers.
        num_cond_inputs: size of an optional conditioning vector.
        s_act: activation name for the shift ("s") network.
        t_act: activation name for the log-scale ("t") network.
        pre_exp_tanh: if True, squash ``a`` with tanh before exponentiating.
    """

    def __init__(self,
                 num_inputs,
                 num_hidden,
                 num_cond_inputs=None,
                 s_act='tanh',
                 t_act='relu',
                 pre_exp_tanh=False):
        super(MADESplit, self).__init__()

        self.pre_exp_tanh = pre_exp_tanh

        input_mask = get_mask(num_inputs, num_hidden, num_inputs,
                              mask_type='input')
        hidden_mask = get_mask(num_hidden, num_hidden, num_inputs)
        output_mask = get_mask(num_hidden, num_inputs, num_inputs,
                               mask_type='output')

        act_func = layers.str_to_activ_module(s_act)
        self.s_joiner = MaskedLinear(num_inputs, num_hidden, input_mask,
                                     num_cond_inputs)

        self.s_trunk = nn.Sequential(act_func(),
                                     MaskedLinear(num_hidden, num_hidden,
                                                  hidden_mask), act_func(),
                                     MaskedLinear(num_hidden, num_inputs,
                                                  output_mask))

        act_func = layers.str_to_activ_module(t_act)
        self.t_joiner = MaskedLinear(num_inputs, num_hidden, input_mask,
                                     num_cond_inputs)

        self.t_trunk = nn.Sequential(act_func(),
                                     MaskedLinear(num_hidden, num_hidden,
                                                  hidden_mask), act_func(),
                                     MaskedLinear(num_hidden, num_inputs,
                                                  output_mask))

    def _shift_and_log_scale(self, x, cond_inputs):
        """Return (m, a) for ``x``; ``a`` is tanh-squashed if configured."""
        m = self.s_trunk(self.s_joiner(x, cond_inputs))
        a = self.t_trunk(self.t_joiner(x, cond_inputs))
        if self.pre_exp_tanh:
            a = torch.tanh(a)
        return m, a

    def forward(self, inputs, cond_inputs=None, mode='direct'):
        if mode == 'direct':
            m, a = self._shift_and_log_scale(inputs, cond_inputs)
            u = (inputs - m) * torch.exp(-a)
            return u, -a.sum(-1, keepdim=True)

        x = torch.zeros_like(inputs)
        for i_col in range(inputs.shape[1]):
            m, a = self._shift_and_log_scale(x, cond_inputs)
            x[:, i_col] = inputs[:, i_col] * torch.exp(
                a[:, i_col]) + m[:, i_col]
        # x = u * exp(a) + m, so log|dx/du| = +sum(a). This is the opposite
        # sign of the direct pass, like every other layer's inverse mode.
        return x, a.sum(-1, keepdim=True)


class MADE(nn.Module):
    """ An implementation of MADE
    (https://arxiv.org/abs/1502.03509).

    A single masked network outputs 2*D values, split into the shift ``m``
    and the log-scale ``a``.
    Direct mode: u = (x - m(x)) * exp(-a(x)), log-det = -sum(a).
    Inverse mode: x is rebuilt one column at a time (``num_inputs`` passes),
    log-det = +sum(a).

    Args:
        num_inputs: dimensionality D of the flow.
        num_hidden: width of the masked hidden layers.
        num_cond_inputs: size of an optional conditioning vector.
        act: activation name, resolved via ``layers.str_to_activ_module``.
        pre_exp_tanh: if True, squash ``a`` with tanh before exponentiating
            (same meaning as in ``MADESplit``).
    """

    def __init__(self,
                 num_inputs,
                 num_hidden,
                 num_cond_inputs=None,
                 act='relu',
                 pre_exp_tanh=False):
        super(MADE, self).__init__()
        # Previously accepted but ignored; now honoured in forward().
        self.pre_exp_tanh = pre_exp_tanh
        act_func = layers.str_to_activ_module(act)

        input_mask = get_mask(
            num_inputs, num_hidden, num_inputs, mask_type='input')
        hidden_mask = get_mask(num_hidden, num_hidden, num_inputs)
        output_mask = get_mask(
            num_hidden, num_inputs * 2, num_inputs, mask_type='output')

        self.joiner = MaskedLinear(num_inputs, num_hidden, input_mask,
                                   num_cond_inputs)

        self.trunk = nn.Sequential(act_func(),
                                   MaskedLinear(num_hidden, num_hidden,
                                                hidden_mask), act_func(),
                                   MaskedLinear(num_hidden, num_inputs * 2,
                                                output_mask))

    def _shift_and_log_scale(self, x, cond_inputs):
        """Return (m, a) for ``x``; ``a`` is tanh-squashed if configured."""
        m, a = self.trunk(self.joiner(x, cond_inputs)).chunk(2, 1)
        if self.pre_exp_tanh:
            a = torch.tanh(a)
        return m, a

    def forward(self, inputs, cond_inputs=None, mode='direct'):
        if mode == 'direct':
            m, a = self._shift_and_log_scale(inputs, cond_inputs)
            u = (inputs - m) * torch.exp(-a)
            return u, -a.sum(-1, keepdim=True)

        x = torch.zeros_like(inputs)
        for i_col in range(inputs.shape[1]):
            m, a = self._shift_and_log_scale(x, cond_inputs)
            x[:, i_col] = inputs[:, i_col] * torch.exp(
                a[:, i_col]) + m[:, i_col]
        # x = u * exp(a) + m, so log|dx/du| = +sum(a). This is the opposite
        # sign of the direct pass, like every other layer's inverse mode.
        return x, a.sum(-1, keepdim=True)


class Sigmoid(nn.Module):
    """Elementwise sigmoid flow.

    Direct mode maps x to sigmoid(x), with log|dy/dx| = log(s * (1 - s)).
    Inverse mode maps y to logit(y), with log|dx/dy| = -log(y - y**2).
    Log-determinants are summed over the last dimension (keepdim=True).
    """

    def __init__(self):
        super(Sigmoid, self).__init__()

    def forward(self, inputs, cond_inputs=None, mode='direct'):
        if mode == 'direct':
            y = torch.sigmoid(inputs)
            # log(s * (1 - s)) == logsigmoid(x) + logsigmoid(-x). The latter
            # stays finite for large |x|, where 1 - s underflows to 0.
            log_det = F.logsigmoid(inputs) + F.logsigmoid(-inputs)
            return y, log_det.sum(-1, keepdim=True)
        return (torch.log(inputs / (1 - inputs)),
                -torch.log(inputs - inputs ** 2).sum(-1, keepdim=True))


class Logit(Sigmoid):
    """Elementwise logit flow: the ``Sigmoid`` flow with its modes swapped."""

    def __init__(self):
        super(Logit, self).__init__()

    def forward(self, inputs, cond_inputs=None, mode='direct'):
        # `mode` must be passed by keyword. Passed positionally it lands in
        # `cond_inputs`, and the parent would always run its direct branch.
        swapped = 'inverse' if mode == 'direct' else 'direct'
        return super(Logit, self).forward(inputs, cond_inputs, mode=swapped)


class BatchNormFlow(nn.Module):
    """ An implementation of a batch normalization layer from
    Density estimation using Real NVP
    (https://arxiv.org/abs/1605.08803).

    Direct mode: y = exp(log_gamma) * (x - mean) / sqrt(var) + beta.
    Training mode normalises with the current batch statistics and stores
    them on ``batch_mean`` / ``batch_var``; inverse mode in training reuses
    those stored values. Eval mode uses ``running_mean`` / ``running_var``.
    Running statistics are updated as
    ``running = momentum * running + (1 - momentum) * batch``. With the
    default ``momentum=0.0`` they simply equal the last training batch.
    """

    def __init__(self, num_inputs, momentum=0.0, eps=1e-5):
        super(BatchNormFlow, self).__init__()

        self.log_gamma = nn.Parameter(torch.zeros(num_inputs))
        self.beta = nn.Parameter(torch.zeros(num_inputs))
        self.momentum = momentum
        self.eps = eps

        self.register_buffer('running_mean', torch.zeros(num_inputs))
        self.register_buffer('running_var', torch.ones(num_inputs))

    def _update_running_stats(self):
        """Blend the stored batch statistics into the running buffers."""
        keep = self.momentum
        fresh = 1 - self.momentum
        self.running_mean.mul_(keep).add_(self.batch_mean.data * fresh)
        self.running_var.mul_(keep).add_(self.batch_var.data * fresh)

    def _statistics(self):
        """Pick (mean, var): batch values in training, running values otherwise."""
        if self.training:
            return self.batch_mean, self.batch_var
        return self.running_mean, self.running_var

    def forward(self, inputs, cond_inputs=None, mode='direct'):
        if mode != 'direct':
            mean, var = self._statistics()
            unscaled = (inputs - self.beta) / torch.exp(self.log_gamma)
            restored = unscaled * var.sqrt() + mean
            inv_log_det = (-self.log_gamma + 0.5 * torch.log(var)).sum(
                -1, keepdim=True)
            return restored, inv_log_det

        if self.training:
            self.batch_mean = inputs.mean(0)
            # eps is folded into the stored variance (and thus the running one).
            self.batch_var = (
                inputs - self.batch_mean).pow(2).mean(0) + self.eps
            self._update_running_stats()

        mean, var = self._statistics()
        normed = (inputs - mean) / var.sqrt()
        out = torch.exp(self.log_gamma) * normed + self.beta
        log_det = (self.log_gamma - 0.5 * torch.log(var)).sum(-1, keepdim=True)
        return out, log_det


class ActNorm(nn.Module):
    """ An implementation of a activation normalization layer
    from Glow: Generative Flow with Invertible 1x1 Convolutions
    (https://arxiv.org/abs/1807.03039).

    Direct mode: y = (x - bias) * exp(weight). ``weight`` holds a
    log-scale.
    On the first call, whichever mode it is in, ``weight`` and ``bias`` are
    initialised from that batch. ``weight`` becomes log(1 / std) and
    ``bias`` becomes the mean.
    NOTE(review): ``initialized`` is a plain attribute, not a buffer. It is
    therefore not restored by ``load_state_dict`` — confirm that loaded
    weights are not overwritten by the first forward call.
    """

    def __init__(self, num_inputs):
        super(ActNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(num_inputs))
        self.bias = nn.Parameter(torch.zeros(num_inputs))
        self.initialized = False

    def forward(self, inputs, cond_inputs=None, mode='direct'):
        if self.initialized is False:
            log_inv_std = torch.log(1.0 / (inputs.std(0) + 1e-12))
            self.weight.data.copy_(log_inv_std)
            self.bias.data.copy_(inputs.mean(0))
            self.initialized = True

        # Per-sample log-determinant, shape (batch, 1).
        batch = inputs.size(0)
        log_det = self.weight.sum(-1, keepdim=True).unsqueeze(0).repeat(batch, 1)

        if mode == 'direct':
            return (inputs - self.bias) * torch.exp(self.weight), log_det
        return inputs * torch.exp(-self.weight) + self.bias, -log_det


class InvertibleMM(nn.Module):
    """ An implementation of a invertible matrix multiplication
    layer from Glow: Generative Flow with Invertible 1x1 Convolutions
    (https://arxiv.org/abs/1807.03039).

    Direct mode: y = x @ W. Inverse mode: x = y @ W^{-1}.
    ``W`` starts orthogonal. The per-sample log-determinant, of shape
    (batch, 1), is log|det W| and is negated in inverse mode.
    """

    def __init__(self, num_inputs):
        super(InvertibleMM, self).__init__()
        self.W = nn.Parameter(torch.Tensor(num_inputs, num_inputs))
        nn.init.orthogonal_(self.W)

    def forward(self, inputs, cond_inputs=None, mode='direct'):
        batch = inputs.size(0)
        log_abs_det = torch.slogdet(self.W)[-1]
        log_det = log_abs_det.unsqueeze(0).unsqueeze(0).repeat(batch, 1)
        if mode == 'direct':
            return inputs @ self.W, log_det
        return inputs @ torch.inverse(self.W), -log_det


class LUInvertibleMM(nn.Module):
    """ An implementation of a invertible matrix multiplication
    layer from Glow: Generative Flow with Invertible 1x1 Convolutions
    (https://arxiv.org/abs/1807.03039).

    W is parameterised as P @ L @ U, from the LU decomposition of a random
    orthogonal matrix:
    - P is a fixed permutation.
    - L is unit lower-triangular.
    - U is strictly upper-triangular plus a diagonal sign_S * exp(log_S).
    log|det W| is therefore sum(log_S).

    NOTE(review): P, sign_S, the triangular masks and the identity are plain
    tensor attributes, not buffers. They are not in ``state_dict`` and
    ``.to()`` does not move them; forward() re-homes them lazily. Confirm
    that reloaded checkpoints are built with the same P.
    """

    def __init__(self, num_inputs):
        super(LUInvertibleMM, self).__init__()
        self.W = torch.Tensor(num_inputs, num_inputs)
        nn.init.orthogonal_(self.W)
        self.L_mask = torch.tril(torch.ones(self.W.size()), -1)
        self.U_mask = self.L_mask.t().clone()

        perm, lower, upper = sp.linalg.lu(self.W.numpy())
        self.P = torch.from_numpy(perm)
        self.L = nn.Parameter(torch.from_numpy(lower))
        self.U = nn.Parameter(torch.from_numpy(upper))

        diag_u = np.diag(upper)
        self.sign_S = torch.from_numpy(np.sign(diag_u))
        self.log_S = nn.Parameter(torch.from_numpy(np.log(abs(diag_u))))

        self.eye = torch.eye(self.L.size(0))

    def _sync_device(self):
        """Move the non-parameter tensors onto the parameters' device."""
        target = self.L.device
        if str(self.L_mask.device) == str(target):
            return
        self.L_mask = self.L_mask.to(target)
        self.U_mask = self.U_mask.to(target)
        self.eye = self.eye.to(target)
        self.P = self.P.to(target)
        self.sign_S = self.sign_S.to(target)

    def forward(self, inputs, cond_inputs=None, mode='direct'):
        self._sync_device()

        lower = self.L * self.L_mask + self.eye
        diagonal = torch.diag(self.sign_S * torch.exp(self.log_S))
        upper = self.U * self.U_mask + diagonal
        weight = self.P @ lower @ upper

        log_det = self.log_S.sum().unsqueeze(0).unsqueeze(0).repeat(
            inputs.size(0), 1)
        if mode == 'direct':
            return inputs @ weight, log_det
        return inputs @ torch.inverse(weight), -log_det


class Shuffle(nn.Module):
    """ An implementation of a shuffling layer from
    Density estimation using Real NVP
    (https://arxiv.org/abs/1605.08803).

    Applies a fixed random permutation to the last (feature) axis. Because
    it is volume-preserving, the log-determinant is zero.
    NOTE(review): ``perm`` is drawn from NumPy's global RNG and stored as a
    plain attribute rather than a buffer. It is therefore not saved in
    ``state_dict`` — verify how checkpoints recover it.
    """

    def __init__(self, num_inputs):
        super(Shuffle, self).__init__()
        self.perm = np.random.permutation(num_inputs)
        self.inv_perm = np.argsort(self.perm)

    def forward(self, inputs, cond_inputs=None, mode='direct'):
        order = self.perm if mode == 'direct' else self.inv_perm
        zero_log_det = torch.zeros(inputs.size(0), 1, device=inputs.device)
        return inputs[:, order], zero_log_det


class Reverse(nn.Module):
    """ An implementation of a reversing layer from
    Density estimation using Real NVP
    (https://arxiv.org/abs/1605.08803).

    Reverses the order of the feature axis. The log-determinant is zero.
    """

    def __init__(self, num_inputs):
        super(Reverse, self).__init__()
        # Counting down from the last index gives a contiguous reversed range.
        self.perm = np.arange(num_inputs - 1, -1, -1)
        self.inv_perm = np.argsort(self.perm)

    def forward(self, inputs, cond_inputs=None, mode='direct'):
        order = self.perm if mode == 'direct' else self.inv_perm
        zero_log_det = torch.zeros(inputs.size(0), 1, device=inputs.device)
        return inputs[:, order], zero_log_det


class AdditiveCouplingLayer(nn.Module):
    """ An implementation of a additive coupling layer
    from RealNVP (https://arxiv.org/abs/1605.08803).

    Entries where ``mask == 1`` pass through unchanged. The remaining entries
    are shifted by ``net(masked inputs [, cond_inputs])``. The map is
    volume-preserving, so the log-determinant is zero.
    NOTE(review): ``mask`` is a plain attribute rather than a buffer, so
    ``.to()`` will not move it. The build helpers place it on the target
    device up front.
    """

    def __init__(self,
                 num_inputs,
                 num_hidden,
                 mask,
                 num_cond_inputs=None,
                 act='relu',
                 normalization_str='batchnorm',
                 layer_modifier='none'):
        super(AdditiveCouplingLayer, self).__init__()

        self.num_inputs = num_inputs
        self.mask = mask

        act_func = layers.str_to_activ_module(act)
        cond_width = 0 if num_cond_inputs is None else num_cond_inputs
        total_inputs = num_inputs + cond_width

        # Grab the type of linear layer
        layer_fn = get_layer_fn(layer_modifier)

        def dense(width_in):
            return layers.BasicDenseBlock(width_in, num_hidden,
                                          layer_fn=layer_fn,
                                          normalization_str=normalization_str,
                                          activation_str=act)

        self.net = nn.Sequential(
            dense(total_inputs), act_func(),
            dense(num_hidden), act_func(),
            nn.Linear(num_hidden, num_inputs))

    def forward(self, inputs, cond_inputs=None, mode='direct'):
        mask = self.mask
        net_in = inputs * mask
        if cond_inputs is not None:
            net_in = torch.cat([net_in, cond_inputs], -1)

        shift = self.net(net_in) * (1 - mask)
        zero_log_det = torch.zeros_like(shift).sum(-1, keepdim=True)
        if mode == 'direct':
            return inputs + shift, zero_log_det
        return inputs - shift, zero_log_det


class CouplingLayer(nn.Module):
    """ An implementation of a coupling layer
    from RealNVP (https://arxiv.org/abs/1605.08803).

    Entries where ``mask == 1`` pass through unchanged. The others are mapped
    as x * exp(log_s) + t, where ``log_s`` and ``t`` come from ``scale_net``
    and ``translate_net`` applied to the masked inputs (concatenated with
    ``cond_inputs`` when given). The log-determinant is sum(log_s).
    NOTE(review): ``mask`` is a plain attribute rather than a buffer, so
    ``.to()`` will not move it.
    """

    def __init__(self,
                 num_inputs,
                 num_hidden,
                 mask,
                 num_cond_inputs=None,
                 s_act='tanh',
                 t_act='relu',
                 normalization_str='batchnorm',
                 layer_modifier='none'):
        super(CouplingLayer, self).__init__()

        self.num_inputs = num_inputs
        self.mask = mask

        s_act_func = layers.str_to_activ_module(s_act)
        t_act_func = layers.str_to_activ_module(t_act)

        cond_width = 0 if num_cond_inputs is None else num_cond_inputs
        total_inputs = num_inputs + cond_width

        # Grab the type of linear layer
        layer_fn = get_layer_fn(layer_modifier)

        def dense(width_in, act_name):
            return layers.BasicDenseBlock(width_in, num_hidden,
                                          layer_fn=layer_fn,
                                          normalization_str=normalization_str,
                                          activation_str=act_name)

        self.scale_net = nn.Sequential(
            dense(total_inputs, s_act), s_act_func(),
            dense(num_hidden, s_act), s_act_func(),
            nn.Linear(num_hidden, num_inputs))
        self.translate_net = nn.Sequential(
            dense(total_inputs, t_act), t_act_func(),
            dense(num_hidden, t_act), t_act_func(),
            nn.Linear(num_hidden, num_inputs))

    def forward(self, inputs, cond_inputs=None, mode='direct'):
        mask = self.mask
        net_in = inputs * mask
        if cond_inputs is not None:
            net_in = torch.cat([net_in, cond_inputs], -1)

        free = 1 - mask
        log_s = self.scale_net(net_in) * free
        t = self.translate_net(net_in) * free
        if mode == 'direct':
            return inputs * torch.exp(log_s) + t, log_s.sum(-1, keepdim=True)
        return (inputs - t) * torch.exp(-log_s), -log_s.sum(-1, keepdim=True)


class FlowSequential(nn.Sequential):
    """ A sequential container for flows.

    Each child must implement ``forward(inputs, cond_inputs, mode)`` and
    return ``(outputs, logdet)``. In 'direct' mode the children run in
    order; in 'inverse' mode they run in reverse order. Log-determinants
    are accumulated per sample.
    ``log_probs`` evaluates the direct pass under a standard-normal base
    density. ``sample`` pushes noise through the inverse pass.
    """

    def forward(self, inputs, cond_inputs=None, mode='direct', logdets=None):
        """ Performs a forward or backward pass for flow modules.

        Args:
            inputs: (batch, num_inputs) tensor to transform.
            cond_inputs: optional conditioning tensor, passed to every layer.
            mode: 'direct' or 'inverse'.
            logdets: optional initial (batch, 1) log-determinant. It is never
                modified in place.

        Returns:
            (outputs, logdets): the transformed inputs and the accumulated
            log-determinants.

        Raises:
            ValueError: if ``mode`` is neither 'direct' nor 'inverse'.
        """
        if mode not in ('direct', 'inverse'):
            raise ValueError(
                "mode must be 'direct' or 'inverse', got {!r}".format(mode))

        # Remembered so sample() can draw noise of the right width.
        self.num_inputs = inputs.size(-1)

        if logdets is None:
            logdets = torch.zeros(inputs.size(0), 1, device=inputs.device)

        modules = self._modules.values()
        if mode == 'inverse':
            modules = reversed(modules)
        for module in modules:
            inputs, logdet = module(inputs, cond_inputs, mode)
            # Out-of-place: never mutate a caller-supplied tensor, and never
            # run an in-place op on a grad-requiring leaf.
            logdets = logdets + logdet

        return inputs, logdets

    def log_probs(self, inputs, cond_inputs=None):
        """Return log p(inputs), shape (batch, 1), under a standard-normal base."""
        u, log_jacob = self(inputs, cond_inputs)
        log_probs = (-0.5 * u.pow(2) - 0.5 * math.log(2 * math.pi)).sum(
            -1, keepdim=True)
        return (log_probs + log_jacob).sum(-1, keepdim=True)

    def sample(self, num_samples=None, noise=None, cond_inputs=None):
        """Draw samples by running noise through the inverse pass.

        Args:
            num_samples: number of standard-normal draws. Used only when
                ``noise`` is None, and requires a prior forward() call so
                that ``num_inputs`` is known.
            noise: explicit noise tensor (batch, num_inputs).
            cond_inputs: optional conditioning tensor.

        Raises:
            ValueError: if neither ``num_samples`` nor ``noise`` is given.
        """
        if noise is None:
            if num_samples is None:
                raise ValueError("sample() needs `num_samples` or `noise`")
            noise = torch.Tensor(num_samples, self.num_inputs).normal_()
        # Parameter-free flows (e.g. only permutations) keep the noise's device.
        first_param = next(self.parameters(), None)
        device = first_param.device if first_param is not None else noise.device
        noise = noise.to(device)
        if cond_inputs is not None:
            cond_inputs = cond_inputs.to(device)
        samples = self.forward(noise, cond_inputs, mode='inverse')[0]
        return samples


def build_maf_flow(num_inputs, num_hidden, num_cond_inputs=None,
                   num_blocks=5, activation_str='relu'):
    """Build a MAF flow from ``num_blocks`` repetitions of
    MADE -> BatchNormFlow -> Reverse, wrapped in a FlowSequential."""
    def one_block():
        return [MADE(num_inputs, num_hidden, num_cond_inputs, act=activation_str),
                BatchNormFlow(num_inputs),
                Reverse(num_inputs)]

    stack = [layer for _ in range(num_blocks) for layer in one_block()]
    return FlowSequential(*stack)


def build_maf_split_flow(num_inputs, num_hidden, num_cond_inputs=None,
                         num_blocks=5, s_activation_str='tanh',
                         t_activation_str='relu'):
    """Build a MAF flow from ``num_blocks`` repetitions of
    MADESplit -> BatchNormFlow -> Reverse, wrapped in a FlowSequential."""
    stack = []
    for _ in range(num_blocks):
        stack.extend((
            MADESplit(num_inputs, num_hidden, num_cond_inputs,
                      s_act=s_activation_str, t_act=t_activation_str),
            BatchNormFlow(num_inputs),
            Reverse(num_inputs),
        ))
    return FlowSequential(*stack)


def build_maf_split_glow_flow(num_inputs, num_hidden, num_cond_inputs=None,
                              num_blocks=5, s_activation_str='tanh',
                              t_activation_str='relu'):
    """Build a flow from ``num_blocks`` repetitions of
    MADESplit -> BatchNormFlow -> InvertibleMM (a learned mixing matrix in
    place of a fixed reversal), wrapped in a FlowSequential."""
    stack = []
    for _ in range(num_blocks):
        stack.extend((
            MADESplit(num_inputs, num_hidden, num_cond_inputs,
                      s_act=s_activation_str, t_act=t_activation_str),
            BatchNormFlow(num_inputs),
            InvertibleMM(num_inputs),
        ))
    return FlowSequential(*stack)


def build_realnvp_flow(num_inputs, num_hidden, num_cond_inputs=None,
                       num_blocks=5, s_activation_str='tanh',
                       t_activation_str='relu', normalization_str='batchnorm',
                       cuda=False, layer_modifier='none'):
    """Simple helper to build num_blocks of a realNVP based flow.

    Each block is CouplingLayer -> BatchNormFlow. The alternating binary
    coupling mask (0,1,0,1,...) is flipped after every block.

    Args:
        num_inputs: flow dimensionality.
        num_hidden: hidden width of the coupling networks.
        num_cond_inputs: size of an optional conditioning vector.
        num_blocks: number of coupling blocks.
        s_activation_str / t_activation_str: activations of the scale and
            translate networks.
        normalization_str: normalisation used inside the coupling networks.
        cuda: if True, the coupling masks are created on ``cuda:0``.
        layer_modifier: linear-layer flavour passed to CouplingLayer (see
            ``get_layer_fn``). The default 'none' keeps the previous
            behaviour; this mirrors ``build_glow_flow``.
    """
    device = torch.device("cuda:0" if cuda else "cpu")
    mask = (torch.arange(0, num_inputs) % 2).to(device).float()

    modules = []
    for _ in range(num_blocks):
        modules += [
            CouplingLayer(num_inputs, num_hidden, mask, num_cond_inputs,
                          s_act=s_activation_str, t_act=t_activation_str,
                          normalization_str=normalization_str,
                          layer_modifier=layer_modifier),
            BatchNormFlow(num_inputs)
        ]
        mask = 1 - mask

    return FlowSequential(*modules)


def build_glow_flow(num_inputs, num_hidden, num_cond_inputs=None,
                    num_blocks=5, activation_str='relu',
                    normalization_str='batchnorm', layer_modifier='none',
                    cuda=False):
    """Build a Glow-style flow from ``num_blocks`` repetitions of
    ActNorm -> InvertibleMM -> AdditiveCouplingLayer.

    The alternating binary coupling mask (0,1,0,1,...) is flipped after
    every block. It is created on ``cuda:0`` when ``cuda`` is True,
    otherwise on the CPU.
    """
    target = torch.device("cuda:0" if cuda else "cpu")
    mask = (torch.arange(0, num_inputs) % 2).to(target).float()

    stack = []
    for _ in range(num_blocks):
        stack.extend((
            ActNorm(num_inputs),
            InvertibleMM(num_inputs),
            AdditiveCouplingLayer(num_inputs, num_hidden, mask,
                                  act=activation_str,
                                  num_cond_inputs=num_cond_inputs,
                                  normalization_str=normalization_str,
                                  layer_modifier=layer_modifier),
        ))
        mask = 1 - mask

    return FlowSequential(*stack)
