import math
import torch
import torch.nn as nn
from gae2 import DAE, RCAE, initialize_linear_layer_weight
from Pn_util import *

class GDAE_P_n(DAE):
    """Denoising autoencoder for points on P_n (symmetric positive-definite matrices).

    Points are handled both in vectorised form (lower-case / ``_vec``, via
    ``mat2vec``) and in matrix form (upper-case / ``_mat``, via ``vec2mat``).
    The autoencoder output ``v`` is treated as a tangent vector at the
    identity; the reconstruction at ``X`` is ``X_sqrt @ Exp(v) @ X_sqrt``.
    """
    def __init__(self, dim, num_hidden_layers, noise_std, useLeakyReLU = True, initial = 'default', exp_approx = 1, log_approx = 1, corrupt_approx = False):
        super(GDAE_P_n, self).__init__(dim, num_hidden_layers, noise_std, useLeakyReLU, initial)
        # for clarity, we use uppercase or the suffix '_mat' for matrix variables and lower case or the suffix '_vec' for vector variables
        
        # approximation settings forwarded to Exp_vec_approx / Log_mat_FnormSq_approx
        self.exp_approx = exp_approx
        self.log_approx = log_approx
        # when True, calculate_loss linearises sqrt(X_tilde) around X instead of computing it exactly
        self.corrupt_approx = corrupt_approx
        self.mat_dim = vecdim2matdim(self.input_dim)
        self.Eye = torch.eye(self.mat_dim).cuda().view(1, self.mat_dim, self.mat_dim)
        self.eye = mat2vec(self.Eye)
        # per-entry scale of the vectorised noise: 1 on diagonal entries, 1/sqrt(2) off the diagonal
        noise_coeff_mat = torch.cuda.FloatTensor(self.mat_dim, self.mat_dim).fill_(1 / math.sqrt(2))
        noise_coeff_mat.fill_diagonal_(1)
        self.noise_coeff = mat2vec(noise_coeff_mat.view(1, self.mat_dim, self.mat_dim))
        
    def forward(self, x, X, X_sqrt, fixed_noise = None, other_quantities_for_loss_at_x = None):
        """Corrupt the batch on P_n and encode the corrupted points.

        Noise ``epsilon`` (Gaussian with std ``self.noise_std``, or
        ``fixed_noise`` if given) is scaled by ``self.noise_coeff``, mapped
        through ``Exp_vec_approx`` and applied as
        ``X_tilde = X_sqrt @ Exp(epsilon) @ X_sqrt``. ``X_sqrt`` is assumed
        symmetric. ``X`` and ``other_quantities_for_loss_at_x`` are not used
        here; they are kept for signature compatibility with subclasses.

        Returns:
            (v, X_tilde): the autoencoder output for ``mat2vec(X_tilde)``
            (the tangent vector at the identity, not the reconstruction)
            and the corrupted matrices.
        """
        if fixed_noise is None:
            epsilon = torch.cuda.FloatTensor(x.shape).normal_(0.0, self.noise_std) * self.noise_coeff
        else:
            epsilon = fixed_noise * self.noise_coeff
        Exp_epsilon = Exp_vec_approx(epsilon, approx = self.exp_approx, Eye = self.Eye, safe_backward = False)
        X_tilde = torch.bmm(torch.bmm(X_sqrt, Exp_epsilon), X_sqrt)
        
        return self.forward_autoencoder(mat2vec(X_tilde)), X_tilde
    
    def get_reconstruction(self, x, X_sqrt):
        """Return the vectorised reconstruction ``X_sqrt @ Exp(v) @ X_sqrt`` with ``v`` the autoencoder output for ``x``."""
        v = self.forward_autoencoder(x)
        Exp_v = Exp_vec(v)
        
        return mat2vec(torch.bmm(torch.bmm(X_sqrt, Exp_v), X_sqrt))
    
    def calculate_loss(self, x, X, X_sqrt, X_invsqrt, fixed_noise = None, duplicate_num = None, other_quantities_for_loss_at_x = None):
        """Denoising loss summed over the batch.

        Computes ``Log_mat_FnormSq_approx(Temp^T Exp(v) Temp)`` with
        ``Temp = sqrt(X_tilde) @ X_invsqrt``, where ``(v, X_tilde)`` come from
        ``self.forward``.

        If ``duplicate_num`` is given (and ``fixed_noise`` is not), the batch
        is tiled ``duplicate_num`` times along its first dimension so every
        sample receives ``duplicate_num`` separate noise draws, and the loss
        is divided by ``duplicate_num``. ``X`` and the tensors in
        ``other_quantities_for_loss_at_x`` whose leading dimension matches
        the batch size are tiled as well, so the ``corrupt_approx`` path sees
        consistent batch sizes.
        """
        use_duplicates = duplicate_num is not None and fixed_noise is None
        if use_duplicates:
            batch_size = x.shape[0]

            def repeat_batch(t):
                # tile along the first (batch) dimension only
                return t.repeat(duplicate_num, *([1] * (t.dim() - 1)))

            def is_batched(t):
                return torch.is_tensor(t) and t.dim() > 0 and t.shape[0] == batch_size

            x = repeat_batch(x)
            X_sqrt = repeat_batch(X_sqrt)
            X_invsqrt = repeat_batch(X_invsqrt)
            if is_batched(X):
                X = repeat_batch(X)
            if other_quantities_for_loss_at_x is not None:
                other_quantities_for_loss_at_x = [repeat_batch(q) if is_batched(q) else q
                                                  for q in other_quantities_for_loss_at_x]
        v, X_tilde = self.forward(x, X, X_sqrt, fixed_noise = fixed_noise, other_quantities_for_loss_at_x = other_quantities_for_loss_at_x)
        
        # for very small v, this line may fall in infinite loop if self.exp_approx is None and safe_backward = True
        Exp_v = Exp_vec_approx(v, approx = self.exp_approx, Eye = self.Eye, safe_backward = True)
        
        if self.corrupt_approx:
            # first-order approximation of sqrt(X_tilde) around X, using precomputed S, U
            S = other_quantities_for_loss_at_x[0]
            U = other_quantities_for_loss_at_x[1]
            Xdot = X_tilde - X
            X_tilde_sqrt = X_sqrt + get_sqrt_sym_DirDeriv(X, Xdot, eps = 1e-14, S = S, U = U)
        else:
            X_tilde_sqrt = get_sqrt_sym(X_tilde)
        Temp = torch.bmm(X_tilde_sqrt, X_invsqrt)
        dX = torch.bmm(torch.bmm(Temp.permute(0,2,1), Exp_v), Temp)
        loss = Log_mat_FnormSq_approx(dX, approx = self.log_approx, Eye = self.Eye, safe_backward = True)
        if use_duplicates:
            loss = loss / duplicate_num
        return loss
        
    
    def clean_forward(self, x):
        """Encode ``x`` without corruption.

        Returns the output of ``self.forward_autoencoder`` (the tangent
        vector at the identity), not the reconstruction.
        """
        return self.forward_autoencoder(x)
    
    def estimate_score(self, x, X_sqrt, metric):
        """Return ``(X_sqrt V X_sqrt)`` (vectorised) multiplied by ``metric``, divided by ``noise_std**2``."""
        v = self.forward_autoencoder(x)
        log_x_r = mat2vec(torch.bmm(torch.bmm(X_sqrt, vec2mat(v)), X_sqrt))
        
        return torch.bmm(log_x_r.view(-1,1,self.input_dim), metric).view(-1,self.input_dim) / self.noise_std**2
    
    def calculate_expected_loss(self, x, X, X_sqrt, X_invsqrt, Niter, other_quantities_for_loss_at_x = None):
        """Monte-Carlo estimate (no autograd) of the per-sample denoising loss averaged over ``Niter`` noise draws.

        Always uses the exact ``get_sqrt_sym`` for ``sqrt(X_tilde)``,
        regardless of ``self.corrupt_approx``.
        """
        lossSum = 0.0
        N = x.size()[0]
        with torch.no_grad():
            for i in range(Niter):
                v, X_tilde = self.forward(x, X, X_sqrt, other_quantities_for_loss_at_x = other_quantities_for_loss_at_x)
                Exp_v = Exp_vec_approx(v, approx = self.exp_approx, Eye = self.Eye, safe_backward = False)

                X_tilde_sqrt = get_sqrt_sym(X_tilde)
                Temp = torch.bmm(X_tilde_sqrt, X_invsqrt)
                dX = torch.bmm(torch.bmm(Temp.permute(0,2,1), Exp_v), Temp)
                lossSum += Log_mat_FnormSq_approx(dX, approx = self.log_approx, Eye = self.Eye, safe_backward = False) / N
        
        return lossSum / Niter
    
    def get_autoencoder_derivative(self, x, create_graph=False, other_quantities_at_x = None):
        """Return the Jacobian dv/dx of the autoencoder, stacked so that ``[:, j, :]`` is d v_j / dx.

        Sets ``x.requires_grad = True`` if needed and runs one
        ``torch.autograd.grad`` call per output coordinate.
        ``other_quantities_at_x`` is unused (kept for subclass compatibility).
        """
        ### calculate dv/dx value
        def unit_vectors(N, length):
            result = []
            for i in range(0, length):
                x = torch.zeros(N, length).cuda()
                x[:,i] = 1
                result.append(x)
            return result

        if not x.requires_grad:
            x.requires_grad = True
        y = self.forward_autoencoder(x)
        if create_graph:
            result = [torch.autograd.grad(outputs=y, inputs=x, grad_outputs=unit, create_graph=True)[0] for unit in unit_vectors(y.size(0), y.size(1))]
        else:
            result = [torch.autograd.grad(outputs=y, inputs=x, grad_outputs=unit, retain_graph=True)[0] for unit in unit_vectors(y.size(0), y.size(1))]
        jacobian = torch.stack(result, dim=1)
        return jacobian
    
    def get_autoencoder_derivative_old(self, x, create_graph=False, other_quantities_at_x = None):
        """Legacy Jacobian dv/dx computed through repeated ``backward`` calls on ``x.grad``."""
        ### calculate dv/dx value
        if not x.requires_grad:
            x.requires_grad = True
        y = self.forward_autoencoder(x)
        grad_set = []
        x.grad=None
        for i in range(self.input_dim):
            if x.is_cuda:
                temp = torch.zeros(y.size()).cuda()
            else:
                temp = torch.zeros(y.size())
            temp[:,i] = 1
            if create_graph:
                y.backward(gradient=temp, create_graph=True)
            else:
                y.backward(gradient=temp, retain_graph=True)
            grad_set.append(x.grad)
            x.grad = None
        return torch.stack(grad_set, dim=1)
    
    def get_score_covec_derivative(self, x, X_sqrt, X_sqrt_dirderiv_set = None, create_graph=False, other_quantities_at_x = None, returndVdxAlso = False):
        """Return ds/dx for ``s = mat2vec(X_sqrt V X_sqrt)``; column ``i`` is the derivative along x_i.

        Uses the product rule
        ``ds = X_sqrt dV X_sqrt + dX_sqrt V X_sqrt + (dX_sqrt V X_sqrt)^T``.
        If ``X_sqrt_dirderiv_set`` is None, each directional derivative of
        ``sqrt(X)`` is computed with ``get_sqrt_sym_DirDeriv``; otherwise
        ``X_sqrt_dirderiv_set[:, :, :, i]`` is used.

        Returns:
            ``dsdx``, or ``(dsdx, dvdx)`` when ``returndVdxAlso`` is True.
        """
        v = self.forward_autoencoder(x)
        V = vec2mat(v)
        X = vec2mat(x)
        dvdx = self.get_autoencoder_derivative(x, create_graph = create_graph, other_quantities_at_x = other_quantities_at_x)
        dsdx = torch.cuda.FloatTensor(dvdx.shape).zero_()
        tempdir = torch.cuda.FloatTensor(x.shape).zero_()
        for i in range(self.input_dim):
            if X_sqrt_dirderiv_set is None:
                tempdir[:,i] = 1
                X_sqrt_dirderiv_i = get_sqrt_sym_DirDeriv(X, vec2mat(tempdir))
                tempdir[:,i] = 0
            else:
                X_sqrt_dirderiv_i = X_sqrt_dirderiv_set[:,:,:,i]

            temp = torch.bmm(torch.bmm(X_sqrt_dirderiv_i, V), X_sqrt)
            # dvdx[:,:,i] is a vectorised matrix; bmm needs its matrix form
            dsdx[:,:,i] = mat2vec(
                torch.bmm(torch.bmm(X_sqrt, vec2mat(dvdx[:,:,i])), X_sqrt)
                + temp + temp.permute(0,2,1)
            )
            
        if returndVdxAlso:
            return dsdx, dvdx
        return dsdx
    
    def get_derivative(self, x, X_sqrt, X_sqrt_dirderiv_set = None, create_graph=False, useApprox=False, safe_backward=False, 
                       v=None, S_V=None, U_V=None, Exp_v=None, returndVdxAlso = False, returndVdxOnly = False, other_quantities_at_x = None):
        """Return dr/dx for the reconstruction ``r = mat2vec(X_sqrt Exp(V) X_sqrt)``.

        ``v``, ``S_V``/``U_V`` (eigen-decomposition of V) and ``Exp_v`` may be
        passed in to avoid recomputation. With ``useApprox`` and
        ``exp_approx`` in {1, 2}, the derivative of Exp uses the first/second
        order expansion; otherwise ``ExpDirDeriv`` is used.

        Returns:
            ``dvdx`` if ``returndVdxOnly``; ``(drdx, dvdx)`` if
            ``returndVdxAlso``; otherwise ``drdx``.
        """
        ### calculate dr/dx value
        if v is None:
            v = self.forward_autoencoder(x)
        V = vec2mat(v)
        if useApprox and self.exp_approx in [1,2]:
            if Exp_v is None:
                Exp_v = Exp_vec_approx(v, approx = self.exp_approx, Eye = self.Eye)
        else:
            if S_V is None or U_V is None:
                S_V, U_V = batch_eigsym(V, safe_backward)
            if Exp_v is None:
                Exp_v = Exp_mat(V, S = S_V, U = U_V, safe_backward = safe_backward)
        X = vec2mat(x)
        dvdx = self.get_autoencoder_derivative(x, create_graph = create_graph, other_quantities_at_x = other_quantities_at_x)
        
        if returndVdxOnly:
            return dvdx
        
        drdx = torch.cuda.FloatTensor(dvdx.shape).zero_()
        tempdir = torch.cuda.FloatTensor(x.shape).zero_()
        for i in range(self.input_dim):
            if useApprox and self.exp_approx in [1,2]:
                if self.exp_approx == 1:
                    Exp_dirderiv_i = vec2mat(dvdx[:,:,i])
                else:
                    # derivative of I + V + V^2/2 along dV: dV + (V dV + dV V)/2 (V, dV symmetric)
                    dVdx_i = vec2mat(dvdx[:,:,i])
                    VdVdx_i = torch.bmm(V, dVdx_i)
                    Exp_dirderiv_i = dVdx_i + 0.5*(VdVdx_i + VdVdx_i.permute(0,2,1))
            else:
                Exp_dirderiv_i = ExpDirDeriv(V, vec2mat(dvdx[:,:,i]), S = S_V, U = U_V, safe_backward = safe_backward)
            tempdir[:,i] = 1
            if X_sqrt_dirderiv_set is None:
                X_sqrt_dirderiv_i = get_sqrt_sym_DirDeriv(X, vec2mat(tempdir))
            else:
                X_sqrt_dirderiv_i = X_sqrt_dirderiv_set[:,:,:,i]

            temp = torch.bmm(torch.bmm(X_sqrt_dirderiv_i, Exp_v), X_sqrt)
            drdx[:,:,i] = mat2vec(
                torch.bmm(torch.bmm(X_sqrt, Exp_dirderiv_i), X_sqrt)
                + temp + temp.permute(0,2,1)
            )
            tempdir[:,i] = 0
        if returndVdxAlso:
            return drdx, dvdx
        return drdx
    
    def estimate_score_error(self, x, X_sqrt, christoffelSum, X_sqrt_dirderiv_set, create_graph=True, useApprox=True, other_quantities_at_x = None):
        """Return the batch-averaged score-matching objective built from ||V||^2, trace(dr/dx) and the Christoffel term, minus ``2 * dim / noise_std**2``."""
        N = x.shape[0]
        dim = x.shape[1]
        v = self.forward_autoencoder(x)
        V = vec2mat(v)
        log_x_r = mat2vec(torch.bmm(torch.bmm(X_sqrt,V), X_sqrt))
        term1 = torch.sum(V**2) / self.noise_std**4
        term3 = 2.*torch.sum(log_x_r*christoffelSum) / self.noise_std**2
        
        drdx = self.get_derivative(x, X_sqrt, X_sqrt_dirderiv_set, create_graph=create_graph, useApprox=useApprox, safe_backward=True, other_quantities_at_x = other_quantities_at_x)
        
        # sum of the diagonals of drdx over the batch
        term2 = 2. * torch.sum(drdx * torch.eye(dim).cuda()) / self.noise_std**2
        
        return (term1 + term2 + term3)/N - 2.0 * dim / self.noise_std**2
    
    def get_contractive_term(self, x, X_sqrt, metricInv, X_sqrt_dirderiv_set = None, create_graph=False, useApprox = False, safe_backward=False, 
                             expandOutput=False, other_quantities_at_x = None):
        """Evaluate trace(drdx^T G(r) drdx G(x)^{-1}) with ``G(r) = metric_P_n(R)``.

        Returns:
            With ``expandOutput``: a tuple of Python floats (contractive
            term, sum v^2, sum drdx^2, sum dvdx^2), each divided by N.
            Otherwise: the detached per-sample traces divided by N.
        """
        N = x.shape[0]
        v = self.forward_autoencoder(x)
        
        if useApprox and self.exp_approx in [1,2]:
            S_V, U_V = None, None
            Exp_v = Exp_vec_approx(v, approx = self.exp_approx, Eye = self.Eye)
        else:
            V = vec2mat(v)
            S_V, U_V = batch_eigsym(V, safe_backward)
            Exp_v = Exp_mat(V, S = S_V, U = U_V, safe_backward = safe_backward)
        
        R = torch.bmm(torch.bmm(X_sqrt, Exp_v), X_sqrt)
        metric_r = metric_P_n(R)
        if expandOutput:
            drdx, dvdx = self.get_derivative(x, X_sqrt, X_sqrt_dirderiv_set, create_graph=create_graph, useApprox = useApprox, safe_backward=safe_backward, 
                                             v=v, S_V=S_V, U_V=U_V, Exp_v=Exp_v, 
                                             returndVdxAlso=True, other_quantities_at_x=other_quantities_at_x)
        else:
            drdx = self.get_derivative(x, X_sqrt, X_sqrt_dirderiv_set, create_graph=create_graph, useApprox = useApprox, safe_backward=safe_backward, 
                                       v=v, S_V=S_V, U_V=U_V, Exp_v=Exp_v, 
                                       other_quantities_at_x=other_quantities_at_x)
        mat = torch.bmm(torch.bmm(torch.bmm(drdx.permute(0,2,1), metric_r), drdx), metricInv)
        
        if expandOutput:
            return (torch.sum(mat * torch.eye(self.input_dim).unsqueeze(0).cuda()).item() / N, torch.sum(v*v).item()/N,
                    torch.sum(drdx**2).item()/N, torch.sum(dvdx**2).item()/N)
        return (mat * torch.eye(self.input_dim).unsqueeze(0).cuda()).sum((1,2)).data / N
    
    def contractive_loss(self, x, X_sqrt, metricInv, X_sqrt_dirderiv_set = None, useApprox=False, other_quantities_at_x = None):
        """Return ``noise_std**2 * trace(drdx^T G(r) drdx G(x)^{-1}) + ||V||^2``, summed over the batch."""
        v = self.forward_autoencoder(x)
        if useApprox and self.exp_approx in [1,2]:
            Exp_v = Exp_vec_approx(v, approx = self.exp_approx, Eye = self.Eye)
            drdx = self.get_derivative(x, X_sqrt, X_sqrt_dirderiv_set, create_graph=True, useApprox=useApprox, v=v, Exp_v=Exp_v, 
                                       other_quantities_at_x=other_quantities_at_x)
        else:
            S_V, U_V = batch_eigsym(vec2mat(v))
            Exp_v = Exp_vec(v, S = S_V, U = U_V)
            drdx = self.get_derivative(x, X_sqrt, X_sqrt_dirderiv_set, create_graph=True, useApprox=useApprox, v=v, S_V=S_V, U_V=U_V, Exp_v=Exp_v, 
                                       other_quantities_at_x=other_quantities_at_x)
        R = torch.bmm(torch.bmm(X_sqrt, Exp_v), X_sqrt)
        metric_r = metric_P_n(R)
        mat = torch.bmm(torch.bmm(torch.bmm(drdx.permute(0,2,1), metric_r), drdx), metricInv)
        
        return torch.sum(mat * torch.eye(self.input_dim).unsqueeze(0).cuda())*self.noise_std**2 + torch.sum(vec2mat(v)**2)
    
