import torch
import torch.nn as nn
import numpy as np
from torch.nn.utils.parametrizations import orthogonal, weight_norm
from torch.nn.utils import weight_norm

class LRRNN(nn.Module):
    """
    Low-rank recurrent neural network (piece-wise linear for the ReLU-type
    activations).

    The latent state z (dimension dim_z) is propagated by a ``Transition``
    module built from the low-rank connectivity (m, n). Observations are
    produced by an ``Observation`` module that reads out either the latents,
    the neuron currents (m z) or the neuron rates, depending on
    ``params['readout_rates']``.
    """

    def __init__(self, dim_x, dim_z, dim_u, dim_N, params):
        """
        Args:
            dim_x (int): dimensionality of the data
            dim_z (int): dimensionality of the latent space (rank)
            dim_u (int): dimensionality of the input
            dim_N (int): number of neurons in the network
            params (dict): dictionary of parameters. NOTE: if params['clipped']
                is truthy and params['activation'] == 'relu', this dict is
                modified in place (activation becomes 'clipped_relu').
        """
        super(LRRNN, self).__init__()
        self.d_x = dim_x
        self.d_z = dim_z
        self.d_u = dim_u
        self.d_N = dim_N

        self.params = params
        self.normal = torch.distributions.Normal(0, 1)

        # Initialise noise
        # ------

        # Full-covariance parameterisation: the strictly lower triangle of x is
        # used as-is and the diagonal is stored as a log-variance, so
        # exp(diag / 2) keeps the Cholesky diagonal positive.
        self.chol_cov_embed = lambda x: torch.tril(x, diagonal=-1) + torch.diag_embed(
            torch.exp(x[range(x.shape[0]), range(x.shape[0])] / 2))
        self.full_cov_embed = lambda x: self.chol_cov_embed(x) @ (self.chol_cov_embed(x).T)

        # Observation noise: "Cov" -> full covariance, other truthy value -> one
        # shared log-variance, falsy -> one log-variance per data dimension.
        if params["scalar_noise_x"] == "Cov":
            self.R_x = nn.Parameter(torch.eye(self.d_x) * np.log(params['init_noise_x']) * 2, requires_grad=params['train_noise_obs'])
            self.std_embed_x = lambda x: torch.sqrt(torch.diagonal(self.full_cov_embed(x)))
        elif params["scalar_noise_x"]:
            self.R_x = nn.Parameter(torch.ones(1) * np.log(params['init_noise_x']) * 2, requires_grad=params['train_noise_obs'])
            self.std_embed_x = lambda log_var: torch.exp(log_var / 2).expand(self.d_x)
            self.var_embed_x = lambda log_var: torch.exp(log_var).expand(self.d_x)
        else:
            self.R_x = nn.Parameter(torch.ones(self.d_x) * np.log(params['init_noise_x']) * 2, requires_grad=params['train_noise_obs'])
            self.std_embed_x = lambda log_var: torch.exp(log_var / 2)
            self.var_embed_x = lambda log_var: torch.exp(log_var)

        # Latent (transition) noise, same three parameterisations.
        if params["scalar_noise_z"] == "Cov":
            self.R_z = nn.Parameter(torch.eye(self.d_z) * np.log(params['init_noise_z']) * 2, requires_grad=params['train_noise_prior'])
            self.std_embed_z = lambda x: torch.sqrt(torch.diagonal(self.full_cov_embed(x)))
        elif params["scalar_noise_z"]:
            self.R_z = nn.Parameter(torch.ones(1) * np.log(params['init_noise_z']) * 2, requires_grad=params['train_noise_prior'])
            self.std_embed_z = lambda log_var: torch.exp(log_var / 2).expand(self.d_z)
            self.var_embed_z = lambda log_var: torch.exp(log_var).expand(self.d_z)
        else:
            self.R_z = nn.Parameter(torch.ones(self.d_z) * np.log(params['init_noise_z']) * 2, requires_grad=params['train_noise_prior'])
            self.std_embed_z = lambda log_var: torch.exp(log_var / 2)
            self.var_embed_z = lambda log_var: torch.exp(log_var)

        # Latent noise for t = 0. Every parameterisation (including "Cov") uses
        # init_noise_z_t0, falling back to init_noise_z when it is not given.
        init_noise_z_t0 = params.get('init_noise_z_t0', params['init_noise_z'])
        if params["scalar_noise_z_t0"] == "Cov":
            self.R_z_t0 = nn.Parameter(torch.eye(self.d_z) * np.log(init_noise_z_t0) * 2, requires_grad=params['train_noise_prior_t0'])
            self.std_embed_z_t0 = lambda x: torch.sqrt(torch.diagonal(self.full_cov_embed(x)))
        elif params["scalar_noise_z_t0"]:
            self.R_z_t0 = nn.Parameter(torch.ones(1) * np.log(init_noise_z_t0) * 2, requires_grad=params['train_noise_prior_t0'])
            self.std_embed_z_t0 = lambda log_var: torch.exp(log_var / 2).expand(self.d_z)
            self.var_embed_z_t0 = lambda log_var: torch.exp(log_var).expand(self.d_z)
        else:
            self.R_z_t0 = nn.Parameter(torch.ones(self.d_z) * np.log(init_noise_z_t0) * 2, requires_grad=params['train_noise_prior_t0'])
            self.std_embed_z_t0 = lambda log_var: torch.exp(log_var / 2)
            self.var_embed_z_t0 = lambda log_var: torch.exp(log_var)

        # initialise the transition step
        # ---------
        if 'clipped' in params.keys():
            if params['clipped'] and params['activation'] == 'relu':
                params['activation'] = 'clipped_relu'

        self.transition = Transition(self.d_z, self.d_u, self.d_N, nonlinearity=params['activation'],
                                     exp_par=params['exp_par'], shared_tau=params['shared_tau'],
                                     weight_dist=params['weight_dist'], m_orth=params["orth"],
                                     m_norm=params["m_norm"], weight_scaler=params["weight_scaler"],
                                     train_latent_bias=params["train_latent_bias"], train_neuron_bias=params["train_neuron_bias"])

        # initialise the observation step
        # ---------
        self.readout_rates = params['readout_rates']

        # Read out from neuron activity (rates or currents, dim_N) or from the latents (dim_z).
        readout_dim = self.d_N if self.readout_rates in ("rates", "currents") else self.d_z
        self.observation = Observation(readout_dim, self.d_x, train_bias=params['train_obs_bias'],
                                       train_weights=params['train_obs_weights'],
                                       identity_readout=params['identity_readout'])

        # initialise the initial state
        # ---------
        # Input contribution to the initial state: the input currents Wu u
        # projected onto the column space of m.
        project_input = lambda u: orth_proj(self.transition.m_transform(self.transition.m),
                                            torch.einsum('Nu,Bu->BN', self.transition.Wu, u))
        if params['initial_state'] in ('zero', 'trainable'):
            self.initial_state = nn.Parameter(torch.zeros(self.d_z),
                                              requires_grad=params['initial_state'] == 'trainable')
            self.get_initial_state = lambda u: self.initial_state.unsqueeze(0) + project_input(u)
        elif params['initial_state'] == 'bias':
            # NOTE(review): h has dim_N entries while the projection has dim_z;
            # this only broadcasts when dim_N == dim_z (or one of them is 1) —
            # confirm the intended expression.
            self.get_initial_state = lambda u: -self.transition.h.unsqueeze(0) + project_input(u)

    def forward(self, z, noise_scale=0, grad=True, u=None):
        """forward step of the RNN, predict z one step ahead
        Args:
            z (torch.tensor; n_trials x dim_z x time_steps x k): latent time series
            noise_scale (float): scale of the noise (no noise is added unless > 0)
            grad (bool): if False, the transition uses detached parameters
            u (torch.tensor or None): input, passed on to the transition
        Returns:
            z (torch.tensor; n_trials x dim_z x time_steps x k): latent time series
        """
        z_next = self.transition(z, grad=grad, u=u)
        if noise_scale > 0:
            # Normal(0, 1) samples on the CPU; move them next to z so GPU models work.
            eps = self.normal.sample(z.shape).to(z.device)
            if self.params['scalar_noise_z'] == "Cov":
                cov_chol = self.chol_cov_embed(self.R_z)
                z_next = z_next + noise_scale * torch.einsum("xz,BzTK->BxTK", cov_chol, eps)
            else:
                std = self.std_embed_z(self.R_z).unsqueeze(0).unsqueeze(2).unsqueeze(3)
                z_next = z_next + noise_scale * eps * std
        return z_next

    def get_latent_time_series(self, time_steps=1000, cut_off=0, noise_scale=1, z0=None, u=None):
        """
        Generate a latent time series of length time_steps (without gradients).
        Args:
            time_steps (int): length of the latent time series
            cut_off (int): cut off the first cut_off time steps
            noise_scale (float): scale of the noise
            z0 (torch.tensor; 1 x dim_z x 1): initial latent state (random normal if None)
            u (torch.tensor or None): input; indexed along dim 2 per time step,
                a trailing particle dim is added if it has fewer than 4 dims
        Returns:
            Z (torch.tensor; n_trials x dim_z x time_steps x k): latent time series
        """
        with torch.no_grad():
            Z = []
            if z0 is None:
                z = torch.randn(1, self.d_z, 1, 1, device=self.R_x.device)
            elif len(z0.shape) < 4:
                z = z0.to(device=self.R_x.device).reshape(1, self.d_z, 1, 1)
            else:
                z = z0.to(device=self.R_x.device)
            if u is not None:
                if len(u.shape) < 4:
                    u = u.unsqueeze(-1)  # add particle dim
                for t in range(time_steps + cut_off):
                    z = self.forward(z, noise_scale=noise_scale, u=u[:, :, t].unsqueeze(2))
                    Z.append(z[:, :, 0])
            else:
                for t in range(time_steps + cut_off):
                    z = self.forward(z, noise_scale=noise_scale)
                    Z.append(z[:, :, 0])
            Z = torch.stack(Z)
            Z = Z[cut_off:]
            Z = Z.permute(1, 2, 0, 3)
        return Z

    def get_rates(self, z, grad=True, u=None):
        """transform the latent states to the neuron activity"""
        R = self.transition.get_rates(z, grad=grad, u=u)
        return R

    def get_observation(self, z, noise_scale=0, grad=True):
        """
        Generate observations from the latent states
        Args:
            z (torch.tensor; n_trials x dim_z x time_steps x k): latent time series
            noise_scale (float): scale of the noise. For the "Cov" observation
                noise only the marginal standard deviations are used here.
            grad (bool): if False, the readout uses detached parameters
        Returns:
            X (torch.tensor; n_trials x dim_x x time_steps x k): observations
        """
        if self.readout_rates == "rates":
            R = self.get_rates(z, grad=grad)
        elif self.readout_rates == "currents":
            m = self.transition.m_transform(self.transition.m)
            if not grad:
                m = m.detach()
            R = torch.einsum('Nz,BzTK->BNTK', m, z)
        else:
            R = z
        X = self.observation(R, grad=grad)
        # Sample on the CPU as before (keeps the RNG stream), then move to X's device.
        eps = self.normal.sample(X.shape).to(X.device)
        X += noise_scale * eps * self.std_embed_x(self.R_x).unsqueeze(0).unsqueeze(2).unsqueeze(3)
        return X

    def inv_observation(self, X, grad=True):
        """
        Map observations back to latents with the pseudo-inverse of the linear readout.

        For "currents" readout, X = B^T m z + bias, so z = pinv(B^T m) (X - bias).
        For "rates" readout this returns the pseudo-inverse estimate of the
        rates (dim_N), since the nonlinearity is not inverted.
        Args:
            X (torch.tensor; n_trials x dim_x x time_steps): observations
            grad (bool): if False, the readout parameters are detached
        Returns:
            z (torch.tensor; n_trials x dim_z x time_steps): latent time series
        """
        B = self.observation.cast_B(self.observation.B)
        if self.readout_rates == "currents":
            m = self.transition.m_transform(self.transition.m)
            # pinv(m^T B) == pinv(B^T m)^T, laid out as (dim_x, dim_z) for the einsum below.
            B_inv = torch.linalg.pinv(m.T @ B)
        else:
            B_inv = torch.linalg.pinv(B)
        bias = self.observation.Bias.squeeze(-1)
        if not grad:
            B_inv, bias = B_inv.detach(), bias.detach()
        return torch.einsum('xz,bxT->bzT', B_inv, X - bias)


