import sys
import numpy as np
import torch
import time
from util import *
from sph_n_DataUtil import *
from sph_n_ambient import *

class PointGrouper(object):
    """Greedy, order-dependent grouping of points by a distance tolerance.

    Points are visited in order.  Each point joins the existing group that
    contains its closest member, provided that distance is strictly below
    ``distTol``; otherwise the point starts a new group.  This base class
    measures Euclidean distance; subclasses change the metric by overriding
    ``_get_pairwise_distance`` and ``_distance_to_group``.
    """

    def group_points(self, points, printEpochPeriod = 100, loggingFileName = None, mode = 'small_memory', distTol = 0.1):
        """Return the group index of every point in ``points``.

        Args:
            points: tensor iterated along its first dimension, one point per
                row.
            printEpochPeriod: a progress line is logged every this many
                points; a falsy value (0 / None) disables progress logging.
            loggingFileName: if not None, passed to ``set_logger`` and the
                resulting logger is handed to ``print_info``.
            mode: ``'small_memory'`` computes distances on the fly against the
                stored group members; any other value precomputes the full
                N x N distance matrix once and looks distances up by index.
            distTol: a point joins a group only if its distance to the
                group's closest member is strictly smaller than this.

        Returns:
            A 1-D ``torch.long`` tensor on ``points.device`` with one entry
            per point.  Group indices are numbered in order of creation.
        """
        logger = None if loggingFileName is None else set_logger(loggingFileName)
        small_memory = mode == 'small_memory'
        start_time = time.time()

        distMat = None
        if not small_memory:
            print_info("time {:.1f} --- calculate pairwise distances using gpu".format(time.time() - start_time), logger)
            distMat = self._get_pairwise_distance(points).cpu()
            print_info("time {:.1f} --- calculation finished".format(time.time() - start_time), logger)

        # groups[k] holds the members of group k: the point tensors themselves
        # in small-memory mode, row indices into distMat otherwise.
        groups = []
        group_assignment = []
        for i, point in enumerate(points):
            # A falsy period would otherwise raise ZeroDivisionError.
            if printEpochPeriod and i % printEpochPeriod == 0:
                print_info("time {:.1f} --- {:d}-th point grouping... currently having {:d} groups".format(time.time() - start_time, i, len(groups)), logger)

            if small_memory:
                member = point
                target = self._determine_nearest_group2(point, groups, distTol)
            else:
                member = i
                target = self._determine_nearest_group(i, groups, distMat, distTol)

            if target is None:
                # No group is close enough: open a new one at the next index.
                target = len(groups)
                groups.append([])
            groups[target].append(member)
            group_assignment.append(target)

        # Build the result on the same device as the input (this also covers
        # non-default CUDA devices).
        return torch.tensor(group_assignment, dtype=torch.long, device=points.device)

    def _get_pairwise_distance(self, points):
        """Return the N x N float32 matrix of Euclidean distances.

        The computation runs on the GPU when CUDA is available and otherwise
        stays on the device ``points`` already lives on.  Entries are true
        (square-rooted) distances so that ``distTol`` has the same meaning as
        in ``_distance_to_group``.
        """
        if not points.is_cuda and torch.cuda.is_available():
            points = points.cuda()

        # Filled one column at a time on purpose: broadcasting
        # points.unsqueeze(0) - points.unsqueeze(1) in one shot may require
        # large memory.
        N = points.shape[0]
        distMat = torch.zeros((N, N), dtype=torch.float32, device=points.device)
        for i in range(N):
            diff = points[i].unsqueeze(0) - points
            distMat[:, i] = (diff * diff).sum(1).sqrt()
        return distMat

    @staticmethod
    def _pick_nearest(distances, distTol):
        """Return the index of the smallest distance strictly below ``distTol``.

        ``distances`` is an iterable with one distance per group.  Ties keep
        the earliest index.  Returns None when no distance qualifies.
        """
        nearest_group_index = None
        min_group_dist = float("inf")
        for index, distance_to_group in enumerate(distances):
            if distance_to_group < min(distTol, min_group_dist):
                nearest_group_index = index
                min_group_dist = distance_to_group
        return nearest_group_index

    def _determine_nearest_group(self, idx, groups, distMat, distTol):
        """Nearest group for point ``idx`` using the precomputed ``distMat``.

        Each group is a list of row indices; the distance to a group is the
        minimum of ``distMat[idx, member]`` over its members.
        """
        return self._pick_nearest(
            (torch.min(distMat[idx, group]).item() for group in groups),
            distTol,
        )

    def _determine_nearest_group2(self, point, groups, distTol):
        """Nearest group for ``point`` when groups store the point tensors."""
        return self._pick_nearest(
            (self._distance_to_group(point, group) for group in groups),
            distTol,
        )

    def _distance_to_group(self, point, group):
        """Return the smallest Euclidean distance from ``point`` to ``group``.

        ``group`` is a non-empty list of point tensors; the result is a
        0-dim tensor.
        """
        members = torch.vstack(group)
        diff = point.unsqueeze(0) - members
        return torch.sqrt(torch.sum(diff * diff, dim=1)).min()

