"""Utility classes for NICE.
"""

import torch
import torch.nn as nn
from torch.nn.modules.linear import Linear
import numpy as np

class ZeroLinear(Linear):
    """Linear layer whose parameters are initialized very close to zero.

    Overrides ``Linear.reset_parameters`` so that the weight (and the bias,
    when the layer has one) are drawn from a normal distribution with mean 0
    and standard deviation 1e-6, i.e. near zero but not exactly zero.
    """

    def reset_parameters(self):
        """Re-initialize the weight and, if present, the bias near zero."""
        torch.nn.init.normal_(self.weight, mean=0.0, std=1e-6)
        # A layer built with ``bias=False`` has ``self.bias`` set to None;
        # initializing it unconditionally would fail inside the constructor.
        if self.bias is not None:
            torch.nn.init.normal_(self.bias, mean=0.0, std=1e-6)

"""Additive coupling layer.
"""
class Coupling(nn.Module):
    """Additive coupling layer acting on consecutive pairs of input units.

    The W input features are viewed as W//2 consecutive pairs. One element of
    every pair passes through unchanged and is fed to an MLP; the MLP output
    (the "shift") is added to, or subtracted from, the other element.
    """

    def __init__(self, in_out_dim, mid_dim, hidden, mask_config):
        """Build the MLP that computes the shift.

        Args:
            in_out_dim: input/output dimensions.
            mid_dim: number of units in a hidden layer.
            hidden: number of hidden layers; ``hidden - 1`` blocks are placed
                between the input block and the output layer.
            mask_config: if truthy, the first element of each pair is
                transformed and the second one drives the MLP; if falsy the
                roles are swapped.
        """
        super(Coupling, self).__init__()
        self.mask_config = mask_config
        half_dim = in_out_dim // 2

        # Submodules are created in the order in -> mid -> out.
        self.in_block = nn.Sequential(nn.Linear(half_dim, mid_dim), nn.ReLU())
        hidden_blocks = []
        for _ in range(hidden - 1):
            hidden_blocks.append(
                nn.Sequential(nn.Linear(mid_dim, mid_dim), nn.ReLU()))
        self.mid_block = nn.ModuleList(hidden_blocks)
        self.out_block = nn.Linear(mid_dim, half_dim)

    def forward(self, x, reverse=False):
        """Apply the coupling transform or its inverse.

        Args:
            x: 2-D input tensor of shape (batch, features).
            reverse: if True the shift is subtracted (undoing a previous
                call made with ``reverse=False``); if False it is added.
        Returns:
            transformed tensor with the same shape as ``x``.
        """
        batch, width = x.size()
        pairs = x.reshape((batch, width // 2, 2))
        first, second = pairs[:, :, 0], pairs[:, :, 1]

        # "active" is modified, "passive" is copied through and feeds the MLP.
        if self.mask_config:
            active, passive = first, second
        else:
            active, passive = second, first

        features = self.in_block(passive)
        for block in self.mid_block:
            features = block(features)
        shift = self.out_block(features)

        active = active - shift if reverse else active + shift

        # Put both halves back into their original slot within each pair.
        ordered = (active, passive) if self.mask_config else (passive, active)
        return torch.stack(ordered, dim=2).reshape((batch, width))

"""Additive coupling layer.
"""
class RandomCoupling(nn.Module):
    """Additive coupling layer applied under a fixed permutation of the units.

    The input columns are first reordered by ``self.permutation``, then the
    same pairwise additive coupling as in a plain coupling layer is applied,
    and finally the columns are put back with ``self.reverse_permutation``.

    The two permutations are stored as plain tensor attributes (they are not
    parameters or registered buffers).
    """

    def __init__(self, in_out_dim, mid_dim, hidden, mask_config, permutations = None):
        """Initialize a coupling layer.

        Args:
            in_out_dim: input/output dimensions.
            mid_dim: number of units in a hidden layer.
            hidden: number of hidden layers; ``hidden - 1`` blocks are placed
                between the input block and the output layer.
            mask_config: if truthy, the first element of each (permuted) pair
                is transformed and the second one drives the MLP; if falsy
                the roles are swapped.
            permutations: optional pair ``(permutation, inverse)`` of index
                sequences (tensors, arrays or lists), or a stacked 2 x D
                tensor/array. When None or empty, a random permutation of
                ``in_out_dim`` indices is drawn with ``np.random``.
        """
        super(RandomCoupling, self).__init__()
        self.mask_config = mask_config

        # Test for None / emptiness explicitly: truth-testing a tensor or
        # ndarray holding more than one element raises an error.
        if permutations is not None and len(permutations) > 0:
            self.permutation = torch.as_tensor(permutations[0], dtype=torch.long)
            self.reverse_permutation = torch.as_tensor(permutations[1], dtype=torch.long)
        else:
            drawn = np.random.permutation(in_out_dim)
            self.permutation = torch.as_tensor(drawn, dtype=torch.long)
            # argsort of a permutation is its inverse permutation.
            self.reverse_permutation = torch.as_tensor(np.argsort(drawn), dtype=torch.long)

        self.in_block = nn.Sequential(
            nn.Linear(in_out_dim//2, mid_dim),
            nn.ReLU())
        self.mid_block = nn.ModuleList([
            nn.Sequential(
                nn.Linear(mid_dim, mid_dim),
                nn.ReLU()) for _ in range(hidden - 1)])
        self.out_block = nn.Linear(mid_dim, in_out_dim//2)

    def forward(self, x, reverse=False):
        """Apply the permuted coupling transform or its inverse.

        Args:
            x: 2-D input tensor of shape (batch, features).
            reverse: if True the shift is subtracted (undoing a previous
                call made with ``reverse=False``); if False it is added.
        Returns:
            transformed tensor with the same shape and column order as ``x``.
        """
        [B, W] = list(x.size())
        # Reorder the columns, then view them as W//2 consecutive pairs.
        x = x[:, self.permutation].reshape((B, W//2, 2))
        if self.mask_config:
            on, off = x[:, :, 0], x[:, :, 1]
        else:
            off, on = x[:, :, 0], x[:, :, 1]

        # The untouched half ("off") determines the shift applied to "on".
        off_ = self.in_block(off)
        for block in self.mid_block:
            off_ = block(off_)
        shift = self.out_block(off_)
        if reverse:
            on = on - shift
        else:
            on = on + shift

        if self.mask_config:
            x = torch.stack((on, off), dim=2)
        else:
            x = torch.stack((off, on), dim=2)
        # Undo the column permutation so outputs line up with the inputs.
        return x.reshape((B, W))[:, self.reverse_permutation]

"""Log-scaling layer.
"""
class Scaling(nn.Module):
    """Element-wise scaling layer parameterized by log-scales."""

    def __init__(self, dim):
        """Create the learnable log-scale vector, initialized to zeros.

        Args:
            dim: input/output dimensions.
        """
        super(Scaling, self).__init__()
        # Shape (1, dim) so it broadcasts over the batch dimension.
        self.scale = nn.Parameter(torch.zeros((1, dim)), requires_grad=True)

    def forward(self, x, reverse=False):
        """Multiply ``x`` by ``exp(scale)`` or, in reverse, by ``exp(-scale)``.

        Args:
            x: input tensor.
            reverse: if True, multiply by ``exp(-scale)`` (undoing a call
                made with ``reverse=False``); otherwise by ``exp(scale)``.
        Returns:
            tuple of the transformed tensor and the sum of the log-scales.
            The second element is the same for both directions.
        """
        exponent = -self.scale if reverse else self.scale
        scaled = x * torch.exp(exponent)
        return scaled, torch.sum(self.scale)

"""NICE main model.
"""
class NICE(nn.Module):
    """NICE model: a stack of RandomCoupling layers followed by a Scaling layer."""

    def __init__(self, prior, prior_noise, coupling,
        in_out_dim, mid_dim, hidden, mask_config, permutations_list=None, equalize_perms=False):
        """Initialize a NICE.

        Args:
            prior: distribution over latent space Z; must provide
                ``log_prob(z)`` and ``sample(shape)``.
            prior_noise: second latent distribution with the same interface,
                used by ``log_prob_noise`` and ``sample_noise``.
            coupling: number of coupling layers.
            in_out_dim: input/output dimensions.
            mid_dim: number of units in a hidden layer.
            hidden: number of hidden layers.
            mask_config: mask setting of the first coupling layer; layer
                ``i`` receives ``(mask_config + i) % 2``.
            permutations_list: optional list with one
                ``[permutation, inverse]`` entry per coupling layer. When
                missing or empty, one permutation is generated and reused
                for every layer.
            equalize_perms: when generating, use the identity permutation
                if True, otherwise a random one from ``np.random``.
        """
        super(NICE, self).__init__()
        self.prior = prior
        self.prior_noise = prior_noise
        self.in_out_dim = in_out_dim

        if permutations_list:
            self.permutations = permutations_list
        else:
            if equalize_perms:
                base = np.arange(in_out_dim)
            else:
                base = np.random.permutation(in_out_dim)
            inverse = np.arange(len(base))[np.argsort(base)]
            # Every layer gets its own tensors built from the same permutation.
            self.permutations = [
                [torch.LongTensor(base), torch.LongTensor(inverse)]
                for _ in range(coupling)]

        layers = []
        for idx in range(coupling):
            layers.append(RandomCoupling(
                in_out_dim=in_out_dim,
                mid_dim=mid_dim,
                hidden=hidden,
                mask_config=(mask_config + idx) % 2,  # alternate the mask per layer
                permutations=self.permutations[idx]))
        self.coupling = nn.ModuleList(layers)
        self.scaling = Scaling(in_out_dim)

    def g(self, z):
        """Transformation g: Z -> X (inverse of f).

        Undoes the scaling first, then the coupling layers in reverse order.

        Args:
            z: tensor in latent space Z.
        Returns:
            transformed tensor in data space X.
        """
        x, _ = self.scaling(z, reverse=True)
        for layer in reversed(list(self.coupling)):
            x = layer(x, reverse=True)
        return x

    def f(self, x):
        """Transformation f: X -> Z (inverse of g).

        Args:
            x: tensor in data space X.
        Returns:
            the output of the scaling layer: a tuple of the tensor in latent
            space Z and the log-determinant term.
        """
        for layer in self.coupling:
            x = layer(x)
        return self.scaling(x)

    def _log_likelihood(self, x, prior):
        """Sum ``prior.log_prob(f(x))`` over dim 1 and add the log-det term."""
        z, log_det_J = self.f(x)
        log_ll = torch.sum(prior.log_prob(z), dim=1)
        return log_ll + log_det_J

    def log_prob(self, x):
        """Computes data log-likelihood under ``prior``.

        (See Section 3.3 in the NICE paper.)

        Args:
            x: input minibatch.
        Returns:
            log-likelihood of input.
        """
        return self._log_likelihood(x, self.prior)

    def sample(self, size):
        """Generates samples by mapping draws from ``prior`` through g.

        Args:
            size: number of samples to generate.
        Returns:
            samples from the data space X.
        """
        latent = self.prior.sample((size, self.in_out_dim))
        return self.g(latent)

    def log_prob_noise(self, x):
        """Computes data log-likelihood under ``prior_noise``.

        Same computation as ``log_prob`` with ``prior_noise`` as the latent
        distribution.

        Args:
            x: input minibatch.
        Returns:
            log-likelihood of input.
        """
        return self._log_likelihood(x, self.prior_noise)

    def sample_noise(self, size):
        """Generates samples by mapping draws from ``prior_noise`` through g.

        Args:
            size: number of samples to generate.
        Returns:
            samples from the data space X.
        """
        latent = self.prior_noise.sample((size, self.in_out_dim))
        return self.g(latent)

    def forward(self, x):
        """Forward pass.

        Args:
            x: input minibatch.
        Returns:
            log-likelihood of input (same as ``log_prob``).
        """
        return self.log_prob(x)