class GDAE_P_n_fromLog(GDAE_P_n):
    """Variant of ``GDAE_P_n`` whose autoencoder input is the vectorised matrix logarithm ``Log_x``.

    ``forward`` corrupts ``X`` on P_n and feeds ``mat2vec(Log(X_tilde))`` to
    the autoencoder. ``get_derivative`` converts derivatives with respect to
    ``Log_x`` into derivatives with respect to ``x`` using a caller-supplied
    ``dLog_xdx``.
    """
    def __init__(self, dim, num_hidden_layers, noise_std, useLeakyReLU = True, initial = 'default', exp_approx = 1, log_approx = 1, corrupt_approx = False):
        super(GDAE_P_n_fromLog, self).__init__(dim, num_hidden_layers, noise_std, useLeakyReLU, initial, exp_approx, log_approx, corrupt_approx)
        # for clarity, we use uppercase or the suffix '_mat' for matrix variables and lower case or the suffix '_vec' for vector variables
        
    def forward(self, Log_x, X, X_sqrt, fixed_noise = None, other_quantities_for_loss_at_x = None):
        """Corrupt ``X`` on P_n and encode the log of the corrupted point.

        ``X_sqrt`` is assumed symmetric. With ``corrupt_approx``, ``Log(X_tilde)``
        is linearised around ``X`` with ``LogDirDeriv`` using
        ``other_quantities_for_loss_at_x[0:2]`` as ``S, U``; otherwise
        ``Log_mat`` is evaluated exactly.

        Returns:
            (v, X_tilde): the autoencoder output (tangent vector at the
            identity, not the reconstruction) and the corrupted matrices.
        """
        if fixed_noise is None:
            epsilon = torch.cuda.FloatTensor(Log_x.shape).normal_(0.0, self.noise_std) * self.noise_coeff
        else:
            epsilon = fixed_noise * self.noise_coeff
        Exp_epsilon = Exp_vec_approx(epsilon, approx = self.exp_approx, Eye = self.Eye, safe_backward = False)
        X_tilde = torch.bmm(torch.bmm(X_sqrt, Exp_epsilon), X_sqrt)
        if self.corrupt_approx:
            S = other_quantities_for_loss_at_x[0]
            U = other_quantities_for_loss_at_x[1]
            Xdot = X_tilde - X
            Log_X_tilde = vec2mat(Log_x) + LogDirDeriv(X, Xdot, eps = 1e-14, S = S, U = U)
        else:
            Log_X_tilde = Log_mat(X_tilde)
        return self.forward_autoencoder(mat2vec(Log_X_tilde)), X_tilde
    
    def get_derivative(self, Log_x, X_sqrt, X_sqrt_dirderiv_set = None, create_graph=False, useApprox=False, safe_backward=False, 
                       v=None, S_V=None, U_V=None, Exp_v=None, returndVdxAlso = False, returndVdxOnly = False, other_quantities_at_x = None):
        """Return dr/dx for ``r = mat2vec(X_sqrt Exp(V) X_sqrt)`` when the network input is ``Log_x``.

        ``other_quantities_at_x`` must be a one-element sequence holding
        ``dLog_xdx``; ``dvdx = dvdLog_x @ dLog_xdx``.

        Returns:
            ``dvdx`` if ``returndVdxOnly``; ``(drdx, dvdx)`` if
            ``returndVdxAlso``; otherwise ``drdx``.

        Raises:
            ValueError: if ``other_quantities_at_x`` is missing or not of length 1.
        """
        if other_quantities_at_x is None or len(other_quantities_at_x) != 1:
            raise ValueError("dLog_xdx should be given in other_quantities_at_x")
        dLog_xdx = other_quantities_at_x[0]
        ### calculate dr/dx and dv/dx value
        if v is None:
            v = self.forward_autoencoder(Log_x)
        V = vec2mat(v)
        if useApprox and self.exp_approx in [1,2]:
            if Exp_v is None:
                Exp_v = Exp_vec_approx(v, approx = self.exp_approx, Eye = self.Eye)
        else:
            if S_V is None or U_V is None:
                S_V, U_V = batch_eigsym(V, safe_backward)
            if Exp_v is None:
                Exp_v = Exp_mat(V, S = S_V, U = U_V, safe_backward = safe_backward)
        dvdLog_x = self.get_autoencoder_derivative(Log_x, create_graph = create_graph, other_quantities_at_x = other_quantities_at_x)
        # chain rule: dv/dx = dv/dLog_x @ dLog_x/dx
        dvdx = torch.bmm(dvdLog_x, dLog_xdx)
        
        if returndVdxOnly:
            return dvdx
        
        # X = Exp(Log_X) does not depend on the loop index, so compute it once
        X = Exp_mat(vec2mat(Log_x)) if X_sqrt_dirderiv_set is None else None
        
        drdx = torch.cuda.FloatTensor(dvdx.shape).zero_()
        tempdir = torch.cuda.FloatTensor(Log_x.shape).zero_()
        for i in range(self.input_dim):
            if useApprox and self.exp_approx in [1,2]:
                if self.exp_approx == 1:
                    Exp_dirderiv_i = vec2mat(dvdx[:,:,i])
                else:
                    # derivative of I + V + V^2/2 along dV: dV + (V dV + dV V)/2 (V, dV symmetric)
                    dVdx_i = vec2mat(dvdx[:,:,i])
                    VdVdx_i = torch.bmm(V, dVdx_i)
                    Exp_dirderiv_i = dVdx_i + 0.5*(VdVdx_i + VdVdx_i.permute(0,2,1))
            else:
                Exp_dirderiv_i = ExpDirDeriv(V, vec2mat(dvdx[:,:,i]), S = S_V, U = U_V, safe_backward = safe_backward)
            tempdir[:,i] = 1
            if X_sqrt_dirderiv_set is None:
                X_sqrt_dirderiv_i = get_sqrt_sym_DirDeriv(X, vec2mat(tempdir))
            else:
                X_sqrt_dirderiv_i = X_sqrt_dirderiv_set[:,:,:,i]

            temp = torch.bmm(torch.bmm(X_sqrt_dirderiv_i, Exp_v), X_sqrt)
            drdx[:,:,i] = mat2vec(
                torch.bmm(torch.bmm(X_sqrt, Exp_dirderiv_i), X_sqrt)
                + temp + temp.permute(0,2,1)
            )
            tempdir[:,i] = 0
        if returndVdxAlso:
            return drdx, dvdx
        return drdx