class Observation(nn.Module):
    """
    Linear readout (plus bias) from the latent states or the neuron activity.
    """
    def __init__(self, dz, dx, train_bias=True, train_weights=True, identity_readout=False):
        """
        Args:
            dz (int): dimensionality of the readout source (latents or neurons)
            dx (int): dimensionality of the data
            train_bias (bool): whether to train the bias
            train_weights (bool): whether to train the readout weights B
            identity_readout (bool): if True, B is restricted to its main
                diagonal (initialised to ones) by a fixed 0/1 mask
        """
        super(Observation, self).__init__()
        self.dz = dz
        self.dx = dx

        if identity_readout:
            mask = torch.zeros(self.dz, self.dx)
            k = min(self.dz, self.dx)  # also supports dx > dz
            mask[range(k), range(k)] = 1
            # B gets its own storage; the original code let B alias the mask,
            # so training B silently changed the mask as well.
            self.B = nn.Parameter(mask.clone(), requires_grad=train_weights)
            # Non-persistent buffer: follows .to(device) but stays out of the state_dict.
            self.register_buffer("mask", mask, persistent=False)
            self.cast_B = lambda x: x * self.mask
        else:
            self.B = nn.Parameter(np.sqrt(2 / dz) * torch.randn(self.dz, self.dx), requires_grad=train_weights)
            self.register_buffer("mask", torch.ones(1), persistent=False)
            self.cast_B = lambda x: x

        self.Bias = nn.Parameter(torch.zeros(1, self.dx, 1, 1), requires_grad=train_bias)

    def forward(self, z, grad=True):
        """
        Args:
            z (torch.tensor; n_trials x dim_z x time_steps x k): latent time series
            grad (bool): if False, B and Bias are detached
        Returns:
            X (torch.tensor; n_trials x dim_x x time_steps x k): observations
        """
        B = self.cast_B(self.B)
        bias = self.Bias
        if not grad:
            B, bias = B.detach(), bias.detach()
        return torch.einsum('zx,bzTK->bxTK', B, z) + bias



