import abc

import torch
import torch.nn as nn
from torch.distributions import Uniform, TransformedDistribution
from torch.distributions.transforms import SigmoidTransform, AffineTransform
from torch.nn import Module, ReLU, Sequential, Linear, ModuleList


def build_nice_simple_network(
    latent_dim: int, hidden_dim: int, num_layers: int, output_dim: int = None
) -> Module:
    """Build the MLP used inside a NICE coupling layer.

    Layout (module indices in the returned ``Sequential``):
      0            Linear(latent_dim -> hidden_dim)
      1..num_layers  Sequential(Linear(hidden_dim -> hidden_dim), ReLU())
      last         Linear(hidden_dim -> output_dim)

    Args:
        latent_dim: input feature count.
        hidden_dim: width of every hidden layer.
        num_layers: number of hidden Linear+ReLU blocks.
        output_dim: output feature count; any falsy value (``None`` or ``0``)
            falls back to ``latent_dim``.

    Returns:
        The assembled ``Sequential`` module.
    """
    if not output_dim:
        output_dim = latent_dim
    # Layers are constructed in input-to-output order so parameter
    # initialisation consumes the RNG in the same sequence as before.
    # NOTE(review): no activation follows ``head``, so it composes linearly with
    # the first hidden Linear; changing that would alter state_dict keys.
    head = Linear(latent_dim, hidden_dim)
    body = [Sequential(Linear(hidden_dim, hidden_dim), ReLU()) for _ in range(num_layers)]
    tail = Linear(hidden_dim, output_dim)
    return Sequential(head, *body, tail)


