import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from data_util import vector2tensor_1dim, tensor2vector_1dim, Log2Log_vec, Log_vec2Log
from tensor_data_util import group_action, Exp, get_sqrt, get_sqrt_sym, Exp_sqrt, deltaMat_sqrt_approx, \
get_sqrt_sym_DirDeriv, ExpDirDeriv, metric_N_n

class AE(nn.Module):
    """Residual autoencoder whose reconstruction is ``x + autoencoder(x)``.

    Args:
        dim: pair ``[input_dim, h_dim]``: data dimension and hidden width.
        num_hidden_layers: number of hidden layers of width ``h_dim``. With 0
            the residual branch is a single ``Linear(input_dim, input_dim)``.
        useLeakyReLU: use ``LeakyReLU(0.1)`` activations when True, otherwise
            ``Tanh``.
    """

    def __init__(self, dim, num_hidden_layers, useLeakyReLU = True):
        super(AE, self).__init__()
        [input_dim, h_dim] = dim
        self.input_dim = input_dim
        self.h_dim = h_dim
        self.num_layers = num_hidden_layers

        def make_activation():
            # A fresh module per position in the Sequential.
            if useLeakyReLU:
                return nn.LeakyReLU(0.1, inplace=True)
            return nn.Tanh()

        network_layers = []
        if num_hidden_layers == 0:
            network_layers.append(nn.Linear(input_dim, input_dim))
        else:
            network_layers.append(nn.Linear(input_dim, h_dim))
            for _ in range(num_hidden_layers - 1):
                network_layers.append(make_activation())
                network_layers.append(nn.Linear(h_dim, h_dim))
            network_layers.append(make_activation())
            network_layers.append(nn.Linear(h_dim, input_dim))
        # Layers keep PyTorch's default initialization.
        self.autoencoder = torch.nn.Sequential(*network_layers)

    def forward(self, x):
        """Return the reconstruction ``x + autoencoder(x)``."""
        return x + self.autoencoder(x)

    def calculate_loss(self, x):
        """Return the summed squared residual ``sum((x - forward(x))**2)``."""
        diff = x - self.forward(x)
        return torch.sum(diff*diff)

    def _batch_jacobian(self, y, x, create_graph):
        """Return per-sample Jacobians of ``y`` w.r.t. ``x``, shape (N, input_dim, input_dim).

        Entry ``[n, i, j]`` is ``d y[n, i] / d x[n, j]``. Row ``i`` is the
        gradient of the batch sum of ``y[:, i]``, which equals the per-sample
        row because the MLP built in ``__init__`` treats samples independently.

        ``torch.autograd.grad`` is used rather than ``Tensor.backward`` so that
        nothing is accumulated into the parameters' ``.grad`` (which would
        corrupt a later optimizer step), stale ``x.grad`` values cannot leak
        into the result, and non-leaf ``x`` is supported.
        """
        rows = []
        for i in range(self.input_dim):
            # zeros_like keeps y's dtype and device.
            grad_out = torch.zeros_like(y)
            grad_out[:, i] = 1
            (row,) = torch.autograd.grad(y, x, grad_outputs=grad_out,
                                         retain_graph=True,
                                         create_graph=create_graph)
            rows.append(row)
        return torch.stack(rows, dim=1)

    def get_derivative(self, x, create_graph=False):
        """Return dr/dx for the reconstruction ``r = x + autoencoder(x)``.

        Args:
            x: batch of inputs, shape (N, input_dim). If it does not require
                grad, its ``requires_grad`` flag is switched on in place.
            create_graph: keep the graph of the derivative so it can itself be
                differentiated (needed when it enters a training loss).

        Returns:
            Tensor of shape (N, input_dim, input_dim), ``[n, i, j] = dr_i/dx_j``.
        """
        if not x.requires_grad:
            x.requires_grad_(True)
        return self._batch_jacobian(x + self.autoencoder(x), x, create_graph)

    def get_autoencoder_derivative(self, x, create_graph=False):
        """Return d(autoencoder(x))/dx, i.e. the Jacobian of the residual branch only.

        Same arguments, in-place ``requires_grad`` behaviour and output layout
        as ``get_derivative``.
        """
        if not x.requires_grad:
            x.requires_grad_(True)
        return self._batch_jacobian(self.autoencoder(x), x, create_graph)
    
    
class RCAE(AE):
    """Residual contractive autoencoder.

    The loss is the squared reconstruction error plus ``noise_std**2`` times
    the squared Frobenius norm of the reconstruction Jacobian.
    ``estimate_score`` returns the residual branch divided by ``noise_std**2``.
    """

    def __init__(self, dim, num_hidden_layers, noise_std, useLeakyReLU = True):
        self.noise_std = noise_std
        super(RCAE, self).__init__(dim, num_hidden_layers, useLeakyReLU)

    def calculate_loss(self, x, weight = None):
        """Return reconstruction error plus contractive penalty.

        Args:
            x: batch of inputs, shape (N, D).
            weight: optional per-sample weights of shape (N,) or (N, 1), or a
                scalar. The same weight multiplies both per-sample terms.
        """
        recon = x + self.autoencoder(x)
        drdx = self.get_derivative(x, create_graph=True)
        diff2 = recon - x
        if weight is None:
            return torch.sum(diff2*diff2) + self.noise_std**2 * torch.sum(drdx**2)
        # Flatten (N, 1) / (N,) weights so both per-sample terms broadcast the
        # same way (previously the Jacobian term produced an (N, N) product).
        w = torch.as_tensor(weight, dtype=x.dtype, device=x.device)
        if w.dim() > 0:
            w = w.reshape(-1)
        recon_sq = (diff2*diff2).sum(1)
        jac_sq = (drdx**2).sum((1,2))
        return torch.sum(recon_sq*w) + self.noise_std**2 * torch.sum(jac_sq*w)

    def estimate_score(self, x):
        """Return ``autoencoder(x) / noise_std**2``.

        A tensor ``noise_std`` is moved to ``x``'s device; Python or NumPy
        scalars are used as-is.
        """
        noise_std = self.noise_std
        if torch.is_tensor(noise_std):
            noise_std = noise_std.to(x.device)
        return self.autoencoder(x) / noise_std**2
        