class Transition(nn.Module):
    """
    Latent dynamics of the prior.

    One step maps z to
        A * z + (n * scaling) @ phi(m @ z + Wu @ u, h) + hz,
    where A holds the decay factors (per latent or shared), (m, n) is the
    low-rank connectivity, phi is the neuron nonlinearity with bias h and
    hz is the latent bias.
    """
    def __init__(self, dz, du, hidden_dim, nonlinearity, exp_par, shared_tau, weight_dist="uniform",
                 m_orth=False, m_norm=False, weight_scaler=1, train_latent_bias=True, train_neuron_bias=True):
        """
        Args:
            dz (int): dimensionality of the latent space
            du (int): dimensionality of the input (0 for no input)
            hidden_dim (int): number of neurons in the network
            nonlinearity (str): one of "relu", "clipped_relu", "tanh", "identity"
            exp_par (bool): whether to use the exponential parameterisation for time constants
            shared_tau: falsy for one decay factor per latent dimension; otherwise
                the initial value of a single decay factor shared across latents
            weight_dist (str): "uniform" or "gauss" initialisation of (n, m)
            m_orth (bool): parameterise m with orthonormal columns
            m_norm: not used by this class
            weight_scaler (float): factor multiplying n in the dynamics
            train_latent_bias (bool): whether hz is trainable
            train_neuron_bias (bool): whether h is trainable
        Raises:
            ValueError: if nonlinearity is not one of the supported names.
        """
        super(Transition, self).__init__()
        self.dz = dz
        self.du = du

        # nonlinearity phi(x, h) and its derivative w.r.t. x (used by jacobian)
        if nonlinearity == "relu":
            print("using ReLU activation")
            relu = torch.nn.ReLU()
            self.nonlinearity = lambda x, h: relu(x - h)
            self.dnonlinearity = relu_derivative
        elif nonlinearity == "clipped_relu":
            print("using clipped ReLU activation")
            relu = torch.nn.ReLU()
            self.nonlinearity = lambda x, h: relu(x + h) - relu(x)
            self.dnonlinearity = clipped_relu_derivative
        elif nonlinearity == "tanh":
            print("using tanh activation")
            # torch.tanh applies the function; torch.nn.Tanh(...) would build a module.
            self.nonlinearity = lambda x, h: torch.tanh(x - h)
            self.dnonlinearity = tanh_derivative
        elif nonlinearity == "identity":
            print("using identity activation")
            self.nonlinearity = lambda x, h: x - h
            # two-argument signature, matching how jacobian calls dnonlinearity
            self.dnonlinearity = lambda x, h: torch.ones_like(x)
        else:
            raise ValueError(f"unknown nonlinearity: {nonlinearity!r}")

        # time constants
        if shared_tau:
            if exp_par:
                self.AW = nn.Parameter(torch.log(-torch.log(torch.ones(1, 1, 1, 1) * shared_tau)))
                self.cast_A = lambda x: torch.exp(-torch.exp(x))
            else:
                self.AW = nn.Parameter(torch.ones(1, 1, 1, 1) * shared_tau)
                self.cast_A = lambda x: x
        else:
            if exp_par:
                self.AW = init_AW_exp_par(self.dz)
                self.cast_A = exp_par_F
            else:
                self.AW = init_AW(self.dz)
                self.cast_A = lambda x: x.unsqueeze(0).unsqueeze(2).unsqueeze(3)

        # bias of the neurons
        if nonlinearity == "clipped_relu":
            self.h = nn.Parameter(uniform_init1d(hidden_dim), requires_grad=train_neuron_bias)
        else:
            self.h = nn.Parameter(torch.zeros(hidden_dim), requires_grad=train_neuron_bias)

        # bias of the latents
        self.hz = nn.Parameter(torch.zeros(dz), requires_grad=train_latent_bias)

        # weights (left and right singular vectors)
        if not m_orth:
            if weight_dist == "uniform":
                self.n, self.m = initialize_Ws_uniform(dz, hidden_dim)
            elif weight_dist == "gauss":
                # initialize_Ws_gauss requires a scaling argument; n is divided by it
                # here and multiplied by self.scaling in forward.
                # NOTE(review): confirm weight_scaler is the intended scaling.
                self.n, self.m = initialize_Ws_gauss(dz, hidden_dim, weight_scaler)
            else:
                print("WARNING: weight distribution not implemented, using uniform")
                self.n, self.m = initialize_Ws_uniform(dz, hidden_dim)
            self.m_transform = lambda x: x
        else:
            print("orthogonalising m")
            # Orthonormal columns: the Linear weight has shape (hidden_dim, dz)
            self.m = orthogonal(nn.Linear(dz, hidden_dim, bias=False))
            self.n, _ = initialize_Ws_uniform(dz, hidden_dim)
            self.m_transform = lambda x: x.weight
        self.scaling = weight_scaler
        print("weight scaler", self.scaling)

        # Input weights
        if self.du > 0:
            self.Wu = nn.Parameter(uniform_init2d(hidden_dim, self.du), requires_grad=True)
        else:
            # Non-persistent buffer: follows .to(device) but stays out of the state_dict.
            self.register_buffer("Wu", torch.zeros(hidden_dim, 0), persistent=False)

    def forward(self, z, u=None, grad=True):
        """
        One step forward
        Args:
            z (torch.tensor; n_trials x dim_z x time_steps x k): latent time series
            u (torch.tensor or None; n_trials x dim_u x time_steps x k): input
            grad (bool): if False, all parameters are detached (gradients only flow through z/u)
        Returns:
            z (torch.tensor; n_trials x dim_z x time_steps x k): latent time series
        """
        A = self.cast_A(self.AW)
        R = self.get_rates(z, grad=grad, u=u)
        n = self.n * self.scaling
        hz = self.hz.unsqueeze(0).unsqueeze(2).unsqueeze(3)
        if not grad:
            A, n, hz = A.detach(), n.detach(), hz.detach()
        return A * z + torch.einsum('zN,BNTK->BzTK', n, R) + hz

    def get_rates(self, z, grad=True, u=None):
        """Transform latents to neuron activity
        Args:
            z (torch.tensor; n_trials x dim_z x time_steps x k): latent time series
            grad (bool): if False, m, Wu and h are detached
            u (torch.tensor or None; n_trials x dim_u x time_steps x k): input
        Returns:
            R (torch.tensor; n_trials x dim_N x time_steps x k): neuron activity"""
        m = self.m_transform(self.m)
        h = self.h
        if not grad:
            # detach h as well, like every other parameter in this branch
            m, h = m.detach(), h.detach()
        X = torch.einsum('Nz,BzTK->BNTK', m, z)
        if u is not None:
            Wu = self.Wu if grad else self.Wu.detach()
            # out-of-place so z and u may broadcast over the trial dimension
            X = X + torch.einsum('Nu,BuTK->BNTK', Wu, u)
        return self.nonlinearity(X, h.unsqueeze(0).unsqueeze(2).unsqueeze(3))

    def jacobian(self, z):
        """Get jacobian of one forward step (without input) along a trajectory
        Args:
            z (torch.tensor; n_trials x dim_z x time_steps x k): latent time series (assumes k == 1)
        Returns:
            jacobian (torch.tensor; n_trials x dim_z x dim_z x time_steps): jacobian along the trajectory,
                J = diag(A) + (n * scaling) diag(phi'(m z)) m
        """
        z = z.squeeze(-1)
        # One decay per latent; a shared decay is expanded so it only lands on the diagonal.
        a = self.cast_A(self.AW).reshape(-1).expand(self.dz)
        A = torch.diag_embed(a).unsqueeze(0).unsqueeze(-1)  # 1 x dz x dz x 1
        m = self.m_transform(self.m)
        X = torch.einsum('Nz,BzT->BNT', m, z)
        derivatives_act = self.dnonlinearity(X, self.h.unsqueeze(0).unsqueeze(2))
        # forward uses n * scaling, so the jacobian must too
        proj_left = (self.n * self.scaling).unsqueeze(0).unsqueeze(-1) * derivatives_act.unsqueeze(1)
        return A + torch.einsum('BzNT,Nx->BzxT', proj_left, m)
   