class CouplingLayer(nn.Module):
    """Base class for coupling layers that split features by column parity.

    With ``first=True`` the even-indexed columns are kept unchanged and the
    odd-indexed columns are transformed; with ``first=False`` the roles swap.
    Subclasses implement ``coupling_law`` (forward map, returning the transformed
    half and its log-det term) and ``reveres_law`` (inverse map).

    NOTE: the class does not use ``abc.ABCMeta``, so ``@abc.abstractmethod`` is
    not enforced at instantiation; the base methods raise ``NotImplementedError``.
    """

    def __init__(self, first=True):
        super().__init__()
        self.first = first

    def forward(self, x, log_det_J, reverse=False):
        """Apply the coupling (or its inverse) and re-interleave the columns.

        Args:
            x: tensor whose dim 1 holds the features to split by parity.
            log_det_J: accepted for interface compatibility; not used here.
            reverse: apply ``reveres_law`` instead of ``coupling_law``.

        Returns:
            ``(out, jacobian)``. ``out`` has every column in its original position
            and is reshaped to ``(x.shape[0], -1)``. ``jacobian`` is the
            ``coupling_law`` log-det term, or ``None`` when ``reverse`` is set.
        """
        evens, odds = x[:, 0::2], x[:, 1::2]
        kept, changed = (evens, odds) if self.first else (odds, evens)

        if reverse:
            transformed, jacobian = self.reveres_law(changed, kept), None
        else:
            transformed, jacobian = self.coupling_law(changed, kept)

        # Stacking on a new trailing axis and flattening interleaves the two halves
        # back so the kept half returns to its original parity slots.
        pair = (kept, transformed) if self.first else (transformed, kept)
        merged = torch.stack(pair, dim=-1).reshape(x.shape[0], -1)
        return merged, jacobian

    @abc.abstractmethod
    def coupling_law(self, x2, x1):
        """Forward-transform ``x2`` conditioned on ``x1``; return ``(y, log_det)``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def reveres_law(self, x2, x1):
        """Invert ``coupling_law``: recover the original ``x2`` given ``x1``."""
        raise NotImplementedError()


class AdditiveCoupling(CouplingLayer):
    """Additive coupling: ``y = x2 + m(x1)``, volume preserving (log|det J| = 0).

    Args:
        in_out_dim: total feature count; each half has ``in_out_dim // 2`` features.
        mid_dim: hidden width of the coupling network ``m``.
        hidden: number of hidden layers in ``m``.
        **kwargs: forwarded to ``CouplingLayer`` (e.g. ``first``).
    """

    def __init__(self, in_out_dim, mid_dim, hidden, **kwargs):
        super().__init__(**kwargs)
        self.network = build_nice_simple_network(in_out_dim // 2, mid_dim, hidden)

    def coupling_law(self, x2, x1):
        """Return ``(x2 + m(x1), 0)``.

        The zero log-det is a 0-dim tensor created on ``x2``'s device and dtype.
        That way it combines with other log-det terms without relying on CPU-scalar
        promotion or the default dtype.
        """
        return x2 + self.network(x1), x2.new_zeros(())

    def reveres_law(self, x2, x1):
        """Invert the additive coupling: ``x2 - m(x1)``."""
        return x2 - self.network(x1)


class AffineCoupling(CouplingLayer):
    """Affine coupling: ``y = x2 * exp(sigmoid(s)) + t`` with ``(s, t) = m(x1)``.

    The coupling network outputs ``in_out_dim`` values. They are split in half
    into a raw log-scale ``s`` and a shift ``t``. The sigmoid keeps each
    per-feature log-scale in (0, 1), so every feature is multiplied by a factor
    in (1, e).

    Args:
        in_out_dim: total feature count; each half has ``in_out_dim // 2`` features.
        mid_dim: hidden width of the coupling network.
        hidden: number of hidden layers in the coupling network.
        **kwargs: forwarded to ``CouplingLayer`` (e.g. ``first``).
    """

    def __init__(self, in_out_dim, mid_dim, hidden, **kwargs):
        super().__init__(**kwargs)
        self.network = build_nice_simple_network(
            in_out_dim // 2, mid_dim, hidden, in_out_dim
        )

    def _log_scale_and_shift(self, x1):
        """Return ``(sigmoid(s), t)`` predicted by the coupling network from ``x1``."""
        raw_scale, shift = torch.chunk(self.network(x1), 2, dim=-1)
        return torch.sigmoid(raw_scale), shift

    def coupling_law(self, x2, x1):
        """Return the transformed half and its per-sample log-det (sum over dim 1)."""
        log_scale, shift = self._log_scale_and_shift(x1)
        transformed = x2 * torch.exp(log_scale) + shift
        return transformed, log_scale.sum(dim=1)

    def reveres_law(self, x2, x1):
        """Invert the affine coupling: ``(x2 - t) / exp(sigmoid(s))``."""
        log_scale, shift = self._log_scale_and_shift(x1)
        return (x2 - shift) * torch.reciprocal(torch.exp(log_scale))


class Scaling(nn.Module):
    """Learned per-feature diagonal scaling (the final NICE layer).

    The forward map multiplies each feature by ``exp(scale) + eps``; the reverse
    map multiplies by the reciprocal of that factor. ``eps`` keeps the factor
    strictly positive.

    Args:
        dim: number of features (size of the input's last axis).
        eps: additive floor on the scale factor (defaults to 1e-5).
    """

    def __init__(self, dim, eps=1e-5):
        super(Scaling, self).__init__()
        self.scale = nn.Parameter(torch.zeros(dim), requires_grad=True)
        self.eps = eps

    def forward(self, x, reverse=False):
        """Scale ``x`` feature-wise, or undo the scaling when ``reverse`` is set.

        Broadcasting over the last axis gives the same values as
        ``x @ torch.diag(factor)``. It avoids materialising a dim x dim matrix and
        the O(dim^2) matmul. It also keeps infinite inputs from turning into NaN
        through ``inf * 0`` off-diagonal terms.
        """
        factor = torch.exp(self.scale) + self.eps
        if reverse:
            return x * torch.reciprocal(factor)
        return x * factor


class NICE(nn.Module):
    """NICE normalizing flow: parity-alternating coupling layers plus a diagonal scaling.

    Args:
        prior: ``"gaussian"`` (standard normal) or ``"logistic"``. The logistic
            prior is the logit of Uniform(0, 1) followed by an identity affine.
        coupling: number of coupling layers. Layer ``i`` keeps the even-indexed
            features fixed when ``i`` is even and the odd-indexed ones otherwise.
        coupling_type: ``"additive"`` or ``"affine"``.
        in_out_dim: flattened feature count. It must be even, because each coupling
            layer splits the features into two halves of ``in_out_dim // 2``.
        mid_dim: hidden width of each coupling network.
        hidden: number of hidden layers in each coupling network.
        device: device on which the prior's parameters are created.

    Raises:
        ValueError: if ``prior`` is unknown or ``in_out_dim`` is odd.
        NotImplementedError: if ``coupling_type`` is unknown.
    """

    def __init__(self, prior, coupling, coupling_type, in_out_dim, mid_dim, hidden, device):
        super().__init__()
        self.device = device
        if prior == "gaussian":
            self.prior = torch.distributions.Normal(
                torch.tensor(0.0, device=device), torch.tensor(1.0, device=device)
            )
        elif prior == "logistic":
            self.prior = TransformedDistribution(
                Uniform(torch.tensor(0.0, device=device), torch.tensor(1.0, device=device)),
                [
                    SigmoidTransform().inv,
                    AffineTransform(
                        loc=torch.tensor(0.0, device=device),
                        scale=torch.tensor(1.0, device=device),
                    ),
                ],
            )
        else:
            raise ValueError("Prior not implemented.")
        if coupling_type == "additive":
            coupling_type_klass = AdditiveCoupling
        elif coupling_type == "affine":
            coupling_type_klass = AffineCoupling
        else:
            raise NotImplementedError("No such coupling")
        if in_out_dim % 2 != 0:
            # Coupling networks take in_out_dim // 2 inputs and the interleave
            # needs equal halves; an odd size would otherwise fail in forward().
            raise ValueError(f"in_out_dim must be even, got {in_out_dim}.")

        self.in_out_dim = in_out_dim
        self.coupling_type = coupling_type
        self.coupling = ModuleList(
            [
                coupling_type_klass(in_out_dim, mid_dim, hidden, first=i % 2 == 0)
                for i in range(coupling)
            ]
        )
        self.scaling = Scaling(in_out_dim)

    def f_inverse(self, z):
        """Map latent ``z`` back to data space without tracking gradients.

        Undoes the scaling, then the coupling layers in reverse order. The result
        is flat, ``(batch, -1)``; the original non-flat shape is not restored.
        """
        with torch.no_grad():
            data = self.scaling(z, reverse=True)
            for couple_layer in reversed(self.coupling):
                data, _ = couple_layer(data, None, reverse=True)
        return data

    def f(self, x):
        """Map data ``x`` (flattened to ``(batch, -1)``) to latent space.

        Returns:
            ``(z, log_det_J)``. ``log_det_J`` is a scalar for additive couplings.
            For affine couplings it is per-sample, shape ``(batch,)``.
        """
        data = x.reshape(x.shape[0], -1)
        sum_log_j = 0
        for couple_layer in self.coupling:
            data, j = couple_layer(data, None)
            sum_log_j = sum_log_j + j
        data = self.scaling(data)
        # The scaling layer multiplies by exp(scale) + eps, so its exact log-det
        # is sum(log(exp(scale) + eps)), not sum(scale).
        scaling_log_det = torch.log(torch.exp(self.scaling.scale) + self.scaling.eps).sum()
        return data, scaling_log_det + sum_log_j

    def log_prob(self, x):
        """Per-sample log-likelihood of ``x`` under the flow; shape ``(batch,)``."""
        z, log_det_J = self.f(x)
        # Log-det for rescaling from [0, 256] (after dequantization) to [0, 1]:
        # each of the in_out_dim coordinates contributes -log(256).
        log_det_J = log_det_J - torch.log(torch.tensor(256.0)) * self.in_out_dim
        log_ll = torch.sum(self.prior.log_prob(z), dim=1)
        return log_ll + log_det_J

    def sample(self, size):
        """Draw ``size`` data-space samples, shape ``(size, in_out_dim)``.

        Draws from the prior in latent space, then maps the draws through
        ``f_inverse``.
        """
        z = self.prior.sample((size, self.in_out_dim)).to(self.device)
        return self.f_inverse(z)

    def forward(self, x):
        """Alias for ``log_prob``."""
        return self.log_prob(x)
