import sys
import numpy as np
import torch
import time
from util import *
from Pn_util import *


class PointGrouper(object):
    """Greedy single-linkage grouping of points under a distance tolerance.

    Points are visited in order. Each point joins the existing group that
    contains its closest member, provided that distance is strictly below
    ``distTol``; otherwise it starts a new group. On ties the earliest group
    wins. The base class uses the Euclidean distance over the flattened
    entries of each point; subclasses override ``_get_quantities``,
    ``_get_pairwise_distance`` and ``_distance_to_group`` to use another
    metric.
    """

    def group_points(self, points, printEpochPeriod = 100, loggingFileName = None, mode = 'small_memory', distTol = 0.1):
        """Assign every point to a group.

        Args:
            points: tensor whose first dimension indexes the points.
            printEpochPeriod: emit a progress message every this many points;
                a falsy value (0 or None) disables the periodic message.
            loggingFileName: if given, messages also go to the logger built by
                ``set_logger(loggingFileName)``.
            mode: ``'small_memory'`` compares each point against the stored
                group members on the fly; any other value precomputes the full
                N x N distance matrix (faster lookups, O(N^2) memory).
            distTol: a point joins a group only if its distance to the nearest
                member of that group is strictly smaller than this value. The
                distance is the same (non-squared) quantity in both modes.

        Returns:
            LongTensor of length N holding the group index of each point, on
            the same device as ``points``.
        """
        logger = None if loggingFileName is None else set_logger(loggingFileName)
        start_time = time.time()
        small_memory = (mode == 'small_memory')
        group_assignment = []
        # small_memory: each group stores its member points themselves;
        # otherwise: each group stores member indices into distMat.
        groups = []
        distMat = None
        quantities_at_points = None

        if small_memory:
            # Optional per-point precomputation used by subclasses (None here).
            quantities_at_points = self._get_quantities(points)
        else:
            print_info("time {:.1f} --- calculate pairwise distances using gpu".format(time.time() - start_time), logger)
            distMat = self._get_pairwise_distance(points).cpu()
            print_info("time {:.1f} --- calculation finished".format(time.time() - start_time), logger)

        for i, point in enumerate(points):
            if printEpochPeriod and i % printEpochPeriod == 0:
                print_info("time {:.1f} --- {:d}-th point grouping... currently having {:d} groups".format(time.time() - start_time, i, len(groups)), logger)
            if small_memory:
                quantities = None if quantities_at_points is None else quantities_at_points[i]
                nearest_group_index = self._determine_nearest_group2(point, groups, quantities, distTol)
            else:
                nearest_group_index = self._determine_nearest_group(i, groups, distMat, distTol)

            member = point if small_memory else i
            if nearest_group_index is None:
                # No group is close enough: open a new one at the next index.
                nearest_group_index = len(groups)
                groups.append([member])
            else:
                groups[nearest_group_index].append(member)
            group_assignment.append(nearest_group_index)

        return torch.tensor(group_assignment, dtype=torch.long, device=points.device)

    def _get_quantities(self, points):
        """Return per-point precomputed data for ``_distance_to_group``.

        The base Euclidean metric needs none, so this returns None.
        """
        return None

    @staticmethod
    def _as_rows(stacked):
        """Reshape a stack of points (dim 0 = point index) into 2-D rows, one flattened point per row."""
        if stacked.dim() == 1:
            # Scalar points: each becomes a one-entry row.
            return stacked.unsqueeze(1)
        return stacked.flatten(1)

    def _get_pairwise_distance(self, points):
        """Return the N x N float32 matrix of Euclidean distances between points.

        Entry ``[j, i]`` is the distance between ``points[j]`` and
        ``points[i]``, taken over all flattened entries of each point. The
        computation runs on the GPU when one is available (CPU points are
        moved there); without CUDA it runs on the points' own device.
        """
        if not points.is_cuda and torch.cuda.is_available():
            points = points.cuda()

        rows = self._as_rows(points)
        N = rows.shape[0]
        # Filled one column at a time: the one-shot broadcast
        # (rows[None] - rows[:, None]) would need an N x N x D temporary.
        distMat = torch.zeros(N, N, dtype=torch.float32, device=rows.device)
        for i in range(N):
            diff = rows[i].unsqueeze(0) - rows
            # sqrt so that distTol has the same meaning as in _distance_to_group.
            distMat[:, i] = (diff * diff).sum(1).sqrt()
        return distMat

    def _determine_nearest_group(self, idx, groups, distMat, distTol):
        """Return the index of the closest group within ``distTol`` of point ``idx``, or None.

        ``groups`` holds lists of point indices. The distance to a group is
        the minimum of ``distMat[idx, member]`` over its members.
        """
        nearest_group_index = None
        min_group_dist = float("inf")
        for index, group in enumerate(groups):
            distance_to_group = torch.min(distMat[idx, group]).item()
            # Must beat both the tolerance and the best group found so far.
            if distance_to_group < min(distTol, min_group_dist):
                nearest_group_index = index
                min_group_dist = distance_to_group
        return nearest_group_index

    def _determine_nearest_group2(self, point, groups, quantities_at_point, distTol):
        """Return the index of the closest group within ``distTol`` of ``point``, or None.

        ``groups`` holds lists of member points. Distances come from
        ``_distance_to_group`` and are compared as Python floats.
        """
        nearest_group_index = None
        min_group_dist = float("inf")
        for index, group in enumerate(groups):
            distance_to_group = float(self._distance_to_group(point, group, quantities_at_point))
            # Must beat both the tolerance and the best group found so far.
            if distance_to_group < min(distTol, min_group_dist):
                nearest_group_index = index
                min_group_dist = distance_to_group
        return nearest_group_index

    def _distance_to_group(self, point, group, quantities_at_point):
        """Return the minimum Euclidean distance (0-dim tensor) from ``point`` to the members of ``group``.

        ``quantities_at_point`` is unused by this metric.
        """
        members = self._as_rows(torch.stack(group))
        diff = point.reshape(1, -1) - members
        return torch.min(torch.sqrt(torch.sum(diff * diff, dim=1)))
    