class PointGrouper_sph_ambient(PointGrouper):
    """PointGrouper whose metric is the arccosine of the row dot product.

    NOTE(review): this is the great-circle angle only if every row of
    ``points`` is a unit vector -- confirm against the callers.
    """

    def _get_pairwise_distance(self, points):
        """Return the N x N float32 matrix ``acos(<p_i, p_j>)``.

        The computation runs on the GPU when CUDA is available and otherwise
        stays on the device ``points`` already lives on.
        """
        if not points.is_cuda and torch.cuda.is_available():
            points = points.cuda()

        # Filled one column at a time on purpose: the one-shot alternative
        # (pairwise_distance(points)) may require large memory.
        N = points.shape[0]
        distMat = torch.zeros((N, N), dtype=torch.float32, device=points.device)
        for i in range(N):
            cosines = (points[i].unsqueeze(0) * points).sum(1)
            # Clamp into acos's domain; values just outside [-1, 1] would
            # otherwise produce NaN.
            distMat[:, i] = torch.acos(torch.clamp(cosines, -1, 1))
        return distMat

    def _distance_to_group(self, point, group):
        """Return the smallest ``distance`` from ``point`` to any group member.

        ``group`` is a non-empty list of point tensors.  ``distance`` is not
        defined in this class; presumably it comes from one of the module's
        star imports -- verify.
        """
        members = torch.vstack(group)
        return torch.min(distance(point.unsqueeze(0), members))
    
class PointGrouper_sph(PointGrouper):
    """PointGrouper whose metric is the arccosine of dot products of positions.

    Points are first mapped through ``getPos_torch``.
    NOTE(review): presumably this converts sphere coordinates into unit
    position vectors -- confirm against its definition.
    """

    def _get_pairwise_distance(self, points):
        """Return the N x N float32 matrix ``acos(<pos_i, pos_j>)``.

        ``pos`` is ``getPos_torch(points)``.  The computation runs on the GPU
        when CUDA is available and otherwise stays on the device ``points``
        already lives on.
        """
        if not points.is_cuda and torch.cuda.is_available():
            points = points.cuda()

        # Filled one column at a time on purpose: the one-shot alternative
        # (getPairwiseDist_torch(points)) may require large memory.
        N = points.shape[0]
        distMat = torch.zeros((N, N), dtype=torch.float32, device=points.device)
        pos = getPos_torch(points)
        for i in range(N):
            cosines = (pos[i].unsqueeze(0) * pos).sum(1)
            # Clamp into acos's domain; values just outside [-1, 1] would
            # otherwise produce NaN.
            distMat[:, i] = torch.acos(torch.clamp(cosines, -1, 1))
        return distMat

    def _distance_to_group(self, point, group):
        """Return the smallest ``getDist_torch`` value from ``point`` to the group.

        ``group`` is a non-empty list of point tensors.
        """
        members = torch.vstack(group)
        return torch.min(getDist_torch(point.unsqueeze(0), members))
