import numpy as np
import torch
import torch.distributions as distributions
import torch.nn as nn
import torch.nn.functional as F


## Flow layers


class Conv2d(nn.Module):
    """``nn.Conv2d`` reparameterized with weight normalization.

    Every constructor argument is forwarded verbatim to ``nn.Conv2d``. The
    wrapped layer lives in ``self.conv`` with ``torch.nn.utils.weight_norm``
    applied, so its trainable weights are ``weight_g`` / ``weight_v`` (plus
    ``bias`` unless it was disabled).
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        plain_conv = nn.Conv2d(*args, **kwargs)
        self.conv = nn.utils.weight_norm(plain_conv)

    def forward(self, *args, **kwargs):
        # Pure pass-through to the weight-normalized convolution.
        conv = self.conv
        return conv(*args, **kwargs)


class ResidualBlock(nn.Module):
    """Pre-activation residual block computing ``x + f(x)``.

    ``f`` is two repetitions of GroupNorm(4 groups) -> LeakyReLU ->
    weight-normalized 3x3 convolution (padding 1), so the channel count
    ``dim`` and the spatial size are preserved. ``dim`` must be divisible by 4
    for ``GroupNorm(4, dim)``.
    """

    def __init__(self, dim):
        super().__init__()
        layers = []
        for _ in range(2):
            layers.extend([
                nn.GroupNorm(4, dim),
                nn.LeakyReLU(),
                Conv2d(dim, dim, kernel_size=3, padding=1),
            ])
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual


class Resnet(nn.Module):
    """Convolutional residual network with summed skip connections.

    A stem 3x3 convolution maps ``d_in`` -> ``d_hid`` channels and is followed
    by ``num_blocks`` ``ResidualBlock``s. After the stem and after every
    residual block, a GroupNorm/LeakyReLU/1x1-conv skip branch is applied. The
    skip outputs are summed, and the sum is projected to ``d_out`` channels by
    ``out_block``. Every convolution preserves spatial size.

    Args:
        d_in: input channels.
        d_hid: hidden channels; must be divisible by 4 (``GroupNorm(4, d_hid)``).
        d_out: output channels.
        num_blocks: number of residual blocks; must be positive.
    """

    def __init__(self, d_in, d_hid, d_out, num_blocks):
        # nn.Module.__init__ has to run before any attribute is assigned on the
        # module; the original assigned ``num_blocks`` first.
        super().__init__()
        assert num_blocks > 0
        self.num_blocks = num_blocks

        # Initial conv + residual blocks
        self.blocks = nn.ModuleList(
            [Conv2d(d_in, d_hid, kernel_size=3, padding=1)] +
            [ResidualBlock(d_hid) for _ in range(num_blocks)])

        # One skip branch per entry in ``self.blocks`` (stem + residual blocks)
        self.skip_blocks = nn.ModuleList([
            nn.Sequential(nn.GroupNorm(4, d_hid), nn.LeakyReLU(),
                          Conv2d(d_hid, d_hid, kernel_size=1, padding=0))
            for _ in range(num_blocks + 1)])

        # Final output projection
        self.out_block = nn.Sequential(
            nn.GroupNorm(4, d_hid), nn.LeakyReLU(), Conv2d(d_hid, d_out, kernel_size=1, padding=0))

    def forward(self, x):
        """Return ``out_block(sum of skip-branch outputs)``, shape ``(N, d_out, H, W)``."""
        skip = 0
        for block, skip_block in zip(self.blocks, self.skip_blocks):
            x = block(x)
            skip = skip + skip_block(x)
        return self.out_block(skip)


def get_mask(mask, shape):
    """Build a binary coupling mask of shape ``(1, C, H, W)``.

    Args:
        mask: a name starting with ``'checker'`` or ``'channel'``.
            ``'checker...'`` gives a spatial checkerboard tiled from
            ``torch.eye(2)``, so the top-left pixel is 1.
            ``'channel...'`` gives ones on the first half of the channels and
            zeros on the second half.
            A trailing ``'1'`` inverts the mask (e.g. ``'checker1'``).
        shape: per-sample ``(C, H, W)`` shape (PyTorch NCHW without N).
            ``'channel'`` masks require an even ``C``.

    Returns:
        float32 tensor of shape ``(1, C, H, W)`` with entries in {0, 1}.

    Raises:
        ValueError: if ``mask`` starts with neither supported prefix.
    """
    n_ch, height, width = shape  # pytorch uses NCHW

    if mask.startswith('checker'):
        tiles_h = (height + 1) // 2
        tiles_w = (width + 1) // 2
        # Tile the 2x2 identity, then crop to (H, W) so odd sizes work too.
        board = torch.eye(2).repeat(n_ch, tiles_h, tiles_w)
        out = board.view(1, n_ch, 2 * tiles_h, 2 * tiles_w)[:, :, :height, :width]
    elif mask.startswith('channel'):
        assert n_ch % 2 == 0
        half = n_ch // 2
        upper = torch.ones(1, half, height, width)
        lower = torch.zeros(1, half, height, width)
        out = torch.cat([upper, lower], 1)
    else:
        raise ValueError(f'Invalid mask {mask}')

    if mask.endswith('1'):
        out = 1 - out

    assert out.shape == torch.Size([1, n_ch, height, width]) and out.dtype == torch.float32
    return out


class Flow(nn.Module):
    """Base class for invertible layers composed by ``FlowSequential``.

    ``FlowSequential`` calls each layer as
    ``layer(x, logdet, factored, x_cond=x_cond, inverse=inverse)`` and unpacks
    the result as ``(out, logdet, factored)``. Subclasses must override
    ``forward`` with that signature. They return the transformed tensor, the
    updated per-sample log-determinant, and the (possibly updated) stream of
    factored-out variables.
    """

    def forward(self, x, logdet, factored=None, x_cond=None, inverse=False):
        # ``x_cond`` is part of the signature because every subclass declares
        # it and FlowSequential always passes it by keyword; without it the
        # stub raised TypeError instead of NotImplementedError.
        raise NotImplementedError


class ElemwiseScale(nn.Module):
    """Multiply the input by a learnable tensor of fixed ``shape`` (broadcast).

    The scale parameter starts at all zeros and is trainable.
    """

    def __init__(self, *shape):
        super().__init__()
        self.shape = shape
        initial = torch.zeros(*shape)
        self.scale = nn.Parameter(initial)  # requires_grad defaults to True

    def forward(self, x):
        return torch.mul(x, self.scale)


class CouplingLayer(Flow):
    """Masked affine coupling layer.

    Entries where ``mask == 1`` pass through unchanged. They are normalised
    (GroupNorm), expanded to ``[h, -h]`` with a leaky ReLU, and fed to a
    ``Resnet`` that predicts ``logs`` and ``shift``. Entries where ``mask == 0``
    are transformed:

        forward:  out = x * exp(logs) + shift,     logdet += sum(logs)
        inverse:  out = (x - shift) * exp(-logs),  logdet -= sum(logs)

    ``logs`` is ``tanh``-bounded and then multiplied by a learned
    ``ElemwiseScale``.

    Args:
        shape: per-sample ``(C, H, W)`` input shape.
        d_hid: hidden width of the internal ``Resnet``.
        mask: mask name passed to ``get_mask`` (e.g. ``'checker0'``).
        d_cond: conditioning channels. Conditioning is not implemented:
            ``forward`` raises ``NotImplementedError`` when this is set.
        num_blocks: residual blocks in the internal ``Resnet``.
    """

    def __init__(self, shape, d_hid, mask, d_cond=None, num_blocks=5):
        super().__init__()
        self.shape = shape
        self.d_cond = d_cond
        n_channels = shape[0]
        in_channel = 2 * n_channels if d_cond is None else 2 * n_channels + d_cond

        # Registration order kept as-is: it fixes state_dict/parameter order.
        self.mask = nn.Parameter(get_mask(mask, shape), requires_grad=False)
        self.init_norm = nn.GroupNorm(1, n_channels)
        self.logs_and_t = Resnet(in_channel, d_hid, 2 * n_channels, num_blocks)
        self.elemwise_scale = ElemwiseScale(1, *shape)

    def forward(self, x, logdet, factored=None, x_cond=None, inverse=False):
        n_channels = self.shape[0]
        free = 1 - self.mask  # 1 where this layer transforms x, 0 where it passes through

        h = self.init_norm(self.mask * x)
        h = F.leaky_relu(torch.cat([h, -h], dim=1))

        if self.d_cond is not None:
            raise NotImplementedError

        raw_logs, raw_shift = self.logs_and_t(h).split(n_channels, dim=1)
        assert raw_logs.shape == raw_shift.shape == x.shape
        logs = self.elemwise_scale(raw_logs.tanh()) * free
        shift = raw_shift * free
        logs_sum = logs.view(len(logs), -1).sum(dim=-1)  # per-sample log|det J|

        if inverse:
            scale = (-logs).exp()
            assert torch.isfinite(scale).all(), 'inf/nan in scale during inverse'
            return (x - shift) * scale, logdet - logs_sum, factored

        scale = logs.exp()
        assert torch.isfinite(scale).all(), 'inf/nan in scale'
        return x * scale + shift, logdet + logs_sum, factored


class Squeeze(Flow):
    """Space-to-depth reshuffle by an integer ``factor`` k.

    Forward maps ``(N, C, H, W) -> (N, C*k*k, H/k, W/k)`` by moving each k x k
    spatial patch into channels; inverse undoes it exactly. The map only
    permutes entries, so ``logdet`` and ``factored`` pass through unchanged.
    """

    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def forward(self, x, logdet, factored=None, x_cond=None, inverse=False):
        r = self.factor

        if not inverse:
            n, c, h, w = x.shape
            assert h % r == 0 and w % r == 0, f'image size not divisible by squeezing factor {r}'
            patches = x.view(n, c, h // r, r, w // r, r)  # N, C, H/k, k_H, W/k, k_W
            patches = patches.permute(0, 1, 3, 5, 2, 4)    # N, C, k_H, k_W, H/k, W/k
            out = patches.reshape(n, c * r * r, h // r, w // r)
            # Volume-preserving, so logdet is returned untouched.
            return out, logdet, factored

        n, c_rr, h_small, w_small = x.shape
        assert c_rr % (r * r) == 0, f'channel count not divisible by {r}^2'
        c = c_rr // (r * r)
        patches = x.view(n, c, r, r, h_small, w_small)  # N, C, k_H, k_W, H/k, W/k
        patches = patches.permute(0, 1, 4, 2, 5, 3)     # N, C, H/k, k_H, W/k, k_W
        out = patches.reshape(n, c, h_small * r, w_small * r)
        return out, logdet, factored


class Factor(Flow):
    """Multi-scale factoring: set half of the channels aside.

    Forward splits ``x`` in half along channels (``C`` must be even). The first
    half continues through the flow. The second half is flattened to
    ``(N, -1)`` and prepended to ``factored``, so the most recently factored
    block sits at the front.

    Inverse pops that leading block off ``factored``, reshapes it like ``x``,
    and concatenates it back onto ``x`` along channels. ``factored`` becomes
    ``None`` once it is exhausted.
    """

    def forward(self, x, logdet, factored=None, x_cond=None, inverse=False):
        if not inverse:
            n, c = x.shape[:2]
            assert c % 2 == 0, f'channel count not even during Factoring'
            kept, dropped = x.split(c // 2, dim=1)
            dropped = dropped.view(n, -1)
            factored = dropped if factored is None else torch.cat([dropped, factored], dim=1)
            return kept, logdet, factored

        assert factored is not None
        width = np.prod(x.shape[1:])
        z = factored[:, :width].view(*x.shape)
        remainder = factored[:, width:]
        factored = None if remainder.shape[1] == 0 else remainder
        return torch.cat([x, z], dim=1), logdet, factored


class LogitTransform(Flow):
    """Logit preprocessing layer for inputs in ``[0, 1]``.

    Forward:  ``y = logit(eps + (1 - 2 * eps) * x)``. Squeezing into
              ``[eps, 1 - eps]`` keeps the logit finite at 0 and 1.
    Inverse:  ``x = (sigmoid(y) - eps) / (1 - 2 * eps)``.

    Both directions add the per-sample log |Jacobian determinant| to
    ``logdet``. ``eps`` is expected to satisfy ``0 < eps < 0.5`` so that
    ``log(1 - 2 * eps)`` is finite.
    """

    def __init__(self, eps):
        super().__init__()
        self.eps = eps

    def forward(self, x, logdet, factored=None, x_cond=None, inverse=False):
        if inverse:
            return self._postprocess(x, logdet, factored)
        return self._preprocess(x, logdet, factored)

    def _preprocess(self, x, logdet, factored):
        """Forward direction: squeeze into [eps, 1 - eps], then apply logit."""
        x = (1 - 2 * self.eps) * x + self.eps
        out = torch.log(x) - torch.log(1 - x)
        assert torch.isfinite(out).all(), 'nan/inf found during logit transformation'

        # log|dy/dx| = log(1 - 2 eps) - log(x) - log(1 - x), summed per sample.
        # Reducing over the flattened features works for inputs of any rank.
        dim = np.prod(x.shape[1:])
        log_terms = (torch.log(1 - x) + torch.log(x)).reshape(len(x), -1).sum(dim=1)
        logdet = logdet + (dim * np.log(1 - 2 * self.eps) - log_terms)
        assert logdet.shape == (len(x),)
        return out, logdet, factored

    def _postprocess(self, x, logdet, factored):
        """Inverse direction: sigmoid, then undo the [eps, 1 - eps] squeeze."""
        p = torch.sigmoid(x)
        out = (p - self.eps) / (1 - 2 * self.eps)

        # log(p) + log(1 - p) with p = sigmoid(x), computed as
        # logsigmoid(x) + logsigmoid(-x). Taking log of sigmoid(x) directly
        # yields -inf as soon as sigmoid rounds to exactly 0 or 1 for large
        # |x|, which would corrupt logdet when sampling.
        dim = np.prod(x.shape[1:])
        log_terms = (F.logsigmoid(-x) + F.logsigmoid(x)).reshape(len(x), -1).sum(dim=1)
        logdet = logdet - (dim * np.log(1 - 2 * self.eps) - log_terms)
        return out, logdet, factored


## Models

class FlowSequential(nn.Module):
    """Compose ``Flow`` layers and manage the multi-scale ``factored`` stream.

    Forward runs the layers in order, starting with ``factored=None``. At the
    end, the remaining ``x`` is flattened, concatenated with whatever the
    layers factored out, and viewed back to the input's shape.

    The per-sample shape of the remaining ``x`` is recorded in
    ``_final_shape`` on the first forward call. The inverse direction relies
    on it to split its input back into ``x`` and ``factored`` before running
    the layers in reverse order.

    ``forward`` returns ``(x, logdet)``, where ``logdet`` has shape ``(N,)``.
    """

    def __init__(self, layers):
        super().__init__()
        self._final_shape = None
        for layer in layers:
            assert isinstance(layer, Flow)
        self.layers = nn.ModuleList(layers)

    def forward(self, x, x_cond=None, inverse=False):
        in_shape = x.shape
        factored = None

        if inverse:
            assert self._final_shape is not None
            n_kept = np.prod(self._final_shape)
            flat = x.view(len(x), -1)
            factored = flat[:, n_kept:]
            x = flat[:, :n_kept].reshape(len(flat), *self._final_shape)
            ordered = reversed(self.layers)
        else:
            ordered = self.layers

        logdet = torch.zeros(len(x), dtype=x.dtype, device=x.device)
        for layer in ordered:
            x, logdet, factored = layer(x, logdet, factored, x_cond=x_cond, inverse=inverse)

        # Remember the unfactored latent shape the first time we run forward.
        if self._final_shape is None and not inverse:
            self._final_shape = x.shape[1:]

        if factored is not None:
            x = torch.cat([x.view(len(x), -1), factored], dim=1)

        return x.view(*in_shape), logdet


class FlowModel(nn.Module):
    """Interface for flow-based density models.

    Every method here is a stub that raises ``NotImplementedError``; concrete
    models override them (possibly with extra arguments).
    """

    def log_prob(self, x, x_cond=None):
        """Log-density of ``x``; ``x_cond`` is an optional conditioning input."""
        raise NotImplementedError()

    def log_prior(self, z):
        """Log-density of latent ``z`` under the prior."""
        raise NotImplementedError()

    def sample(self):
        """Draw samples in data space."""
        raise NotImplementedError()

    def sample_prior(self):
        """Draw samples from the latent prior."""
        raise NotImplementedError()


class RealNVP(FlowModel):
    """Multi-scale RealNVP flow over images of shape ``image_shape`` (C, H, W).

    Architecture:
    - An optional ``LogitTransform(logit_eps)``.
    - ``n_scales - 1`` scales, each made of:
      checker coupling layers (0, 1, 0), a 2x ``Squeeze``, channel coupling
      layers (0, 1, 0), and a ``Factor`` that sets half of the channels aside.
    - A final stack of four checkerboard coupling layers.

    H and W must be divisible by ``2 ** (n_scales - 1)``.

    Args:
        image_shape: per-sample ``(C, H, W)`` shape.
        d_hidden: hidden width of each coupling layer's ``Resnet``.
        n_blocks: residual blocks per coupling ``Resnet``.
        n_scales: number of scales (>= 1).
        logit_eps: if not None, prepend a ``LogitTransform`` with this eps.
    """

    def __init__(self, *, image_shape, d_hidden, n_blocks, n_scales, logit_eps=None):
        super().__init__()

        self.image_shape = image_shape
        self.D = int(np.prod(image_shape))
        self.d_hidden = d_hidden
        self.n_blocks = n_blocks
        self.n_scales = n_scales
        self.logit_eps = logit_eps

        def coupling(shape, mask):
            # Every coupling layer in the model shares width and depth.
            return CouplingLayer(shape, self.d_hidden, mask, num_blocks=self.n_blocks)

        shape = list(self.image_shape)
        modules = [] if logit_eps is None else [LogitTransform(logit_eps)]

        # Module creation order is kept identical to the layer order: it
        # determines state_dict keys and parameter initialisation.
        for _ in range(self.n_scales - 1):
            modules += [coupling(shape, m) for m in ('checker0', 'checker1', 'checker0')]
            modules.append(Squeeze(2))

            assert shape[1] % 2 == 0 and shape[2] % 2 == 0
            shape = [shape[0] * 4, shape[1] // 2, shape[2] // 2]

            modules += [coupling(shape, m) for m in ('channel0', 'channel1', 'channel0')]
            modules.append(Factor())

            assert shape[0] % 2 == 0
            shape = [shape[0] // 2, shape[1], shape[2]]

        # Final layer
        modules += [coupling(shape, m) for m in ('checker0', 'checker1', 'checker0', 'checker1')]

        self.model = FlowSequential(modules)
        # One dummy forward pass records FlowSequential._final_shape, which
        # the inverse direction needs; no autograd graph is required for it.
        with torch.no_grad():
            self.model(torch.zeros(1, *self.image_shape))

    def forward(self, x, inverse=False):
        """Map data -> latent (or latent -> data if ``inverse``).

        ``x`` is viewed as ``(N, *image_shape)``. Returns ``(out, logdet)``,
        where ``out`` has that same shape and ``logdet`` has shape ``(N,)``.
        """
        x = x.view(len(x), *self.image_shape)
        x, logdet = self.model(x, inverse=inverse)
        return x, logdet

    def log_prob(self, x):
        """Return ``(log p(x), log p(z), z, logdet)`` for data ``x``."""
        z, logdet = self(x)
        log_pz = self.log_prior(z)
        assert log_pz.shape == logdet.shape
        log_px = log_pz + logdet
        return log_px, log_pz, z, logdet

    def log_prior(self, z):
        """Standard-normal log-density of ``z``, one value per sample."""
        z = z.reshape(len(z), -1)
        # Sum of squares directly, rather than ``norm() ** 2``: this avoids the
        # sqrt/square round-trip and the norm's singular gradient at z == 0.
        return -0.5 * (self.D * np.log(2 * np.pi) + z.pow(2).sum(dim=1))

    def sample(self, n, temp=1.0, device=None):
        """Draw ``n`` samples in data space at prior temperature ``temp``."""
        z = self.sample_prior(n, temp=temp, device=device)
        x = self(z, inverse=True)[0]
        return x

    def sample_prior(self, n, temp=1.0, device=None):
        """Draw ``n`` standard-normal latents scaled by ``temp``."""
        z = torch.randn(n, *self.image_shape, device=device) * temp
        return z

