### variation from https://github.com/mattnedrich/MeanShift_py
import os
import numpy as np
import torch
from abc import ABC, abstractmethod
import time
from DTI_pointGrouper2 import DTI_pointGrouper, pointGrouper
from torch_batch_svd import svd
from tensor_data_util import Exp, Log, posMetric_sqrt_func, get_sqrt, get_sqrt_sym, group_action
from data_util import Log_vec2Log, Log2Log_vec, tensor2vector_1dim, vector2tensor_1dim
import scipy.io as sio

class DTI_meanShift(ABC):
    """Abstract mean shift driver for points made of a position part and a covariance part.

    Each point is one row of ``dim = pos_dim + pos_dim*(pos_dim+1)/2`` values: the first
    ``pos_dim`` entries are the position, the remaining ``cov_dim`` entries are the
    covariance part (in this base class it is fed to ``Log_vec2Log``/``Exp``, so it is
    presumably a vectorised matrix logarithm -- confirm against ``data_util``).

    Subclasses supply ``shift_points`` (one mean shift step) and ``save_shiftedPoints``
    (persisting intermediate results).
    """
    ####################### abstract class for DTI mean shift #########################
    def __init__(self, step_size, pos_threshold = 0.2, cov_threshold = 0.02, 
                 pos_group_threshold = 2, cov_group_threshold = 0.2, eliminate_threshold = None,
                lowerBound = None, upperBound = None, pos_dim = 3, pos_metric = None, pos_dist_calc_mode = 'mean',
                savefolder = './test_meanshiftIter/'):
        """Store the mean shift / grouping configuration.

        Args:
            step_size: step size stored for the subclasses' ``shift_points``.
            pos_threshold: a point's position move must fall below this to converge.
            cov_threshold: a point's covariance move must fall below this to converge.
            pos_group_threshold: forwarded to ``DTI_pointGrouper`` by ``run_clustering``.
            cov_group_threshold: forwarded to ``DTI_pointGrouper`` by ``run_clustering``.
            eliminate_threshold: ignore groups with fewer components than this threshold
                (forwarded to ``DTI_pointGrouper``).
            lowerBound: per-coordinate lower limits (indexable by coordinate, entries may
                be None), or None for no lower limits at all.
            upperBound: per-coordinate upper limits, same conventions as ``lowerBound``.
            pos_dim: number of position coordinates per point.
            pos_metric: None selects the Euclidean position distance; anything else makes
                the position distance use per-point matrices from ``posMetric_sqrt_func``.
            pos_dist_calc_mode: forwarded to ``DTI_pointGrouper`` by ``run``.
            savefolder: path prefix stored for the subclasses' ``save_shiftedPoints``.
        """
        super(DTI_meanShift, self).__init__()
        self.step_size = step_size
        self.pos_threshold = pos_threshold
        self.cov_threshold = cov_threshold
        self.pos_group_threshold = pos_group_threshold
        self.cov_group_threshold = cov_group_threshold
        # ignore groups with the number of components less than this threshold
        self.eliminate_threshold = eliminate_threshold
        self.lowerBound = lowerBound
        self.upperBound = upperBound
        self.pos_dim = pos_dim
        # number of independent entries of a symmetric pos_dim x pos_dim matrix
        self.cov_dim = int(self.pos_dim*(self.pos_dim+1)/2)
        self.dim = self.pos_dim + self.cov_dim
        self.pos_metric = pos_metric
        self.pos_dist_calc_mode = pos_dist_calc_mode
        self.savefolder = savefolder
            
    def run_meanShift(self, dataPoints, max_iter, save_iter = None, pos_metric = None, save_prefix = '', 
                      cleanInput = None, error_weight = None, save_shiftedPoints = False):
        """Shift every point until it converges or ``max_iter`` iterations have run.

        A point is marked converged when, within one iteration, its covariance move is
        below ``cov_threshold`` and its position move is below ``pos_threshold``.
        Converged points are left untouched by later iterations.

        Args:
            dataPoints: tensor with one point per row; it is cloned, not modified.
            max_iter: maximum number of iterations.
            save_iter: when not None and ``save_shiftedPoints`` is true, the current points
                are saved every ``save_iter`` iterations.
            pos_metric: accepted for interface compatibility; not read by this method.
            save_prefix: prefix handed to ``save_shiftedPoints``.
            cleanInput: optional reference points.  When given, the method works in
                "filtering" mode: after every iteration the log-Euclidean and
                affine-invariant errors against the reference are recorded together with
                a snapshot of all points.
            error_weight: optional weight forwarded to the error functions.
            save_shiftedPoints: enables the calls to ``self.save_shiftedPoints``.

        Returns:
            ``(dataPoints, shiftedPoints, converged)`` when ``cleanInput`` is None, where
            ``converged`` is a list of bools (one per point); otherwise a
            ``meanShift_filtering_result``.
        """
        #### correspond to filtering when cleanInput is not None
        filtering = cleanInput is not None
        if filtering:
            le_error, le_wt_error = [], []
            ai_error, ai_wt_error = [], []
            shiftedPointsSet = []
            cleanLog_vec = self.get_cleanLog_vec(cleanInput)
            cleanCovInv_sqrt = self.get_cleanCovInv_sqrt(cleanInput)
            
        start_time = time.time()
        shiftedPoints = dataPoints.clone()
        max_pos_dist = self.pos_threshold + 1
        # NOTE(review): max_cov_dist is never refreshed inside the loop, so this half of
        # the loop condition stays true; the loop ends on max_iter or once every point
        # is converged (the ``break`` below).
        max_cov_dist = self.cov_threshold + 1
        iter_num = 0
        converged = [False]*shiftedPoints.shape[0]
        
        ###################### mean shift algorithm ######################
        while ((max_pos_dist > self.pos_threshold or max_cov_dist > self.cov_threshold) 
               and iter_num < max_iter):
            iter_num += 1
            # row indices of the points that still have to move
            active = [i for i, done in enumerate(converged) if not done]
            if not active:
                break
            curPoints = torch.cat([shiftedPoints[i].view(1,self.dim) for i in active], 0)
            print("iter: %d, num points to shift: %d" % (iter_num, curPoints.shape[0]))
            
            curPoints_shifted, cov_dist, cur_posMetric_sqrt = self.shiftPoints_and_calculateCovDist_and_PosMetric(curPoints)
            
            ###### update current points and check convergence ######
            pos_dist_calculated = False
            for k, i in enumerate(active):
                shiftedPoints[i] = curPoints_shifted[k]
                if cov_dist[k] < self.cov_threshold:
                    ## calculate distance only for the points whose covariances are converged ##
                    cur_pos_dist = self.get_pos_distance(curPoints[k], curPoints_shifted[k], cur_posMetric_sqrt[k])
                    # the first distance of an iteration resets the running maximum
                    if not pos_dist_calculated or cur_pos_dist > max_pos_dist:
                        max_pos_dist = cur_pos_dist
                    pos_dist_calculated = True
                    if cur_pos_dist < self.pos_threshold:
                        converged[i] = True
            if save_iter is not None and (iter_num % save_iter == 0) and save_shiftedPoints:
                self.save_shiftedPoints(iter_num, shiftedPoints, save_prefix)
            #########################################################
            
            ##################### save filtering results 
            if filtering:
                cur_error, cur_wt_error = self.get_logEuclidean_error(shiftedPoints, cleanLog_vec, 
                                                                      error_weight = error_weight)
                le_error.append(cur_error.view(1))
                le_wt_error.append(cur_wt_error.view(1))
                cur_error, cur_wt_error = self.get_affineInvariant_error(shiftedPoints, cleanCovInv_sqrt, 
                                                                      error_weight = error_weight)
                ai_error.append(cur_error.view(1))
                ai_wt_error.append(cur_wt_error.view(1))
                shiftedPointsSet.append(shiftedPoints.clone())
            ########################################################
            
        print("mean shift terminated !!! elapsed time: %.1f" % (time.time() - start_time))
        ##################################################################
        
        if not filtering:
            return dataPoints, shiftedPoints, converged

        le_error = torch.cat(le_error, 0)
        le_wt_error = torch.cat(le_wt_error, 0)
        ai_error = torch.cat(ai_error, 0)
        ai_wt_error = torch.cat(ai_wt_error, 0)

        if save_shiftedPoints:
            # Save the snapshot that minimises each error curve.  The number handed to
            # save_shiftedPoints is the 0-based snapshot index, whereas the periodic
            # saves inside the loop use the 1-based iteration counter.
            for errors, tag in ((le_error, 'min_le_e_'), (le_wt_error, 'min_le_wte_'),
                                (ai_error, 'min_ai_e_'), (ai_wt_error, 'min_ai_wte_')):
                best_idx = torch.argmin(errors)
                self.save_shiftedPoints(best_idx.cpu().data.numpy(), shiftedPointsSet[best_idx], tag+save_prefix)

        filteringResult = meanShift_filtering_result(shiftedPointsSet, le_error, le_wt_error, 
                                                     ai_error, ai_wt_error, converged)
        return filteringResult
    
    def run_clustering(self, shiftedPoints, pos_metric = None, pos_dist_calc_mode = 'mean'):
        """Group already shifted points with a ``DTI_pointGrouper``.

        Args:
            shiftedPoints: points to group, as produced by ``run_meanShift``.
            pos_metric: forwarded to ``DTI_pointGrouper``.
            pos_dist_calc_mode: forwarded to ``DTI_pointGrouper``.

        Returns:
            The four values returned by ``DTI_pointGrouper.group_points``:
            ``(clusterIds, groups_pointIdx, clusterIds_ori, groups_pointIdx_ori)``.
        """
        ########################### clustering ###########################
        start_time = time.time()
        point_grouper = DTI_pointGrouper(self.pos_group_threshold, self.cov_group_threshold, 
                                        self.eliminate_threshold, self.pos_dim, pos_metric, pos_dist_calc_mode)
        clusterIds, groups_pointIdx, clusterIds_ori, groups_pointIdx_ori = point_grouper.group_points(shiftedPoints)
        print("grouping terminated !!! elapsed time: %.1f" % (time.time() - start_time))
        ##################################################################
        
        return clusterIds, groups_pointIdx, clusterIds_ori, groups_pointIdx_ori
    
    def run(self, dataPoints, max_iter):
        """Run mean shift followed by grouping and return a ``meanShift_result``."""
        dataPoints, shiftedPoints, converged = self.run_meanShift(dataPoints, max_iter, pos_metric = self.pos_metric)
        clusterIds, groups_pointIdx, _, _ = self.run_clustering(shiftedPoints, pos_metric = self.pos_metric, 
                                                               pos_dist_calc_mode = self.pos_dist_calc_mode)
        
        return meanShift_result(dataPoints, shiftedPoints, converged, clusterIds, groups_pointIdx)
    
    def get_pos_distance(self, curPoint, curPoint_shifted, cur_posMetric_sqrt):
        """Return the length of the position move between a point and its shifted version.

        With ``cur_posMetric_sqrt`` None this is the Euclidean norm of the position
        difference; otherwise the difference (as a row vector) is first multiplied by the
        ``pos_dim x pos_dim`` matrix ``cur_posMetric_sqrt``.
        """
        delta = curPoint_shifted[:self.pos_dim] - curPoint[:self.pos_dim]
        if cur_posMetric_sqrt is not None:
            delta = torch.mm(delta.view(1,self.pos_dim), cur_posMetric_sqrt)
        ###### Euclidean metric is used for position when no metric is given ######
        return torch.sqrt(torch.sum(delta**2))
    
    def projectPointsIntoBoundary(self, curPoints_shifted):
        """Clamp every coordinate into ``[lowerBound[i], upperBound[i]]`` in place.

        A bound list that is None, or a None entry inside it, means "no limit" for the
        affected coordinates.  Returns the (modified) input tensor.
        """
        if self.lowerBound is not None:
            for i in range(self.dim):
                if self.lowerBound[i] is not None:
                    curPoints_shifted[curPoints_shifted[:,i] < self.lowerBound[i], i] = self.lowerBound[i]
        if self.upperBound is not None:
            for i in range(self.dim):
                if self.upperBound[i] is not None:
                    curPoints_shifted[curPoints_shifted[:,i] > self.upperBound[i], i] = self.upperBound[i]
        return curPoints_shifted
    
    def shiftPoints_and_calculateCovDist_and_PosMetric(self, curPoints):
        """Perform one shift step and gather what the convergence check needs.

        Returns:
            ``(curPoints_shifted, cov_dist, cur_posMetric_sqrt)`` where ``cov_dist`` is the
            per-point Euclidean norm of the change of the covariance part and
            ``cur_posMetric_sqrt`` holds one entry per point (all None when
            ``self.pos_metric`` is None, else the output of ``posMetric_sqrt_func``).
        """
        ### shift points
        curPoints_shifted = self.shift_points(curPoints)
        curPoints_shifted = self.projectPointsIntoBoundary(curPoints_shifted)
        
        ### calculate cov distance
        cov_dist = curPoints_shifted[:,self.pos_dim:] - curPoints[:,self.pos_dim:]
        cov_dist = torch.sqrt(torch.sum(cov_dist**2, dim=1))
        print(torch.max(cov_dist))
        print(torch.min(cov_dist))
        
        ### calculate pos metric sqrt
        if self.pos_metric is None:
            cur_posMetric_sqrt =  [None]*curPoints.shape[0]
        else:
            cur_posMetric_sqrt =  posMetric_sqrt_func(curPoints)
        
        return curPoints_shifted, cov_dist, cur_posMetric_sqrt
    
    def get_cleanLog_vec(self, cleanInput):
        """Return the covariance part (columns from ``pos_dim`` on) of the reference points."""
        return cleanInput[:,self.pos_dim:]
    
    def get_logEuclidean_error(self, shiftedPoints, cleanLog_vec, error_weight = None):
        """Return ``(error, wt_error)``: squared covariance-part difference summed and divided by the row count.

        ``wt_error`` multiplies the squared differences by ``error_weight`` before
        summing; without a weight it is the same value as ``error``.
        """
        sq_diff = (shiftedPoints[:,self.pos_dim:] - cleanLog_vec)**2
        num_points = cleanLog_vec.shape[0]
        error = torch.sum(sq_diff)/num_points
        if error_weight is not None:
            wt_error = torch.sum(sq_diff*error_weight)/num_points
        else:
            wt_error = error
        return error, wt_error
    
    def get_cleanCovInv_sqrt(self, cleanInput):
        """Return the second output of ``get_sqrt(..., returnInvAlso=True)`` for the reference covariances.

        The computation is always carried out on CUDA; for a CPU input the data is moved
        to the GPU and the result moved back.
        """
        if cleanInput.is_cuda:
            _, cleanCovInv_sqrt = get_sqrt(Exp(Log_vec2Log(cleanInput[:,self.pos_dim:])), returnInvAlso = True)
        else:
            _, cleanCovInv_sqrt = get_sqrt(Exp(Log_vec2Log(cleanInput[:,self.pos_dim:].cuda())), returnInvAlso = True)
            cleanCovInv_sqrt = cleanCovInv_sqrt.cpu()
        return cleanCovInv_sqrt
    
    def get_affineInvariant_error(self, shiftedPoints, cleanCovInv_sqrt, error_weight = None):
        """Return ``(error, wt_error)`` based on ``Log`` of the shifted covariances acted on by ``cleanCovInv_sqrt``.

        As in ``get_cleanCovInv_sqrt`` the tensor helpers are run on CUDA, moving CPU
        data to the GPU and the result back.  Both errors are sums of squares divided by
        the number of reference points; ``wt_error`` applies ``error_weight`` first and
        equals ``error`` when no weight is given.
        """
        if shiftedPoints.is_cuda:
            log_val = Log(group_action(Exp(Log_vec2Log(shiftedPoints[:,self.pos_dim:])), 
                                   cleanCovInv_sqrt.permute(0,2,1), returnVec = True))
        else:
            log_val = Log(group_action(
                Exp(Log_vec2Log(shiftedPoints[:,self.pos_dim:].cuda())), 
                                   cleanCovInv_sqrt.permute(0,2,1).cuda(), returnVec = True)).cpu()
        error_vec = Log2Log_vec(log_val)
        error = torch.sum(error_vec**2)/cleanCovInv_sqrt.shape[0]
        if error_weight is not None:
            wt_error = torch.sum(error_vec**2*error_weight)/cleanCovInv_sqrt.shape[0]
        else:
            wt_error = error
        return error, wt_error
    
    @abstractmethod
    def shift_points(self, points):
        """Return the points after one mean shift step (implemented by subclasses)."""
        pass
    
    @abstractmethod
    def save_shiftedPoints(self, iter_num, shiftedPoints, save_prefix = ''):
        """Persist ``shiftedPoints`` for iteration ``iter_num`` (implemented by subclasses)."""
        pass

        
