### variation from https://github.com/mattnedrich/MeanShift_py
from abc import ABC, abstractmethod
import torch
import point_grouper_sph as pg
import time
from sph_n_ambient import *
from sph_n_DataUtil import *
from util import *

# Per-iteration displacement threshold used by MeanShift: a point whose shift
# distance is not greater than this value stops being shifted, and the run is
# considered converged once the largest shift distance falls below it.
MIN_DISTANCE = 0.000001


class MeanShift(ABC):
    """Abstract base class for iterative mean-shift clustering on torch tensors.

    Subclasses provide the actual update rule through ``_shift_point`` and may
    override ``_get_dist`` and ``point_grouper`` so that distances are measured
    in the geometry appropriate for their data.
    """

    def __init__(self, step_size):
        """Store the step size and install the default point grouper.

        Args:
            step_size: scaling factor that subclasses apply to each update.
        """
        super(MeanShift, self).__init__()
        self.step_size = step_size
        # The point grouper should be chosen carefully by each derived class so
        # that grouping uses the appropriate distance calculation.
        self.point_grouper = pg.PointGrouper()

    def cluster(self, points, max_iter, returnShiftedOnly=False, printEpochPeriod = 100, loggingFileName = None, 
                mode = 'small_memory', distTol = 0.1):
        """Shift every point until it stops moving, then group the shifted points.

        Args:
            points: tensor whose first dimension indexes the points; it is
                cloned, so the caller's tensor is not modified.
            max_iter: maximum number of shifting iterations.
            returnShiftedOnly: when true, skip grouping and return the shifted
                tensor together with the mask of points still moving.
            printEpochPeriod: progress is reported every this many iterations;
                a falsy value (0 / None) disables progress output in the
                shifting loop. The value is also forwarded to the point grouper.
            loggingFileName: if given, a logger is created via ``set_logger``
                and passed to ``print_info``.
            mode: forwarded unchanged to ``point_grouper.group_points``.
            distTol: forwarded unchanged to ``point_grouper.group_points``.

        Returns:
            ``(shift_points, still_shifting)`` tensors when
            ``returnShiftedOnly`` is true, otherwise a ``MeanShiftResult``
            holding numpy copies of the original points, shifted points,
            group assignments and the still-shifting mask.
        """
        logger = None if loggingFileName is None else set_logger(loggingFileName)

        shift_points = points.clone()
        converged = False
        iter_num = 0
        num_points = points.shape[0]
        # Build the mask directly on the device that holds ``points``. A bare
        # ``.cuda()`` would place it on the current default GPU, which is not
        # necessarily the device of ``points``.
        still_shifting = torch.ones(num_points, dtype=torch.bool, device=points.device)
        # Placeholder so the summary message below works when the loop never runs.
        dist = torch.zeros(num_points, device=points.device)

        start_time = time.time()
        while not converged and iter_num < max_iter:
            iter_num += 1

            # Only points that moved more than MIN_DISTANCE last time are updated.
            curPoints = shift_points[still_shifting]
            if printEpochPeriod and iter_num % printEpochPeriod == 0:
                print_info("iter: {:d} ---- time {:.1f} ----  num points to shift: {:d}".format(iter_num, time.time() - start_time, curPoints.shape[0]), logger)
            newPoints = self._shift_point(curPoints)
            dist = self._get_dist(newPoints, curPoints)

            shift_points[still_shifting] = newPoints
            # Refresh the mask only at the positions that were just shifted;
            # points that already stopped stay stopped.
            still_shifting[still_shifting] = self._determine_to_shift(dist)
            if not still_shifting.any():
                converged = True
            else:
                converged = self._convergence_check(dist)

        if converged:
            print_info("shifting finished at iter_num {:d}".format(iter_num), logger)
        else:
            print_info("shifting exceeded max_iter... {:d} points not converged with max, mean, min distances {:f}, {:f}, {:f}".format(torch.sum(still_shifting).item(), torch.max(dist), torch.mean(dist), torch.min(dist)), logger)

        if returnShiftedOnly:
            return shift_points, still_shifting

        original_points = points.cpu().numpy()
        shifted_points = shift_points.cpu().numpy()
        still_shifting = still_shifting.cpu().numpy()

        print_info("\n start point grouping...", logger)
        # Point grouping can take a long time for many data points. The grouper
        # receives the torch tensor (not the numpy copy); its result is moved
        # to CPU/numpy below.
        group_assignments = self.point_grouper.group_points(shift_points, printEpochPeriod = printEpochPeriod, 
                                                            loggingFileName = loggingFileName, mode = mode, distTol = distTol)
        return MeanShiftResult(original_points, shifted_points, group_assignments.cpu().numpy(), still_shifting)

    def _convergence_check(self, dist):
        """Return whether the largest shift distance is below MIN_DISTANCE."""
        return torch.max(dist) < MIN_DISTANCE

    def _determine_to_shift(self, dist):
        """Return a boolean mask of points that moved more than MIN_DISTANCE."""
        return dist > MIN_DISTANCE

    def _get_dist(self, newPoints, curPoints):
        """Return the row-wise Euclidean distance between two point tensors."""
        diff = newPoints - curPoints
        return torch.sqrt(torch.sum(diff*diff, dim=1))

    @abstractmethod
    def _shift_point(self, points):
        """Return the updated positions for ``points``; implemented by subclasses."""
        raise NotImplementedError
    
