import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class SylvesterFlow(nn.Module):
    """Orthogonal Sylvester Flow
    All layers are defined in this class rather than split into sub-Modules
        since we need to process the iterative orthogonalization
        for a batch of matrices Q (one per layer) in parallel

    The transformation of each layer is given by
        f(z) = z + Q R1 tanh( R2 Q^T z + b ),
    where R1, R2 are upper triangular matrices with diag(R1*R2) > -1
        and Q is a matrix with orthonormal column vectors

    Notation
        L = num_layers, D = dim, M = num_ortho_vecs (M <= D), B = batch size

    Trainable parameters
        R1_: unconstrained matrix, shape=(L,M,M)
        R2_: unconstrained matrix, shape=(L,M,M)
        Q_ : unconstrained matrix, shape=(L,D,M)
        b  : bias vector,          shape=(L,1,M)

    For z with shape (B,D),
        f(z) = z + tanh( z Q R2^T + b ) R1^T Q^T
                 [(B,D) @ (D,M) @ (M,M)] @ (M,M) @ (M,D) -> (B,D)

    Args:
        dim: dimension D of the input vectors.
        num_layers: number L of flow layers applied in sequence.
        num_ortho_vecs: number M of orthonormal columns of each Q.
        replicate: index of the replicate; shifts the random stream used for
            initialization so that different replicates get different draws.
        seed_init: base seed of the NumPy generator used for initialization.
    """
    def __init__(self, dim, num_layers, num_ortho_vecs,
                 replicate=0, seed_init=235711131719):
        super().__init__()
        self.dim            = dim
        self.num_layers     = num_layers
        self.num_ortho_vecs = num_ortho_vecs
        self.replicate      = replicate
        self.seed_init      = seed_init
        # stopping rule of the iterative orthogonalization in `get_orthogonal`
        self.Q_step_max     = 100
        self.Q_threshold    = 1e-6
        assert num_ortho_vecs <= dim, 'num_ortho_vecs must not exceed dim'

        R1_, R2_, Q_, b = self.initialize_param()
        self.R1_ = nn.Parameter(R1_)    # (L,M,M)
        self.R2_ = nn.Parameter(R2_)    # (L,M,M)
        self.Q_  = nn.Parameter(Q_)     # (L,D,M)
        self.b   = nn.Parameter(b)      # (L,1,M)
        # identity stored as a buffer so that it follows the module's device/dtype
        self.register_buffer('eye', torch.eye(num_ortho_vecs).unsqueeze(0))

    def get_upper_tri(self, R):
        """Map an unconstrained (L,M,M) tensor to upper triangular matrices.

        We follow van den Berg et al. (2018) and transform the diagonal with
        `tanh`, so every diagonal entry lies in (-1, 1); the strictly upper
        part is taken from `R` unchanged and the lower part is zeroed.

        Returns:
            Upper triangular matrices with shape (L,M,M).
        """
        strict_upper = R.triu(1)
        squashed_diag = R.diagonal(dim1=1, dim2=2).tanh()
        return strict_upper + squashed_diag.diag_embed(dim1=1, dim2=2)

    def _orthogonality_gap(self, Q):
        """Return (I - Q^T Q, max over layers of its Frobenius norm)."""
        QTQ = torch.bmm(Q.transpose(1, 2), Q)  # (L,M,D) @ (L,D,M) -> (L,M,M)
        gap = self.eye - QTQ                   # (1,M,M) - (L,M,M) -> (L,M,M)
        worst = (gap**2).sum(dim=(1, 2)).sqrt().max()
        return gap, worst

    def get_orthogonal(self, Q):
        """Iteratively orthogonalize the columns of a batch of matrices.

        Each matrix is first scaled to unit Frobenius norm, then the update
            Q <- Q (I + 0.5 (I - Q^T Q))
        is repeated until the largest Frobenius norm of I - Q^T Q over the
        batch is at most `self.Q_threshold`, or `self.Q_step_max` updates
        have been applied. A `UserWarning` is issued when the final matrices
        still miss the threshold.

        Args:
            Q: unconstrained matrices, shape (L,D,M).

        Returns:
            Matrices with (approximately) orthonormal columns, shape (L,D,M).
        """
        # local import: the file's import block does not provide `warnings`
        import warnings

        Qnorm = (Q**2).sum(dim=(1, 2), keepdim=True).sqrt()
        Q = torch.div(Q, Qnorm)
        for _ in range(self.Q_step_max):
            gap, worst = self._orthogonality_gap(Q)
            if worst <= self.Q_threshold:
                return Q
            Q = torch.bmm(Q, self.eye + 0.5 * gap)  # (L,D,M) @ (L,M,M) -> (L,D,M)

        # Iteration budget exhausted: measure the matrices actually returned,
        # not the ones from before the last update.
        _, worst = self._orthogonality_gap(Q)
        # `not (<=)` rather than `>` so that a NaN residual is reported as well
        if not worst <= self.Q_threshold:
            warnings.warn(
                f'Orthogonalization not complete! Final max norm = {worst.item()}',
                stacklevel=2,
            )
        return Q

    def forward(self, z):
        """Apply all flow layers in sequence.

        Args:
            z: input batch, shape (B,D).

        Returns:
            A pair (z_out, log_jac_det): the transformed batch with shape
            (B,D) and the log-determinant of the Jacobian accumulated over
            the layers, shape (B,).
        """
        R1 = self.get_upper_tri(self.R1_) # (L,M,M)
        R2 = self.get_upper_tri(self.R2_) # (L,M,M)
        Q  = self.get_orthogonal(self.Q_) # (L,D,M)

        log_jac_det = 0.
        for k in range(self.num_layers):
            R1_k = R1[k]      # (M,M)
            R2_k = R2[k]      # (M,M)
            Q_k  = Q[k]       # (D,M)
            b_k  = self.b[k]  # (1,M)

            h = (z @ Q_k @ R2_k.T + b_k).tanh()  # (B,D) @ (D,M) @ (M,M) -> (B,M)
            z = z + h @ R1_k.T @ Q_k.T           # (B,M) @ (M,M) @ (M,D) -> (B,D)

            # With orthonormal Q and triangular R1, R2 the Jacobian determinant
            # reduces to prod_i (1 + tanh'(.)_i * R2_ii * R1_ii); 1 - h**2 is tanh'.
            diag_term = (1 - h**2) * R2_k.diag() * R1_k.diag()  # (B,M)*(M,) -> (B,M)
            log_jac_det = log_jac_det + diag_term.log1p().sum(1)  # (B,)
        return z, log_jac_det

    def initialize_param(self):
        """Draw the initial values of R1_, R2_, Q_ and b.

        All entries are sampled from Unif([-1/sqrt(dim), 1/sqrt(dim)]) with a
        NumPy generator seeded by `self.seed_init` and advanced according to
        `self.replicate`, so the result is reproducible per replicate.

        Returns:
            Tensors (R1_, R2_, Q_, b) with shapes (L,M,M), (L,M,M), (L,D,M)
            and (L,1,M), in the torch default dtype.
        """
        # set the random state for the j-th replicate
        # we use a loose upper bound to ensure no overlapping of the draws
        M = self.num_ortho_vecs
        L = self.num_layers
        num_params = (2 * M**2 + self.dim * M + M) * L
        upper_per_replicate = num_params * 1000
        rng = np.random.default_rng(self.seed_init)
        rng.bit_generator.advance(upper_per_replicate * self.replicate)

        # use Unif([-1/sqrt(dim), 1/sqrt(dim)]) for all effective parameters
        # NOTE: the draw order (R1_, R2_, Q_, b) fixes the initial values
        bound = 1. / self.dim**0.5
        R1_ = rng.uniform(-bound, bound, size=(L, M, M))
        R2_ = rng.uniform(-bound, bound, size=(L, M, M))
        Q_  = rng.uniform(-bound, bound, size=(L, self.dim, M))
        b   = rng.uniform(-bound, bound, size=(L, 1, M))

        # convert to tensor so that it can be used by nn.Parameter()
        dtype = torch.get_default_dtype()
        R1_ = torch.tensor(R1_, dtype=dtype)
        R2_ = torch.tensor(R2_, dtype=dtype)
        Q_  = torch.tensor(Q_,  dtype=dtype)
        b   = torch.tensor(b,   dtype=dtype)
        return R1_, R2_, Q_, b