class GRCAE_P_n(RCAE):
    """Contractive autoencoder for points on P_n (symmetric positive-definite matrices).

    The autoencoder output ``v`` is treated as a tangent vector at the
    identity; the reconstruction at ``X`` is ``X_sqrt @ Exp(v) @ X_sqrt``.
    The loss combines ||V||^2 with a metric-weighted contractive term.
    """
    def __init__(self, dim, num_hidden_layers, noise_std, useLeakyReLU = True, initial = 'default', exp_approx = 1):
        super(GRCAE_P_n, self).__init__(dim, num_hidden_layers, noise_std, useLeakyReLU, initial)
        # for clarity, we use uppercase or the suffix '_mat' for matrix variables and lower case or the suffix '_vec' for vector variables
        
        # approximation setting forwarded to Exp_vec_approx
        self.exp_approx = exp_approx
        self.mat_dim = vecdim2matdim(self.input_dim)
        self.Eye = torch.eye(self.mat_dim).cuda().view(1, self.mat_dim, self.mat_dim)
        
    def forward(self, x):
        """Return the autoencoder output (tangent vector at the identity), not the reconstruction."""
        return self.forward_autoencoder(x)
    
    def get_reconstruction(self, x, X_sqrt, useApprox = False):
        """Return the vectorised reconstruction ``X_sqrt @ Exp(v) @ X_sqrt``.

        ``Exp`` is approximated with ``Exp_vec_approx`` when ``useApprox`` is
        set and ``exp_approx`` is 1 or 2; otherwise ``Exp_vec`` is used.
        """
        v = self.forward_autoencoder(x)
        
        if useApprox and self.exp_approx in [1,2]:
            Exp_v = Exp_vec_approx(v, approx = self.exp_approx, Eye = self.Eye)
        else:
            Exp_v = Exp_vec(v)
        return mat2vec(torch.bmm(torch.bmm(X_sqrt, Exp_v), X_sqrt))
    
    def calculate_loss(self, x, X_sqrt, metricInv, X_sqrt_dirderiv_set = None, other_quantities_for_loss_at_x = None, separate = False):
        """Return ``noise_std**2 * contractive + ||V||^2`` summed over the batch.

        The contractive term is ``trace(drdx^T G(r) drdx metricInv)`` with
        ``G(r) = metric_P_n(R)``. With ``separate`` the two parts are
        returned unscaled as ``(||V||^2, contractive)``.
        """
        v = self.forward(x)
        Exp_v = Exp_vec_approx(v, approx = self.exp_approx, Eye = self.Eye)
        
        # calculate the contractive term
        drdx = self.get_derivative(x, X_sqrt, X_sqrt_dirderiv_set, create_graph=True, useApprox=True, v=v, Exp_v=Exp_v, 
                                       other_quantities_at_x=other_quantities_for_loss_at_x)
        
        R = torch.bmm(torch.bmm(X_sqrt, Exp_v), X_sqrt)
        metric_r = metric_P_n(R)
        mat = torch.bmm(torch.bmm(torch.bmm(drdx.permute(0,2,1), metric_r), drdx), metricInv)
        if separate:
            return torch.sum(vec2mat(v)**2), torch.sum(mat * torch.eye(self.input_dim).unsqueeze(0).cuda())
        return torch.sum(mat * torch.eye(self.input_dim).unsqueeze(0).cuda())*self.noise_std**2 + torch.sum(vec2mat(v)**2)
        
    def estimate_score(self, x, X_sqrt, metric):
        """Return ``(X_sqrt V X_sqrt)`` (vectorised) multiplied by ``metric``, divided by ``noise_std**2``."""
        v = self.forward_autoencoder(x)
        log_x_r = mat2vec(torch.bmm(torch.bmm(X_sqrt, vec2mat(v)), X_sqrt))
        
        return torch.bmm(log_x_r.view(-1,1,self.input_dim), metric).view(-1,self.input_dim) / self.noise_std**2
    
    def get_autoencoder_derivative(self, x, create_graph=False):
        """Return the Jacobian dv/dx of the autoencoder, stacked so that ``[:, j, :]`` is d v_j / dx.

        Sets ``x.requires_grad = True`` if needed and runs one
        ``torch.autograd.grad`` call per output coordinate.
        """
        ### calculate dv/dx value
        def unit_vectors(N, length):
            result = []
            for i in range(0, length):
                x = torch.zeros(N, length).cuda()
                x[:,i] = 1
                result.append(x)
            return result

        if not x.requires_grad:
            x.requires_grad = True
        y = self.forward_autoencoder(x)
        if create_graph:
            result = [torch.autograd.grad(outputs=y, inputs=x, grad_outputs=unit, create_graph=True)[0] for unit in unit_vectors(y.size(0), y.size(1))]
        else:
            result = [torch.autograd.grad(outputs=y, inputs=x, grad_outputs=unit, retain_graph=True)[0] for unit in unit_vectors(y.size(0), y.size(1))]
        jacobian = torch.stack(result, dim=1)
        return jacobian
    
    def get_score_covec_derivative(self, x, X_sqrt, X_sqrt_dirderiv_set = None, create_graph=False, other_quantities_at_x = None, returndVdxAlso = False):
        """Return ds/dx for ``s = mat2vec(X_sqrt V X_sqrt)``; column ``i`` is the derivative along x_i.

        Uses the product rule
        ``ds = X_sqrt dV X_sqrt + dX_sqrt V X_sqrt + (dX_sqrt V X_sqrt)^T``.

        Returns:
            ``dsdx``, or ``(dsdx, dvdx)`` when ``returndVdxAlso`` is True.
        """
        v = self.forward_autoencoder(x)
        V = vec2mat(v)
        X = vec2mat(x)
        dvdx = self.get_autoencoder_derivative(x, create_graph = create_graph)
        dsdx = torch.cuda.FloatTensor(dvdx.shape).zero_()
        tempdir = torch.cuda.FloatTensor(x.shape).zero_()
        for i in range(self.input_dim):
            if X_sqrt_dirderiv_set is None:
                tempdir[:,i] = 1
                X_sqrt_dirderiv_i = get_sqrt_sym_DirDeriv(X, vec2mat(tempdir))
                tempdir[:,i] = 0
            else:
                X_sqrt_dirderiv_i = X_sqrt_dirderiv_set[:,:,:,i]

            temp = torch.bmm(torch.bmm(X_sqrt_dirderiv_i, V), X_sqrt)
            # dvdx[:,:,i] is a vectorised matrix; bmm needs its matrix form
            dsdx[:,:,i] = mat2vec(
                torch.bmm(torch.bmm(X_sqrt, vec2mat(dvdx[:,:,i])), X_sqrt)
                + temp + temp.permute(0,2,1)
            )
            
        if returndVdxAlso:
            return dsdx, dvdx
        return dsdx
    
    def get_derivative(self, x, X_sqrt, X_sqrt_dirderiv_set = None, create_graph=False, useApprox=False, safe_backward=False, 
                       v=None, S_V=None, U_V=None, Exp_v=None, returndVdxAlso = False, returndVdxOnly = False, other_quantities_at_x = None):
        """Return dr/dx for the reconstruction ``r = mat2vec(X_sqrt Exp(V) X_sqrt)``.

        ``v``, ``S_V``/``U_V`` and ``Exp_v`` may be passed in to avoid
        recomputation. ``other_quantities_at_x`` is unused here.

        Returns:
            ``dvdx`` if ``returndVdxOnly``; ``(drdx, dvdx)`` if
            ``returndVdxAlso``; otherwise ``drdx``.
        """
        ### calculate dr/dx and dv/dx value
        if v is None:
            v = self.forward_autoencoder(x)
        V = vec2mat(v)
        if useApprox and self.exp_approx in [1,2]:
            if Exp_v is None:
                Exp_v = Exp_vec_approx(v, approx = self.exp_approx, Eye = self.Eye)
        else:
            if S_V is None or U_V is None:
                S_V, U_V = batch_eigsym(V, safe_backward)
            if Exp_v is None:
                Exp_v = Exp_mat(V, S = S_V, U = U_V, safe_backward = safe_backward)
        X = vec2mat(x)
        dvdx = self.get_autoencoder_derivative(x, create_graph = create_graph)
        
        if returndVdxOnly:
            return dvdx
        
        drdx = torch.cuda.FloatTensor(dvdx.shape).zero_()
        tempdir = torch.cuda.FloatTensor(x.shape).zero_()
        for i in range(self.input_dim):
            if useApprox and self.exp_approx in [1,2]:
                if self.exp_approx == 1:
                    Exp_dirderiv_i = vec2mat(dvdx[:,:,i])
                else:
                    # derivative of I + V + V^2/2 along dV: dV + (V dV + dV V)/2 (V, dV symmetric)
                    dVdx_i = vec2mat(dvdx[:,:,i])
                    VdVdx_i = torch.bmm(V, dVdx_i)
                    Exp_dirderiv_i = dVdx_i + 0.5*(VdVdx_i + VdVdx_i.permute(0,2,1))
            else:
                Exp_dirderiv_i = ExpDirDeriv(V, vec2mat(dvdx[:,:,i]), S = S_V, U = U_V, safe_backward = safe_backward)
            tempdir[:,i] = 1
            if X_sqrt_dirderiv_set is None:
                X_sqrt_dirderiv_i = get_sqrt_sym_DirDeriv(X, vec2mat(tempdir))
            else:
                X_sqrt_dirderiv_i = X_sqrt_dirderiv_set[:,:,:,i]

            temp = torch.bmm(torch.bmm(X_sqrt_dirderiv_i, Exp_v), X_sqrt)
            drdx[:,:,i] = mat2vec(
                torch.bmm(torch.bmm(X_sqrt, Exp_dirderiv_i), X_sqrt)
                + temp + temp.permute(0,2,1)
            )
            tempdir[:,i] = 0
        if returndVdxAlso:
            return drdx, dvdx
        return drdx
    
    def estimate_score_error(self, x, X_sqrt, christoffelSum, X_sqrt_dirderiv_set, create_graph=True, useApprox=True, other_quantities_at_x = None):
        """Return the batch-averaged score-matching objective built from ||V||^2, trace(dr/dx) and the Christoffel term, minus ``2 * dim / noise_std**2``."""
        N = x.shape[0]
        dim = x.shape[1]
        v = self.forward_autoencoder(x)
        V = vec2mat(v)
        log_x_r = mat2vec(torch.bmm(torch.bmm(X_sqrt,V), X_sqrt))
        term1 = torch.sum(V**2) / self.noise_std**4
        term3 = 2.*torch.sum(log_x_r*christoffelSum) / self.noise_std**2
        
        drdx = self.get_derivative(x, X_sqrt, X_sqrt_dirderiv_set, create_graph=create_graph, useApprox=useApprox, safe_backward=True, other_quantities_at_x = other_quantities_at_x)
        
        # sum of the diagonals of drdx over the batch
        term2 = 2. * torch.sum(drdx * torch.eye(dim).cuda()) / self.noise_std**2
        
        return (term1 + term2 + term3)/N - 2.0 * dim / self.noise_std**2
    