class DTI_meanShift_dae(DTI_meanShift):
    """Mean shift whose update direction is the output of ``dae_model.autoencoder``.

    ``pos_dim`` is taken from ``dae_model.pos_dim`` (overriding the constructor
    argument).  ``save_shiftedPoints`` reads ``dae_model.noise_std`` and, when present,
    ``dae_model.cov_metric_coeff`` to build file names.
    """
    def __init__(self, dae_model, step_size, pos_threshold = 0.2, cov_threshold = 0.02,
                pos_group_threshold = 2, cov_group_threshold = 0.2, eliminate_threshold = None,
                lowerBound = None, upperBound = None, pos_dim = 3, pos_metric = None, pos_dist_calc_mode = 'mean',
                savefolder = './test_meanshiftIter/'):
        """Store the model and forward every other argument to ``DTI_meanShift.__init__``."""
        self.dae_model = dae_model
        super(DTI_meanShift_dae, self).__init__(step_size, pos_threshold, cov_threshold,
                                           pos_group_threshold, cov_group_threshold, eliminate_threshold,
                                           lowerBound, upperBound, pos_dim, pos_metric, pos_dist_calc_mode,
                                               savefolder)
        self.pos_dim = self.dae_model.pos_dim
        # the base constructor derived these sizes from the ``pos_dim`` argument; redo
        # them so they agree with the model's pos_dim that is actually used
        self.cov_dim = int(self.pos_dim*(self.pos_dim+1)/2)
        self.dim = self.pos_dim + self.cov_dim
        
    def shift_points(self, points):
        """Return ``points + step_size * dae_model.autoencoder(points)`` (model run without gradients)."""
        with torch.no_grad():
            dr = self.dae_model.autoencoder(points)
        return points + self.step_size * dr
    
    def save_shiftedPoints(self, iter_num, shiftedPoints, save_prefix = ''):
        """Write ``shiftedPoints`` to a file under ``savefolder``.

        For ``pos_dim == 2`` the tensor is stored with ``torch.save`` (``.pt``); otherwise
        it is stored with ``scipy.io.savemat`` under the key ``'shiftedPoints'``
        (``.mat``).  The file name encodes the noise std, the covariance metric
        coefficient (only if the model has one), the step size and ``iter_num``.
        """
        two_dim = self.pos_dim == 2
        has_cov_metric = hasattr(self.dae_model, 'cov_metric_coeff')

        # models with a covariance metric coefficient get the 'gdae' tag and an extra field
        model_tag = 'gdae_DTI' if has_cov_metric else 'dae_DTI'
        if two_dim:
            model_tag += '2dim'

        savefilename = self.savefolder+save_prefix+model_tag+'_std' + str(self.dae_model.noise_std)
        if has_cov_metric:
            savefilename += '_covmet' + str(self.dae_model.cov_metric_coeff)
        savefilename += '_stepsize' + str(self.step_size) + '_iter' + str(iter_num)

        if two_dim:
            torch.save(shiftedPoints, savefilename + '_shiftedTensor.pt')
        else:
            sio.savemat(savefilename + '_shiftedTensor.mat', {'shiftedPoints':shiftedPoints.cpu().numpy()})
        return
    