def init_AW(dz):
    """
    Diagonal decay initialisation following Talathi & Vartak (2016),
    "Improving Performance of Recurrent Neural Network with ReLU Nonlinearity".

    Builds I + (R^T R) / dz from a standard-normal R, rescales the result so
    that its largest absolute eigenvalue is one, and keeps only the diagonal.

    Args:
        dz (int): number of latent dimensions.
    Returns:
        nn.Parameter of shape (dz,), trainable.
    """
    gaussian = torch.randn(dz, dz)
    scaled_left = (1 / dz) * gaussian.T
    shifted = torch.eye(dz) + scaled_left @ gaussian
    spectral_radius = torch.linalg.eigvals(shifted).abs().max()
    diagonal = torch.diagonal(shifted / spectral_radius).clone()
    return nn.Parameter(diagonal, requires_grad=True)
    
def init_AW_exp_par(dz, eps=1e-6):
    """
    Exponential-parameterisation variant of the Talathi & Vartak (2016)
    initialisation ("Improving Performance of Recurrent Neural Network with
    ReLU Nonlinearity").

    Builds I + (R R^T) / dz^2 from a standard-normal R and rescales it to a
    largest absolute eigenvalue of one. Each diagonal entry d is then mapped
    to log(-log(d)), the inverse of the decay map exp(-exp(A)).

    Args:
        dz (int): number of latent dimensions.
        eps (float): the diagonal is clamped to at most 1 - eps before the
            double log. Without this, dz == 1 gives d == 1 exactly and the
            parameter becomes -inf (decay stuck at 1 with zero gradient).
    Returns:
        nn.Parameter of shape (dz,), trainable.
    """
    matrix_random = torch.randn(dz, dz)
    matrix_positive_normal = 1 / (dz * dz) * matrix_random @ matrix_random.T
    matrix = torch.eye(dz) + matrix_positive_normal
    max_ev = torch.max(abs(torch.linalg.eigvals(matrix)))
    diag = (matrix / max_ev)[range(dz), range(dz)]
    A = torch.log(-torch.log(torch.clamp(diag, max=1 - eps)))
    return nn.Parameter(A, requires_grad=True)