class GRCAE_P_n_fromLog(GRCAE_P_n):
    """Variant of ``GRCAE_P_n`` whose autoencoder input is the vectorised matrix logarithm ``Log_x``.

    Derivatives with respect to ``Log_x`` are converted to derivatives with
    respect to ``x`` using a caller-supplied ``dLog_xdx``.
    """
    def __init__(self, dim, num_hidden_layers, noise_std, useLeakyReLU = True, initial = 'default', exp_approx = 1):
        super(GRCAE_P_n_fromLog, self).__init__(dim, num_hidden_layers, noise_std, useLeakyReLU, initial, exp_approx)
        # for clarity, we use uppercase or the suffix '_mat' for matrix variables and lower case or the suffix '_vec' for vector variables
        
    def get_derivative(self, Log_x, X_sqrt, X_sqrt_dirderiv_set = None, create_graph=False, useApprox=False, safe_backward=False, 
                       v=None, S_V=None, U_V=None, Exp_v=None, returndVdxAlso = False, returndVdxOnly = False, other_quantities_at_x = None):
        """Return dr/dx for ``r = mat2vec(X_sqrt Exp(V) X_sqrt)`` when the network input is ``Log_x``.

        ``other_quantities_at_x`` must be a one-element sequence holding
        ``dLog_xdx``; ``dvdx = dvdLog_x @ dLog_xdx``.

        Returns:
            ``dvdx`` if ``returndVdxOnly``; ``(drdx, dvdx)`` if
            ``returndVdxAlso``; otherwise ``drdx``.

        Raises:
            ValueError: if ``other_quantities_at_x`` is missing or not of length 1.
        """
        if other_quantities_at_x is None or len(other_quantities_at_x) != 1:
            raise ValueError("dLog_xdx should be given in other_quantities_at_x")
        dLog_xdx = other_quantities_at_x[0]
        ### calculate dr/dx and dv/dx value
        if v is None:
            v = self.forward_autoencoder(Log_x)
        V = vec2mat(v)
        if useApprox and self.exp_approx in [1,2]:
            if Exp_v is None:
                Exp_v = Exp_vec_approx(v, approx = self.exp_approx, Eye = self.Eye)
        else:
            if S_V is None or U_V is None:
                S_V, U_V = batch_eigsym(V, safe_backward)
            if Exp_v is None:
                Exp_v = Exp_mat(V, S = S_V, U = U_V, safe_backward = safe_backward)
        dvdLog_x = self.get_autoencoder_derivative(Log_x, create_graph = create_graph)
        # chain rule: dv/dx = dv/dLog_x @ dLog_x/dx
        dvdx = torch.bmm(dvdLog_x, dLog_xdx)
        
        if returndVdxOnly:
            return dvdx
        
        # X = Exp(Log_X) does not depend on the loop index, so compute it once
        X = Exp_mat(vec2mat(Log_x)) if X_sqrt_dirderiv_set is None else None
        
        drdx = torch.cuda.FloatTensor(dvdx.shape).zero_()
        tempdir = torch.cuda.FloatTensor(Log_x.shape).zero_()
        for i in range(self.input_dim):
            if useApprox and self.exp_approx in [1,2]:
                if self.exp_approx == 1:
                    Exp_dirderiv_i = vec2mat(dvdx[:,:,i])
                else:
                    # derivative of I + V + V^2/2 along dV: dV + (V dV + dV V)/2 (V, dV symmetric)
                    dVdx_i = vec2mat(dvdx[:,:,i])
                    VdVdx_i = torch.bmm(V, dVdx_i)
                    Exp_dirderiv_i = dVdx_i + 0.5*(VdVdx_i + VdVdx_i.permute(0,2,1))
            else:
                Exp_dirderiv_i = ExpDirDeriv(V, vec2mat(dvdx[:,:,i]), S = S_V, U = U_V, safe_backward = safe_backward)
            tempdir[:,i] = 1
            if X_sqrt_dirderiv_set is None:
                X_sqrt_dirderiv_i = get_sqrt_sym_DirDeriv(X, vec2mat(tempdir))
            else:
                X_sqrt_dirderiv_i = X_sqrt_dirderiv_set[:,:,:,i]

            temp = torch.bmm(torch.bmm(X_sqrt_dirderiv_i, Exp_v), X_sqrt)
            drdx[:,:,i] = mat2vec(
                torch.bmm(torch.bmm(X_sqrt, Exp_dirderiv_i), X_sqrt)
                + temp + temp.permute(0,2,1)
            )
            tempdir[:,i] = 0
        if returndVdxAlso:
            return drdx, dvdx
        return drdx