class MeanShift_dae(MeanShift):
    """Mean shift whose update adds the scaled output of ``model.autoencoder``."""

    def __init__(self, model, step_size = 1.0):
        """Keep a reference to ``model`` and forward ``step_size`` to the base class."""
        super().__init__(step_size)
        self.model = model

    def _shift_point(self, points):
        """Return ``points`` moved by ``step_size`` times the autoencoder output."""
        # ``.data`` yields the raw values without autograd history.
        displacement = self.model.autoencoder(points).data
        scaled_step = self.step_size * displacement
        return points + scaled_step
    
class MeanShift_dae_sph_ambient(MeanShift):
    """Autoencoder-driven mean shift whose result is passed through ``project_to_sphere``."""

    def __init__(self, model, step_size = 1.0):
        """Keep a reference to ``model`` and forward ``step_size`` to the base class."""
        super().__init__(step_size)
        self.model = model

    def _shift_point(self, points):
        """Take a scaled autoencoder step, then project the result with ``project_to_sphere``."""
        # ``.data`` yields the raw values without autograd history.
        scaled_step = self.step_size * self.model.autoencoder(points).data
        moved = points + scaled_step
        return project_to_sphere(moved)

class MeanShift_lsldg(MeanShift):
    """Mean shift driven by ``model.meanShiftUpdate`` with a linear step-size blend."""

    def __init__(self, model, step_size = 1.0):
        """Keep a reference to ``model`` and forward ``step_size`` to the base class."""
        super().__init__(step_size)
        self.model = model

    def _shift_point(self, points):
        """Return the model update, or a partial move toward it when ``step_size != 1.0``."""
        updated = self.model.meanShiftUpdate(points)
        if self.step_size == 1.0:
            # A full step is exactly the model's proposal.
            return updated
        # Otherwise move only a fraction of the way along the proposed displacement.
        displacement = updated - points
        return points + self.step_size * displacement

