### variation from https://github.com/mattnedrich/MeanShift_py
from abc import ABC, abstractmethod
import torch
import point_grouper_Pn as pg
import time
from util import *
from Pn_util import *

# Shift-length threshold shared by the mean-shift stopping rules: a point keeps
# being shifted only while its last shift distance exceeds this value
# (MeanShift._determine_to_shift), and an iteration counts as converged when the
# largest shift distance falls below it (MeanShift._convergence_check).
MIN_DISTANCE = 0.000001

def get_quantities_P_n(curPoints):
    """Precompute per-iteration quantities for a batch of symmetric matrices.

    The eigenvalues returned by ``batch_eigsym`` are floored at 1e-14 (in
    place) before being handed to ``get_sqrt_sym``, presumably so that the
    inverse square root stays finite for near-singular inputs.

    Args:
        curPoints: batch of symmetric matrices (assumed ``(batch, n, n)``
            since downstream code multiplies them with ``torch.bmm`` -- TODO
            confirm).

    Returns:
        Tuple ``(S, U, X_sqrt, X_invsqrt)``: the floored eigenvalues, the
        eigenvectors, and the matrix square root and inverse square root as
        produced by ``get_sqrt_sym``.
    """
    eig_floor = 1e-14
    eigvals, eigvecs = batch_eigsym(curPoints)
    # Floor tiny or negative eigenvalues; this mutates the tensor returned by batch_eigsym.
    eigvals[eigvals < eig_floor] = eig_floor
    sqrt_mat, invsqrt_mat = get_sqrt_sym(
        curPoints, returnInvAlso=True, S=eigvals, U=eigvecs
    )
    return eigvals, eigvecs, sqrt_mat, invsqrt_mat

def get_dist_P_n(newPoints, curPoints, quantities):
    """Distance between corresponding SPD matrices of two batches.

    Computes ``sqrt(sum_i log(lambda_i)^2)``, where ``lambda_i`` are the
    eigenvalues of ``X^{-1/2} Y X^{-1/2}``. Here ``X`` is a matrix of
    ``curPoints`` and ``Y`` the matching matrix of ``newPoints``. This is the
    affine-invariant Riemannian distance, assuming ``batch_eigsym`` returns the
    eigenvalues along the last dimension (TODO confirm).

    Eigenvalues below 1e-7 are floored before the log, so very large
    distances are effectively capped.

    Args:
        newPoints: batch of SPD matrices ``Y`` (used with ``torch.bmm``).
        curPoints: batch of SPD matrices ``X``. Not read directly; its inverse
            square root is taken from ``quantities``.
        quantities: the 4-tuple returned by ``get_quantities_P_n(curPoints)``.
            Only its last element (``X^{-1/2}``) is used.

    Returns:
        Tensor with one distance per matrix in the batch.

    Raises:
        ValueError: if ``quantities`` is ``None`` or not a 4-element sequence.
    """
    eps2 = 1e-7
    if quantities is None or len(quantities) != 4:
        raise ValueError("check input quantities")
    curPoints_invsqrt = quantities[3]
    # Whiten Y by X: T = X^{-1/2} Y X^{-1/2}.
    T = torch.bmm(torch.bmm(curPoints_invsqrt, newPoints), curPoints_invsqrt)
    S_T, _ = batch_eigsym(T)
    # Floor the eigenvalues so the log stays finite.
    S_T[S_T < eps2] = eps2
    logS = torch.log(S_T)
    return torch.sqrt(torch.sum(logS * logS, dim=-1))