class PointGrouper_P_n(PointGrouper):
    """PointGrouper using the ``squared_distance`` metric from ``Pn_util``.

    The grouping distance is the square root of ``squared_distance``.
    NOTE(review): judging by ``get_sqrt_sym`` and the class name, points are
    presumably symmetric positive-definite matrices (P_n) -- confirm against
    Pn_util.
    """

    def _get_quantities(self, points):
        """Return the per-point inverse square roots from ``get_sqrt_sym(points, returnInvAlso=True)``."""
        _, points_invsqrt = get_sqrt_sym(points, returnInvAlso = True)
        return points_invsqrt

    def _get_pairwise_distance(self, points):
        """Return the N x N float32 matrix of distances between all points.

        Column ``i`` is ``sqrt(squared_distance(points, points[i]))``. The
        computation runs on the GPU when one is available (CPU points are
        moved there); without CUDA it runs on the points' own device.
        """
        if not points.is_cuda and torch.cuda.is_available():
            points = points.cuda()

        # this may require large memory...
        #return pairwise_distance(points)

        N = points.shape[0]
        distMat = torch.zeros(N, N, dtype=torch.float32, device=points.device)
        points_invsqrt = self._get_quantities(points)
        for i in range(N):
            squared = squared_distance(points, points[i].unsqueeze(0), X_invsqrt = points_invsqrt)
            # Clamp guards against tiny negative round-off values, which would
            # otherwise turn into NaN under sqrt.
            distMat[:, i] = squared.clamp(min=0).sqrt()
        return distMat

    def _distance_to_group(self, point, group, quantities_at_point):
        """Return the minimum distance (0-dim tensor) from ``point`` to the members of ``group``.

        ``quantities_at_point`` is this point's entry from ``_get_quantities``.
        """
        members = torch.stack(group, dim=0)
        squared = squared_distance(point.unsqueeze(0), members, X_invsqrt = quantities_at_point.unsqueeze(0))
        # Clamp before sqrt: a NaN would make torch.min return NaN for the whole group.
        return torch.min(squared.clamp(min=0).sqrt())