import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from special import ExpL


class PlanarBase(nn.Module):
    """
    A single layer of planar flow
        f(z) = z + v * tanh( w^T z + b ), where w^T v > -1
        Reparameterization is defined in the child classes

    Input
        dim: dimension of the input variable
        w, b, v_: initial values [optional]; when one is None, it is drawn 
            from Unif(-1/sqrt(dim), 1/sqrt(dim)) with the global torch RNG
        Note: v_ is the underlying unconstrained parameter for v

    Child classes must override the property `v_wv`, which returns the pair
        (v, w^T v) built from the parameters `w` and `v_`.
    """
    def __init__(self, dim, w=None, b=None, v_=None):
        super().__init__()
        bound = 1. / dim**0.5
        
        # The creation order (w, b, v_) is kept fixed so that the sequence of
        #    draws from the global torch RNG does not depend on this refactor
        self.w  = self._make_param(w,  dim, bound)
        self.b  = self._make_param(b,  1,   bound)
        self.v_ = self._make_param(v_, dim, bound)
    
    @staticmethod
    def _make_param(value, size, bound):
        """Wrap `value` as a parameter, or draw `size` values from Unif(-bound, bound)."""
        if value is not None:
            return nn.Parameter(value)
        param = nn.Parameter(torch.empty(size))
        nn.init.uniform_(param, -bound, bound)
        return param
    
    @property
    def v_wv(self):
        raise NotImplementedError("Reparameterization is not implemented.")
    
    def forward(self, z):
        """
        Apply the layer to `z` of shape (..., dim), e.g. (n, dim)

        Output
            fz: transformed variable, same shape as `z`
            log_jac_det: log(1 + w^T v * (1 - tanh^2)), one value per sample
        """
        v, wv = self.v_wv
        hyper_level = z@self.w + self.b      # (...,)    <- (...,dim)@(dim,) + (1,)
        tanh = hyper_level.tanh()            # (...,)
        # unsqueeze the LAST axis so that any number of leading batch axes 
        #    broadcasts against v; for (n,dim) input this is the same as axis 1
        fz = z + tanh.unsqueeze(-1) * v      # (...,dim) <- (...,dim) + (...,1)*(dim,)
        log_jac_det = (wv * (1-tanh**2)).log1p() 
        return fz, log_jac_det


class Planar(PlanarBase):
    '''
    Planar layer with the piecewise reparameterization
        v = v_ + [m(w^T v_) - w^T v_] w/|w|^2
    where 
        m(x) = x       if x >= 0
        m(x) = e^x - 1 if x <  0
    so that w^T v = m(w^T v_) > -1
    '''
    def __init__(self, dim, w=None, b=None, v_=None): 
        super().__init__(dim, w, b, v_)
        
    @property
    def v_wv(self):
        """Return the pair (v, w^T v) built from `w` and `v_`."""
        raw_dot = torch.dot(self.w, self.v_)
        
        # m is the identity on this branch, so v_ already satisfies the constraint
        if not (raw_dot < 0):
            return self.v_, raw_dot
        
        # negative branch: move v_ along w until w^T v equals e^x - 1
        mapped_dot = torch.expm1(raw_dot)
        w_sq_norm = (self.w**2).sum()
        v = self.v_ + (mapped_dot - raw_dot)*self.w/w_sq_norm
        return v, mapped_dot


class Planar0(PlanarBase):
    '''
    Planar layer with the softplus reparameterization
        v = v_ + [m(w^T v_) - w^T v_] w/|w|^2
    where 
        m(x) = log(1+exp(x)) - 1
    so that w^T v = m(w^T v_) > -1
    '''
    def __init__(self, dim, w=None, b=None, v_=None): 
        super().__init__(dim, w, b, v_)    
    
    @property
    def v_wv(self):
        """Return the pair (v, w^T v) built from `w` and `v_`."""
        raw_dot = torch.dot(self.w, self.v_)
        # `softplus` is used in place of exp().log1p(): it does not overflow
        #    and takes fewer steps in the forward/backward pass
        mapped_dot = F.softplus(raw_dot) - 1.
        w_sq_norm = (self.w**2).sum()
        shift = (mapped_dot - raw_dot)*self.w/w_sq_norm
        return self.v_ + shift, mapped_dot


class LinearLowerTri(nn.Module):
    """
    Affine layer f(z) = W z + b with a lower-triangular matrix W
        The strictly-lower part of W is taken from `W_` as is; the diagonal 
        of W is the diagonal of `W_` passed through `ExpL()`.

    Input
        dim: dimension of the input variable
        W_, b: initial values [optional]; when one is None, it is drawn 
            from Unif(-1/sqrt(dim), 1/sqrt(dim)) with the global torch RNG
    """
    def __init__(self, dim, W_=None, b=None):
        super().__init__()
        bound = 1. / dim**0.5
        
        # W_ is created before b, so the order of the random draws is fixed
        if W_ is None:
            self.W_ = nn.Parameter(torch.empty((dim, dim)))
            nn.init.uniform_(self.W_, -bound, bound)
        else:
            self.W_ = nn.Parameter(W_)
        
        if b is None:
            self.b = nn.Parameter(torch.empty(dim))
            nn.init.uniform_(self.b, -bound, bound)
        else:
            self.b = nn.Parameter(b)
        
        self.expL = ExpL()
        
    def forward(self, z):
        """Return (fz, log_jac_det); log_jac_det is repeated z.size(0) times."""
        # The diagonal is reparameterized with `ExpL()`:
        #    - for x > 0, it shifts up by 1 unit -> same derivative
        #    - for x < 0, it is exp(x) 
        strict_lower = torch.tril(self.W_, diagonal=-1)
        diag_entries = self.expL(torch.diag(self.W_))
        W = strict_lower + torch.diag(diag_entries)
        
        fz = F.linear(z, W, self.b)
        # W is triangular, so log|det W| is the sum of the log-diagonal
        log_jac_det = torch.diag(W).log().sum()
        batch_size = z.size(0)
        return fz, log_jac_det.expand(batch_size)