class MeanShift_gdae_sph_ambient(MeanShift):
    """Mean shift driven by ``model.clean_forward``, grouped with ``PointGrouper_sph_ambient``."""

    def __init__(self, model, step_size = 1.0):
        """Store ``model`` and replace the default grouper with the sph-ambient one."""
        super().__init__(step_size)
        self.model = model
        self.point_grouper = pg.PointGrouper_sph_ambient()

    def _shift_point(self, points):
        """Return the model output, or a scaled step toward it when ``step_size != 1.0``."""
        # ``.data`` yields the raw values without autograd history.
        target = self.model.clean_forward(points).data
        if self.step_size == 1.0:
            return target
        # Presumably log/exp maps on the sphere (helpers are imported from
        # elsewhere -- confirm): scale the tangent from ``points`` to ``target``
        # and map it back.
        tangent = logarithm_map(points, target)
        return exponential_map(points, self.step_size * tangent)

    def _get_dist(self, newPoints, curPoints):
        """Measure the shift with the imported ``distance`` helper instead of the base-class norm."""
        return distance(newPoints, curPoints)

    
class MeanShift_rlsldg(MeanShift):
    """Mean shift driven by ``model.meanShiftUpdate`` for points given in the ambient space."""

    def __init__(self, model, step_size = 1.0, memory_efficient = False):
        """Store ``model`` and the ``memory_efficient`` flag; use the sph-ambient grouper."""
        super().__init__(step_size)
        self.model = model
        self.point_grouper = pg.PointGrouper_sph_ambient()
        self.memory_efficient = memory_efficient

    def _shift_point(self, points):
        """Return updated points; inputs are represented in the ambient space."""
        if self.step_size != 1.0:
            # Partial step: request the model's log output, scale it, and map
            # it back through ``exponential_map``.
            log_vec = self.model.meanShiftUpdate(points, returnLog=True, memory_efficient = self.memory_efficient)
            return exponential_map(points, self.step_size * log_vec)
        # Full step: the model's update is used as-is.
        return self.model.meanShiftUpdate(points, memory_efficient = self.memory_efficient)

    def _get_dist(self, newPoints, curPoints):
        """Measure the shift with the imported ``distance`` helper instead of the base-class norm."""
        return distance(newPoints, curPoints)
    
    
class MeanShift_rlsldgInCoord(MeanShift):
    """Mean shift driven by ``model.meanShiftUpdateInCoord`` for points in spherical coordinates."""

    def __init__(self, model, step_size = 1.0, memory_efficient = False):
        """Store ``model`` and the ``memory_efficient`` flag; use ``PointGrouper_sph``."""
        super().__init__(step_size)
        self.model = model
        self.point_grouper = pg.PointGrouper_sph()
        self.memory_efficient = memory_efficient

    def _shift_point(self, points):
        """Return updated points; inputs are represented in spherical coordinates."""
        if self.step_size != 1.0:
            # Partial step: the model returns the current position and a log
            # vector; the scaled result is converted back with ``getCoord_torch``.
            cur_pos, log_vec = self.model.meanShiftUpdateInCoord(points, returnCurPosAndLog=True, memory_efficient = self.memory_efficient)
            return getCoord_torch(exponential_map(cur_pos, self.step_size * log_vec))
        # Full step: the model's update is used as-is.
        return self.model.meanShiftUpdateInCoord(points, memory_efficient = self.memory_efficient)

    def _get_dist(self, newPoints, curPoints):
        """Return the arccosine of the clipped row-wise inner product of the converted points.

        Inputs are represented in spherical coordinates and are converted with
        ``getPos_torch`` before the inner product is taken.
        """
        new_pos = getPos_torch(newPoints)
        cur_pos = getPos_torch(curPoints)
        inner = (cur_pos * new_pos).sum(1)
        # Clip to [-1, 1] so that values slightly outside the range do not
        # produce NaN in acos.
        inner = torch.clamp(inner, -1, 1)
        return torch.acos(inner)
    
class MeanShiftResult:
    """Plain container bundling the outputs of a mean-shift clustering run."""

    def __init__(self, original_points, shifted_points, cluster_ids, still_shifting):
        """Store the four result fields unchanged as attributes of the same name."""
        # Input points and where they ended up after shifting.
        self.original_points, self.shifted_points = original_points, shifted_points
        # Group assignment per point and the mask of points that had not stopped moving.
        self.cluster_ids, self.still_shifting = cluster_ids, still_shifting