################ for timer experiments ################    
class GDAE_P_n_fromLog_vanilla(GDAE_P_n_fromLog):
    """Timing-experiment variant: uncorrupted autoencoder with a squared-Frobenius reconstruction loss.

    ``forward`` skips the parent's noise injection and simply runs the
    autoencoder; ``calculate_loss`` compares ``X`` against
    ``X_sqrt @ Exp(v) @ X_sqrt``.
    """
    def __init__(self, dim, num_hidden_layers, noise_std, useLeakyReLU = True, initial = 'default', exp_approx = 1, log_approx = 1, corrupt_approx = False):
        # matrix quantities are upper-case / '_mat', vector quantities lower-case / '_vec'
        super().__init__(dim, num_hidden_layers, noise_std, useLeakyReLU, initial, exp_approx, log_approx, corrupt_approx)

    def forward(self, x):
        """Return the raw autoencoder output (a tangent vector at the identity), not the reconstruction."""
        return self.forward_autoencoder(x)
    
    def calculate_loss(self, x, X_sqrt, X):
        """Return ``sum ||X - X_sqrt Exp(v) X_sqrt||_F^2`` over the batch."""
        tangent = self.forward(x)
        exp_tangent = Exp_vec_approx(tangent, approx = self.exp_approx, Eye = self.Eye)
        reconstruction = X_sqrt.bmm(exp_tangent).bmm(X_sqrt)
        residual = X - reconstruction
        return residual.pow(2).sum()
    