class GRCAE(RCAE):
    """Residual contractive autoencoder with a data-space metric.

    The metric is supplied through its square root ``metric_sqrt`` and inverse
    square root ``metricInv_sqrt``. These are per-coordinate factors when
    ``diagonal_metric`` is True, and batched matrices (used with ``torch.bmm``)
    otherwise. If ``metricSqrtFunc`` is given, ``metric_sqrt`` is recomputed
    from the reconstruction inside the loss.
    """

    def __init__(self, dim, num_hidden_layers, noise_std, diagonal_metric = False,
                 metricSqrtFunc = None, useLeakyReLU = True):
        self.diagonal_metric = diagonal_metric
        self.metric_sqrt_func = metricSqrtFunc
        super(GRCAE, self).__init__(dim, num_hidden_layers, noise_std, useLeakyReLU)

    def calculate_loss(self, x, metricInv_sqrt, metric_sqrt, weight = None):
        """Return the metric-weighted reconstruction error plus contractive penalty.

        The Jacobian is conjugated by the metric square root: ``S^T J S^{-T}``
        for matrices, i.e. ``s_i * J_ij / s_j`` for a diagonal metric.

        Args:
            x: batch of inputs, shape (N, D).
            metricInv_sqrt, metric_sqrt: metric inverse square root / square root.
            weight: optional per-sample weights of shape (N,) or (N, 1), or a scalar.
        """
        recon = x + self.autoencoder(x)
        drdx = self.get_derivative(x, create_graph=True)
        if self.metric_sqrt_func is not None:
            metric_sqrt = self.metric_sqrt_func(recon)
        if self.diagonal_metric:
            diff2 = (x - recon) * metric_sqrt
            # Scale row i by s_i and column j by 1/s_j. Plain elementwise products
            # would align both factors with the last (column) axis of drdx.
            s_row = metric_sqrt
            s_col = metricInv_sqrt
            if torch.is_tensor(s_row) and s_row.dim() > 0:
                s_row = s_row.unsqueeze(-1)
            if torch.is_tensor(s_col) and s_col.dim() > 0:
                s_col = s_col.unsqueeze(-2)
            drdx2 = s_row * drdx * s_col
        else:
            diff2 = torch.bmm((x - recon).view(x.size()[0], 1, x.size()[1]), metric_sqrt)
            drdx2 = torch.bmm(torch.bmm(metric_sqrt.permute(0,2,1), drdx), metricInv_sqrt.permute(0,2,1))
        if weight is None:
            return torch.sum(diff2*diff2) + self.noise_std**2*torch.sum(drdx2*drdx2)
        # Per-sample terms so the weight applies identically to both parts.
        w = torch.as_tensor(weight, dtype=x.dtype, device=x.device)
        if w.dim() > 0:
            w = w.reshape(-1)
        recon_sq = (diff2*diff2).reshape(x.size()[0], -1).sum(1)
        jac_sq = (drdx2**2).sum((1,2))
        return torch.sum(recon_sq*w) + self.noise_std**2*torch.sum(jac_sq*w)

    def estimate_score(self, x, metric):
        """Return ``(recon - x)`` mapped by ``metric`` and divided by ``noise_std**2``.

        With a non-diagonal metric the result has shape (N, 1, D) (output of
        ``torch.bmm``); with a diagonal metric it is elementwise.
        """
        recon = x + self.autoencoder(x)
        diff = recon - x
        if self.diagonal_metric:
            score_est = (diff * metric) / self.noise_std**2
        else:
            score_est = torch.bmm(diff.view(diff.size()[0], 1, diff.size()[1]), metric) \
            / self.noise_std**2
        return score_est
        
class RCAE_DTI(RCAE):
    """Residual contractive autoencoder for rows ``[position, covariance part]``.

    ``pos_dim`` is the position dimension and
    ``cov_dim = pos_dim*(pos_dim+1)/2`` is the length of the covariance part.
    """

    def __init__(self, dim, num_hidden_layers, noise_std, useLeakyReLU = True, pos_dim = 3):
        self.pos_dim = pos_dim
        self.cov_dim = int(pos_dim*(pos_dim+1)/2)
        self.dim = self.pos_dim + self.cov_dim
        # RCAE.__init__ takes noise_std as its third positional argument; it
        # must be forwarded explicitly (useLeakyReLU previously landed there).
        super(RCAE_DTI, self).__init__(dim, num_hidden_layers, noise_std, useLeakyReLU)

    def calculate_loss(self, x, weight = None):
        """Return reconstruction error plus ``noise_std**2`` * squared Jacobian norm.

        weight: optional per-sample weights of shape (N,) or (N, 1), or a scalar;
        applied identically to both per-sample terms.
        """
        recon = x + self.autoencoder(x)
        drdx = self.get_derivative(x, create_graph=True)
        diff2 = recon - x
        if weight is None:
            return torch.sum(diff2*diff2) + self.noise_std**2 * torch.sum(drdx**2)
        w = torch.as_tensor(weight, dtype=x.dtype, device=x.device)
        if w.dim() > 0:
            w = w.reshape(-1)
        recon_sq = (diff2*diff2).sum(1)
        jac_sq = (drdx**2).sum((1,2))
        return torch.sum(recon_sq*w) + self.noise_std**2 * torch.sum(jac_sq*w)
        
