### variation from https://github.com/mattnedrich/MeanShift_py
import sys
import numpy as np
import torch
import random
from tensor_data_util import posMetric_sqrt_func

####### greedily group points: each point joins the first existing group that satisfies the distance thresholds

class pointGrouper:
    """Greedy (first-fit) grouping of points by maximum Euclidean distance.

    Points are visited in a fixed pseudo-random order (seed 230). Each point
    joins the first existing group whose farthest member lies closer than
    ``pos_group_threshold``; otherwise it starts a new group. Groups with fewer
    than ``eliminate_threshold`` members are then marked as eliminated.
    """

    def __init__(self, pos_group_threshold, eliminate_threshold, dim=None):
        """
        Args:
            pos_group_threshold: a point may join a group only if its maximum
                Euclidean distance to the group's members is below this value.
            eliminate_threshold: groups with fewer members than this get their
                points assigned -1; ``None`` disables elimination.
            dim: number of coordinates per point. ``None`` (default) flattens
                each point and uses its length as-is.
        """
        self.pos_group_threshold = pos_group_threshold
        self.eliminate_threshold = eliminate_threshold
        # Previously never assigned, which made group_points raise AttributeError.
        self.dim = dim

    def _as_row(self, point):
        """Return ``point`` reshaped to a single (1, D) row."""
        dim = getattr(self, "dim", None)
        if dim is None:
            return point.reshape(1, -1)
        return point.view(1, dim)

    def group_points(self, shiftedPoints):
        """Assign every point (row along dim 0) of ``shiftedPoints`` to a group.

        Args:
            shiftedPoints: tensor indexed by point along dim 0; presumably
                (N, D) — TODO confirm against callers.

        Returns:
            ``(group_assignment, groups_pointIdx, group_assignment_ori,
            groups_pointIdx_ori)`` where
              - group_assignment: int array of length N holding each point's
                group index, or -1 if its group was eliminated;
              - groups_pointIdx: per group, the indices of its points (an
                empty list for eliminated groups);
              - group_assignment_ori / groups_pointIdx_ori: the same data
                before elimination.
        """
        groups, groups_pointIdx_ori, group_assignment = self._assign_groups(shiftedPoints)

        ###### eliminate groups with small number of points ######
        group_assignment_ori = group_assignment.copy()
        groups_pointIdx = groups_pointIdx_ori
        if self.eliminate_threshold is not None:
            groups_pointIdx = []
            for members, point_ids in zip(groups, groups_pointIdx_ori):
                if len(members) < self.eliminate_threshold:
                    groups_pointIdx.append([])
                    group_assignment[point_ids] = -1
                else:
                    groups_pointIdx.append(point_ids)
        return (np.array(group_assignment), groups_pointIdx,
                np.array(group_assignment_ori), groups_pointIdx_ori)

    def _assign_groups(self, shiftedPoints):
        """Greedily assign points, in a seeded shuffled order, to groups.

        Returns (groups, groups_pointIdx_ori, group_assignment), where each
        group is a list of (1, D) tensors.
        """
        n_points = shiftedPoints.shape[0]
        visit_order = list(range(n_points))
        # A local RNG yields the same permutation as random.seed(230) followed
        # by random.shuffle, without resetting the caller's global RNG state.
        random.Random(230).shuffle(visit_order)

        groups = []
        groups_pointIdx_ori = []
        # dtype=int keeps the array integer even when there are no points.
        group_assignment = np.zeros(n_points, dtype=int)
        for count, idx in enumerate(visit_order, start=1):
            if count % 100 == 0:
                print("assigning point %d, cur number of groups: %d ..." % (count, len(groups)))
            point = shiftedPoints[idx]
            row = self._as_row(point)
            group_idx = self.get_nearest_group(point, groups)
            if group_idx is None:
                # create new group
                group_idx = len(groups)
                groups.append([row])
                groups_pointIdx_ori.append([idx])
            else:
                groups[group_idx].append(row)
                groups_pointIdx_ori[group_idx].append(idx)
            group_assignment[idx] = group_idx
        return groups, groups_pointIdx_ori, group_assignment

    def get_nearest_group(self, point, groups):
        """Return the index of the first group ``point`` may join, or None.

        NOTE: despite the name this is first-fit, not nearest-fit. Groups are
        scanned in creation order, and the first one whose maximum distance to
        ``point`` is below ``pos_group_threshold`` wins, even if a later group
        is closer.
        """
        for i, group in enumerate(groups):
            if self.distance_to_group(point, group) < self.pos_group_threshold:
                return i
        return None

    def distance_to_group(self, point, group):
        """Return the maximum Euclidean distance from ``point`` to ``group``.

        ``group`` is a list of (1, D) tensors. An empty group yields
        ``sys.float_info.max`` so that it never satisfies a threshold.
        """
        if not group:
            return sys.float_info.max
        members = torch.cat(group, 0)
        ###### temporarily use Euclidean metric for position ######
        diffs = self._as_row(point) - members  # broadcasts over group members
        return torch.max(torch.sqrt(torch.sum(diffs ** 2, 1)))