class GDAE_P_n_fromLog_geodesiconly(GDAE_P_n_fromLog):
    """Ablation variant whose loss is only the squared Frobenius norm of the tangent output ``V``."""
    def __init__(self, dim, num_hidden_layers, noise_std, useLeakyReLU = True, initial = 'default', exp_approx = 1, log_approx = 1, corrupt_approx = False):
        # matrix quantities are upper-case / '_mat', vector quantities lower-case / '_vec'
        super().__init__(dim, num_hidden_layers, noise_std, useLeakyReLU, initial, exp_approx, log_approx, corrupt_approx)

    def forward(self, x):
        """Return the raw autoencoder output (a tangent vector at the identity), not the reconstruction."""
        return self.forward_autoencoder(x)
    
    def calculate_loss(self, x):
        """Return ``sum ||V||_F^2`` over the batch, with ``V = vec2mat(self.forward(x))``."""
        tangent_mat = vec2mat(self.forward(x))
        return tangent_mat.pow(2).sum()


class GDAE_P_n_fromLog_expmaponly(GDAE_P_n_fromLog):
    """Ablation variant: noisy ``forward`` from ``GDAE_P_n_fromLog`` with a squared-Frobenius reconstruction loss.

    The loss compares ``X`` with ``X_sqrt @ Exp(v) @ X_sqrt`` instead of
    using the log-based denoising loss.
    """
    def __init__(self, dim, num_hidden_layers, noise_std, useLeakyReLU = True, initial = 'default', exp_approx = 1, log_approx = 1, corrupt_approx = False):
        super(GDAE_P_n_fromLog_expmaponly, self).__init__(dim, num_hidden_layers, noise_std, useLeakyReLU, initial, exp_approx, log_approx, corrupt_approx)
        # for clarity, we use uppercase or the suffix '_mat' for matrix variables and lower case or the suffix '_vec' for vector variables
        
    def calculate_loss(self, x, X_sqrt, X, fixed_noise = None, duplicate_num = None, other_quantities_for_loss_at_x = None):
        """Return ``sum ||X - X_sqrt Exp(v) X_sqrt||_F^2`` over the batch.

        If ``duplicate_num`` is given (and ``fixed_noise`` is not), ``x``,
        ``X_sqrt``, ``X`` and the tensors in ``other_quantities_for_loss_at_x``
        whose leading dimension matches the batch size are tiled
        ``duplicate_num`` times along the first dimension. Each sample then
        receives ``duplicate_num`` noise draws, and the loss is divided by
        ``duplicate_num``.
        """
        use_duplicates = duplicate_num is not None and fixed_noise is None
        if use_duplicates:
            batch_size = x.shape[0]

            def repeat_batch(t):
                # tile along the first (batch) dimension only
                return t.repeat(duplicate_num, *([1] * (t.dim() - 1)))

            x = repeat_batch(x)
            X_sqrt = repeat_batch(X_sqrt)
            # X is compared against the (tiled) reconstruction, so it must be tiled too
            X = repeat_batch(X)
            if other_quantities_for_loss_at_x is not None:
                other_quantities_for_loss_at_x = [
                    repeat_batch(q) if torch.is_tensor(q) and q.dim() > 0 and q.shape[0] == batch_size else q
                    for q in other_quantities_for_loss_at_x
                ]
        v, _ = self.forward(x, X, X_sqrt, fixed_noise = fixed_noise, other_quantities_for_loss_at_x = other_quantities_for_loss_at_x)
        
        Exp_v = Exp_vec_approx(v, approx = self.exp_approx, Eye = self.Eye)
        R = torch.bmm(torch.bmm(X_sqrt, Exp_v), X_sqrt)
        loss = ((X - R)**2).sum()
        if use_duplicates:
            loss = loss / duplicate_num
        return loss
    