class GRCAE_DTI(RCAE):
    """Residual contractive autoencoder for DTI-like rows with a block metric.

    Rows are ``[position (pos_dim), covariance part (cov_dim)]``. The metric is
    block diagonal:
    - the position block is given by the batched square root ``posMetric_sqrt``
      (and its inverse), or computed from the reconstruction by
      ``posMetricSqrtFunc``;
    - the covariance block is ``covMetricCoeff`` times the identity.
    """

    def __init__(self, dim, num_hidden_layers, noise_std, covMetricCoeff = 1.0,
                 posMetricSqrtFunc = None, useLeakyReLU = True, pos_dim = 3):
        self.pos_metric_sqrt_func = posMetricSqrtFunc
        self.cov_metric_coeff = covMetricCoeff
        self.cov_metric_coeff_sqrt = np.sqrt(covMetricCoeff)
        self.pos_dim = pos_dim
        self.cov_dim = int(pos_dim*(pos_dim+1)/2)
        self.dim = self.pos_dim + self.cov_dim
        super(GRCAE_DTI, self).__init__(dim, num_hidden_layers, noise_std, useLeakyReLU)

    def calculate_loss(self, x, posMetricInv_sqrt, posMetric_sqrt, weight = None):
        """Return the metric-weighted reconstruction error plus contractive penalty.

        Args:
            x: batch of rows, shape (N, pos_dim + cov_dim).
            posMetricInv_sqrt, posMetric_sqrt: batched (N, pos_dim, pos_dim)
                inverse square root / square root of the position metric.
            weight: optional per-sample weights of shape (N,) or (N, 1), or a scalar.
        """
        recon = x + self.autoencoder(x)
        drdx = self.get_derivative(x, create_graph=True)
        if self.pos_metric_sqrt_func is not None:
            posMetric_sqrt = self.pos_metric_sqrt_func(recon)
        diff = x - recon
        diff2 = torch.cat(
            (
                torch.bmm(diff[:,0:self.pos_dim].view(x.size()[0], 1, self.pos_dim),
                          posMetric_sqrt).view(x.size()[0], self.pos_dim),
                diff[:,self.pos_dim:] *  self.cov_metric_coeff_sqrt
            ),
        1)
        # Blockwise S^T J S^{-T} with S = blockdiag(posMetric_sqrt, sqrt(c) I):
        # left column [S^T J_pp S^-T ; sqrt(c) J_cp S^-T], right column
        # [S^T J_pc / sqrt(c) ; J_cc].
        drdx2 = torch.cat(
            (
                torch.cat(
                    (
                        torch.bmm(torch.bmm(posMetric_sqrt.permute(0,2,1), drdx[:,:self.pos_dim,:self.pos_dim]),
                          posMetricInv_sqrt.permute(0,2,1)),
                        torch.bmm(drdx[:,self.pos_dim:,:self.pos_dim], posMetricInv_sqrt.permute(0,2,1)) \
                        * self.cov_metric_coeff_sqrt
                    )
                , 1),
                torch.cat(
                    (
                        torch.bmm(posMetric_sqrt.permute(0,2,1), drdx[:,:self.pos_dim,self.pos_dim:]) \
                        / self.cov_metric_coeff_sqrt,
                        drdx[:,self.pos_dim:,self.pos_dim:]
                    )
                , 1)
            )
        , 2)

        if weight is None:
            return torch.sum(diff2*diff2) + self.noise_std**2*torch.sum(drdx2*drdx2)
        # Per-sample terms so the weight applies identically to both parts.
        w = torch.as_tensor(weight, dtype=x.dtype, device=x.device)
        if w.dim() > 0:
            w = w.reshape(-1)
        recon_sq = (diff2*diff2).sum(1)
        jac_sq = (drdx2**2).sum((1,2))
        return torch.sum(recon_sq*w) + self.noise_std**2*torch.sum(jac_sq*w)

    def estimate_score(self, x, posMetric):
        """Return the metric-mapped residual ``(recon - x)`` divided by ``noise_std**2``.

        The position block is multiplied by the batched ``posMetric``; the
        covariance block is scaled by ``covMetricCoeff``. Output shape: (N, D).
        """
        recon = x + self.autoencoder(x)
        diff = recon - x
        score_est = torch.cat(
            (
                torch.bmm(diff[:,0:self.pos_dim].view(diff.size()[0], 1, self.pos_dim),
                          posMetric).view(diff.size()[0], self.pos_dim) / self.noise_std**2,
                diff[:,self.pos_dim:] * self.cov_metric_coeff / self.noise_std**2
            ),
        1)
        return score_est
    
class DAE(AE):
    """Residual denoising autoencoder with Gaussian input corruption.

    ``noise_std`` may be a Python/NumPy scalar or a tensor; a tensor is
    broadcast against the input and moved to the input's device when used.
    """

    def __init__(self, dim, num_hidden_layers, noise_std, useLeakyReLU = True):
        self.noise_std = noise_std
        super(DAE, self).__init__(dim, num_hidden_layers, useLeakyReLU)

    def _noise_std_like(self, x):
        """Return ``noise_std`` on ``x``'s device when it is a tensor, else unchanged."""
        if torch.is_tensor(self.noise_std):
            return self.noise_std.to(x.device)
        return self.noise_std

    def _euclidean_score(self, x):
        """Return ``autoencoder(x) / noise_std**2``."""
        return self.autoencoder(x) / self._noise_std_like(x)**2

    def forward(self, x):
        """Corrupt ``x`` with N(0, noise_std^2) noise and return its reconstruction.

        Fixes the CUDA + tensor ``noise_std`` path, which previously never
        assigned the noise and raised UnboundLocalError.
        """
        epsilon = torch.randn_like(x) * self._noise_std_like(x)
        return x + epsilon + self.autoencoder(x + epsilon)

    def calculate_loss(self, x, fixed_noise = None):
        """Return ``sum((recon(x + noise) - x)**2)``.

        Uses fresh random noise unless ``fixed_noise`` is supplied.
        """
        if fixed_noise is None:
            recon_corrupt = self.forward(x)
        else:
            recon_corrupt = self.clean_forward(x + fixed_noise)
        diff2 = recon_corrupt - x
        return torch.sum(diff2*diff2)

    def clean_forward(self, x):
        """Return the reconstruction of ``x`` without corruption."""
        return x + self.autoencoder(x)

    def update_dir_msc(self, x):
        """Return the mean-shift-clustering update direction ``autoencoder(x) / noise_std**2``."""
        # msc: mean shift clustering
        return self._euclidean_score(x)

    def estimate_score(self, x):
        """Return the score estimate ``autoencoder(x) / noise_std**2``."""
        return self._euclidean_score(x)

    def calculate_expected_loss(self, x, Niter):
        """Return the per-sample loss averaged over ``Niter`` random corruptions (no grad)."""
        lossSum = 0.0
        N = x.size()[0]
        with torch.no_grad():
            for _ in range(Niter):
                recon_corrupt = self.forward(x)
                diff2 = recon_corrupt - x
                lossSum += torch.sum(diff2*diff2) / N
        return lossSum / Niter
        
        
