import torch
import torch.nn as nn
from torch.nn import Module
from torch.nn import functional as F
import numpy as np
from inspect import signature
import math


def lower_diagonal_mask(features, strict=True):
    """Return a ``(features, features)`` boolean lower-triangular mask.

    Entry ``[i, j]`` is True when ``i > j`` if ``strict`` is truthy, and
    when ``i >= j`` otherwise, i.e. the diagonal is only excluded in the
    strict case.
    """
    positions = torch.arange(features)
    rows = positions.unsqueeze(1)  # column vector of row indices
    cols = positions.unsqueeze(0)  # row vector of column indices
    compare = torch.gt if strict else torch.ge
    return compare(rows, cols)


class BlockLowerTriangular(Module):
    """Linear layer whose weight is masked to a (block) lower-triangular pattern.

    Maps ``features * in_block`` inputs to ``features * out_block`` outputs.
    Inputs and outputs are grouped into ``features`` contiguous blocks of
    width ``in_block`` / ``out_block``; output block ``i`` only sees input
    blocks ``j < i`` when ``strict`` is True, or ``j <= i`` otherwise.

    Args:
        features: number of blocks (autoregressive positions).
        in_block: width of each input block.
        out_block: width of each output block.
        strict: exclude the diagonal blocks from the mask.
        bias: forwarded to the underlying ``nn.Linear``.
    """

    def __init__(self, features, in_block=1, out_block=1,
                 strict=True, bias=False):
        super(BlockLowerTriangular, self).__init__()
        self.in_block = in_block
        self.out_block = out_block
        self.strict = strict
        self.transform = nn.Linear(
            in_features=features * in_block,
            out_features=features * out_block,
            bias=bias
        )
        mask = lower_diagonal_mask(features, strict=strict)
        if in_block > 1 or out_block > 1:
            # Expand every (i, j) entry into an (out_block, in_block) tile so
            # the mask lines up with nn.Linear's (out, in) weight layout.
            mask = mask[:, None, :, None] \
                .repeat(1, out_block, 1, in_block) \
                .view(features * out_block, features * in_block)
        assert(self.transform.weight.size() == mask.size())
        self.register_buffer("mask", mask)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the weight and zero its masked entries.

        After ``nn.Linear``'s own init, every row is multiplied by
        ``sqrt(fan_in) / (1e-5 + sqrt(n_unmasked_in_row))``; presumably this
        rescales each row as if its fan-in were only the unmasked inputs.
        Rows with no unmasked inputs get a huge factor but are zeroed by the
        mask afterwards. The bias, if any, is set to zero.
        """
        self.transform.reset_parameters()
        inv_stdv = math.sqrt(self.transform.weight.size(1))
        row_stdv = torch.sqrt(self.mask.float().sum(dim=1, keepdim=True))
        self.transform.weight.data.mul_(inv_stdv / (1e-5 + row_stdv))

        if self.transform.bias is not None:
            self.transform.bias.data.zero_()

        if self.in_block == 1 and self.out_block == 1:
            # Sanity check: with 1x1 blocks the mask must agree with tril.
            tril_weight = torch.tril(self.transform.weight,
                                     diagonal=-1 if self.strict else 0)
            mask_weight = self.transform.weight.masked_fill(~self.mask, 0.)
            assert (tril_weight == mask_weight).all()
        self.transform.weight.data.masked_fill_(~self.mask, 0.)

    def forward(self, input, mask=True):
        """Apply the masked linear map to ``input``.

        The registered mask is re-applied on every call, so gradient updates
        to masked weight entries have no effect. ``mask`` is unused and only
        kept so existing callers that pass it keep working.
        """
        weight = self.transform.weight.masked_fill(~self.mask, 0.)
        return F.linear(input, weight, self.transform.bias)


class AutoregressiveMLP(Module):
    """Stack of ``BlockLowerTriangular`` layers with an additive context term.

    The first layer is strictly lower-triangular and the rest include the
    diagonal, so output block ``i`` depends only on input positions
    ``j < i`` (plus ``context``). Each hidden layer adds its own slice of a
    linear projection of ``context`` before ``activation``; the final layer
    has no activation.

    Args:
        in_features: number of autoregressive positions in ``input``.
        context_features: size of the last dimension of ``context``.
        layers: total number of masked layers; must be at least 2.
        hidden_upscale: hidden units per position.
        output_upscale: output units per position.
        activation: hidden non-linearity; a fresh ``nn.Tanh()`` when None.

    Raises:
        ValueError: if ``layers < 2``.
    """
    # Fixed sized hidden layers for now.
    def __init__(self, in_features, context_features,
                 layers, hidden_upscale=1, output_upscale=1,
                 activation=None):
        super(AutoregressiveMLP, self).__init__()
        if layers < 2:
            # With fewer layers there is no hidden layer to receive context
            # and the stack would not end in the output layer.
            raise ValueError(
                f"AutoregressiveMLP needs at least 2 layers, got {layers}")
        self.autoreg = nn.ModuleList(
            [BlockLowerTriangular(
                features=in_features,
                in_block=1,
                out_block=hidden_upscale,
                strict=True)] +
            [BlockLowerTriangular(
                features=in_features,
                in_block=hidden_upscale,
                out_block=hidden_upscale,
                strict=False) for _ in range(1, layers - 1)] +
            [BlockLowerTriangular(
                features=in_features,
                in_block=hidden_upscale,
                out_block=output_upscale,
                strict=False)]
        )

        # One context projection per hidden layer, split apart in forward.
        self.ctx_transforms = nn.Linear(
            context_features,
            hidden_upscale * in_features * (layers - 1)
        )
        self.ctx_transforms.bias.data.fill_(0.)

        # Build the default here so instances never share one module object.
        self.activation = nn.Tanh() if activation is None else activation
        self.layers = layers

    def forward(self, input, context):
        """Return the final-layer output for ``input`` given ``context``."""
        ctx_terms = torch.chunk(
            self.ctx_transforms(context),
            self.layers - 1,
            dim=-1
        )
        hidden = input
        for layer, ctx_term in zip(self.autoreg[:-1], ctx_terms):
            hidden = self.activation(layer(hidden) + ctx_term)
        return self.autoreg[-1](hidden)

class Bijector(Module):
    """Base class for invertible transforms with log-det-Jacobian helpers.

    Subclasses implement ``forward``/``inverse`` and at least one of
    ``forward_log_det_jacobian`` / ``inverse_log_det_jacobian``. Each default
    is minus the other, so a subclass that overrides neither recurses
    without end.

    Args:
        dims: dimension(s) subclasses sum over when reducing per-element
            log-det terms; stored as ``self.dims``.
    """
    def __init__(self, dims=-1):
        super(Bijector, self).__init__()
        self.dims = dims

    def forward(self, x, conditioned=None):
        """Map ``x`` to ``y``; must be overridden."""
        raise NotImplementedError

    def inverse(self, y, conditioned=None):
        """Map ``y`` back to ``x``; must be overridden."""
        raise NotImplementedError

    def forward_log_det_jacobian(self, x, y=None, conditioned=None):
        """Return ``log|det dy/dx|`` as minus the inverse log-det at ``y``.

        ``y`` may be supplied to skip recomputing ``forward(x)``.
        ``conditioned`` is passed to ``forward`` and, when not None, to
        ``inverse_log_det_jacobian``.
        """
        if y is None:
            y = self.forward(x, conditioned=conditioned)
        if conditioned is None:
            # Same call as an unconditioned use always made, so subclasses
            # whose inverse_log_det_jacobian has no ``conditioned`` still work.
            return -self.inverse_log_det_jacobian(y)
        return -self.inverse_log_det_jacobian(y, conditioned=conditioned)

    def inverse_log_det_jacobian(self, y, x=None, conditioned=None):
        """Return ``log|det dx/dy|`` as minus the forward log-det at ``x``.

        ``x`` may be supplied to skip recomputing ``inverse(y)``.
        ``conditioned`` is passed to ``inverse`` and, when not None, to
        ``forward_log_det_jacobian``.
        """
        if x is None:
            x = self.inverse(y, conditioned=conditioned)
        if conditioned is None:
            return -self.forward_log_det_jacobian(x)
        return -self.forward_log_det_jacobian(x, conditioned=conditioned)

    def forward_and_invlogdet(self, x, context=None):
        """Return ``(y, inverse log-det at y)`` with ``y = forward(x, context)``."""
        y = self.forward(x, context)
        # NOTE(review): ``context`` is passed positionally; subclasses that
        # take it as their second parameter receive it correctly, but with
        # this base class's default signature it lands in the ``x`` slot.
        return y, self.inverse_log_det_jacobian(y, context)

    def inverse_and_invlogdet(self, y, context=None):
        """Return ``(x, inverse log-det at y)`` with ``x = inverse(y, context)``."""
        x = self.inverse(y, context)
        # NOTE(review): same positional ``context`` caveat as above.
        return x, self.inverse_log_det_jacobian(y, context)



class Exp(Bijector):
    """Elementwise ``y = exp(x)``; log-det terms are summed over ``self.dims``.

    ``conditioned`` / ``context`` arguments are accepted so the ``Bijector``
    helpers (and ``Sequential``) can call this class uniformly; they are
    ignored because the map is unconditional.
    """

    def __init__(self, dims=-1):
        super(Exp, self).__init__(dims=dims)

    def forward(self, x, conditioned=None):
        return torch.exp(x)

    def inverse(self, y, conditioned=None):
        return torch.log(y)

    def inverse_log_det_jacobian(self, y, x=None, conditioned=None):
        """Return ``log|det dx/dy| = -sum(log y) = -sum(x)`` over ``self.dims``.

        When ``x`` is given, ``-x`` is used directly, which avoids the
        rounding/overflow of ``log(exp(x))``.
        """
        if x is None:
            result = -torch.log(y)
        else:
            result = -x
        return result.sum(dim=self.dims)

    def forward_and_invlogdet(self, x, context=None):
        """Return ``(exp(x), -sum(x))``."""
        y = self.forward(x)
        return y, self.inverse_log_det_jacobian(y, x=x)

    def inverse_and_invlogdet(self, y, context=None):
        """Return ``(log(y), -sum(log(y)))``."""
        x = self.inverse(y)
        return x, self.inverse_log_det_jacobian(y, x=x)


class Sigmoid(Bijector):
    """Elementwise logistic sigmoid bijector, ``y = sigmoid(x)``.

    ``conditioned`` is accepted for interface compatibility and ignored.
    """

    def __init__(self, dims=-1):
        super(Sigmoid, self).__init__(dims=dims)

    def forward(self, x, conditioned=None):
        return torch.sigmoid(x)

    def inverse(self, y, conditioned=None):
        """Logit of ``y``, computed as ``log(y) - log(1 - y)``."""
        log_y = torch.log(y)
        log_complement = torch.log(1 - y)
        return log_y - log_complement

    def inverse_log_det_jacobian(self, y, conditioned=None):
        """Sum over ``self.dims`` of ``log(dx/dy) = -log(y) - log(1 - y)``."""
        neg_log_y = -torch.log(y)
        log_complement = torch.log(1 - y)
        return (neg_log_y - log_complement).sum(self.dims)

def soft_clamp(x, min=-5., max=5.):
    """Smoothly squash ``x`` towards the interval ``[min, max]``.

    A softplus-based soft upper bound at ``max`` is applied first, then a
    soft lower bound at ``min``. The result is not strictly bounded above:
    the second softplus can push large inputs slightly past ``max``.

    The parameter names shadow the ``min``/``max`` builtins; they are kept
    so keyword callers keep working.
    """
    lower, upper = min, max
    below_upper = upper - F.softplus(upper - x)
    return lower + F.softplus(below_upper - lower)

class NLSq(Bijector):
    """Elementwise non-linear squared bijector.

    Forward map: ``y = a + b * x + c / (1 + (d * x + g) ** 2)`` with
    ``b = exp(log_b)``, ``d = exp(log_d)`` and ``c = tanh(lin_c) * factor``
    (see ``_constrain_params``). The parameters are passed on every call;
    log-det terms are summed over ``self.dims``.
    """
    def __init__(self, dims=-1):
        # NOTE(review): overwritten by Bijector.__init__(dims=dims) below.
        self.dims = -1
        super(NLSq, self).__init__(dims=dims)
        # log(8 * sqrt(3) / 9), used by _constrain_params to bound ``c``.
        self.register_buffer('log_const',
                             torch.tensor(math.log(8 * math.sqrt(3) / 9.)))

    def _safe(self, a, log_b, lin_c, log_d, g):
        """Return the raw parameters unchanged (soft clamping is disabled)."""
        # a = soft_clamp(a)
        # log_b = soft_clamp(log_b)
        # log_d = soft_clamp(log_d)
        # lin_c = soft_clamp(lin_c)
        # g = soft_clamp(g)
        return a, log_b, lin_c, log_d, g

    def _constrain_params(self, a, log_b, lin_c, log_d, g, log_eps=-16.):
        """Map raw parameters to ``(a, b, c, d, g)``.

        ``b = exp(log_b)``, ``d = exp(log_d)`` and
        ``c = tanh(lin_c) * (8 * sqrt(3) / 9) * b / d``. Because
        ``max_z 2z / (1 + z**2)**2 = 9 / (8 * sqrt(3))``, this keeps
        ``dy/dx = b - 2cd(dx + g) / (1 + (dx + g)**2)**2`` positive while
        ``|tanh(lin_c)| < 1``. Where ``log(factor) < log_eps`` the factor
        (and therefore ``c``) is set to exactly 0. Non-finite results are
        only printed, not raised.
        """
        b = torch.exp(log_b)
        d = torch.exp(log_d)
        log_factor = self.log_const + log_b - log_d
        mask = log_factor < log_eps
        # Clamp before exp, then zero the tiny factors exactly.
        factor = torch.exp(log_factor.clamp(min=log_eps)).masked_fill(mask, 0.)
        c = torch.tanh(lin_c) * factor
        # Debug aid: report (but do not raise on) non-finite parameters.
        mask = (~torch.isfinite(a) | 
                ~torch.isfinite(b) |
                ~torch.isfinite(c) |
                ~torch.isfinite(d) |
                ~torch.isfinite(g))
        if mask.any():
            print("a", a[mask])
            print("b", b[mask])
            print("c", c[mask])
            print("d", d[mask])
            print("g", g[mask])
        return a, b, c, d, g

    def _nlsq_logdet(self, x, a, b, c, d, g):
        """Return ``sum(log(dy/dx))`` over ``self.dims``, computed directly.

        ``dy/dx = b - 2cd(dx + g) / (1 + (dx + g)**2)**2``; the second term
        is dropped elementwise where ``d > 1e5``. ``a`` is unused.
        """
        denom = d * x + g
        term_1 = b
        term_2 = (d * 2 * c * denom) / (1 + denom ** 2)**2
        pre_log = term_1 - term_2.masked_fill(d > 1e5, 0.)
        # NOTE(review): ``mask`` is computed but never used.
        mask = ~torch.isfinite(pre_log)
        return torch.log(pre_log).sum(self.dims)

    def _nlsq_logdet_numerical(self, x, a, log_b, c, log_d, g):
        """Log-space variant of ``_nlsq_logdet`` taking ``log_b``/``log_d``.

        Computes ``log_b + log(1 - sign * exp(log|term_2| - log_b))``, with
        ``|2c|`` and ``|denom|`` clamped to at least 1e-6 before their logs
        and ratios below ``exp(-16)`` flushed to 0. Wherever any ``d > 1e5``
        along ``self.dims``, the whole reduced entry is replaced by
        ``sum(log_b)`` (unlike ``_nlsq_logdet``, which drops term_2 only at
        the offending elements). ``a`` is unused.
        """
        d = torch.exp(log_d)
        denom = d * x + g
        log_term_1 = log_b
        # term_2 = 2cd*denom / (...)**2 with d > 0, so its sign is sign(c)*sign(denom).
        sign_term_2 = torch.sign(c) * torch.sign(denom)
        """
        print('a', a.min(), a.max())
        print('log_b', log_b.min(), log_b.max())
        print('c', c.min(), c.max())
        print('log_d', log_d.min(), log_d.max())
        print('g', g.min(), g.max())
        """
        # log|term_2| = log d + log|2c| + log|denom| - 2 log(1 + denom**2)
        log_unsigned_term_2 = (
            log_d  +
            torch.log(torch.abs(2 * c).clamp(min=1e-6)) +
            torch.log(torch.abs(denom).clamp(min=1e-6))
            - 2 * torch.log(1 + denom ** 2)
        )
        # |term_2| / b, taken in log space and flushed to 0 when negligible.
        norm_log_term_2 = log_unsigned_term_2 - log_term_1
        norm_unsigned_term_2 = torch.exp(norm_log_term_2).masked_fill(norm_log_term_2 < -16, 0.)
        full_out = (torch.log(1 - sign_term_2 * norm_unsigned_term_2) + log_term_1).sum(self.dims)

        inf_case = (d > 1e5).any(dim=self.dims)
        full_out[inf_case] = log_term_1.sum(self.dims)[inf_case]

        return full_out



    def forward_and_invlogdet(self, x, a, log_b, lin_c, log_d, g):
        """Return ``(y, ldj)`` for the forward map.

        NOTE(review): ``ldj`` is ``sum(log(dy/dx))`` (the forward log-det),
        whereas ScaleShift and Gating return the inverse log-det from
        methods with this name; confirm the sign convention at call sites.
        """
        a, log_b, lin_c, log_d, g = self._safe(a, log_b, lin_c, log_d, g)
        a, b, c, d, g = self._constrain_params(a, log_b, lin_c, log_d, g)
        y = a + b * x + \
            c / (1. + (d * x + g) ** 2)
        ldji = self._nlsq_logdet_numerical(x, a, log_b, c, log_d, g)
        # NOTE(review): ``ldji_`` only feeds the commented-out comparison
        # prints below; it doubles the log-det cost and could be removed.
        ldji_ = self._nlsq_logdet(x, a, b, c, d, g)
        # print("new", ldji)
        # print("old", ldji_)
        return y, ldji

    def forward(self, x, a, log_b, lin_c, log_d, g):
        """Return only ``y`` from ``forward_and_invlogdet``."""
        return self.forward_and_invlogdet(x, a, log_b, lin_c, log_d, g)[0]

    def inverse_and_invlogdet(self, y, a, log_b, lin_c, log_d, g):
        """Return ``(x, ldj)`` where ``x`` inverts the forward map.

        NOTE(review): ``ldj`` is ``_nlsq_logdet`` evaluated at ``x``, i.e.
        the forward log-det ``sum(log(dy/dx))`` via the direct formula,
        while ``forward_and_invlogdet`` uses the log-space variant.
        """
        a, log_b, lin_c, log_d, g = self._safe(a, log_b, lin_c, log_d, g)
        a, b, c, d, g = self._constrain_params(a, log_b, lin_c, log_d, g)
        x = self._inv_nlsq(y, a, b, c, d, g)
        return x, self._nlsq_logdet(x, a, b, c, d, g)

    def inverse(self, y, a, log_b, lin_c, log_d, g):
        """Return only ``x`` from ``inverse_and_invlogdet``."""
        return self.inverse_and_invlogdet(y, a, log_b, lin_c, log_d, g)[0]

    def _cube_root(self, x):
        """Real cube root that keeps the sign of ``x``."""
        return torch.sign(x) * torch.abs(x) ** (1 / 3.)

    def _div(self, a, b):
        """Compute ``a / b`` from log-magnitudes and explicit signs.

        Gives 0 where ``a == 0`` (and ``b != 0``) and NaN where ``b == 0``.
        """
        return torch.sign(a) * torch.sign(b) * torch.exp(
            torch.log(torch.abs(a)) - torch.log(torch.abs(b))
        )

    def _inv_nlsq(self, y, a, b, c, d, g):
        """Invert the forward map by solving a cubic in ``x``.

        Multiplying ``y - a = b x + c / (1 + (d x + g)**2)`` through by
        ``1 + (d x + g)**2`` gives ``A x**3 + B x**2 + C x + D = 0`` with the
        coefficients listed in the string below. All four are divided by
        ``|B|`` (so ``B`` becomes ``sign(B)``), and the real root comes from
        the general cubic formula. The assert requires
        ``delta_1**2 - 4 * delta_0**3 > 0``, i.e. exactly one real root.

        NOTE(review): where the unnormalised ``B`` is 0 the ``_div`` calls
        yield NaN; the assert is also stripped under ``python -O``.
        """
        y_minus_a = y - a
        """
        A = -b * (d**2)
        B = y_minus_a * (d**2) - 2 * d * g * b
        C = y_minus_a * 2 * d * g - b * (g**2 + 1)
        D = y_minus_a * (g**2 + 1) - c
        """
        B_ = y_minus_a * (d ** 2) - 2 * d * g * b
        abs_B = torch.abs(B_)
        A = self._div(-b * (d ** 2), abs_B)
        C = self._div(y_minus_a * 2 * d * g - b * (g ** 2 + 1), abs_B)
        D = self._div(y_minus_a * (g ** 2 + 1) - c, abs_B)
        B = torch.sign(B_)

        delta_0 = B ** 2. - 3 * A * C
        delta_1 = 2 * B ** 3. - (9 * A * B * C) + 27 * (A ** 2) * D
        det = delta_1 ** 2. - 4 * (delta_0 ** 3)

        assert (det > 0.).all()
        root_det = torch.sqrt(det)
        cubed_C = (delta_1 - root_det) / 2
        # With delta_0 == 0 the "-" branch can collapse to 0; use delta_1.
        cubed_C[delta_0 == 0] = delta_1[delta_0 == 0]
        big_C = self._cube_root(cubed_C)
        delta_over_C = (delta_0 / big_C).masked_fill(big_C == 0., 0)
        x = -(B + big_C + delta_over_C) / (3. * A)
        return x


class ScaleShift(Bijector):
    """Elementwise affine bijector ``y = x * exp(log_w) + b``.

    ``log_w`` and ``b`` are passed on every call; nothing is stored here.
    Log-det terms are reduced over ``self.dims``.
    """

    def __init__(self, dims=-1):
        super(ScaleShift, self).__init__(dims=dims)

    def forward(self, x, log_w, b):
        # TODO: check whether log_w needs clamping (e.g. to [-12, 12]).
        scale = torch.exp(log_w)
        return x * scale + b

    def inverse(self, y, log_w, b):
        # TODO: check whether log_w needs clamping (e.g. to [-12, 12]).
        centred = y - b
        return centred / torch.exp(log_w)

    def inverse_log_det_jacobian(self, y, log_w, b):
        """Return ``-sum(log_w)`` broadcast against ``y[..., 0]``.

        A 0-dim ``log_w`` is shared by every element, so it is multiplied by
        the size of ``y`` along ``self.dims``. ``b`` does not enter the
        Jacobian and ``y`` only supplies the shape.
        """
        if log_w.dim() == 0:
            return -log_w * y.size(self.dims)
        # Adding zeros broadcasts the reduced term to y's leading shape.
        return -log_w.sum(self.dims) + torch.zeros_like(y[..., 0])

    def forward_and_invlogdet(self, x, log_w, b):
        y = self.forward(x, log_w, b)
        return y, self.inverse_log_det_jacobian(y, log_w, b)

    def inverse_and_invlogdet(self, y, log_w, b):
        x = self.inverse(y, log_w, b)
        # x has the same shape as y, which is all the log-det needs.
        return x, self.inverse_log_det_jacobian(x, log_w, b)



class ContextScaleShift(Bijector):
    """Affine bijector whose ``(log_w, b)`` come from a linear map of ``ctx``.

    ``transform(ctx)`` is split in half along the last dimension into
    ``log_w`` and ``b``; ``log_w`` is then softly bounded below at -32.
    """

    def __init__(self, in_features, out_features):
        super(ContextScaleShift, self).__init__(dims=-1)
        self.scaleshift_op = ScaleShift(dims=-1)
        self.transform = nn.Linear(in_features, 2 * out_features)

    def get_params(self, ctx):
        """Return ``(log_w, b)`` for context ``ctx``."""
        projected = self.transform(ctx)
        raw_log_w, shift = torch.chunk(projected, 2, dim=-1)
        # softplus(z + 32) - 32 stays above -32 and is ~z for large z.
        bounded_log_w = F.softplus(raw_log_w + 32.) - 32.
        return bounded_log_w, shift

    def forward(self, x, ctx):
        return self.forward_and_invlogdet(x, ctx)[0]

    def inverse(self, y, ctx):
        log_w, b = self.get_params(ctx)
        op = self.scaleshift_op
        return op.inverse(y, log_w, b)

    def inverse_and_invlogdet(self, y, ctx):
        log_w, b = self.get_params(ctx)
        op = self.scaleshift_op
        x = op.inverse(y, log_w, b)
        ldji = op.inverse_log_det_jacobian(y, log_w, b)
        return (x, ldji)

    def inverse_log_det_jacobian(self, y, ctx):
        log_w, b = self.get_params(ctx)
        op = self.scaleshift_op
        return op.inverse_log_det_jacobian(y, log_w, b)

    def forward_and_invlogdet(self, x, ctx):
        log_w, b = self.get_params(ctx)
        op = self.scaleshift_op
        return op.forward_and_invlogdet(x, log_w, b)




class Gating(Bijector):
    """Convex gate ``y = g * x + (1 - g) * x_new`` with ``g = sigmoid(lin_gate)``.

    ``x_new`` and ``lin_gate`` are treated as given, so only ``x`` is
    transformed and its per-element derivative is ``g``.
    """

    def __init__(self, dims=-1):
        super(Gating, self).__init__(dims=dims)

    def forward(self, x, x_new, lin_gate):
        gate = torch.sigmoid(lin_gate)
        kept = gate * x
        replaced = (1 - gate) * x_new
        return kept + replaced

    def inverse(self, y, x_new, lin_gate):
        gate = torch.sigmoid(lin_gate)
        residual = y - (1 - gate) * x_new
        return residual / gate

    def inverse_log_det_jacobian(self, y, x_new, lin_gate):
        """Return ``-sum(log g)``, computed stably as ``sum(softplus(-lin_gate))``."""
        per_element = F.softplus(-lin_gate)
        return per_element.sum(self.dims)


class IAF(Bijector):
    """Gated autoregressive flow step built on ``AutoregressiveMLP``.

    The MLP emits two interleaved values per position from ``x`` and
    ``ctx``: gate logits (even columns) and candidate values ``x_new`` (odd
    columns). The output is ``y = g * x + (1 - g) * x_new`` with
    ``g = sigmoid(lin_gate - gate_shift)``. No inverse is implemented here.

    Args:
        out_features: number of positions in ``x``.
        in_features: size of the last dimension of ``ctx``.
        n_layers: number of masked layers in the MLP.
        block_size: hidden units per position in the MLP.
        gate_shift: constant subtracted from the gate logits. With logits
            near zero the default gives ``g = sigmoid(-3.5) ~= 0.03``, so
            ``y`` is dominated by ``x_new``.
    """

    # Class-level default so instances unpickled from before ``gate_shift``
    # was an attribute still find a value.
    gate_shift = 3.5

    def __init__(self, out_features, in_features, n_layers,
                 block_size=1, gate_shift=3.5):
        super(IAF, self).__init__(dims=-1)
        self.mlp = AutoregressiveMLP(out_features, in_features, n_layers,
                                     hidden_upscale=block_size,
                                     output_upscale=2)
        self.gate = Gating(dims=-1)
        self.gate_shift = gate_shift

    def forward_and_invlogdet(self, x, ctx):
        """Return ``(y, inverse log-det)`` for one gated step.

        The log-det uses the same shifted logits as the gate: assuming
        ``x_new`` and the logits depend only on earlier positions (the MLP's
        first layer is strictly lower-triangular), ``dy/dx`` is triangular
        with diagonal ``g``, so the inverse log-det is ``-sum(log g)``.
        """
        output = self.mlp(x, ctx)

        # Even columns: gate logits; odd columns: candidate values.
        lin_gate, x_new = output[..., ::2], output[..., 1::2]
        shifted_gate = lin_gate - self.gate_shift
        y = self.gate(x, x_new, shifted_gate)
        return y, self.gate.inverse_log_det_jacobian(y, x_new, shifted_gate)

    def forward(self, x, ctx):
        return self.forward_and_invlogdet(x, ctx)[0]

    def forward_log_det_jacobian(self, x, ctx):
        return -self.forward_and_invlogdet(x, ctx)[1]


class RealNVP(Bijector):
    """Affine coupling layer over the two halves of the last dimension.

    Halves come from ``chunk(2, dim=-1)``. ``transform`` maps
    ``(first_half, ctx)`` to ``(log_w, b)`` for scaling and shifting the
    second half. If ``transform_2`` is given, it maps
    ``(new_second_half, ctx)`` to ``(log_w, b)`` for the first half;
    otherwise the first half passes through unchanged.
    """

    def __init__(self, transform, transform_2=None):
        super(RealNVP, self).__init__(dims=-1)
        self.scaleshift = ScaleShift()
        self.transform_1 = transform
        self.transform_2 = transform_2

    def forward_and_invlogdet(self, x: torch.Tensor, ctx: torch.Tensor):
        head, tail = x.chunk(2, dim=-1)
        affine = self.scaleshift

        log_w, b = self.transform_1(head, ctx)
        tail_out, tail_ldji = affine.forward_and_invlogdet(tail, log_w, b)

        if self.transform_2 is None:
            head_out, head_ldji = head, 0.
        else:
            log_w, b = self.transform_2(tail_out, ctx)
            head_out, head_ldji = affine.forward_and_invlogdet(head, log_w, b)

        joined = torch.cat((head_out, tail_out), dim=-1)
        return joined, tail_ldji + head_ldji

    def inverse_and_invlogdet(self, y: torch.Tensor, ctx: torch.Tensor):
        head_y, tail_y = y.chunk(2, dim=-1)
        affine = self.scaleshift

        # Undo the second coupling first; it was conditioned on tail_y,
        # which is still available unchanged.
        if self.transform_2 is None:
            head, head_ldji = head_y, 0.
        else:
            log_w, b = self.transform_2(tail_y, ctx)
            head = affine.inverse(head_y, log_w, b)
            head_ldji = affine.inverse_log_det_jacobian(head, log_w, b)

        log_w, b = self.transform_1(head, ctx)
        tail = affine.inverse(tail_y, log_w, b)
        tail_ldji = affine.inverse_log_det_jacobian(tail, log_w, b)
        joined = torch.cat((head, tail), dim=-1)
        return joined, tail_ldji + head_ldji

    def inverse(self, y: torch.Tensor, ctx: torch.Tensor):
        return self.inverse_and_invlogdet(y, ctx)[0]

    def forward(self, x: torch.Tensor, ctx: torch.Tensor):
        return self.forward_and_invlogdet(x, ctx)[0]

    def forward_log_det_jacobian(self, x, ctx):
        return -self.forward_and_invlogdet(x, ctx)[1]


class PermuteDimensions(Bijector):
    """Reorder the last dimension by a fixed index permutation.

    ``idx_order`` is stored as buffer ``idx`` and its argsort as buffer
    ``idx_inverse``. The log-det returned is always zero.
    """

    def __init__(self, idx_order):
        super(PermuteDimensions, self).__init__(dims=-1)
        self.register_buffer("idx", idx_order)
        self.register_buffer("idx_inverse", self.idx.argsort())

    def forward(self, x):
        return x[..., self.idx]

    def inverse(self, y):
        return y[..., self.idx_inverse]

    def forward_and_invlogdet(self, x):
        permuted = self.forward(x)
        return permuted, self.inverse_log_det_jacobian(permuted)

    def inverse_and_invlogdet(self, y):
        restored = self.inverse(y)
        return restored, self.inverse_log_det_jacobian(restored)

    def inverse_log_det_jacobian(self, y):
        # Zeros shaped (and typed) like a reduction over the last dimension.
        return torch.zeros_like(y.sum(dim=-1))



class Sequential(Bijector):
    """Compose bijectors: forward applies them in order, inverse in reverse.

    For each bijector, the number of parameters of its bound
    ``forward_and_invlogdet`` decides how much of ``(prev, ctx)`` it
    receives (1 -> ``(prev,)``, 2 -> ``(prev, ctx)``). The same count is
    reused for ``inverse_and_invlogdet``. The returned inverse log-dets are
    summed.
    """

    def __init__(self, *bijectors):
        super(Sequential, self).__init__(dims=-1)
        self.bijectors = nn.ModuleList(bijectors)
        # Arity of each bound forward_and_invlogdet (``self`` not counted).
        self.argcount = [len(signature(m.forward_and_invlogdet).parameters)
                         for m in self.bijectors]

    def forward(self, x, ctx):
        return self.forward_and_invlogdet(x, ctx)[0]

    def inverse(self, y, ctx):
        return self.inverse_and_invlogdet(y, ctx)[0]

    def forward_and_invlogdet(self, x, ctx):
        """Return ``(y, summed inverse log-det)`` after all bijectors."""
        prev = x
        log_det_jac_inv = 0.
        for fun, argc in zip(self.bijectors, self.argcount):
            prev, ld = fun.forward_and_invlogdet(*(prev, ctx)[:argc])
            # Out-of-place: ``+=`` on a tensor is in-place and fails when a
            # later term broadcasts to a larger shape than the running sum.
            log_det_jac_inv = log_det_jac_inv + ld
        return prev, log_det_jac_inv

    def inverse_and_invlogdet(self, y, ctx):
        """Return ``(x, summed inverse log-det)`` undoing all bijectors."""
        prev = y
        log_det_jac_inv = 0.
        for fun, argc in zip(reversed(self.bijectors), reversed(self.argcount)):
            prev, ld = fun.inverse_and_invlogdet(*(prev, ctx)[:argc])
            log_det_jac_inv = log_det_jac_inv + ld
        return prev, log_det_jac_inv





if __name__ == "__main__":
    # Smoke test: push a random vector through a small flow and invert it;
    # the last printed tensor is expected to reproduce z0 up to float error.
    import sys

    input_size = 8
    z0 = torch.randn(1, input_size)


    class NVPTransform(nn.Module):
        """Coupling network returning ``(log_scale, shift)`` for RealNVP."""

        def __init__(self, size_1, size_2, out_size, hidden_size=None):
            super(NVPTransform, self).__init__()
            if hidden_size is None:
                self.hidden_size = max(size_1, size_2)
            else:
                self.hidden_size = hidden_size
            self.in_combine_1 = nn.Linear(size_1, self.hidden_size)
            self.in_combine_2 = nn.Linear(size_2, self.hidden_size)
            self.out_transform = nn.Sequential(
                nn.GELU(),
                nn.Linear(self.hidden_size, out_size * 2)
            )
            # Zero final weights: the initial output is just the bias.
            nn.init.zeros_(self.out_transform[-1].weight)

        def forward(self, in1, in2=None):
            # NOTE(review): ``in2`` defaults to None but is always used; the
            # demo below always passes the context.
            lin_hidden = (self.in_combine_1(in1) +
                          self.in_combine_2(in2)) / 2.
            log_scale, shift = self.out_transform(lin_hidden).chunk(2, dim=-1)
            # Soft lower bound of about -32 on the log-scale.
            log_scale = F.softplus(log_scale + 32) - 32
            return log_scale, shift


    transform = Sequential(
        PermuteDimensions(torch.from_numpy(np.random.permutation(input_size))),
        ContextScaleShift(1, input_size),
        RealNVP(transform=NVPTransform(input_size // 2, 1, input_size // 2),
                transform_2=NVPTransform(input_size // 2, 1, input_size // 2))
    )
    zeros = torch.zeros((1, 1))
    print(z0)
    z = transform(z0, zeros)
    print(z)
    print(transform.inverse(z, zeros))
    # Exit here, so the later ``__main__`` block is not reached when this
    # file is run as a script. sys.exit is used because the ``exit``
    # builtin is only injected by the ``site`` module.
    sys.exit()


if __name__ == "__main__":
    # Checks ContextScaleShift's log-det via a change of variables and runs
    # one IAF forward pass.
    # NOTE(review): when this file is run as a script, the preceding
    # ``__main__`` block terminates the process first, so this block does
    # not run.
    # ``autoencoder`` is a project-local module not shown here; presumably
    # ``log_normal(x, mean, log_var)`` returns element-wise log-densities
    # (confirm).
    import autoencoder
    ss = ContextScaleShift(5, 5)
    # Uninitialised tensors, then filled in place with N(0, 1) samples.
    x = torch.Tensor(10, 5)
    x.data.normal_()
    ctx = torch.Tensor(10, 5)
    ctx.data.normal_()

    # Log-density of x with both distribution arguments 0 (presumably a
    # standard normal), summed over the feature dimension.
    log_p_x = autoencoder.log_normal(
        x,
        torch.from_numpy(np.array(0., dtype=np.float32)),
        torch.from_numpy(np.array(0., dtype=np.float32)),
    ).sum(1)

    x_, logdet = ss.forward_and_invlogdet(x, ctx)
    # NOTE(review): uses the raw ``w`` chunk rather than the softplus-bounded
    # log_w from get_params, and ``w * 2`` (presumably a log-variance); the
    # two only agree where the soft bound is inactive.
    w, b = torch.chunk(ss.transform(ctx), 2, dim=-1)
    log_p_x_ = autoencoder.log_normal(
        x_, b, w * 2,
    ).sum(1)

    # The last two printed values are expected to agree (change of variables).
    print(log_p_x)
    print(log_p_x_)
    print(log_p_x + logdet)
    iaf = IAF(5, 5, 4, block_size=4)
    print(iaf)
    print(iaf.forward_and_invlogdet(x, torch.ones_like(x)))