class GRCAE_P_n_fromLog_contractiveonly(GRCAE_P_n_fromLog):
    """GRCAE_P_n_fromLog variant with a reconstruction loss plus a contractive penalty.

    The penalty is the batch-summed trace of ``drdx^T metric(R) drdx metricInv``,
    weighted by ``noise_std**2`` when added to the reconstruction term.
    """
    def __init__(self, dim, num_hidden_layers, noise_std, useLeakyReLU = True, initial = 'default', exp_approx = 1):
        super(GRCAE_P_n_fromLog_contractiveonly, self).__init__(dim, num_hidden_layers, noise_std, useLeakyReLU, initial, exp_approx)
        # for clarity, we use uppercase or the suffix '_mat' for matrix variables and lower case or the suffix '_vec' for vector variables

    def calculate_loss(self, x, X_sqrt, X, metricInv, X_sqrt_dirderiv_set = None, other_quantities_for_loss_at_x = None, separate = False):
        """Return the reconstruction loss plus the contractive penalty.

        Args:
            x: batch of inputs in vectorized form.
            X_sqrt: batch of matrix square roots used to map ``Exp(v)`` back to ``R``.
            X: batch of target matrices compared entry-wise with ``R``.
            metricInv: batch of matrices right-multiplied into the penalty
                (presumably the inverse metric at ``x`` -- verify at caller).
            X_sqrt_dirderiv_set, other_quantities_for_loss_at_x: forwarded to
                ``self.get_derivative``.
            separate: if True, return ``(reconstruction, contractive)`` with the
                contractive term NOT scaled by ``noise_std**2``.

        Returns:
            ``contractive * noise_std**2 + reconstruction`` as a scalar tensor,
            or the unscaled pair when ``separate`` is True.
        """
        v = self.forward(x)
        Exp_v = Exp_vec_approx(v, approx = self.exp_approx, Eye = self.Eye)

        # calculate the contractive term
        drdx = self.get_derivative(x, X_sqrt, X_sqrt_dirderiv_set, create_graph=True, useApprox=True, v=v, Exp_v=Exp_v,
                                   other_quantities_at_x=other_quantities_for_loss_at_x)

        R = torch.bmm(torch.bmm(X_sqrt, Exp_v), X_sqrt)
        metric_r = metric_P_n(R)
        mat = torch.bmm(torch.bmm(torch.bmm(drdx.permute(0,2,1), metric_r), drdx), metricInv)
        # Batch-summed trace taken from the diagonal directly: no identity matrix
        # is allocated per call and no device is hard-coded.
        contractive = torch.diagonal(mat, dim1=-2, dim2=-1).sum()
        reconstruction = ((X - R)**2).sum()
        if separate:
            return reconstruction, contractive
        return contractive*self.noise_std**2 + reconstruction
##############################################
        