class GDAE(DAE):
    """Denoising autoencoder whose corruption and loss use a data-space metric.

    The metric enters through ``metric_sqrt`` (loss) and ``metricInv_sqrt``
    (noise shaping). These are per-coordinate factors when ``diagonal_metric``
    is True, and batched matrices used with ``torch.bmm`` otherwise. When
    ``metricSqrtFunc`` is set, ``metric_sqrt`` is recomputed from the clean
    reconstruction.
    """

    def __init__(self, dim, num_hidden_layers, noise_std, diagonal_metric = False,
                 metricSqrtFunc = None, useLeakyReLU = True):
        self.diagonal_metric = diagonal_metric
        self.metric_sqrt_func = metricSqrtFunc
        super(GDAE, self).__init__(dim, num_hidden_layers, noise_std, useLeakyReLU)

    def forward(self, x, metricInv_sqrt):
        """Corrupt ``x`` with metric-shaped float32 Gaussian noise and reconstruct it."""
        n_rows, n_cols = x.size()[0], x.size()[1]
        device = 'cuda' if x.is_cuda else 'cpu'
        if self.diagonal_metric:
            white = torch.empty(n_rows, n_cols, dtype=torch.float32, device=device)
            epsilon = white.normal_(0.0, self.noise_std) * metricInv_sqrt
        else:
            white = torch.empty(n_rows, 1, n_cols, dtype=torch.float32, device=device)
            epsilon = torch.bmm(white.normal_(0.0, self.noise_std), metricInv_sqrt).view(n_rows, n_cols)
        corrupted = x + epsilon
        return corrupted + self.autoencoder(corrupted)

    def _whitened_residual(self, x, recon_corrupt, metric_sqrt):
        """Return ``x - recon_corrupt`` multiplied by the metric square root."""
        if self.metric_sqrt_func is not None:
            metric_sqrt = self.metric_sqrt_func(self.clean_forward(x))
        residual = x - recon_corrupt
        if not self.diagonal_metric:
            return torch.bmm(residual.view(x.size()[0], 1, x.size()[1]), metric_sqrt)
        return residual * metric_sqrt

    def calculate_loss(self, x, metricInv_sqrt, metric_sqrt, fixed_noise = None):
        """Return the summed metric-weighted squared reconstruction error."""
        if fixed_noise is not None:
            recon_corrupt = self.clean_forward(x + fixed_noise)
        else:
            recon_corrupt = self.forward(x, metricInv_sqrt)
        whitened = self._whitened_residual(x, recon_corrupt, metric_sqrt)
        return torch.sum(whitened*whitened)

    def estimate_score(self, x, metric):
        """Return ``(recon - x)`` mapped by ``metric`` and divided by ``noise_std**2``."""
        step = (x + self.autoencoder(x)) - x
        variance = self.noise_std**2
        if not self.diagonal_metric:
            return torch.bmm(step.view(step.size()[0], 1, step.size()[1]), metric) / variance
        return (step * metric) / variance

    def calculate_expected_loss(self, x, metricInv_sqrt, metric_sqrt, Niter):
        """Return the per-sample loss averaged over ``Niter`` random corruptions (no grad)."""
        batch = x.size()[0]
        total = 0.0
        with torch.no_grad():
            for _ in range(Niter):
                whitened = self._whitened_residual(x, self.forward(x, metricInv_sqrt), metric_sqrt)
                total += torch.sum(whitened*whitened) / batch
        return total / Niter
        
class DAE_DTI(DAE):
    """Euclidean denoising autoencoder for rows ``[position, covariance part]``.

    ``pos_dim`` is the position dimension and
    ``cov_dim = pos_dim*(pos_dim+1)/2`` is the length of the covariance part.
    Optional ``weight`` multiplies the elementwise squared error before summing.
    """

    def __init__(self, dim, num_hidden_layers, noise_std, useLeakyReLU = True, pos_dim = 3):
        self.pos_dim = pos_dim
        self.cov_dim = int(pos_dim*(pos_dim+1)/2)
        self.dim = self.pos_dim + self.cov_dim
        super(DAE_DTI, self).__init__(dim, num_hidden_layers, noise_std, useLeakyReLU)

    @staticmethod
    def _residual_energy(residual, weight):
        """Return ``sum(residual**2 * weight)`` (weight omitted when None)."""
        energy = residual*residual
        if weight is not None:
            energy = energy*weight
        return torch.sum(energy)

    def calculate_loss(self, x, fixed_noise = None, weight = None):
        """Return the (optionally weighted) squared error of the corrupted reconstruction."""
        if fixed_noise is not None:
            recon_corrupt = self.clean_forward(x + fixed_noise)
        else:
            recon_corrupt = self.forward(x)
        return self._residual_energy(recon_corrupt - x, weight)

    def calculate_expected_loss(self, x, Niter, weight = None):
        """Return the per-sample loss averaged over ``Niter`` random corruptions (no grad)."""
        batch = x.size()[0]
        total = 0.0
        with torch.no_grad():
            for _ in range(Niter):
                total += self._residual_energy(self.forward(x) - x, weight) / batch
        return total / Niter
    