class N_n_meanShift_dae(DTI_meanShift):
    """Mean shift driven by ``dae_model.clean_forward`` with a multiplicative covariance update.

    The position part is moved additively.  The covariance part is updated through
    ``group_action`` with the square root of the current covariance (see
    ``shift_points``).  ``dae_model.use_logvec_input`` selects whether the covariance
    part of a point goes through ``Log_vec2Log``/``Exp`` first (true) or is handed to
    the tensor helpers directly (false).  ``pos_dim`` is taken from ``dae_model.pos_dim``.
    """
    def __init__(self, dae_model, step_size, pos_threshold = 0.2, cov_threshold = 0.02,
                pos_group_threshold = 2, cov_group_threshold = 0.2, eliminate_threshold = None,
                lowerBound = None, upperBound = None, pos_dim = 3, pos_metric = None, pos_dist_calc_mode = 'mean',
                savefolder = './test_meanshiftIter/'):
        """Store the model and forward every other argument to ``DTI_meanShift.__init__``."""
        self.dae_model = dae_model
        super(N_n_meanShift_dae, self).__init__(step_size, pos_threshold, cov_threshold,
                                           pos_group_threshold, cov_group_threshold, eliminate_threshold,
                                           lowerBound, upperBound, pos_dim, pos_metric, pos_dist_calc_mode,
                                               savefolder)
        self.pos_dim = self.dae_model.pos_dim
        # the base constructor derived these sizes from the ``pos_dim`` argument; redo
        # them so they agree with the model's pos_dim that is actually used
        self.cov_dim = int(self.pos_dim*(self.pos_dim+1)/2)
        self.dim = self.pos_dim + self.cov_dim
        
    def shift_points(self, points):
        """Perform one shift step on a copy of ``points``.

        Unlike the abstract contract, this returns a triple
        ``(shifted_points, dr, covInv_sqrt)``: ``dr`` is the raw model output and
        ``covInv_sqrt`` the second output of ``get_sqrt_sym(..., returnInvAlso=True)``
        for the covariances before the shift.

        Raises:
            AssertionError: if the input or the model output contains NaN.
        """
        points = points.clone()
        with torch.no_grad():
            assert(torch.sum(torch.isnan(points)) == 0)
            dr = self.dae_model.clean_forward(points)
            assert(torch.sum(torch.isnan(dr)) == 0)
        points[:,:self.pos_dim] += dr[:,:self.pos_dim] * self.step_size
        
        # Exp of the scaled covariance direction: exact, first-order (I + X) or
        # second-order (I + X + X^2/2), depending on dae_model.approx_order
        if self.dae_model.approx_order is None:
            Exp_dr = Exp(dr[:,self.pos_dim:] * self.step_size, returnVec = True)
        elif self.dae_model.approx_order == 1:
            Exp_dr = self.dae_model.Eye_vec + dr[:,self.pos_dim:] * self.step_size
        else:
            dcov = vector2tensor_1dim(dr[:,self.pos_dim:] * self.step_size)
            Exp_dr = tensor2vector_1dim(self.dae_model.Eye + dcov + 0.5*torch.bmm(dcov, dcov))
        
        if self.dae_model.use_logvec_input:
            cov_sqrt, covInv_sqrt = get_sqrt_sym(Exp(Log_vec2Log(points[:,self.pos_dim:])), returnInvAlso = True)
            points[:,self.pos_dim:] = Log2Log_vec(Log(group_action(Exp_dr, cov_sqrt.permute(0,2,1), returnVec = True)))
        else:
            cov_sqrt, covInv_sqrt = get_sqrt_sym(points[:,self.pos_dim:], returnInvAlso = True)
            points[:,self.pos_dim:] = group_action(Exp_dr, cov_sqrt.permute(0,2,1), returnVec = True)
        return points, dr, covInv_sqrt
    
    def _clip_columns(self, points, columns):
        # Clamp the given columns of ``points`` in place; a bound list that is None or a
        # None entry inside it means "no limit" for that column.
        for i in columns:
            if self.lowerBound is not None and self.lowerBound[i] is not None:
                points[points[:,i] < self.lowerBound[i], i] = self.lowerBound[i]
            if self.upperBound is not None and self.upperBound[i] is not None:
                points[points[:,i] > self.upperBound[i], i] = self.upperBound[i]

    def projectPointsIntoBoundary(self, curPoints_shifted):
        """Project shifted points back into the configured bounds (in place).

        Position coordinates are clamped per column.  With ``use_logvec_input`` the
        covariance columns are clamped the same way.  Otherwise the covariance matrices
        are decomposed with ``svd``, the values in ``S`` are sign-corrected and clamped
        to the single pair ``lowerBound[pos_dim]`` / ``upperBound[pos_dim]``, and the
        matrices are rebuilt as ``U diag(S) U^T``.  When neither bound list is set the
        points are returned unchanged.
        """
        if self.lowerBound is None and self.upperBound is None:
            return curPoints_shifted

        ### position part
        self._clip_columns(curPoints_shifted, range(self.pos_dim))
        
        ### covariance part
        if self.dae_model.use_logvec_input:
            self._clip_columns(curPoints_shifted, range(self.pos_dim, self.dim))
        else:
            U, S, V = svd(vector2tensor_1dim(curPoints_shifted[:,self.pos_dim:]))
            # flip the sign of S where the matching columns of U and V point in opposite
            # directions (presumably recovering signed eigenvalues of a symmetric
            # matrix -- confirm)
            UtV = torch.diagonal(torch.bmm(U.permute(0,2,1),V), dim1=1, dim2=2)
            S[UtV < 0] = -(S[UtV < 0])
            lower = None if self.lowerBound is None else self.lowerBound[self.pos_dim]
            upper = None if self.upperBound is None else self.upperBound[self.pos_dim]
            if lower is not None:
                S[S < lower] = lower
            if upper is not None:
                S[S > upper] = upper
            curPoints_shifted[:,self.pos_dim:] = tensor2vector_1dim(
                torch.bmm(
                    U,
                    torch.bmm(torch.diag_embed(S), U.permute(0,2,1))
                )
            )
        return curPoints_shifted
    
    def shiftPoints_and_calculateCovDist_and_PosMetric(self, curPoints):
        """Perform one shift step and gather what the convergence check needs.

        Returns:
            ``(curPoints_shifted, cov_dist, cur_posMetric_sqrt)``.  With
            ``use_logvec_input`` ``cov_dist`` is the Euclidean norm of the change of the
            covariance part; otherwise it is the norm of the model's covariance output
            scaled by ``dae_model.cov_metric_coeff_sqrt`` and ``step_size``.
            ``cur_posMetric_sqrt`` is a list of None when ``self.pos_metric`` is None,
            else the ``covInv_sqrt`` returned by ``shift_points``.
        """
        ### shift points
        curPoints_shifted, dr, covInv_sqrt = self.shift_points(curPoints)
        curPoints_shifted = self.projectPointsIntoBoundary(curPoints_shifted)
        
        ### calculate cov distance
        if self.dae_model.use_logvec_input:
            cov_dist = curPoints_shifted[:,self.pos_dim:] - curPoints[:,self.pos_dim:]
            cov_dist = torch.sqrt(torch.sum(cov_dist**2, dim=1))
        else:
            cov_dist = torch.sqrt(torch.sum((dr[:,self.pos_dim:]*self.dae_model.cov_metric_coeff_sqrt)**2, 
                                        dim=1)) * self.step_size
        print(torch.max(cov_dist))
        print(torch.min(cov_dist))
            
        ### calculate pos metric sqrt
        if self.pos_metric is None:
            cur_posMetric_sqrt =  [None]*curPoints.shape[0]
        else:
            cur_posMetric_sqrt = covInv_sqrt
        
        return curPoints_shifted, cov_dist, cur_posMetric_sqrt
    
    def get_cleanLog_vec(self, cleanInput):
        """Return the reference covariance part as a log-vector (applying ``Log`` unless the input already is one)."""
        if self.dae_model.use_logvec_input:
            return cleanInput[:,self.pos_dim:]
        return Log2Log_vec(Log(cleanInput[:,self.pos_dim:]))
    
    def get_logEuclidean_error(self, shiftedPoints, cleanLog_vec, error_weight = None):
        """Return ``(error, wt_error)``: squared log-vector difference summed and divided by the row count.

        ``wt_error`` applies ``error_weight`` before summing and equals ``error`` when no
        weight is given.
        """
        if self.dae_model.use_logvec_input:
            shiftedLog_vec = shiftedPoints[:,self.pos_dim:]
        else:
            shiftedLog_vec = Log2Log_vec(Log(shiftedPoints[:,self.pos_dim:]))
        error = torch.sum((shiftedLog_vec - cleanLog_vec)**2)/cleanLog_vec.shape[0]
        if error_weight is not None:
            wt_error = torch.sum((shiftedLog_vec 
                                  - cleanLog_vec)**2*error_weight)/cleanLog_vec.shape[0]
        else:
            wt_error = error
        return error, wt_error
    
    def get_cleanCovInv_sqrt(self, cleanInput):
        """Return the second output of ``get_sqrt_sym(..., returnInvAlso=True)`` for the reference covariances."""
        if self.dae_model.use_logvec_input:
            _, cleanCovInv_sqrt = get_sqrt_sym(Exp(Log_vec2Log(cleanInput[:,self.pos_dim:])), returnInvAlso = True)
        else:
            _, cleanCovInv_sqrt = get_sqrt_sym(cleanInput[:,self.pos_dim:], returnInvAlso = True)
        return cleanCovInv_sqrt
    
    def get_affineInvariant_error(self, shiftedPoints, cleanCovInv_sqrt, error_weight = None):
        """Return ``(error, wt_error)`` based on ``Log`` of the shifted covariances acted on by ``cleanCovInv_sqrt``.

        Both are sums of squares divided by the number of reference points; ``wt_error``
        applies ``error_weight`` first and equals ``error`` when no weight is given.
        """
        if self.dae_model.use_logvec_input:
            log_val = Log(group_action(Exp(Log_vec2Log(shiftedPoints[:,self.pos_dim:])), 
                                   cleanCovInv_sqrt.permute(0,2,1), returnVec = True))
        else:
            log_val = Log(group_action(shiftedPoints[:,self.pos_dim:], cleanCovInv_sqrt.permute(0,2,1), returnVec = True))
        error_vec = Log2Log_vec(log_val)
        error = torch.sum(error_vec**2)/cleanCovInv_sqrt.shape[0]
        if error_weight is not None:
            wt_error = torch.sum(error_vec**2*error_weight)/cleanCovInv_sqrt.shape[0]
        else:
            wt_error = error
        return error, wt_error
    
    def save_shiftedPoints(self, iter_num, shiftedPoints, save_prefix = ''):
        """Write ``shiftedPoints`` to a ``.mat`` file (key ``'shiftedPoints'``) under ``savefolder``.

        NOTE(review): nothing is written unless ``pos_dim == 3``.
        """
        if self.pos_dim == 3:
            # NOTE(review): the check looks for ``cov_metric_coeff_sqrt`` but the file
            # name reads ``cov_coeff`` -- confirm the model exposes both attributes.
            if hasattr(self.dae_model, 'cov_metric_coeff_sqrt'):
                savefilename = self.savefolder+save_prefix+'gdae_N_n'+'_std' + str(self.dae_model.noise_std) + \
            '_covcoeff' + str(self.dae_model.cov_coeff) + '_stepsize' + str(self.step_size) + \
            '_iter' + str(iter_num) + '_shiftedTensor.mat'
            else:
                savefilename = self.savefolder+save_prefix+'dae_N_n'+'_std' + str(self.dae_model.noise_std) + \
            '_stepsize' + str(self.step_size) + '_iter' + str(iter_num) + '_shiftedTensor.mat'
            sio.savemat(savefilename, {'shiftedPoints':shiftedPoints.cpu().numpy()})
        return