class PlanarFlow(nn.Module):
    """
    A stack of planar-flow layers, optionally preceded by one 
    lower-triangular linear layer.

    Input
        dim: dimension of the input variable
        num_layers: number of planar layers
        reparam: 'new' -> `Planar` layers, 'old' -> `Planar0` layers;
            any other value raises ValueError
        linear_layer: if true, a `LinearLowerTri` layer is put in front
        replicate: index of the replicate; the random stream is advanced 
            so that different replicates use non-overlapping draws
        seed_init: seed of the NumPy generator for the initial values
    """
    def __init__(self, dim, num_layers, reparam, linear_layer, 
                 replicate=0, seed_init=235711131719):
        super().__init__()
        # fail fast: reject a bad option before any parameter is drawn
        if reparam not in ('new', 'old'):
            raise ValueError('Invalid Reparameterization')
        
        self.dim          = dim
        self.num_layers   = num_layers
        self.reparam      = reparam
        self.linear_layer = linear_layer
        self.replicate    = replicate
        self.seed_init    = seed_init
        
        # prepare initial parameters
        lin_W_, lin_b, w, b, v_  = self.initialize_param()
        
        # build flow model: [linear layer] + planar layers
        layers = []
        if linear_layer:
            layers.append(LinearLowerTri(dim, lin_W_, lin_b))
        planar_cls = Planar if reparam == 'new' else Planar0
        layers.extend(planar_cls(dim, w[k], b[k], v_[k]) 
                      for k in range(num_layers))
        self.flow = nn.ModuleList(layers)
        
    def forward(self, z):
        """Pass `z` through all layers; return (z_out, sum of log_jac_det)."""
        log_jac_det = 0.
        for layer in self.flow:
            z, log_jac_det_i = layer(z)
            log_jac_det = log_jac_det + log_jac_det_i  # avoid in-place op
        return z, log_jac_det
    
    def initialize_param(self):
        """
        Draw reproducible initial values for all layers

        Output
            lin_W_, lin_b: (dim,dim) and (dim,) tensors for the linear layer,
                or None, None when `linear_layer` is false
            w, b, v_: (num_layers,dim), (num_layers,1), (num_layers,dim) 
                tensors; row k holds the initial values of the k-th planar layer
        """
        # set the random state for the j-th replicate 
        # we use a loose upper bound to ensure no overlapping of the draws
        upper_planar = (2*self.dim + 1) * self.num_layers
        upper_linear = self.dim**2 + self.dim
        upper_per_replicate = (upper_planar + upper_linear) * 1000 
        rng = np.random.default_rng(self.seed_init)
        rng.bit_generator.advance(upper_per_replicate * self.replicate)
        
        # use Unif([-1/sqrt(dim), 1/sqrt(dim)]) for all effective parameters
        #     Then | w^T v | < [1/sqrt(dim)]^2 * dim = 1 for the effective v,
        #     so the inverse maps below (log1p / log of expm1) are well defined
        # NOTE: the order of the draws (w, b, v, then the linear layer) must 
        #     not change, otherwise a given seed yields different values
        bound = 1. / self.dim**0.5
        w = rng.uniform(-bound, bound, size=(self.num_layers, self.dim))
        b = rng.uniform(-bound, bound, size=(self.num_layers, 1  ))
        v = rng.uniform(-bound, bound, size=(self.num_layers, self.dim))
        
        # transform the effective parameter `v` to the underlying parameter `v_`
        #      v  = v_ + [m(wv_) - wv_] * w/w^2
        #   -> v_ = v  - [m(wv_) - wv_] * w/w^2,  with m(wv_) = wv
        wv = (w*v).sum(axis=1, keepdims=True)
        w_div_w2 = w/(w**2).sum(axis=1, keepdims=True)
        if self.reparam == 'new':
            # inverse of m(x) = x (x>=0), e^x-1 (x<0)
            wv_ = np.log1p(wv) * (wv < 0) + wv * (wv >= 0)
        else:
            wv_ = np.log(np.expm1(wv+1))  # y=log(1+e^x)-1 -> x=log(exp(y+1)-1)
        v_ = v - (wv - wv_) * w_div_w2
        
        # convert to tensor so that it can be used by nn.Parameter()
        dtype = torch.get_default_dtype()
        w  = torch.tensor(w,  dtype=dtype)
        b  = torch.tensor(b,  dtype=dtype)
        v_ = torch.tensor(v_, dtype=dtype)
        
        lin_W_ = None 
        lin_b  = None
        if self.linear_layer:            
            lin_W_ = rng.uniform(-bound, bound, size=(self.dim, self.dim))
            lin_W_ = torch.tensor(lin_W_, dtype=dtype)
            lin_b  = rng.uniform(-bound, bound, size=self.dim)
            lin_b  = torch.tensor(lin_b,  dtype=dtype)
        return lin_W_, lin_b, w, b, v_ 