class GDAE_DTI(DAE):
    """Denoising autoencoder for DTI-like rows with a block-diagonal metric.

    Inputs are ``(N, pos_dim + cov_dim)``, e.g. (N, 9) with the first 3 columns
    position and the last 6 covariance, or (N, 5) for 2-D data. The position
    block uses the batched metric square root ``posMetric_sqrt`` (its inverse
    shapes the noise). The covariance block is scaled by ``covMetricCoeff``.
    """

    def __init__(self, dim, num_hidden_layers, noise_std, covMetricCoeff = 1.0,
                 posMetricSqrtFunc = None, useLeakyReLU = True, pos_dim = 3,
                approx_order = None):
        self.pos_metric_sqrt_func = posMetricSqrtFunc
        self.cov_metric_coeff = covMetricCoeff
        self.cov_metric_coeff_sqrt = np.sqrt(covMetricCoeff)
        self.pos_dim = pos_dim
        self.cov_dim = int(pos_dim*(pos_dim+1)/2)
        self.dim = self.pos_dim + self.cov_dim
        self.approx_order = approx_order
        super(GDAE_DTI, self).__init__(dim, num_hidden_layers, noise_std, useLeakyReLU)

    def forward(self, x, posMetricInv_sqrt):
        """Corrupt ``x`` with metric-shaped float32 Gaussian noise and reconstruct it.

        Position noise is ``z @ posMetricInv_sqrt`` with ``z ~ N(0, noise_std^2)``.
        Covariance noise has std ``noise_std / sqrt(covMetricCoeff)``.
        """
        batch = x.size()[0]
        device = 'cuda' if x.is_cuda else 'cpu'
        pos_white = torch.empty(batch, 1, self.pos_dim, dtype=torch.float32,
                                device=device).normal_(0.0, self.noise_std)
        pos_eps = torch.bmm(pos_white, posMetricInv_sqrt).view(batch, self.pos_dim)
        cov_eps = torch.empty(batch, self.cov_dim, dtype=torch.float32,
                              device=device).normal_(0.0, self.noise_std / self.cov_metric_coeff_sqrt)
        corrupted = x + torch.cat((pos_eps, cov_eps), 1)
        return corrupted + self.autoencoder(corrupted)

    def _whiten(self, diff, posMetric_sqrt):
        """Map a residual to metric-whitened coordinates, blockwise."""
        batch = diff.size()[0]
        pos_part = torch.bmm(diff[:,0:self.pos_dim].view(batch, 1, self.pos_dim),
                             posMetric_sqrt).view(batch, self.pos_dim)
        cov_part = diff[:,self.pos_dim:] *  self.cov_metric_coeff_sqrt
        return torch.cat((pos_part, cov_part), 1)

    @staticmethod
    def _energy(whitened, weight):
        """Return ``sum(whitened**2 * weight)`` (weight omitted when None)."""
        energy = whitened*whitened
        if weight is not None:
            energy = energy*weight
        return torch.sum(energy)

    def calculate_loss(self, x, posMetricInv_sqrt, posMetric_sqrt, fixed_noise = None, weight = None,
                      expJacobian = None):
        """Return the metric-weighted squared error of the corrupted reconstruction.

        When ``posMetricSqrtFunc`` is set and ``approx_order`` is None, the
        position metric square root is recomputed from the clean reconstruction.
        Otherwise the supplied ``posMetric_sqrt`` is used. ``expJacobian`` is
        currently unused.
        """
        if fixed_noise is not None:
            recon_corrupt = self.clean_forward(x + fixed_noise)
        else:
            recon_corrupt = self.forward(x, posMetricInv_sqrt)
        if self.pos_metric_sqrt_func is not None:
            recon = self.clean_forward(x)
            if self.approx_order is None:
                posMetric_sqrt = self.pos_metric_sqrt_func(recon)
            # TODO: approx_order == 1 with expJacobian should build an
            # approximate metric square root from expJacobian; not implemented,
            # so the caller-supplied posMetric_sqrt is used for that case.
        return self._energy(self._whiten(x - recon_corrupt, posMetric_sqrt), weight)

    def estimate_score(self, x, posMetric):
        """Return the metric-mapped residual ``(recon - x)`` divided by ``noise_std**2``."""
        step = (x + self.autoencoder(x)) - x
        batch = step.size()[0]
        variance = self.noise_std**2
        pos_score = torch.bmm(step[:,0:self.pos_dim].view(batch, 1, self.pos_dim),
                              posMetric).view(batch, self.pos_dim) / variance
        cov_score = step[:,self.pos_dim:] * self.cov_metric_coeff / variance
        return torch.cat((pos_score, cov_score), 1)

    def calculate_expected_loss(self, x, posMetricInv_sqrt, posMetric_sqrt, Niter, weight = None):
        """Return the per-sample loss averaged over ``Niter`` random corruptions (no grad)."""
        batch = x.size()[0]
        total = 0.0
        with torch.no_grad():
            for _ in range(Niter):
                recon_corrupt = self.forward(x, posMetricInv_sqrt)
                if self.pos_metric_sqrt_func is not None:
                    posMetric_sqrt = self.pos_metric_sqrt_func(self.clean_forward(x))
                total += self._energy(self._whiten(x - recon_corrupt, posMetric_sqrt), weight) / batch
        return total / Niter
    
