import normflows as nf
from normflows.nets import ConvNet2d
import torch
from normflows.flows.base import Flow
import torch.nn as nn
import numpy as np


def making_model_glow(input_shape, channels, hidden_channels, split_mode, L, K, scale=True,
                      scale_map='sigmoid'):
    """Build a multiscale Glow model as an ``nf.MultiscaleFlow``.

    Args:
      input_shape: Shape of one data sample, indexed as ``(channels, height, width)``.
      channels: Number of data channels used to size the GlowBlocks. Presumably
        equal to ``input_shape[0]`` (the latent shapes are derived from
        ``input_shape[0]``, the coupling networks from ``channels``) — TODO confirm.
      hidden_channels: Hidden channels of each GlowBlock's conv network.
      split_mode: Split mode forwarded to every ``nf.flows.GlowBlock``.
      L: Number of levels of the multiscale architecture.
      K: Number of GlowBlocks per level.
      scale: Whether the affine coupling layers include a scale term.
      scale_map: Map applied to the coupling scale, forwarded to ``GlowBlock``.
        Defaults to ``'sigmoid'``, the value that used to be hard-coded.

    Returns:
      An ``nf.MultiscaleFlow`` with one non-trainable ``DiagGaussian`` base
      distribution per level, no transform and no class conditioning.
    """
    q0 = []
    merges = []
    flows = []
    for i in range(L):
        # GlowBlocks at level i are built for channels * 2 ** (L + 1 - i) channels.
        block_channels = channels * 2 ** (L + 1 - i)
        flows_ = [
            nf.flows.GlowBlock(block_channels, hidden_channels, split_mode=split_mode,
                               scale=scale, scale_map=scale_map)
            for _ in range(K)
        ]
        flows_ += [nf.flows.Squeeze()]
        flows += [flows_]
        if i > 0:
            # Every level after the first is joined to the previous one by a Merge.
            merges += [nf.flows.Merge()]
            factor = 2 ** (L - i)
            latent_shape = (input_shape[0] * factor, input_shape[1] // factor,
                            input_shape[2] // factor)
        else:
            # Level 0 has the smallest spatial resolution (divided by 2 ** L).
            latent_shape = (input_shape[0] * 2 ** (L + 1), input_shape[1] // 2 ** L,
                            input_shape[2] // 2 ** L)
        q0 += [nf.distributions.DiagGaussian(shape=latent_shape, trainable=False)]

    # Construct flow model with the multiscale architecture
    model = nf.MultiscaleFlow(q0, flows, merges, transform=None, class_cond=False)

    return model

def making_model_nsf(input_shape,  hidden_channels, L, K):
    """Build a neural spline flow for flat (vector) data.

    The model stacks ``K`` pairs of
    ``AutoregressiveRationalQuadraticSpline`` + ``LULinearPermute`` on top of a
    non-trainable ``DiagGaussian`` base distribution.

    Args:
      input_shape: Dimensionality of the data. It is passed as a channel count
        and wrapped as ``[input_shape]`` for the base shape, so presumably an int.
      hidden_channels: Hidden units of each spline's autoregressive network.
      L: Number of blocks inside each spline's autoregressive network.
      K: Number of spline + permutation pairs.

    Returns:
      An ``nf.NormalizingFlow``.
    """
    flow_layers = []
    for _ in range(K):
        flow_layers.append(
            nf.flows.AutoregressiveRationalQuadraticSpline(input_shape, L, hidden_channels)
        )
        # Learned linear mixing between spline layers.
        flow_layers.append(nf.flows.LULinearPermute(input_shape))
    # No debug print here: building a model must not write to stdout.
    base = nf.distributions.DiagGaussian([input_shape], trainable=False)
    model = nf.NormalizingFlow(base, flow_layers)
    return model

def making_model_residual(input_shape,  hidden_channels, L, K):
    """Build a multiscale residual flow as an ``nf.MultiscaleFlow``.

    Each of the ``L`` levels holds ``K`` ``nf.flows.Residual`` layers followed by
    a ``Squeeze``. Every residual branch is an ``nf.nets.LipschitzCNN`` with
    kernel sizes ``[3, 1, 3]``, Lipschitz constant 0.9, at most 3 Lipschitz
    iterations and a zero-initialised last layer.

    Args:
      input_shape: Shape of one sample, indexed as ``(channels, height, width)``.
      hidden_channels: Width of the two hidden conv layers in each LipschitzCNN.
      L: Number of levels.
      K: Number of residual layers per level.

    Returns:
      An ``nf.MultiscaleFlow`` with one non-trainable ``DiagGaussian`` base
      distribution per level, no transform and no class conditioning.
    """
    base_dists = []
    merge_layers = []
    levels = []
    for level in range(L):
        level_ch = input_shape[0] * 2 ** (L + 1 - level)
        level_flows = []
        for _ in range(K):
            residual_net = nf.nets.LipschitzCNN(
                [level_ch, hidden_channels, hidden_channels, level_ch],
                init_zeros=True,
                kernel_size=[3, 1, 3],
                lipschitz_const=0.9,
                max_lipschitz_iter=3,
            )
            level_flows.append(
                nf.flows.Residual(residual_net, reduce_memory=True, n_exact_terms=1)
            )
        level_flows.append(nf.flows.Squeeze())
        levels.append(level_flows)

        if level == 0:
            ch_mult, spatial_div = 2 ** (L + 1), 2 ** L
        else:
            merge_layers.append(nf.flows.Merge())
            ch_mult = spatial_div = 2 ** (L - level)
        latent_shape = (
            input_shape[0] * ch_mult,
            input_shape[1] // spatial_div,
            input_shape[2] // spatial_div,
        )
        base_dists.append(nf.distributions.DiagGaussian(shape=latent_shape, trainable=False))

    return nf.MultiscaleFlow(base_dists, levels, merge_layers, transform=None, class_cond=False)



def making_model_realnvp(input_shape, hidden_channels, K):
    """Build a RealNVP flow for flat (vector) data.

    Stacks ``K`` ``MaskedAffineFlow`` layers. The binary mask alternates
    between "even indices" and its complement from layer to layer. Each layer
    has a tanh-bounded scale MLP and a shift MLP, both zero-initialised at the
    output and with a single hidden layer of ``hidden_channels`` units.

    Args:
      input_shape: Data dimensionality (an int; it is passed to ``range``).
      hidden_channels: Hidden width of the scale and shift MLPs.
      K: Number of coupling layers.

    Returns:
      An ``nf.NormalizingFlow`` over a non-trainable ``DiagGaussian`` base.
    """
    # 1 at even positions, 0 at odd positions.
    even_mask = torch.Tensor([(idx + 1) % 2 for idx in range(input_shape)])
    layer_sizes = [input_shape, hidden_channels, input_shape]

    layers = []
    for k in range(K):
        # The scale net is created before the shift net, matching the parameter init order.
        scale_net = nf.nets.MLP(layer_sizes, init_zeros=True, output_fn='tanh')
        shift_net = nf.nets.MLP(layer_sizes, init_zeros=True)
        mask = even_mask if k % 2 == 0 else 1 - even_mask
        layers.append(nf.flows.MaskedAffineFlow(mask, shift_net, scale_net))

    base = nf.distributions.DiagGaussian([input_shape], trainable=False)
    return nf.NormalizingFlow(q0=base, flows=layers, p=None)

def making_model_realnvp_cnn(input_shape, channels, hidden_channels, split_mode, L, K, scale=True):
    """Build a multiscale RealNVP-style model from ``RealNVPBlock`` couplings.

    Each of the ``L`` levels holds ``K`` ``RealNVPBlock`` layers followed by a
    ``Squeeze``. Levels after the first are joined with ``Merge`` layers.

    Args:
      input_shape: Shape of one sample, indexed as ``(channels, height, width)``.
      channels: Channel count used to size the blocks. Presumably equal to
        ``input_shape[0]`` (latent shapes use ``input_shape[0]``) — TODO confirm.
      hidden_channels: Hidden channels of each block's conv network.
      split_mode: Split mode forwarded to ``RealNVPBlock``.
      L: Number of levels.
      K: Number of blocks per level.
      scale: Whether the affine couplings include a scale term.

    Returns:
      An ``nf.MultiscaleFlow`` with one non-trainable ``DiagGaussian`` base
      distribution per level, no transform and no class conditioning.
    """
    base_dists = []
    merge_layers = []
    levels = []
    for level in range(L):
        level_ch = channels * 2 ** (L + 1 - level)
        level_flows = [
            RealNVPBlock(level_ch, hidden_channels, split_mode=split_mode, scale=scale)
            for _ in range(K)
        ]
        level_flows.append(nf.flows.Squeeze())
        levels.append(level_flows)

        if level == 0:
            ch_mult, spatial_div = 2 ** (L + 1), 2 ** L
        else:
            merge_layers.append(nf.flows.Merge())
            ch_mult = spatial_div = 2 ** (L - level)
        latent_shape = (
            input_shape[0] * ch_mult,
            input_shape[1] // spatial_div,
            input_shape[2] // spatial_div,
        )
        base_dists.append(nf.distributions.DiagGaussian(shape=latent_shape, trainable=False))

    return nf.MultiscaleFlow(base_dists, levels, merge_layers, transform=None, class_cond=False)

class RealNVPBlock(Flow):
    """Single RealNVP-style affine coupling block for image tensors.

    The block contains exactly one ``nf.flows.AffineCouplingBlock``. Its
    parameter network is an ``nf.nets.ConvNet2d`` with kernel sizes
    ``(3, 1, 3)`` and two hidden layers of ``hidden_channels`` channels.
    Unlike ``nf.flows.GlowBlock``, it adds no invertible 1x1 convolution and
    no ActNorm layer.

    NOTE(review): because no channel permutation is included, stacking several
    of these blocks with the same split mode presumably transforms the same
    half of the channels each time, so mixing must come from surrounding
    layers (e.g. ``Squeeze``). Confirm this is intended.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        scale=True,
        scale_map="exp",
        split_mode="channel",
        leaky=0.0,
        init_zeros=True,
        net_actnorm=False,
    ):
        """Constructor

        Args:
          channels: Number of channels of the data
          hidden_channels: Number of channels in each of the two hidden layers of the ConvNet
          scale: Flag, whether to include scale in affine coupling layer
          scale_map: Map to be applied to the scale parameter, can be 'exp' as in RealNVP or 'sigmoid' as in Glow
          split_mode: Splitting mode: 'channel', 'channel_inv', or a mode containing 'checkerboard'
          leaky: Leaky parameter of LeakyReLUs of ConvNet2d
          init_zeros: Flag whether to initialize last conv layer with zeros
          net_actnorm: Flag forwarded to ConvNet2d as ``actnorm`` (ActNorm inside the parameter network)

        Raises:
          NotImplementedError: If ``split_mode`` is not one of the supported modes.
        """
        super().__init__()
        self.flows = nn.ModuleList([])
        # Coupling layer
        kernel_size = (3, 1, 3)
        # The network outputs shift (+ scale when scale=True) for the transformed part.
        num_param = 2 if scale else 1
        if "channel" == split_mode:
            # Condition on the first ceil(C/2) channels, transform the remaining floor(C/2).
            channels_ = ((channels + 1) // 2,) + 2 * (hidden_channels,)
            channels_ += (num_param * (channels // 2),)
        elif "channel_inv" == split_mode:
            channels_ = (channels // 2,) + 2 * (hidden_channels,)
            channels_ += (num_param * ((channels + 1) // 2),)
        elif "checkerboard" in split_mode:
            # Checkerboard splits keep the full channel count in both halves.
            channels_ = (channels,) + 2 * (hidden_channels,)
            channels_ += (num_param * channels,)
        else:
            raise NotImplementedError("Mode " + split_mode + " is not implemented.")
        param_map = nf.nets.ConvNet2d(
            channels_, kernel_size, leaky, init_zeros, actnorm=net_actnorm
        )
        self.flows += [nf.flows.AffineCouplingBlock(param_map, scale, scale_map, split_mode)]

    def forward(self, z):
        """Apply the sub-flows in order; return ``(z, log_det)``, one log-det per batch element."""
        log_det_tot = torch.zeros(z.shape[0], dtype=z.dtype, device=z.device)
        for flow in self.flows:
            z, log_det = flow(z)
            log_det_tot += log_det
        return z, log_det_tot

    def inverse(self, z):
        """Apply the sub-flow inverses in reverse order; return ``(z, log_det)``."""
        log_det_tot = torch.zeros(z.shape[0], dtype=z.dtype, device=z.device)
        for flow in reversed(self.flows):
            z, log_det = flow.inverse(z)
            log_det_tot += log_det
        return z, log_det_tot