class DTI_meanShift_vectordae(DTI_meanShift):
    """Mean shift driven by ``dae_model.autoencoder`` on points whose covariance part is handed to ``Log`` directly.

    While computing filtering errors, the "bad index" output of ``Log`` is recorded for
    every call in ``track_log_bad_idx1`` (log-Euclidean error) and
    ``track_log_bad_idx2`` (affine-invariant error).  ``pos_dim`` is taken from
    ``dae_model.pos_dim``.
    """
    def __init__(self, dae_model, step_size, pos_threshold = 0.2, cov_threshold = 0.02,
                pos_group_threshold = 2, cov_group_threshold = 0.2, eliminate_threshold = None,
                lowerBound = None, upperBound = None, pos_dim = 3, pos_metric = None, pos_dist_calc_mode = 'mean',
                savefolder = './test_meanshiftIter/'):
        """Store the model and forward every other argument to ``DTI_meanShift.__init__``."""
        self.dae_model = dae_model
        super(DTI_meanShift_vectordae, self).__init__(step_size, pos_threshold, cov_threshold,
                                           pos_group_threshold, cov_group_threshold, eliminate_threshold,
                                           lowerBound, upperBound, pos_dim, pos_metric, pos_dist_calc_mode,
                                                     savefolder)
        self.pos_dim = self.dae_model.pos_dim
        # the base constructor derived these sizes from the ``pos_dim`` argument; redo
        # them so they agree with the model's pos_dim that is actually used
        self.cov_dim = int(self.pos_dim*(self.pos_dim+1)/2)
        self.dim = self.pos_dim + self.cov_dim
        self.track_log_bad_idx1 = [] # bad index obtained during Log func. when calculating logEuclidean_error
        self.track_log_bad_idx2 = [] # bad index obtained during Log func. when calculating affineInvariant_error

    def shift_points(self, points):
        """Return ``points + step_size * dae_model.autoencoder(points)`` (model run without gradients)."""
        with torch.no_grad():
            dr = self.dae_model.autoencoder(points)
        return points + self.step_size * dr
    
    
    def get_cleanLog_vec(self, cleanInput):
        """Return ``Log2Log_vec(Log(.))`` of the reference covariance part."""
        return Log2Log_vec(Log(cleanInput[:,self.pos_dim:]))
    
    def get_logEuclidean_error(self, shiftedPoints, cleanLog_vec, error_weight = None):
        """Return ``(error, wt_error)`` of the log-vector difference to the reference.

        ``error`` is the sum of squared differences divided by the row count.  With
        ``error_weight`` the second value is the weighted sum divided by the row count;
        without it, the contribution of the entries indexed by ``badIdx`` is removed
        from both the sum and the count (assumes ``badIdx`` indexes rows -- TODO confirm
        against ``Log``).
        """
        shiftedLog, badIdx = Log(shiftedPoints[:,self.pos_dim:], returnBadIdx = True)
        self.track_log_bad_idx1 += [badIdx]
        shiftedLog_vec = Log2Log_vec(shiftedLog)
        error = torch.sum((shiftedLog_vec - cleanLog_vec)**2)/cleanLog_vec.shape[0]
        if error_weight is not None:
            wt_error = torch.sum((shiftedLog_vec 
                                  - cleanLog_vec)**2*error_weight)/cleanLog_vec.shape[0]
        else:
            # return only the errors not from bad indices
            temp = (shiftedLog_vec - cleanLog_vec)**2
            wt_error = (torch.sum(temp) - torch.sum(temp[badIdx]))/(cleanLog_vec.shape[0] - badIdx.shape[0])
        return error, wt_error
    
    def get_cleanCovInv_sqrt(self, cleanInput):
        """Return the second output of ``get_sqrt_sym(..., returnInvAlso=True)`` for the reference covariances."""
        _, cleanCovInv_sqrt = get_sqrt_sym(cleanInput[:,self.pos_dim:], returnInvAlso = True)
        return cleanCovInv_sqrt
    
    def get_affineInvariant_error(self, shiftedPoints, cleanCovInv_sqrt, error_weight = None):
        """Return ``(error, wt_error)`` based on ``Log`` of the shifted covariances acted on by ``cleanCovInv_sqrt``.

        ``error`` is the sum of squares divided by the number of reference points.  With
        ``error_weight`` the second value is the weighted version; without it, the
        contribution of the entries indexed by ``badIdx`` is removed from both the sum
        and the count.
        """
        log_val, badIdx = Log(group_action(shiftedPoints[:,self.pos_dim:], cleanCovInv_sqrt.permute(0,2,1), returnVec = True), returnBadIdx = True)
        self.track_log_bad_idx2 += [badIdx]
        error_vec = Log2Log_vec(log_val)
        error = torch.sum(error_vec**2)/cleanCovInv_sqrt.shape[0]
        if error_weight is not None:
            wt_error = torch.sum(error_vec**2*error_weight)/cleanCovInv_sqrt.shape[0]
        else:
            # return only the errors not from bad indices
            temp = error_vec**2
            wt_error = (torch.sum(temp) - torch.sum(temp[badIdx]))/(cleanCovInv_sqrt.shape[0] - badIdx.shape[0])
        return error, wt_error
    
    def save_shiftedPoints(self, iter_num, shiftedPoints, save_prefix = ''):
        """Write ``shiftedPoints`` to a ``.mat`` file (key ``'shiftedPoints'``) under ``savefolder``."""
        savefilename = self.savefolder+save_prefix+'vectordae_DTI'+'_std' + str(self.dae_model.noise_std) + \
            '_stepsize' + str(self.step_size) + '_iter' + str(iter_num) + '_shiftedTensor.mat'
        sio.savemat(savefilename, {'shiftedPoints':shiftedPoints.cpu().numpy()})
        return
    