class GDAE_N_n(DAE):
    """Denoising autoencoder on (position, SPD covariance) rows; GPU only.

    Each row is ``[position (pos_dim), covariance part (cov_dim)]`` with
    ``cov_dim = pos_dim*(pos_dim+1)/2``, e.g. (N, 9) for pos_dim=3 or (N, 5) for
    pos_dim=2. The covariance part is either the vectorised covariance
    (``use_logvec_input=False``) or a log-vector representation
    (``use_logvec_input=True``, via ``Log2Log_vec`` / ``Log_vec2Log``).

    Noise model:
    - Position noise is ``z @ cov_sqrt * sqrt(covCoeff)``.
    - Covariance noise is drawn as a tangent vector at the identity, mapped
      through ``Exp`` (or a Taylor approximation selected by
      ``approx_order``), and then through ``group_action`` with ``cov_sqrt^T``.

    Only pos_dim 2 and 3 define ``cov_metric_coeff_sqrt``, ``cov_noise_coeff``
    and ``Eye_vec``; other values leave them undefined. Constant tensors are
    created with ``.cuda()``.
    """

    def __init__(self, dim, num_hidden_layers, noise_std, covCoeff = 1.0,
                 posMetricFunc = None, useLeakyReLU = True, pos_dim = 3, approx_order = 1,
                 use_logvec_input = False, use_exp_map_sqrt = False, use_exp_map_corrupt = False):
        #### only consider GPU implementation for this case
        #### pos_metric_func only gets the covariance part (and covInv_sqrt) as input for this class
        self.pos_metric_func = posMetricFunc
        self.cov_coeff = covCoeff
        self.cov_coeff_sqrt = np.sqrt(covCoeff)
        self.pos_dim = pos_dim
        self.cov_dim = int(pos_dim*(pos_dim+1)/2)
        self.dim = self.pos_dim + self.cov_dim
        self.approx_order = approx_order
        self.use_logvec_input = use_logvec_input
        self.use_exp_map_sqrt = use_exp_map_sqrt
        self.use_exp_map_corrupt = use_exp_map_corrupt
        # Batched identity, shape (1, pos_dim, pos_dim).
        self.Eye = torch.eye(self.pos_dim).cuda().view(1,self.pos_dim,self.pos_dim)

        # Vector layouts follow the upper-triangular ordering implied by Eye_vec
        # (diagonal entries 1, off-diagonal entries 0).
        if self.pos_dim == 3:
            self.cov_metric_coeff_sqrt = torch.cuda.FloatTensor([1.0, np.sqrt(2.0),
                                                           np.sqrt(2.0), 1.0, np.sqrt(2.0), 1.0]).view(1,6)
            self.cov_noise_coeff = torch.cuda.FloatTensor([1.0, 1.0/np.sqrt(2.0),
                                                           1.0/np.sqrt(2.0), 1.0, 1.0/np.sqrt(2.0), 1.0]).view(1,6)
            self.Eye_vec = torch.cuda.FloatTensor([1.0, 0.0, 0.0, 1.0, 0.0, 1.0]).view(1,self.cov_dim)
        elif self.pos_dim == 2:
            self.cov_metric_coeff_sqrt = torch.cuda.FloatTensor([1.0, np.sqrt(2.0), 1.0]).view(1,3)
            self.cov_noise_coeff = torch.cuda.FloatTensor([1.0, 1.0/np.sqrt(2.0), 1.0]).view(1,3)
            self.Eye_vec = torch.cuda.FloatTensor([1.0, 0.0, 1.0]).view(1,self.cov_dim)
        super(GDAE_N_n, self).__init__(dim, num_hidden_layers, noise_std, useLeakyReLU)

    def forward(self, x, cov_sqrt, logJacobian):
        """Corrupt ``x`` and run the autoencoder on the corrupted input.

        Args:
            x: batch of rows, shape (N, pos_dim + cov_dim).
            cov_sqrt: D^(1/2) R^T of the covariance R D R^T; plays the role of
                the inverse square root of the position metric.
            logJacobian: batched matrix applied to the covariance update when
                ``use_logvec_input`` is True and ``approx_order`` is not None.

        Returns:
            ``(autoencoder(x_tilde), x_tilde, epsilon_cov)``. The covariance part
            of the autoencoder output is a tangent vector at the identity.
        """
        #### caution: for the covariance part, we return tangent vector at Identity in this function!!!
        epsilon_pos = torch.bmm(torch.cuda.FloatTensor(x.size()[0],1,self.pos_dim).normal_(0.0,
                                   self.noise_std), cov_sqrt).view(x.size()[0],self.pos_dim) \
            * self.cov_coeff_sqrt
        ### multiply sqrt(2) to consider 1/2 part in the Fisher information metric
        epsilon_cov = torch.cuda.FloatTensor(x.size()[0],self.cov_dim).normal_(0.0,
                                   self.noise_std) * self.cov_noise_coeff * np.sqrt(2.0)

        x_tilde = x.clone()
        x_tilde[:,:self.pos_dim] += epsilon_pos
        ### get corrupted covariance (exact Exp, or 1st / 2nd order Taylor expansion)
        if self.approx_order is None or self.use_exp_map_corrupt:
            Exp_epsilon = Exp(epsilon_cov, returnVec = False)
        elif self.approx_order == 1:
            Exp_epsilon = self.Eye_vec + epsilon_cov
        else:
            dcov = vector2tensor_1dim(epsilon_cov)
            Exp_epsilon = tensor2vector_1dim(self.Eye + dcov + 0.5*torch.bmm(dcov, dcov))
        if self.use_logvec_input:
            if self.approx_order is None:
                # NOTE(review): `Log` is not imported in this module, so this
                # branch raises NameError unless it is provided elsewhere. TODO confirm.
                x_tilde[:,self.pos_dim:] = Log2Log_vec(Log(
                    group_action(Exp_epsilon, cov_sqrt.permute(0,2,1), returnVec = True)
                ))
            else:
                x_tilde[:,self.pos_dim:] += Log2Log_vec(
                    torch.bmm(
                        logJacobian,
                        group_action(Exp_epsilon - self.Eye_vec, cov_sqrt.permute(0,2,1),
                                     returnVec = True).view(-1,self.cov_dim,1)
                    ).view(-1,self.cov_dim)
                )
        else:
            x_tilde[:,self.pos_dim:] = group_action(Exp_epsilon, cov_sqrt.permute(0,2,1), returnVec = True)

        return self.autoencoder(x_tilde), x_tilde, epsilon_cov

    def clean_forward(self, x):
        """Return ``autoencoder(x)`` without corruption (no residual is added)."""
        #### caution: for the covariance part, we return tangent vector at Identity in this function!!!
        return self.autoencoder(x)

    def get_reconstruction(self, x, cov_sqrt, logJacobian = None):
        """Return the reconstruction of ``x`` in input coordinates.

        The position part is ``x_pos + v_pos``. The covariance part is updated
        through ``Exp`` of the autoencoder's tangent output, followed by
        ``group_action`` with ``cov_sqrt^T``. ``logJacobian`` is required when
        ``use_logvec_input`` is True and ``approx_order`` is not None.
        """
        ### update covariance using exponential
        v = self.autoencoder(x)
        r = x.clone()
        r[:,:self.pos_dim] += v[:,:self.pos_dim]
        if self.use_logvec_input:
            if self.approx_order is None:
                # NOTE(review): `Log` is not imported in this module, so this
                # branch raises NameError unless it is provided elsewhere. TODO confirm.
                r[:,self.pos_dim:] = Log2Log_vec(Log(
                    group_action(Exp(v[:,self.pos_dim:]), cov_sqrt.permute(0,2,1), returnVec = True)
                ))
            else:
                dcov = vector2tensor_1dim(v[:,self.pos_dim:])
                r[:,self.pos_dim:] += Log2Log_vec(
                    torch.bmm(
                        logJacobian,
                        group_action(dcov + 0.5*torch.bmm(dcov, dcov), cov_sqrt.permute(0,2,1),
                                     returnVec = True).view(-1,self.cov_dim,1)
                    ).view(-1,self.cov_dim)
                )
        else:
            r[:,self.pos_dim:] = group_action(Exp(v[:,self.pos_dim:]), cov_sqrt.permute(0,2,1), returnVec = True)

        return r

    def get_reconstruction_derivative(self, x, cov_sqrt):
        """Return d(get_reconstruction)/dx, shape (N, dim, dim); None when use_logvec_input.

        Position rows are ``I + dv_pos/dx``. Covariance rows are assembled from
        directional derivatives of ``Exp`` (``ExpDirDeriv``) and of the
        covariance square root (``get_sqrt_sym_DirDeriv``).
        """
        ### only consider self.use_logvec_input = False case
        if self.use_logvec_input:
            drdx = None
        else:
            vec = self.autoencoder(x)[:,self.pos_dim:]
            dvdx = self.get_autoencoder_derivative(x)
            drdx = dvdx.clone()
            for i in range(self.pos_dim):
                drdx[:,i,i] = drdx[:,i,i] + 1.0
            Exp_vec = Exp(vec, returnVec = False)

            # Covariance rows, position columns: only Exp(v) varies.
            for i in range(self.pos_dim):
                Exp_dirderiv_i = ExpDirDeriv(vec, dvdx[:,self.pos_dim:,i])
                drdx[:,self.pos_dim:,i] = group_action(Exp_dirderiv_i, cov_sqrt.permute(0,2,1), returnVec = True)

            # Covariance rows, covariance columns: both Exp(v) and cov_sqrt vary.
            tempdir = torch.zeros(x.shape[0], self.cov_dim).cuda()
            for i in range(self.cov_dim):
                Exp_dirderiv_i = ExpDirDeriv(vec, dvdx[:,self.pos_dim:,self.pos_dim + i])
                tempdir[:,i] = 1
                cov_sqrt_dirderiv_i = get_sqrt_sym_DirDeriv(x[:,self.pos_dim:], tempdir)
                temp = torch.bmm(torch.bmm(cov_sqrt_dirderiv_i, Exp_vec), cov_sqrt)
                drdx[:,self.pos_dim:,self.pos_dim + i] \
                = tensor2vector_1dim(
                    group_action(Exp_dirderiv_i, cov_sqrt.permute(0,2,1), returnVec = False) \
                    + temp + temp.permute(0,2,1)
                )
                tempdir[:,i] = 0

        return drdx

    def calculate_loss(self, x, cov_sqrt, covInv_sqrt, logJacobian, cov_eigvec, cov_eigval, weight = None):
        """Return the summed (optionally per-sample weighted) denoising loss.

        The loss is a position term measured with the position metric divided
        by ``covCoeff``, plus ``0.5 * ||dX - I||_F^2`` for the covariance part.

        Args:
            cov_sqrt: D^(1/2) R^T of the covariance R D R^T.
            covInv_sqrt: R D^(-1/2); plays the role of the position metric square root.
            cov_eigvec, cov_eigval: passed to ``deltaMat_sqrt_approx`` on the
                approximate square-root path.
            weight: optional per-sample weights, shape (N,).
        """
        vec, x_tilde, epsilon_cov = self.forward(x, cov_sqrt, logJacobian)

        if self.pos_metric_func is not None:
            vec_clean = self.clean_forward(x)
            if self.approx_order is None:
                posMetric = self.pos_metric_func(vec_clean[:,self.pos_dim:], covInv_sqrt)
                diff_pos_sq = torch.bmm(
                    torch.bmm(vec[:,:self.pos_dim].view(x.size()[0], 1, self.pos_dim), posMetric),
                    vec[:,:self.pos_dim].view(x.size()[0], self.pos_dim, 1)
                ).view(x.size()[0]) / self.cov_coeff
            else:
                # Taylor expansion of the metric square root around the input covariance.
                if self.approx_order == 1:
                    posMetric_sqrt = (covInv_sqrt.permute(0,2,1) + \
                    torch.bmm(
                        -0.5*vector2tensor_1dim(vec_clean[:,self.pos_dim:]),
                        covInv_sqrt.permute(0,2,1)
                    )).permute(0,2,1)
                else:
                    dvec_cov = vector2tensor_1dim(vec_clean[:,self.pos_dim:])
                    posMetric_sqrt = (covInv_sqrt.permute(0,2,1) + \
                    torch.bmm(
                        -0.5*dvec_cov + 0.125*torch.bmm(dvec_cov,dvec_cov),
                        covInv_sqrt.permute(0,2,1)
                    )).permute(0,2,1)
                diff_pos_sq = torch.bmm(vec[:,:self.pos_dim].view(x.size()[0], 1, self.pos_dim),
                          posMetric_sqrt).view(x.size()[0], self.pos_dim) / self.cov_coeff_sqrt
                diff_pos_sq = torch.sum(diff_pos_sq**2, dim = 1)
        else:
            posMetric_sqrt = covInv_sqrt
            diff_pos_sq = torch.bmm(vec[:,:self.pos_dim].view(x.size()[0], 1, self.pos_dim),
                          posMetric_sqrt).view(x.size()[0], self.pos_dim) / self.cov_coeff_sqrt
            diff_pos_sq = torch.sum(diff_pos_sq**2, dim = 1)

        # Square root of the corrupted covariance: exact, or approximated from
        # the eigendecomposition of the clean covariance.
        if self.approx_order is None or self.use_exp_map_sqrt:
            if self.use_logvec_input:
                cov_tilde_sqrt = Exp_sqrt(Log_vec2Log(x_tilde[:,self.pos_dim:]), returnVec = False)
            else:
                cov_tilde_sqrt = get_sqrt_sym(x_tilde[:,self.pos_dim:])
        else:
            if self.approx_order == 1:
                dcov = group_action(epsilon_cov, cov_sqrt.permute(0,2,1))
            else:
                dcov = vector2tensor_1dim(epsilon_cov)
                dcov = group_action(dcov + 0.5*torch.bmm(dcov, dcov), cov_sqrt.permute(0,2,1))
            cov_tilde_sqrt = cov_sqrt + deltaMat_sqrt_approx(cov_eigvec, cov_eigval, dcov)

        temp = torch.bmm(cov_tilde_sqrt, covInv_sqrt).permute(0,2,1)
        if self.approx_order is None:
            Exp_vec = Exp(vec[:,self.pos_dim:])
        elif self.approx_order == 1:
            Exp_vec = self.Eye_vec + vec[:,self.pos_dim:]
        else:
            ### consider up to the second order
            dcov = vector2tensor_1dim(vec[:,self.pos_dim:])
            Exp_vec = self.Eye + dcov + 0.5*torch.bmm(dcov, dcov)
        dX = group_action(Exp_vec, temp, returnVec = False)
        diff_cov_sq = 0.5*((dX-self.Eye)**2).sum((1,2))
        diff2 = diff_pos_sq + diff_cov_sq
        if weight is None:
            return torch.sum(diff2)
        return torch.sum(diff2*weight)

    def estimate_score(self, x, metric, cov_sqrt, logJacobian = None):
        """Return ``(reconstruction - x)`` mapped by ``metric`` and divided by ``noise_std**2``.

        ``metric`` is batched (N, dim, dim); its position block is additionally
        divided by ``covCoeff``.
        """
        recon = self.get_reconstruction(x, cov_sqrt, logJacobian)
        diff = recon - x
        posMetric = metric[:,:self.pos_dim,:self.pos_dim]
        covMetric = metric[:,self.pos_dim:,self.pos_dim:]

        # fix 21.08.08, for covariance part, we do not need to divide the values by self.cov_coeff**2...
        score_est = torch.cat(
            (
                torch.bmm(diff[:,0:self.pos_dim].view(diff.size()[0], 1, self.pos_dim),
                          posMetric).view(diff.size()[0], self.pos_dim) / self.cov_coeff / self.noise_std**2,
                torch.bmm(diff[:,self.pos_dim:].view(diff.size()[0], 1, self.cov_dim),
                          covMetric).view(diff.size()[0], self.cov_dim) / self.noise_std**2
            ),
        1)
        return score_est

    def calculate_expected_loss(self, x, cov_sqrt, covInv_sqrt, logJacobian, Niter, approx_order = 1, weight = None):
        """Return the per-sample loss averaged over ``Niter`` random corruptions (no grad).

        ``self.approx_order`` selects the corrupted-covariance square root.
        The ``approx_order`` argument selects the Exp approximation of the
        reconstruction update.

        NOTE(review): this uses older square-root approximations than
        ``calculate_loss`` (which needs eigen-decomposition inputs not passed
        here), so the two losses may differ.
        """
        lossSum = 0.0
        N = x.size()[0]
        with torch.no_grad():
            for _ in range(Niter):
                # forward returns three values; epsilon_cov is needed by the
                # approximate square-root branches below.
                vec, x_tilde, epsilon_cov = self.forward(x, cov_sqrt, logJacobian)
                if self.pos_metric_func is not None:
                    vec_clean = self.clean_forward(x)
                    posMetric = self.pos_metric_func(vec_clean[:,self.pos_dim:], covInv_sqrt)
                    diff_pos_sq = torch.bmm(
                    torch.bmm(vec[:,:self.pos_dim].view(x.size()[0], 1, self.pos_dim), posMetric),
                            vec[:,:self.pos_dim].view(x.size()[0], self.pos_dim, 1)
                    ).view(x.size()[0]) / self.cov_coeff
                else:
                    posMetric_sqrt = covInv_sqrt
                    diff_pos_sq = torch.bmm(vec[:,:self.pos_dim].view(x.size()[0], 1, self.pos_dim),
                                  posMetric_sqrt).view(x.size()[0], self.pos_dim) / self.cov_coeff_sqrt
                    diff_pos_sq = torch.sum(diff_pos_sq**2, dim = 1)
                if self.approx_order is None:
                    if self.use_logvec_input:
                        cov_tilde_sqrt = get_sqrt(Exp(Log_vec2Log(x_tilde[:,self.pos_dim:])))
                    else:
                        cov_tilde_sqrt = get_sqrt(x_tilde[:,self.pos_dim:])
                elif self.approx_order == 1:
                    cov_tilde_sqrt = (cov_sqrt.permute(0,2,1) + \
                    torch.bmm(cov_sqrt.permute(0,2,1), 0.5*vector2tensor_1dim(epsilon_cov))).permute(0,2,1)
                else:
                    dcov = vector2tensor_1dim(epsilon_cov)
                    cov_tilde_sqrt = (cov_sqrt.permute(0,2,1) + \
                    torch.bmm(cov_sqrt.permute(0,2,1), 0.5*dcov + 0.125*torch.bmm(dcov, dcov))).permute(0,2,1)

                temp = torch.bmm(cov_tilde_sqrt, covInv_sqrt).permute(0,2,1)
                # self.Eye replaces a hard-coded (1, 3, 3) identity so pos_dim=2 works.
                if approx_order is None:
                    Exp_vec = Exp(vec[:,self.pos_dim:])
                elif approx_order == 1:
                    Exp_vec = self.Eye + vector2tensor_1dim(vec[:,self.pos_dim:])
                else:
                    ### consider up to the second order
                    dcov = vector2tensor_1dim(vec[:,self.pos_dim:])
                    Exp_vec = self.Eye + dcov + 0.5*torch.bmm(dcov, dcov)
                # Matrix output is required by the (1, 2)-axis sum below, as in calculate_loss.
                dX = group_action(Exp_vec, temp, returnVec = False)
                diff_cov_sq = 0.5*((dX-self.Eye)**2).sum((1,2))
                diff2 = diff_pos_sq + diff_cov_sq

                if weight is None:
                    lossSum += torch.sum(diff2) / N
                else:
                    lossSum += torch.sum(diff2*weight) / N
        return lossSum / Niter