class MeanShift(ABC):
    """Base class for mean-shift style clustering.

    Subclasses must implement ``_shift_point`` (one update step for a batch of
    points). They may override:

    * ``_get_quantities``: per-iteration precomputation shared by
      ``_shift_point`` and ``_get_dist``;
    * ``_get_dist``: distance between old and new positions, used by the
      stopping rule;
    * ``point_grouper``: the object that groups shifted points into clusters.
    """

    def __init__(self, step_size):
        """
        Args:
            step_size: scale that subclasses apply to each shift update.
        """
        super().__init__()
        self.step_size = step_size
        # The point grouper must match each derived class's geometry so that
        # grouping uses an appropriate distance; subclasses on other spaces
        # replace it.
        self.point_grouper = pg.PointGrouper()

    def cluster(self, points, max_iter, returnShiftedOnly=False, printEpochPeriod = 100, loggingFileName = None, mode = 'small_memory', distTol = 0.1):
        """Shift every point until it stops moving, then group the shifted points.

        Args:
            points: tensor of input points; the first dimension indexes points.
            max_iter: maximum number of shifting iterations.
            returnShiftedOnly: if True, skip grouping and return the tensors
                ``(shift_points, still_shifting)``.
            printEpochPeriod: print progress every this many iterations (a
                falsy value disables progress printing here). Also forwarded to
                the point grouper.
            loggingFileName: optional log file passed to ``set_logger`` and
                forwarded to the point grouper.
            mode: forwarded to ``point_grouper.group_points``.
            distTol: currently unused; kept for interface compatibility.

        Returns:
            ``(shift_points, still_shifting)`` when ``returnShiftedOnly`` is
            set. Otherwise a ``MeanShiftResult`` holding NumPy arrays of the
            original points, the shifted points, the cluster ids and the
            not-converged mask.
        """
        logger = None if loggingFileName is None else set_logger(loggingFileName)

        shift_points = points.clone()
        converged = False
        iter_num = 0
        dist = None  # shift distances from the most recent iteration
        # Mask of points that are still moving. It is created on the input's own
        # device; `.cuda()` would use the *current* CUDA device instead.
        still_shifting = torch.ones(points.shape[0], dtype=torch.bool, device=points.device)

        start_time = time.time()
        while not converged and iter_num < max_iter:
            iter_num += 1
            curPoints = shift_points[still_shifting]
            if printEpochPeriod and iter_num % printEpochPeriod == 0:
                print_info("iter: {:d} ---- time {:.1f} ----  num points to shift: {:d}".format(iter_num, time.time() - start_time, curPoints.shape[0]), logger)

            quantities = self._get_quantities(curPoints)
            newPoints = self._shift_point(curPoints, quantities = quantities)
            dist = self._get_dist(newPoints, curPoints, curPoints_quantities = quantities)

            # Write the new positions back before shrinking the mask: both
            # assignments address the subset selected at the top of this iteration.
            shift_points[still_shifting] = newPoints
            still_shifting[still_shifting] = self._determine_to_shift(dist)
            if not still_shifting.any():
                converged = True
            else:
                converged = self._convergence_check(dist)

        if converged:
            print_info("shifting finished at iter_num {:d}".format(iter_num), logger)
        elif dist is None:
            # max_iter <= 0: no iteration ran, so there are no distances to report.
            print_info("shifting skipped: max_iter is {}, no iterations run".format(max_iter), logger)
        else:
            print_info("shifting exceeded max_iter... {:d} points not converged with max, mean, min distances {:f}, {:f}, {:f}".format(torch.sum(still_shifting).item(), torch.max(dist).item(), torch.mean(dist).item(), torch.min(dist).item()), logger)

        if returnShiftedOnly:
            return shift_points, still_shifting

        original_points = points.cpu().numpy()
        shifted_points = shift_points.cpu().numpy()
        still_shifting = still_shifting.cpu().numpy()

        print_info("\n start point grouping...", logger)
        # Point grouping can take a long time for many data points.
        group_assignments = self.point_grouper.group_points(shift_points, printEpochPeriod = printEpochPeriod,
                                                            loggingFileName = loggingFileName,
                                                            mode = mode)
        return MeanShiftResult(original_points, shifted_points, group_assignments.cpu().numpy(), still_shifting)

    def _convergence_check(self, dist):
        """Return True when every shift distance in ``dist`` is below MIN_DISTANCE."""
        return torch.max(dist) < MIN_DISTANCE

    def _determine_to_shift(self, dist):
        """Return the mask of points whose last shift exceeded MIN_DISTANCE."""
        return dist > MIN_DISTANCE

    def _get_dist(self, newPoints, curPoints, curPoints_quantities = None):
        """Euclidean length of each point's shift.

        The first dimension indexes points. Any trailing dimensions beyond the
        second are flattened, so tensors of shape ``(N, ...)`` yield an
        ``(N,)`` result. ``curPoints_quantities`` is unused here.
        """
        diff = newPoints - curPoints
        if diff.dim() > 2:
            diff = diff.flatten(start_dim=1)
        return torch.sqrt(torch.sum(diff * diff, dim=1))

    def _shift_point(self, points, quantities = None):
        """Return the shifted positions of ``points``; subclasses must override."""
        raise NotImplementedError(
            "{} must implement _shift_point".format(type(self).__name__)
        )

    def _get_quantities(self, curPoints):
        """Per-iteration precomputation passed to _shift_point/_get_dist (none by default)."""
        return None
    