def initialize_Ws_uniform(dz, N):
    """
    Draw the low-rank connectivity vectors from uniform distributions.

    Args:
        dz (int): rank / latent dimensionality.
        N (int): number of neurons.
    Returns:
        (n, m): trainable nn.Parameters. n has shape (dz, N) with entries in
        [-1/sqrt(N), 1/sqrt(N)); m has shape (N, dz) with entries in
        [-1/sqrt(dz), 1/sqrt(dz)).
    """
    print("using uniform init")
    n_init = uniform_init2d(dz, N)
    m_init = uniform_init2d(N, dz)
    return tuple(nn.Parameter(w, requires_grad=True) for w in (n_init, m_init))


def initialize_Ws_gauss(dz, N, scaling=1):
    """Initialize the low-rank connectivity from correlated Gaussians.

    Each pair (n_i, m_i) is drawn jointly with unit variances and
    correlation 0.6. Then:
        n: divided by scaling * sqrt(3 N)
        m: divided by sqrt(3 dz)

    Args:
        dz (int): rank / latent dimensionality.
        N (int): number of neurons.
        scaling (float): extra divisor for n. Defaults to 1, so the
            two-argument call works.
    Returns:
        (n, m): trainable nn.Parameters of shapes (dz, N) and (N, dz).
    """
    print("using gauss init")
    idx = torch.arange(dz)
    cov = torch.eye(dz * 2)
    cov[idx, dz + idx] = 0.6
    cov[dz + idx, idx] = 0.6
    chol_cov = torch.linalg.cholesky(cov)
    loadings = chol_cov @ torch.randn(dz * 2, N)
    n = loadings[:dz, :] / (scaling * np.sqrt(3 * N))
    m = loadings[dz:, :] / np.sqrt(3 * dz)
    return nn.Parameter(n, requires_grad=True), nn.Parameter(m.T, requires_grad=True)