class DTI_meanShift_lsldg(DTI_meanShift):
    """Mean shift whose update target comes from ``lsldg_model.meanShiftUpdate``."""
    def __init__(self, lsldg_model, step_size, pos_threshold = 0.2, cov_threshold = 0.02,
                pos_group_threshold = 2, cov_group_threshold = 0.2, eliminate_threshold = None,
                lowerBound = None, upperBound = None, pos_dim = 3, pos_metric = None, pos_dist_calc_mode = 'mean',
                savefolder = './test_meanshiftIter/'):
        """Keep the model and hand every other setting to the base constructor."""
        self.lsldg_model = lsldg_model
        super().__init__(
            step_size,
            pos_threshold = pos_threshold,
            cov_threshold = cov_threshold,
            pos_group_threshold = pos_group_threshold,
            cov_group_threshold = cov_group_threshold,
            eliminate_threshold = eliminate_threshold,
            lowerBound = lowerBound,
            upperBound = upperBound,
            pos_dim = pos_dim,
            pos_metric = pos_metric,
            pos_dist_calc_mode = pos_dist_calc_mode,
            savefolder = savefolder,
        )

    def shift_points(self, points):
        """Move ``points`` a fraction ``step_size`` of the way towards the model's update."""
        target = self.lsldg_model.meanShiftUpdate(points)
        return points + self.step_size * (target - points)

    def save_shiftedPoints(self, iter_num, shiftedPoints, save_prefix = ''):
        """Store ``shiftedPoints`` under ``savefolder``.

        ``pos_dim == 2`` writes a ``.pt`` file with ``torch.save``; any other value
        writes a ``.mat`` file with the array under the key ``'shiftedPoints'``.
        """
        location = self.savefolder+save_prefix
        run_tag = '_stepsize' + str(self.step_size) + '_iter' + str(iter_num)
        if self.pos_dim != 2:
            mat_name = location + 'lsldg_DTI' + run_tag + '_shiftedTensor.mat'
            sio.savemat(mat_name, {'shiftedPoints':shiftedPoints.cpu().numpy()})
            return
        pt_name = location + 'lsldg_DTI2dim' + run_tag + '_shiftedTensor.pt'
        torch.save(shiftedPoints, pt_name)
        return
    
    
class meanShift_result:
    """Plain container for the outcome of a mean shift run followed by grouping."""
    def __init__(self, dataPoints, shiftedPoints, converged, clusterIds, groups_pointIdx):
        # outcome of the shifting stage
        self.dataPoints, self.shiftedPoints, self.converged = dataPoints, shiftedPoints, converged
        # outcome of the grouping stage
        self.clusterIds, self.groups_pointIdx = clusterIds, groups_pointIdx

class meanShift_filtering_result:
    """Plain container for the per-iteration snapshots and error curves of a filtering run."""
    def __init__(self, shiftedPointsSet, le_errors, le_wt_errors, ai_errors, ai_wt_errors, converged):
        # one snapshot of all points per iteration
        self.shiftedPointsSet = shiftedPointsSet
        # log-Euclidean errors (plain, weighted)
        self.le_errors, self.le_wt_errors = le_errors, le_wt_errors
        # affine-invariant errors (plain, weighted)
        self.ai_errors, self.ai_wt_errors = ai_errors, ai_wt_errors
        # per-point convergence flags
        self.converged = converged
    