class MeanShift_dae(MeanShift):
    """Mean shift whose update direction comes from ``model.autoencoder``.

    Each step moves a batch of points by ``step_size`` times the autoencoder's
    output at those points. The output is taken through ``.data``, so it is not
    tracked by autograd.
    """

    def __init__(self, model, step_size = 1.0):
        super().__init__(step_size)
        self.model = model

    def _shift_point(self, points, quantities = None):
        # This variant needs no per-iteration precomputation, so `quantities` is ignored.
        direction = self.model.autoencoder(points).data
        displacement = direction * self.step_size
        return points + displacement

class MeanShift_lsldg(MeanShift):
    """Mean shift driven by ``model.meanShiftUpdate``.

    With ``step_size == 1.0`` the model's proposed position is used as is.
    Otherwise the point moves ``step_size`` times the displacement to that
    proposal.
    """

    def __init__(self, model, step_size = 1.0):
        super().__init__(step_size)
        self.model = model

    def _shift_point(self, points, quantities = None):
        proposed = self.model.meanShiftUpdate(points)
        if self.step_size != 1.0:
            # Scale the full mean-shift displacement.
            proposed = points + self.step_size * (proposed - points)
        return proposed

class MeanShift_gdae_P_n_fromLog(MeanShift):
    """Mean shift on batches of SPD matrices, driven by an autoencoder on log coordinates.

    Each step maps the current matrices ``X`` to vectors via
    ``mat2vec(Log_mat(X))`` and feeds them to ``model.forward_autoencoder`` to
    get ``v``. The new matrices are ``X^{1/2} E X^{1/2}``, where ``E`` is
    ``Exp_vec`` (or ``Exp_vec_approx``) of ``step_size * v``.
    """

    def __init__(self, model, step_size = 1.0, update_approx_order = None):
        """
        Args:
            model: object exposing ``forward_autoencoder``.
            step_size: scale applied once to the update vector before exponentiation.
            update_approx_order: if not None, use ``Exp_vec_approx`` with this
                ``approx`` order instead of ``Exp_vec``.
        """
        super().__init__(step_size)
        self.model = model
        self.point_grouper = pg.PointGrouper_P_n()
        self.update_approx_order = update_approx_order

    def _get_quantities(self, curPoints):
        """Return ``(S, U, X_sqrt, X_invsqrt)`` from ``get_quantities_P_n``."""
        return get_quantities_P_n(curPoints)

    def _shift_point(self, points, quantities):
        """Return the shifted batch of matrices.

        Raises:
            ValueError: if ``quantities`` is ``None`` or not a 4-element sequence.
        """
        if quantities is None or len(quantities) != 4:
            raise ValueError("check input quantities")
        S, U, X_sqrt, _ = quantities
        # Log_mat reuses the cached eigendecomposition (presumably the matrix
        # logarithm); mat2vec flattens the result for the model.
        log_x = mat2vec(Log_mat(points, S = S, U = U))
        # Detach: the update is not meant to be tracked by autograd.
        v = self.model.forward_autoencoder(log_x).detach()
        scaled_v = self.step_size * v
        if self.update_approx_order is not None:
            Exp_v = Exp_vec_approx(scaled_v, approx = self.update_approx_order)
        else:
            Exp_v = Exp_vec(scaled_v)
        # Map back around the current point: X^{1/2} Exp_v X^{1/2}.
        return torch.bmm(torch.bmm(X_sqrt, Exp_v), X_sqrt)

    def _get_dist(self, newPoints, curPoints, curPoints_quantities):
        """Distance on SPD matrices, computed by ``get_dist_P_n``."""
        return get_dist_P_n(newPoints, curPoints, curPoints_quantities)