################ previous models ################
class GDAE_P_n_RBF(GDAE_P_n):
    """GDAE_P_n variant whose autoencoder consumes RBF features of the input.

    Each input matrix X is mapped to one feature per RBF center C_k, computed
    from ``sum(log(eig(C_k^{-1/2} X C_k^{-1/2}))**2)``; the autoencoder's first
    layer is replaced by a linear layer of input size ``RBF_basis_num``.
    Only ``RBF_type == 'gaussian'`` is implemented.
    """
    def __init__(self, RBF_basis_num, RBF_sigma, dim, num_hidden_layers, noise_std, useLeakyReLU = True, initial = 'default',
                 exp_approx = 1, log_approx = 1, loss_approx = None, RBF_type = 'gaussian'):
        super(GDAE_P_n_RBF, self).__init__(dim, num_hidden_layers, noise_std, useLeakyReLU, initial, exp_approx, log_approx, False)
        # for clarity, we use uppercase or the suffix '_mat' for matrix variables and lower case or the suffix '_vec' for vector variables

        self.RBF_basis_num = RBF_basis_num
        self.RBF_sigma = RBF_sigma
        self.RBF_type = RBF_type
        # None: exact RBF features; 1 / 2: first / second-order Taylor approximation in forward().
        self.loss_approx = loss_approx

        # Empty until set_RBF_centers() is called.
        self.RBF_centers = []
        self.RBF_centers_sqrt = []
        self.RBF_centers_invsqrt = []

        self.initial = initial
        self.nonlinearity = 'leaky_relu' if useLeakyReLU else 'tanh'

        # modify network structure here: swap the first layer for one taking RBF features
        network_layers = list(self.autoencoder.children())[1:]
        layer = nn.Linear(RBF_basis_num, self.h_dim)
        initialize_linear_layer_weight(layer, self.initial, self.nonlinearity)
        network_layers.insert(0, layer)
        self.autoencoder = torch.nn.Sequential(*network_layers)

    def _require_RBF_centers(self):
        # Raise if set_RBF_centers() has not been called yet.
        if len(self.RBF_centers) == 0:
            raise Exception("RBF_centers should be set")

    def _require_supported_RBF_type(self):
        # Only the Gaussian kernel is implemented; fail explicitly rather than
        # with an UnboundLocalError further down.
        if self.RBF_type != 'gaussian':
            raise NotImplementedError("Unsupported RBF_type: {}".format(self.RBF_type))

    def set_RBF_centers(self, RBF_centers, RBF_centers_sqrt = None, RBF_centers_invsqrt = None):
        """Store the RBF centers, computing their (inverse) square roots if not both given."""
        if len(RBF_centers) != self.RBF_basis_num:
            raise Exception("The length of RBF_centers should be equal to RBF_basis_num")
        self.RBF_centers = RBF_centers
        self.RBF_centers_sqrt = RBF_centers_sqrt
        self.RBF_centers_invsqrt = RBF_centers_invsqrt
        if RBF_centers_sqrt is None or RBF_centers_invsqrt is None:
            self.RBF_centers_sqrt, self.RBF_centers_invsqrt = get_sqrt_sym(RBF_centers, returnInvAlso = True)

    def forward_RBF(self, x, eps=1e-7):
        """Return the Gaussian RBF features of every input w.r.t. every center.

        Args:
            x: batch of inputs in the vectorized form accepted by ``vec2mat``.
            eps: eigenvalues below this are clamped to ``eps`` before the log.

        Returns:
            Tensor of shape (batch, number of centers).
        """
        self._require_RBF_centers()
        self._require_supported_RBF_type()
        X = vec2mat(x)
        n = X.shape[1]
        # T[b, k] = C_k^{-1/2} X_b C_k^{-1/2}, flattened to (batch * K, n, n).
        T = torch.matmul(torch.matmul(self.RBF_centers_invsqrt.unsqueeze(0), X.unsqueeze(1)),
                         self.RBF_centers_invsqrt.unsqueeze(0)).view(-1, n, n)
        S_T, _ = batch_eigsym(T)
        # Out-of-place clamp: modifying an eigen-decomposition output in place
        # can invalidate tensors autograd saved for the backward pass.
        S_T = torch.clamp(S_T, min=eps)
        logS = torch.log(S_T).view(X.shape[0], len(self.RBF_centers), -1)

        sqdistSet = (logS*logS).sum(-1)
        return torch.exp(-0.5*sqdistSet/self.RBF_sigma**2)

    def get_RBF_derivative(self, x, X_sqrt, X_invsqrt, metric, eps=1e-7):
        """Return ``(dfdx, f)``: the Gaussian RBF features and their first derivative."""
        self._require_RBF_centers()
        self._require_supported_RBF_type()
        sqdistDerivSet, sqdistSet = sqdistDeriv_P_n(vec2mat(x), self.RBF_centers, X_sqrt, X_invsqrt, metric, eps=eps, returnsqdistAlso = True)
        f = torch.exp(-0.5*sqdistSet/self.RBF_sigma**2)
        dfdx = -0.5*f.unsqueeze(-1)*sqdistDerivSet/self.RBF_sigma**2
        return dfdx, f

    def get_RBF_2nd_derivative(self, x, X_sqrt, X_invsqrt, metric, metricDeriv, eps=1e-7):
        """Return ``(d2fdx2, dfdx, f)``: the Gaussian RBF features and their first two derivatives."""
        self._require_RBF_centers()
        self._require_supported_RBF_type()
        sqdist2ndDerivSet, sqdistDerivSet, sqdistSet = sqdist2ndDeriv_P_n(vec2mat(x), self.RBF_centers, X_sqrt, X_invsqrt, metric, metricDeriv, eps=eps, returnsqdistAlso = True)
        f = torch.exp(-0.5*sqdistSet/self.RBF_sigma**2)
        dfdx = -0.5*f.unsqueeze(-1)*sqdistDerivSet/self.RBF_sigma**2
        # Hessian of exp(-d/(2 sigma^2)): f * (grad d grad d^T / (4 sigma^4) - Hess d / (2 sigma^2))
        d2fdx2 = f.unsqueeze(-1).unsqueeze(-1)*(0.25*torch.matmul(sqdistDerivSet.unsqueeze(-1), sqdistDerivSet.unsqueeze(-2))/self.RBF_sigma**4 - 0.5*sqdist2ndDerivSet/self.RBF_sigma**2)
        return d2fdx2, dfdx, f

    def forward_autoencoder(self, x):
        """Apply the autoencoder to the exact RBF features of ``x``."""
        return self.autoencoder(self.forward_RBF(x, eps=1e-7))

    def forward(self, x, X, X_sqrt, fixed_noise = None, other_quantities_for_loss_at_x = None):
        """Corrupt X on the manifold and return ``(autoencoder output, X_tilde)``.

        When ``loss_approx`` is 1 or 2, the RBF features of ``X_tilde`` are
        replaced by a first/second-order Taylor expansion around ``x`` using
        ``other_quantities_for_loss_at_x = (f, dfdx, d2fdx2)``.
        """
        if self.loss_approx is not None:
            if other_quantities_for_loss_at_x is None or len(other_quantities_for_loss_at_x) != 3:
                raise Exception("f_at_x, dfdx_at_x, and d2fdx2_at_x should be given in other_quantities_for_loss_at_x")
        # assume X_sqrt is symmetric
        # here we do not return the reconstruction but the output of the self.forward_autoencoder,
        # which corresponds to the tangent vector at the identity
        if fixed_noise is None:
            epsilon = torch.cuda.FloatTensor(x.shape).normal_(0.0, self.noise_std) * self.noise_coeff
        else:
            epsilon = fixed_noise * self.noise_coeff
        Exp_epsilon = Exp_vec_approx(epsilon, approx = self.exp_approx, Eye = self.Eye, safe_backward = False)
        X_tilde = torch.bmm(torch.bmm(X_sqrt, Exp_epsilon), X_sqrt)
        if self.loss_approx in [1,2]:
            # approximate computation-heavy self.forward_RBF function
            f = other_quantities_for_loss_at_x[0]
            dfdx = other_quantities_for_loss_at_x[1]
            dx = mat2vec(X_tilde) - x
            if self.loss_approx == 1:
                RBF_output = f + (dfdx * dx.unsqueeze(1)).sum(-1)
            elif self.loss_approx == 2:
                d2fdx2 = other_quantities_for_loss_at_x[2]
                RBF_output = f + (dfdx * dx.unsqueeze(1)).sum(-1) \
                + 0.5 * torch.matmul(dx.unsqueeze(1).unsqueeze(-2), torch.matmul(d2fdx2, dx.unsqueeze(1).unsqueeze(-1))).squeeze(-1).squeeze(-1)
        else:
            RBF_output = self.forward_RBF(mat2vec(X_tilde), eps=1e-7)
        return self.autoencoder(RBF_output), X_tilde

    def get_autoencoder_derivative(self, x, create_graph=False, other_quantities_at_x = None):
        """Return dv/dx at x via the chain rule ``dv/df @ df/dx``.

        ``x`` itself is not used; the RBF features ``f_at_x`` and their
        derivative ``dfdx_at_x`` must be given in ``other_quantities_at_x``.
        Gradients are taken with ``torch.autograd.grad``, so neither the
        caller's tensors nor the autoencoder parameters' ``.grad`` are modified.

        Args:
            create_graph: keep the graph of the derivative for higher-order terms.

        Returns:
            ``torch.matmul(dvdf, dfdx_at_x)`` where ``dvdf`` stacks, along dim 1,
            the gradient of each of the first ``input_dim`` outputs w.r.t. ``f_at_x``.
        """
        if other_quantities_at_x is None or len(other_quantities_at_x) != 2:
            raise Exception("f_at_x and dfdx_at_x should be given in other_quantities_at_x")
        ### calculate dv/dx value
        f_at_x = other_quantities_at_x[0]
        dfdx_at_x = other_quantities_at_x[1]
        if not f_at_x.requires_grad:
            # Differentiate w.r.t. a fresh leaf rather than mutating the caller's tensor.
            f_at_x = f_at_x.detach().requires_grad_(True)
        y = self.autoencoder(f_at_x)
        grad_set = []
        for i in range(self.input_dim):
            selector = torch.zeros_like(y)
            selector[:, i] = 1
            # Unlike y.backward, this differentiates only w.r.t. f_at_x and does not
            # accumulate into the autoencoder parameters' .grad.
            grad_i, = torch.autograd.grad(y, f_at_x, grad_outputs=selector,
                                          retain_graph=True, create_graph=create_graph)
            grad_set.append(grad_i)
        dvdf = torch.stack(grad_set, dim=1)
        return torch.matmul(dvdf, dfdx_at_x)
#################################################