def uniform_init2d(dim1, dim2):
    """
    Sample a (dim1, dim2) tensor uniformly from [-1/sqrt(dim2), 1/sqrt(dim2)).

    Args:
        dim1 (int): number of rows.
        dim2 (int): number of columns; also sets the bound 1/sqrt(dim2).
    Returns:
        torch.Tensor of shape (dim1, dim2).
    """
    bound = 1 / np.sqrt(dim2)
    unit = torch.rand(dim1, dim2)
    return bound * 2 * unit - bound

def uniform_init1d(dim1):
    """
    Sample a length-dim1 vector uniformly from [-1/sqrt(dim1), 1/sqrt(dim1)).

    Args:
        dim1 (int): vector length; also sets the bound.
    Returns:
        torch.Tensor of shape (dim1,).
    """
    bound = 1 / np.sqrt(dim1)
    unit = torch.rand(dim1)
    return bound * 2 * unit - bound

def exp_par_F(A):
    """
    Exponential parameterisation of the time constants.

    Maps unconstrained parameters A to decay factors exp(-exp(A)), which lie
    between 0 and 1. Singleton dims are inserted so that a 1-D A of length dz
    becomes shape (1, dz, 1, 1).
    """
    decay = torch.exp(-torch.exp(A))
    return decay[None, :, None, None]