class MeanShift_rlsldg(MeanShift):
    """Mean shift on batches of SPD matrices, driven by ``model.meanShiftUpdate``.

    The model returns an update direction in vector form. It is whitened as
    ``X^{-1/2} D X^{-1/2}``, scaled by ``step_size``, exponentiated with
    ``Exp_vec`` (or ``Exp_vec_approx``), and mapped back as
    ``X^{1/2} E X^{1/2}``.
    """

    def __init__(self, model, step_size = 1.0, update_approx_order = None):
        """
        Args:
            model: object exposing ``meanShiftUpdate(points, returnDir=..., quantities_for_x=...)``.
            step_size: scale applied once to the whitened update direction.
            update_approx_order: if not None, use ``Exp_vec_approx`` with this
                ``approx`` order instead of ``Exp_vec``.
        """
        super().__init__(step_size)
        self.model = model
        self.point_grouper = pg.PointGrouper_P_n()
        self.update_approx_order = update_approx_order

    def _get_quantities(self, curPoints):
        """Return ``(S, U, X_sqrt, X_invsqrt)`` from ``get_quantities_P_n``."""
        return get_quantities_P_n(curPoints)

    def _shift_point(self, points, quantities):
        """Return the shifted batch of matrices.

        Raises:
            ValueError: if ``quantities`` is ``None`` or not a 4-element sequence.
        """
        if quantities is None or len(quantities) != 4:
            raise ValueError("check input quantities")
        _, _, X_sqrt, X_invsqrt = quantities
        log_x_r = self.model.meanShiftUpdate(points, returnDir = True, quantities_for_x = quantities)
        # Whiten the direction: X^{-1/2} D X^{-1/2}. step_size is applied exactly
        # once (below). Applying it here as well would make the effective step
        # step_size**2.
        v = mat2vec(torch.bmm(torch.bmm(X_invsqrt, vec2mat(log_x_r)), X_invsqrt))
        scaled_v = self.step_size * v
        if self.update_approx_order is not None:
            Exp_v = Exp_vec_approx(scaled_v, approx = self.update_approx_order)
        else:
            Exp_v = Exp_vec(scaled_v)
        # Map back around the current point: X^{1/2} Exp_v X^{1/2}.
        return torch.bmm(torch.bmm(X_sqrt, Exp_v), X_sqrt)

    def _get_dist(self, newPoints, curPoints, curPoints_quantities):
        """Distance on SPD matrices, computed by ``get_dist_P_n``."""
        return get_dist_P_n(newPoints, curPoints, curPoints_quantities)
    
class MeanShiftResult:
    """Plain record of a mean-shift clustering run.

    Stores the given objects by reference, without copying:

    * ``original_points``: the input points;
    * ``shifted_points``: the points after shifting;
    * ``cluster_ids``: the group assignment of each point;
    * ``still_shifting``: mask of points that had not converged.
    """

    def __init__(self, original_points, shifted_points, cluster_ids, still_shifting):
        (self.original_points,
         self.shifted_points,
         self.cluster_ids,
         self.still_shifting) = (original_points, shifted_points, cluster_ids, still_shifting)