import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class MaskedLinear(nn.Module):
    '''
    Linear layer whose weight is multiplied elementwise by a fixed mask.

    Input
        mask (tensor): stored as a buffer named "mask" (not trained);
                       multiplied elementwise with W on every forward pass
        W    (tensor): initial weight, wrapped as a trainable parameter
        b    (tensor): initial bias, wrapped as a trainable parameter
    '''
    def __init__(self, mask, W, b):
        super().__init__()
        self.register_buffer("mask", mask)
        self.W = nn.Parameter(W)
        self.b = nn.Parameter(b)

    def forward(self, z):
        # the mask is applied at call time, so masked-out entries of W
        # never contribute to the output
        masked_weight = self.mask * self.W
        return F.linear(z, masked_weight, self.b)


class MADE(nn.Module):
    '''
    Masked Autoencoder Distribution Estimator (MADE) with Gaussian conditionals
        Germain et al., 2015, ICML
        MADE: Masked Autoencoder for Distribution Estimation
        https://arxiv.org/abs/1502.03509

    Input
        dim          (int): row index at which the last entries of W and b
                            are split between the two output heads
        masks       (list): one mask per masked linear layer; the last one
                            is shared by both output heads
        W, b        (list): initial weights and biases, indexed like masks
        activation (class): instantiated once after every hidden layer

    forward(z) returns the pair (m, s) produced by the two output heads.
    '''
    def __init__(self, dim, masks, W, b, activation):
        super().__init__()
        # every mask except the last one yields a (masked linear, activation)
        # pair in the hidden stack
        layers = [
            layer
            for idx in range(len(masks) - 1)
            for layer in (MaskedLinear(masks[idx], W[idx], b[idx]),
                          activation())
        ]
        self.hidden = nn.Sequential(*layers)

        # rows [:dim] of the last weight/bias feed the m head and rows
        # [dim:] feed the s head
        out_mask, out_W, out_b = masks[-1], W[-1], b[-1]
        self.m = MaskedLinear(out_mask, out_W[:dim], out_b[:dim])
        self.s = MaskedLinear(out_mask, out_W[dim:], out_b[dim:])

    def forward(self, z):
        features = self.hidden(z)
        return self.m(features), self.s(features)


class InverseAutoregressive(nn.Module):
    '''
    Building Block of Inverse Autoregressive Flow
        Kingma et al., 2016, NIPS
        https://arxiv.org/abs/1606.04934

    The constructor arguments are forwarded unchanged to MADE.
    forward(z) returns (fz, log_jac_det), where
        fz          = sigmoid(s) * z + (1 - sigmoid(s)) * m
        log_jac_det = sum over the last axis of log(sigmoid(s))
    and (m, s) is the output of the MADE applied to z.
    '''
    def __init__(self, dim, masks, W, b, activation):
        super().__init__()
        self.made    = MADE(dim, masks, W, b, activation)
        self.sigmoid = nn.Sigmoid()
        self.logsigm = nn.LogSigmoid()

    def forward(self, z):
        m, s = self.made(z)
        gate = self.sigmoid(s)
        # gated combination of the input and the MADE output m
        fz = gate * z + (1 - gate) * m
        # LogSigmoid is used instead of taking gate.log()
        log_gate = self.logsigm(s)
        return fz, log_gate.sum(-1)