class DTI_pointGrouper:
    """Greedy (first-fit) grouping of points with a position part and a
    covariance-like part.

    Each point vector has ``pos_dim`` position coordinates followed by
    ``pos_dim*(pos_dim+1)/2`` further entries (presumably the unique entries
    of a symmetric pos_dim x pos_dim tensor — TODO confirm). A point joins the
    first existing group for which:
      - the maximum Euclidean distance between the trailing entries is below
        ``cov_group_threshold``; and
      - the position distance (max/mean/min over members, optionally under
        ``posMetric_sqrt_func``) is below ``pos_group_threshold``.
    Otherwise it starts a new group. Small groups can then be eliminated.
    """

    def __init__(self, pos_group_threshold, cov_group_threshold, eliminate_threshold, pos_dim, pos_metric,
                pos_dist_calc_mode):
        """
        Args:
            pos_group_threshold: position-distance threshold for joining a group.
            cov_group_threshold: threshold on the maximum distance between the
                trailing (non-position) entries.
            eliminate_threshold: groups with fewer members than this get their
                points assigned -1; ``None`` disables elimination.
            pos_dim: number of leading position coordinates per point.
            pos_metric: ``None`` selects plain Euclidean position distance; any
                other value selects the ``posMetric_sqrt_func``-based distance.
            pos_dist_calc_mode: 'max' or 'mean' reduce the per-member distances
                accordingly; any other value uses the minimum.
        """
        self.pos_group_threshold = pos_group_threshold
        self.cov_group_threshold = cov_group_threshold
        self.eliminate_threshold = eliminate_threshold
        self.pos_dim = pos_dim
        self.dim = self.pos_dim + int(pos_dim*(pos_dim+1)/2)
        self.pos_metric = pos_metric
        self.pos_dist_calc_mode = pos_dist_calc_mode

    def group_points(self, shiftedPoints):
        """Assign every point (row along dim 0) of ``shiftedPoints`` to a group.

        Args:
            shiftedPoints: tensor indexed by point along dim 0; each point must
                have ``self.dim`` entries (it is viewed as (1, self.dim)).

        Returns:
            ``(group_assignment, groups_pointIdx, group_assignment_ori,
            groups_pointIdx_ori)`` where
              - group_assignment: int array of length N holding each point's
                group index, or -1 if its group was eliminated;
              - groups_pointIdx: per group, the indices of its points (an
                empty list for eliminated groups);
              - group_assignment_ori / groups_pointIdx_ori: the same data
                before elimination.
        """
        groups, groups_pointIdx_ori, group_assignment = self._assign_groups(shiftedPoints)

        ###### eliminate groups with small number of points ######
        group_assignment_ori = group_assignment.copy()
        groups_pointIdx = groups_pointIdx_ori
        if self.eliminate_threshold is not None:
            groups_pointIdx = []
            for members, point_ids in zip(groups, groups_pointIdx_ori):
                if len(members) < self.eliminate_threshold:
                    groups_pointIdx.append([])
                    group_assignment[point_ids] = -1
                else:
                    groups_pointIdx.append(point_ids)
        return (np.array(group_assignment), groups_pointIdx,
                np.array(group_assignment_ori), groups_pointIdx_ori)

    def _assign_groups(self, shiftedPoints):
        """Greedily assign points, in a seeded shuffled order, to groups.

        Returns (groups, groups_pointIdx_ori, group_assignment), where each
        group is a list of (1, self.dim) tensors.
        """
        n_points = shiftedPoints.shape[0]
        visit_order = list(range(n_points))
        # A local RNG yields the same permutation as random.seed(230) followed
        # by random.shuffle, without resetting the caller's global RNG state.
        random.Random(230).shuffle(visit_order)

        groups = []
        groups_pointIdx_ori = []
        # dtype=int keeps the array integer even when there are no points.
        group_assignment = np.zeros(n_points, dtype=int)
        for count, idx in enumerate(visit_order, start=1):
            if count % 100 == 0:
                print("assigning point %d, cur number of groups: %d ..." % (count, len(groups)))
            point = shiftedPoints[idx]
            row = point.view(1, self.dim)
            group_idx = self.get_nearest_group(point, groups)
            if group_idx is None:
                # create new group
                group_idx = len(groups)
                groups.append([row])
                groups_pointIdx_ori.append([idx])
            else:
                groups[group_idx].append(row)
                groups_pointIdx_ori[group_idx].append(idx)
            group_assignment[idx] = group_idx
        return groups, groups_pointIdx_ori, group_assignment

    def get_nearest_group(self, point, groups):
        """Return the index of the first group ``point`` may join, or None.

        NOTE: despite the name this is first-fit, not nearest-fit. Groups are
        scanned in creation order, and the first one passing both the
        covariance and the position thresholds wins, even if a later group is
        closer. The cheaper covariance test runs first, so the position
        distance is only computed for candidate groups.
        """
        for i, group in enumerate(groups):
            if self.cov_distance_to_group(point, group) >= self.cov_group_threshold:
                continue
            if self.pos_distance_to_group(point, group) < self.pos_group_threshold:
                return i
        return None

    def pos_distance_to_group(self, point, group):
        """Return the position distance from ``point`` to ``group``.

        Per-member distances use only the first ``pos_dim`` entries. They are
        plain Euclidean when ``pos_metric`` is None; otherwise the difference is
        multiplied by ``posMetric_sqrt_func`` evaluated at the midpoint of the
        two full vectors (assumed to return a (n, pos_dim, pos_dim) batch —
        TODO confirm). The per-member distances are reduced by
        ``pos_dist_calc_mode``: 'max', 'mean', or min for anything else.
        An empty group yields ``sys.float_info.max``.
        """
        if not group:
            return sys.float_info.max
        members = torch.cat(group, 0)
        n_members = members.shape[0]
        row = point.view(1, self.dim)
        pos_diff = row[:, :self.pos_dim] - members[:, :self.pos_dim]  # (n, pos_dim)
        if self.pos_metric is None:
            ###### use Euclidean metric for position ######
            # Sum the squares before the sqrt. Taking the sqrt of each squared
            # component first (as before) silently produced an L1 distance.
            distmat = torch.sqrt(torch.sum(pos_diff ** 2, 1))
            ###############################################
        else:
            posMetric_sqrt = posMetric_sqrt_func(0.5 * (row + members))
            distmat = torch.sqrt(
                (torch.bmm(pos_diff.view(n_members, 1, self.pos_dim), posMetric_sqrt) ** 2).sum((1, 2))
            )
        if self.pos_dist_calc_mode == 'max':
            return torch.max(distmat)
        if self.pos_dist_calc_mode == 'mean':
            return torch.mean(distmat)
        return torch.min(distmat)

    def cov_distance_to_group(self, point, group):
        """Return the maximum Euclidean distance between the trailing
        (non-position) entries of ``point`` and those of each group member.

        An empty group yields ``sys.float_info.max``.
        """
        if not group:
            return sys.float_info.max
        members = torch.cat(group, 0)
        row = point.view(1, self.dim)
        cov_diff = row[:, self.pos_dim:] - members[:, self.pos_dim:]  # broadcasts over members
        return torch.max(torch.sqrt(torch.sum(cov_diff ** 2, 1)))