def diag_mid(x):
    """
    Turn a (B, X, T) tensor into (B, X, X, T) by placing the X-axis values on
    the diagonal of an X x X matrix for every batch element and time step.
    """
    time_last = x.movedim(1, -1)                        # B x T x X
    return torch.diag_embed(time_last).movedim(1, -1)   # B x X x X x T



def clamp_norm(w, max_norm=1, dim=0):
    """
    Rescale w in place so that its 2-norms along dim do not exceed max_norm.

    Slices with norm above max_norm are scaled down to exactly max_norm.
    Slices with norm at or below max_norm are left unchanged. Norms are
    floored at max_norm / 2 before dividing, which avoids division by zero.
    Runs under torch.no_grad() and returns the same tensor object.
    """
    with torch.no_grad():
        current = w.norm(2, dim=dim, keepdim=True).clamp(min=max_norm / 2)
        target = current.clamp(max=max_norm)
        w.mul_(target / current)
    return w

def orthogonalize(A):
    """
    Return the closest matrix to A with orthonormal columns.

    Computes the reduced SVD A = U S V^T and drops S (sets all singular
    values to 1), returning U V^T. Uses torch.linalg.svd, the documented
    replacement for the deprecated torch.svd.
    """
    U, _, Vh = torch.linalg.svd(A, full_matrices=False)
    return U @ Vh