class AutoregressiveFlow(nn.Module):
    '''
    Stack of `num_layers` autoregressive transforms. Each transform receives
    its own MADE masks and its own initial weights and biases, all derived
    from a single numpy random generator.

    Input
        dim          (int): input dimension
        num_layers   (int): number of layers of the flow
        num_hidden_units
                (sequence): numbers of hidden units for each layer of each
                            MADE (list, tuple, ...; stored as a list)
        transform  (class): the building block of the Autoregressive Flow,
                            called as transform(dim, masks, W, b, activation)
        activation (class): the activation used in MADE
        shuffle_z   (bool): shuffle input of each MADE if True
                            otherwise, use the original order for even
                            layers (0, 2, ...) and the reversed order for
                            odd layers
        random_deg  (bool): generate random degrees for each MADE if True
                            otherwise, use sequential degree
        replicate    (int): replicate ID, used to control initial parameters
        seed_init: seed for numpy random generator, used to
                   initialize parameters, shuffle input and generate degrees
    '''
    def __init__(self, dim, num_layers, num_hidden_units, transform,
                 activation=nn.ReLU, shuffle_z=False, random_deg=False,
                 replicate=0, seed_init=235711131719):
        super().__init__()
        self.dim              = dim
        self.num_layers       = num_layers
        # copied into a list so that tuples and other sequences are accepted
        self.num_hidden_units = list(num_hidden_units)
        self.shuffle_z        = shuffle_z
        self.random_deg       = random_deg
        self.replicate        = replicate

        # set the random state for the j-th replicate
        # we use a loose upper bound to ensure no overlapping of the draws
        sizes = self._layer_sizes()
        params_per_layer = [n0*n1 + n1 for n0, n1 in zip(sizes[:-1], sizes[1:])]
        upper_per_replicate = sum(params_per_layer) * num_layers * 1000
        rng = np.random.default_rng(seed_init)
        rng.bit_generator.advance(upper_per_replicate * replicate)

        # initialize parameters for each flow layer
        # (parameters are drawn before degrees; keep this order so that a
        # given seed always yields the same model)
        W, b = self.initialize_param(rng)

        # create masks for each flow layer
        deg   = self.create_degrees(rng)
        masks = self.create_masks(deg)

        self.flow = nn.ModuleList([
            transform(dim, masks[k], W[k], b[k], activation)
            for k in range(num_layers)
        ])

    def _layer_sizes(self):
        # unit counts of one MADE: input, hidden layers, output (2 * dim)
        return [self.dim] + self.num_hidden_units + [self.dim*2]

    def forward(self, z):
        '''
        Apply every flow layer in order.

        Each layer must return a pair (z, log_jac_det); the returned
        log_jac_det is the sum of the per-layer values.
        '''
        log_jac_det = 0.
        for layer in self.flow:
            z, log_jac_det_i = layer(z)
            log_jac_det = log_jac_det + log_jac_det_i  # avoid in-place op
        return z, log_jac_det

    def initialize_param(self, rng):
        '''
        Draw initial weights and biases for each flow layer.

        Every entry is drawn uniformly from [-1/sqrt(n0), 1/sqrt(n0)), where
        n0 is the number of inputs of the linear layer. Returns (W, b) such
        that W[k][l] has shape (n1, n0) and b[k][l] has shape (n1,), both in
        the torch default dtype.
        '''
        dtype = torch.get_default_dtype()
        sizes = self._layer_sizes()
        W = []
        b = []
        for _ in range(self.num_layers):
            W_k = []
            b_k = []
            for n0, n1 in zip(sizes[:-1], sizes[1:]):
                bound = 1. / n0**0.5
                # weight is drawn before bias; the order fixes the RNG stream
                W_numpy = rng.uniform(-bound, bound, size=(n1, n0))
                b_numpy = rng.uniform(-bound, bound, size=(n1,))
                W_k.append(torch.tensor(W_numpy, dtype=dtype))
                b_k.append(torch.tensor(b_numpy, dtype=dtype))
            W.append(W_k)
            b.append(b_k)
        return W, b

    def create_degrees(self, rng):
        '''
        Create the MADE degrees for each flow layer.

        Returns a list with one entry per flow layer; each entry is a list
        of integer arrays: the input degrees followed by the degrees of each
        hidden layer.
        '''
        deg = []
        for k in range(self.num_layers):
            # degrees of input units: ascending for even layers,
            # descending for odd layers
            if k % 2 == 0:
                deg_k_0 = np.arange(1, self.dim + 1)
            else:
                deg_k_0 = np.arange(self.dim, 0, -1)

            if self.shuffle_z:
                rng.shuffle(deg_k_0)
            deg_k = [deg_k_0]

            # degrees of units for each hidden layer
            if self.random_deg:
                for N in self.num_hidden_units:
                    low = deg_k[-1].min()
                    # the upper bound is exclusive; for dim >= 2 it is
                    # always dim, and for dim == 1 it is widened so the
                    # range is not empty (all degrees are then 1)
                    high = max(self.dim, low + 1)
                    deg_k.append(rng.integers(low, high, N))
            else: # sequential degree
                # cycle through 1 .. dim-1; the period is clamped to 1 so
                # that dim == 1 does not take a modulo by zero
                period = max(self.dim - 1, 1)
                for N in self.num_hidden_units:
                    deg_k.append(np.arange(N) % period + 1)
            deg.append(deg_k)
        return deg

    def create_masks(self, deg):
        '''
        Create the masks for each flow layer from the degrees.

        For consecutive degree arrays (d0, d1), entry [i, j] of the hidden
        mask is 1 when d0[j] <= d1[i]. Entry [i, j] of the final (output)
        mask is 1 when the degree of input i is strictly greater than the
        degree of unit j in the last degree array. Masks are returned in the
        torch default dtype.
        '''
        dtype = torch.get_default_dtype()
        masks = []
        for k in range(self.num_layers):
            deg_k = deg[k]
            masks_k = [
                torch.tensor(d0 <= d1[:, np.newaxis], dtype=dtype)
                for d0, d1 in zip(deg_k[:-1], deg_k[1:])
            ]
            # output mask: strict inequality, so output i never depends on
            # a hidden unit whose degree is >= the degree of input i
            M = deg_k[0][:, np.newaxis] > deg_k[-1]
            masks_k.append(torch.tensor(M, dtype=dtype))
            masks.append(masks_k)
        return masks