def split_diag_offdiag(A):
    """
    Split a square matrix into its diagonal part and its off-diagonal remainder.

    Returns:
        (diag, off_diag) with diag + off_diag == A.
    """
    diagonal_part = torch.diag(torch.diag(A))
    return diagonal_part, A - diagonal_part


def drelu_dx(x):
    """Derivative of relu at x: 1 where x > 0, else 0 (same dtype as x)."""
    return (x > 0).to(x.dtype)

def relu_derivative(x, h):
    """
    Derivative of relu(x - h) with respect to x.

    Returns 1 where x - h > 0 and 0 elsewhere (including at x == h).
    """
    shifted = x - h
    return drelu_dx(shifted)

def clipped_relu_derivative(x, h):
    """
    Derivative with respect to x of the clipped ReLU relu(x + h) - relu(x).

    Computed as the difference of the two relu step functions.
    """
    upper = drelu_dx(x + h)
    lower = drelu_dx(x)
    return upper - lower

def tanh_derivative(x, h):
    """Derivative of tanh(x - h) with respect to x: 1 - tanh(x - h)^2."""
    activation = torch.tanh(x - h)
    return 1 - activation ** 2



def gram_schmidt(X):
    """
    Orthogonalise the columns of X without normalising them.

    Uses modified Gram-Schmidt: each projection is removed from the running
    residual rather than from the original column. This is equivalent in
    exact arithmetic but loses much less orthogonality in floating point than
    the classical variant.

    Args:
        X (torch.Tensor; N x Z): columns to orthogonalise.
    Returns:
        torch.Tensor (N x Z): mutually orthogonal columns spanning the same
        nested subspaces as X. Column 0 equals X[:, 0]. A linearly dependent
        column yields NaNs in the following columns (division by a zero norm).
    """
    if X.shape[1] == 0:
        return X.clone()
    basis = []
    for i in range(X.shape[1]):
        v = X[:, i]
        for q in basis:
            v = v - (torch.inner(v, q) / torch.inner(q, q)) * q
        basis.append(v)
    return torch.stack(basis, dim=1)
              
def orth_proj(m, x):
    """Coefficients of the orthogonal projection of x onto the column space of m.

    Solves the normal equations (m^T m) c^T = m^T x^T rather than forming
    (m^T m)^{-1} explicitly. This is cheaper and numerically more stable.

    Args:
        m (torch.Tensor; N x Z): basis vectors as columns (assumed full column rank).
        x (torch.Tensor; B x N): vectors to project, one per row.
    Returns:
        torch.Tensor (B x Z): coefficients c such that c @ m.T is the
        projection of x onto span(m).
    """
    gram = m.T @ m
    return torch.linalg.solve(gram, m.T @ x.T).T

class normalize_m:
    """
    Callable that orthogonalises the columns of a matrix and rescales each
    column to the corresponding column norm of a reference matrix.
    """
    def __init__(self, m):
        """
        Args:
            m (torch.Tensor; N x Z): reference matrix whose column norms are kept.
        """
        self.init_norm = m  # NOTE: stores m itself (not its norm); kept for compatibility
        self.norm = torch.norm(m, dim=0, keepdim=True)  # 1 x Z column norms

    def __call__(self, X):
        """
        Args:
            X (torch.Tensor; N x Z): matrix to orthogonalise.
        Returns:
            torch.Tensor (N x Z) with mutually orthogonal columns; column i has
            norm self.norm[0, i].
        """
        # reuse the shared Gram-Schmidt helper instead of duplicating it here
        Q = gram_schmidt(X)
        return Q / torch.norm(Q, dim=0, keepdim=True) * self.norm
              
