# -*- coding: utf-8 -*-

import warnings
import numpy as np
import pandas as pd
import multiprocessing as mp

from functools import partial
from skimage import morphology
from sklearn.cluster import DBSCAN
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.cluster.hierarchy import linkage
from scipy.spatial import distance, Delaunay
from scipy.special import sph_harm
from dipy.core.sphere import HemiSphere, Sphere, disperse_charges
from dipy.core.geometry import sphere2cart, cart2sphere


###############################################################################
''' TOOLS '''
###############################################################################

def sphere_intersection_poolhelper(instance_indices, point_coords=None, radii=None):
    """Return the intersection volume of two spheres together with both sphere volumes.

    Written as a ``multiprocessing`` map target: ``point_coords`` and ``radii``
    are bound with ``functools.partial`` and ``instance_indices`` is the pair
    ``(i, j)`` selecting the two spheres to compare.

    Parameters
    ----------
    instance_indices : pair of indices into ``point_coords`` / ``radii``.
    point_coords : sequence of sphere centres.
    radii : sequence of sphere radii.

    Returns
    -------
    (intersect_vol, vol1, vol2)
    """
    first, second = instance_indices[0], instance_indices[1]
    r1, r2 = radii[first], radii[second]
    offset = np.array(point_coords[first]) - np.array(point_coords[second])
    d = np.sqrt(np.sum(offset**2))

    vol1 = 4/3*np.pi*r1**3
    vol2 = 4/3*np.pi*r2**3

    if d <= np.abs(r1-r2):
        # one sphere lies entirely inside the other: the overlap is the smaller sphere
        return (4/3*np.pi*np.minimum(r1,r2)**3, vol1, vol2)
    if d > r1+r2:
        # the spheres do not touch
        return (0, vol1, vol2)
    # partial overlap: volume of the lens-shaped intersection
    lens_vol = np.pi * (r1 + r2 - d)**2 * (d**2 + 2*d*r2 - 3*r2**2 + 2*d*r1 + 6*r2*r1 - 3*r1**2) / (12*d)
    return (lens_vol, vol1, vol2)



def harmonic_non_max_suppression(point_coords, point_probs, shape_descriptors, overlap_thresh=0.5, dim_scale=(1,1,1), num_kernel=1, **kwargs):
    """Suppress overlapping detections, dropping the less probable one of each overlapping pair.

    Each detection is approximated by two concentric spheres whose radii are
    the maximum and the minimum of its shape descriptor. For every pair of
    detections, the intersection volume of the two outer spheres and of the
    two inner spheres is each divided by the volume of the smaller sphere of
    the pair. If the mean of both ratios exceeds ``overlap_thresh``, the less
    probable detection of the pair is discarded. Pairs are visited in
    ascending (i, j) order.

    Parameters
    ----------
    point_coords : sequence of coordinate tuples, one per detection.
    point_probs : sequence of detection probabilities.
    shape_descriptors : sequence of array-likes supporting ``.max()`` / ``.min()``
        (presumably sampled radii in uniform voxel units -- TODO confirm).
    overlap_thresh : overlap measure above which a pair is suppressed.
    dim_scale : per-axis voxel spacing; coordinates are multiplied by
        ``dim_scale / min(dim_scale)`` before distances are computed.
    num_kernel : number of worker processes for the volume computations.

    Returns
    -------
    (nms_coords, nms_probs, nms_shapes)
        The surviving detections. With more than 7500 inputs NMS is skipped
        and only the first 2000 detections are returned; with fewer than two
        inputs they are returned unchanged.
    """
    if len(point_coords)>7500:
        print('Too many points, aborting NMS')
        return point_coords[:2000], point_probs[:2000], shape_descriptors[:2000]

    if len(point_coords)<=1:
        return point_coords, point_probs, shape_descriptors

    # local import keeps the module-level import block unchanged
    from scipy.spatial import cKDTree

    # scale coordinates to a uniform voxel grid
    dim_scale = [d/np.min(dim_scale) for d in dim_scale]
    point_coords_uniform = [tuple(p*d for p, d in zip(point_coord, dim_scale)) for point_coord in point_coords]

    # outer (upper) and inner (lower) sphere radius of every detection
    r_upper = [r.max() for r in shape_descriptors]
    r_lower = [r.min() for r in shape_descriptors]
    r_max = np.max(r_upper)

    # Two spheres can only intersect if their centres are at most
    # r1 + r2 <= 2*r_max apart, so only those pairs are evaluated. The distance
    # is Euclidean and measured on the rescaled coordinates, i.e. in the same
    # space as the volume computation. Sorting restores the ascending (i, j)
    # order that the greedy suppression below depends on.
    tree = cKDTree(np.asarray(point_coords_uniform, dtype=float))
    instance_indices = sorted(tree.query_pairs(r=float(2*r_max)))

    if instance_indices:
        with mp.Pool(processes=num_kernel) as pool:
            vol_upper = pool.map(partial(sphere_intersection_poolhelper, point_coords=point_coords_uniform, radii=r_upper), instance_indices)
            vol_lower = pool.map(partial(sphere_intersection_poolhelper, point_coords=point_coords_uniform, radii=r_lower), instance_indices)
    else:
        # no pair can overlap; avoid spawning worker processes
        vol_upper, vol_lower = [], []

    # builtin bool works on every NumPy version (np.bool is absent in NumPy 1.24-1.26)
    instances_keep = np.ones((len(point_coords),), dtype=bool)

    for inst_idx, v_up, v_low in zip(instance_indices, vol_upper, vol_lower):
        # intersection relative to the smaller sphere, averaged over outer and inner spheres
        overlap_measure_up = v_up[0] / np.minimum(v_up[1], v_up[2])
        overlap_measure_low = v_low[0] / np.minimum(v_low[1], v_low[2])
        overlap_measure = (overlap_measure_up+overlap_measure_low)/2

        if overlap_measure > overlap_thresh:
            # The less probable detection is always dropped (on ties the first
            # of the pair, since argmin returns the first minimum). The more
            # probable one keeps its current flag: if an earlier pair already
            # suppressed it, it stays suppressed but still suppresses this neighbour.
            inst_min = inst_idx[np.argmin([point_probs[i] for i in inst_idx])]
            instances_keep[inst_min] = False

    keep = [i for i, v in enumerate(instances_keep) if v]
    nms_coords = [point_coords[i] for i in keep]
    nms_probs = [point_probs[i] for i in keep]
    nms_shapes = [shape_descriptors[i] for i in keep]

    return nms_coords, nms_probs, nms_shapes



# hierarchical clustering of point detections
def agglomerative_clustering(point_coords, point_probs, shape_descriptors=None, max_dist=6, max_points=20000, dim_scale=(1,1,1), use_nms=False, **kwargs):
    """Merge nearby point detections using Ward hierarchical clustering.

    Coordinates are rescaled to a uniform voxel grid and a Ward linkage is
    built. Walking the dendrogram from the top, every not-yet-used merge whose
    distance is <= ``max_dist`` becomes one cluster containing all points
    below it. Each cluster is summarised either by its most probable member
    (``use_nms=True``) or by the probability-weighted mean of its members
    (coordinates rounded to int). Points not contained in any cluster are
    passed through unchanged, in ascending index order.

    Parameters
    ----------
    point_coords : sequence of coordinate tuples.
    point_probs : sequence of detection probabilities (also used as weights).
    shape_descriptors : optional sequence of per-point descriptors, summarised
        the same way as the coordinates.
    max_dist : maximal Ward merge distance for points to be clustered.
    max_points : above this many points, clustering is skipped with a warning.
    dim_scale : per-axis voxel spacing; coordinates are multiplied by
        ``dim_scale / min(dim_scale)``.
    use_nms : keep the most probable member instead of averaging.

    Returns
    -------
    ``(cluster_coords, cluster_probs, cluster_shapes)`` if ``shape_descriptors``
    is given, otherwise ``(cluster_coords, cluster_probs)``. With too many or
    fewer than two points, the inputs are returned unchanged.
    """
    if len(point_coords)>max_points:
        warnings.warn('Too many objects detected! Due to memory limitations, the clustering will be aborted.', stacklevel=0)
        cluster_coords = point_coords
        cluster_probs = point_probs
        cluster_shapes = shape_descriptors

    elif len(point_coords)>1:
        num_points = len(point_coords)

        # scale coordinates to a uniform voxel grid
        dim_scale = [d/np.min(dim_scale) for d in dim_scale]
        point_coords_uniform = [tuple(p*d for p, d in zip(point_coord, dim_scale)) for point_coord in point_coords]

        # Row k of `links` merges nodes links[k,0] and links[k,1]: ids below
        # num_points are input points, id num_points+k is the cluster of row k.
        links = linkage(point_coords_uniform, method='ward', metric='euclidean')

        def collect_cluster(row):
            """Return the leaf points (left to right) and every link row in the subtree of ``row``."""
            points, rows_used = [], []
            # explicit stack instead of recursion: deep dendrograms would hit Python's recursion limit
            stack = [num_points + row]
            while stack:
                node = stack.pop()
                if node < num_points:
                    points.append(node)
                else:
                    link_row = node - num_points
                    rows_used.append(link_row)
                    # push the right child first so the left subtree is emitted first
                    stack.append(int(links[link_row, 1]))
                    stack.append(int(links[link_row, 0]))
            return points, rows_used

        # start from the top of the dendrogram; the first eligible merge found
        # on a branch absorbs everything below it
        clusters = []
        clusters_used = np.zeros((links.shape[0],), dtype=bool)
        for cluster_id in range(links.shape[0]-1, -1, -1):
            if links[cluster_id,2] <= max_dist and not clusters_used[cluster_id]:
                points, rows_used = collect_cluster(cluster_id)
                clusters.append(points)
                clusters_used[rows_used] = True

        # sanity check
        clustered_points = [idx for c in clusters for idx in c]
        assert len(clustered_points) == len(np.unique(clustered_points)), \
                'Some points are assigned to multiple clusters'

        # determine representatives based on the clusters and probabilities
        cluster_coords = []
        cluster_probs = []
        cluster_shapes = []
        for cluster in clusters:
            cluster_prob = [point_probs[c] for c in cluster]
            cluster_points = [point_coords[c] for c in cluster]
            best = np.argmax(cluster_prob)
            if use_nms:
                cluster_centroid = cluster_points[best]
            else:
                # probability-weighted centroid, rounded to voxel coordinates
                cluster_centroid = np.average(cluster_points, axis=0, weights=cluster_prob)
                cluster_centroid = [int(np.round(p)) for p in cluster_centroid]
            if shape_descriptors is not None:
                cluster_descriptors = [shape_descriptors[c] for c in cluster]
                if use_nms:
                    cluster_descriptors = cluster_descriptors[best]
                else:
                    cluster_descriptors = np.average(cluster_descriptors, axis=0, weights=cluster_prob)
                cluster_shapes.append(cluster_descriptors)
            # cluster certainty: best probability, or probabilities weighted by themselves
            if use_nms:
                cluster_prob = cluster_prob[best]
            else:
                cluster_prob = np.average(cluster_prob, axis=0, weights=cluster_prob)
            cluster_coords.append(tuple(cluster_centroid))
            cluster_probs.append(float(cluster_prob))

        # add points that were not clustered, in a deterministic (ascending) order
        for idx in sorted(set(range(num_points)) - set(clustered_points)):
            cluster_coords.append(point_coords[idx])
            cluster_probs.append(float(point_probs[idx]))
            if shape_descriptors is not None:
                cluster_shapes.append(shape_descriptors[idx])

    else:
        cluster_coords = point_coords
        cluster_probs = point_probs
        cluster_shapes = shape_descriptors

    if shape_descriptors is not None:
        return cluster_coords, cluster_probs, cluster_shapes
    else:
        return cluster_coords, cluster_probs
    
    
    
# density-based clustering of point detections    
def dbscan_clustering(point_coords, point_probs, shape_descriptors=None, max_dist=15, min_count=1, max_points=10000, dim_scale=(1,1,1)):
    """Merge nearby point detections using DBSCAN.

    Coordinates are rescaled to a uniform voxel grid and clustered with
    ``DBSCAN(eps=max_dist, min_samples=min_count)``. Each cluster is replaced
    by the probability-weighted mean of its members (coordinates rounded to
    int). Points that DBSCAN labels as noise (label -1) are passed through
    individually. Noise can only occur when ``min_count > 1``, because a
    point counts towards its own neighbourhood.

    Parameters
    ----------
    point_coords : sequence of coordinate tuples.
    point_probs : sequence of detection probabilities (also used as weights).
    shape_descriptors : optional sequence of per-point descriptors.
    max_dist : DBSCAN neighbourhood radius on the rescaled grid.
    min_count : DBSCAN ``min_samples``.
    max_points : above this many points, clustering is skipped with a warning.
    dim_scale : per-axis voxel spacing; coordinates are multiplied by
        ``dim_scale / min(dim_scale)``.

    Returns
    -------
    ``(cluster_coords, cluster_probs, cluster_shapes)`` if ``shape_descriptors``
    is given, otherwise ``(cluster_coords, cluster_probs)``. With too many or
    fewer than two points, the inputs are returned unchanged.
    """
    if len(point_coords)>max_points:
        warnings.warn('Too many objects detected! Due to memory limitations, the clustering will be aborted.', stacklevel=0)
        cluster_coords = point_coords
        cluster_probs = point_probs
        cluster_shapes = shape_descriptors

    elif len(point_coords)>1:

        # scale coordinates to a uniform voxel grid
        dim_scale = [d/np.min(dim_scale) for d in dim_scale]
        point_coords_uniform = [tuple(p*d for p, d in zip(point_coord, dim_scale)) for point_coord in point_coords]

        labels = DBSCAN(eps=max_dist, min_samples=min_count).fit_predict(point_coords_uniform)

        cluster_coords = []
        cluster_probs = []
        cluster_shapes = []
        for cluster_idx in np.unique(labels):
            if cluster_idx == -1:
                # noise is not a cluster; those points are added one by one below
                continue

            cluster = np.where(labels==cluster_idx)[0]

            cluster_prob = [point_probs[c] for c in cluster]
            # probability-weighted centroid, rounded to voxel coordinates
            cluster_points = [point_coords[c] for c in cluster]
            cluster_centroid = np.average(cluster_points, axis=0, weights=cluster_prob)
            cluster_centroid = [int(np.round(p)) for p in cluster_centroid]
            if shape_descriptors is not None:
                cluster_descriptors = [shape_descriptors[c] for c in cluster]
                cluster_descriptors = np.average(cluster_descriptors, axis=0, weights=cluster_prob)
                cluster_shapes.append(cluster_descriptors)
            # cluster certainty: probabilities weighted by themselves
            cluster_prob = np.average(cluster_prob, axis=0, weights=cluster_prob)
            cluster_coords.append(tuple(cluster_centroid))
            cluster_probs.append(float(cluster_prob))

        # keep each noise point as its own detection instead of averaging all
        # unrelated noise points into one
        for idx in np.where(labels==-1)[0]:
            cluster_coords.append(point_coords[idx])
            cluster_probs.append(float(point_probs[idx]))
            if shape_descriptors is not None:
                cluster_shapes.append(shape_descriptors[idx])

    else:
        cluster_coords = point_coords
        cluster_probs = point_probs
        cluster_shapes = shape_descriptors

    if shape_descriptors is not None:
        return cluster_coords, cluster_probs, cluster_shapes
    else:
        return cluster_coords, cluster_probs



def scatter_3d(coords1, coords2, coords3, cartesian=True):
    """Show a 3D scatter plot of the given points.

    Parameters
    ----------
    coords1, coords2, coords3 : (x, y, z) if ``cartesian`` is True, otherwise
        (r, theta, phi), converted with dipy's ``sphere2cart``.
    cartesian : whether the coordinates are already Cartesian.
    """
    if not cartesian:
        coords1, coords2, coords3 = sphere2cart(coords1, coords2, coords3)

    fig = plt.figure()
    # add_subplot attaches the 3D axes to the figure; since Matplotlib 3.6 a
    # bare Axes3D(fig) is no longer added automatically and the shown figure stays empty
    ax = fig.add_subplot(projection='3d')
    ax.scatter(coords1, coords2, coords3, depthshade=True)
    plt.show()
    
    
def triplot_3d(coords1, coords2, coords3, tri_mesh, cartesian=True):
    """Show a triangulated 3D surface through the given points.

    Parameters
    ----------
    coords1, coords2, coords3 : (x, y, z) if ``cartesian`` is True, otherwise
        (r, theta, phi), converted with dipy's ``sphere2cart``.
    tri_mesh : object with a ``simplices`` index array passed to
        ``plot_trisurf`` as triangles.
        NOTE(review): plot_trisurf expects triangles (3 indices per row);
        a 3D Delaunay yields tetrahedra (4 per row) -- confirm what callers pass.
    cartesian : whether the coordinates are already Cartesian.
    """
    if not cartesian:
        coords1, coords2, coords3 = sphere2cart(coords1, coords2, coords3)

    fig = plt.figure()
    # add_subplot attaches the 3D axes to the figure; since Matplotlib 3.6 a
    # bare Axes3D(fig) is no longer added automatically and the shown figure stays empty
    ax = fig.add_subplot(projection='3d')
    ax.plot_trisurf(coords1, coords2, coords3, triangles=tri_mesh.simplices.copy())
    plt.show()



def get_sampling_sphere(num_sample_points=500, num_iterations=5000, plot_sampling=False):
    """Build a set of sampling directions on the unit sphere.

    ``num_sample_points // 2`` random directions are drawn (global NumPy
    random state), spread apart over a hemisphere with dipy's
    ``disperse_charges`` for ``num_iterations`` iterations, and then mirrored
    through the origin so that every vertex has its antipode in the result.

    Parameters
    ----------
    num_sample_points : requested total number of directions.
    num_iterations : iterations passed to ``disperse_charges``.
    plot_sampling : show a scatter plot of the resulting vertices.

    Returns
    -------
    dipy.core.sphere.Sphere
    """
    n_half = num_sample_points // 2
    # random starting angles: polar angle in [0, pi), azimuth in [0, 2*pi)
    start_theta = np.pi * np.random.rand(n_half)
    start_phi = 2 * np.pi * np.random.rand(n_half)

    dispersed = disperse_charges(HemiSphere(theta=start_theta, phi=start_phi), num_iterations)[0]
    half_vertices = dispersed.vertices
    sph = Sphere(xyz=np.vstack((half_vertices, -half_vertices)))

    if plot_sampling:
        scatter_3d(sph.x, sph.y, sph.z, cartesian=True)

    return sph



def samples2delaunay(sample_points, cartesian=True):
    """Build a Delaunay tessellation of the sample points.

    Parameters
    ----------
    sample_points : three coordinate sequences, either [x, y, z] or, with
        ``cartesian=False``, [r, theta, phi] (converted with dipy's ``sphere2cart``).
    cartesian : whether ``sample_points`` is already Cartesian.

    Returns
    -------
    scipy.spatial.Delaunay
    """
    if cartesian:
        xyz = sample_points
    else:
        xyz = sphere2cart(sample_points[0], sample_points[1], sample_points[2])
    # Delaunay expects one point per row
    return Delaunay(np.array(xyz).T)



def delaunay2indices(delaunay_tri, idx):
    """Return the positions of the query points that fall inside a triangulation.

    Parameters
    ----------
    delaunay_tri : scipy.spatial.Delaunay
    idx : query points, one per row.

    Returns
    -------
    Tuple of index arrays (as from ``np.nonzero``) selecting the query points
    for which ``find_simplex`` found a containing simplex.
    """
    containing_simplex = delaunay_tri.find_simplex(idx)
    # find_simplex reports -1 for points outside the triangulation
    inside = containing_simplex != -1
    return np.nonzero(inside)
    


def spherical_instance_sampling(data, theta_phi_sampling, bg_values=(0,), verbose=False, **kwargs):
    """Sample the boundary radius of every labelled instance along fixed directions.

    For each label in ``data`` that is not in ``bg_values``, the boundary voxels
    are collected: the mask XOR its binary erosion, plus all mask voxels on the
    volume faces. They are expressed in spherical coordinates (dipy
    ``cart2sphere``) around the mean of those boundary voxels. For every
    sampling direction, the radius of the boundary voxel whose direction is
    closest on the unit sphere is recorded.

    Parameters
    ----------
    data : label volume. Must be 3D, since the face handling indexes three axes.
    theta_phi_sampling : sequence of (theta, phi) directions, assumed to follow
        the ``cart2sphere`` convention because they are matched against its
        output -- TODO confirm for callers.
    bg_values : labels to ignore.
    verbose : print a progress line.

    Returns
    -------
    instances : ascending list of processed labels.
    r_sampling : one array of sampled radii per instance.
    centroids : boundary-voxel centroid of each instance.
    """
    # np.unique is sorted but the set difference is not; sort for a deterministic order
    instances = sorted(set(np.unique(data)) - set(bg_values))

    # Match directions as unit vectors rather than by Euclidean distance in
    # (theta, phi). That distance ignores the azimuth seam (phi = -pi and
    # phi = pi are the same direction) and the poles. cdist is called without
    # p=1: 'euclidean' does not use that Minkowski parameter, and recent SciPy
    # versions reject unknown metric keywords.
    sampling_angles = np.asarray(theta_phi_sampling, dtype=float).reshape(-1, 2)
    sampling_dirs = np.column_stack(sphere2cart(1, sampling_angles[:, 0], sampling_angles[:, 1]))

    r_sampling = []
    centroids = []

    for num_instance, instance in enumerate(instances):

        if verbose:
            print('\r'*22+'Progress {0:0>2d}% ({1:0>3d}/{2:0>3d})'.format(int(num_instance/len(instances)*100), num_instance, len(instances)), end='\r')

        # mask of the current instance
        instance_mask = data==instance
        # boundary = voxels removed by one erosion step
        instance_boundary = np.logical_xor(instance_mask, morphology.binary_erosion(instance_mask))
        # keep mask voxels on the volume faces so the boundary has no holes there
        instance_boundary[...,0] = instance_mask[...,0]
        instance_boundary[...,-1] = instance_mask[...,-1]
        instance_boundary[:,0,:] = instance_mask[:,0,:]
        instance_boundary[:,-1,:] = instance_mask[:,-1,:]
        instance_boundary[0,...] = instance_mask[0,...]
        instance_boundary[-1,...] = instance_mask[-1,...]

        # boundary coordinates and their centroid
        mask_coords = np.nonzero(instance_boundary)
        centroid = np.array([np.mean(dim_coords) for dim_coords in mask_coords])
        centroids.append(centroid)
        # centroid as origin, then spherical coordinates
        mask_coords = [dim_coords-c for dim_coords, c in zip(mask_coords, centroid)]
        r_mask, theta_mask, phi_mask = cart2sphere(*mask_coords)

        # nearest boundary direction for every sampling direction
        mask_dirs = np.column_stack(sphere2cart(1, theta_mask, phi_mask))
        distances = distance.cdist(sampling_dirs, mask_dirs, metric='euclidean')
        closest_matches = np.argmin(distances, axis=1)
        r_sampling.append(r_mask[closest_matches])

    return instances, r_sampling, centroids
        
    


class sampling2harmonics():
    """Regularised least-squares fit of real spherical-harmonic coefficients to radial samples.

    The basis is built from ``scipy.special.sph_harm`` up to order ``sh_order``,
    giving ``(sh_order+1)**2`` coefficients. Coefficients are ordered by order n
    and, within an order, by degree m = -n..n:

    - m < 0 uses sqrt(2)*Re(Y_n^|m|),
    - m == 0 uses Re(Y_n^0),
    - m > 0 uses sqrt(2)*Im(Y_n^|m|).

    Order n is penalised by ``lb_lambda * n**2 * (n+1)**2``, the squared
    Laplace-Beltrami eigenvalue n(n+1).

    Parameters
    ----------
    sh_order : maximum harmonic order.
    theta_phi_sampling : sequence of (theta, phi) sampling directions; theta is
        passed to ``sph_harm`` as the polar angle and phi as the azimuth.
    lb_lambda : regularisation weight.

    NOTE(review): ``scipy.special.sph_harm`` is deprecated in SciPy 1.15 in
    favour of ``sph_harm_y``; migrate before it is removed.
    """

    def __init__(self, sh_order, theta_phi_sampling, lb_lambda=0.006):
        super(sampling2harmonics, self).__init__()
        self.sh_order = sh_order
        self.theta_phi_sampling = theta_phi_sampling
        self.lb_lambda = lb_lambda
        self.num_samples = len(theta_phi_sampling)
        # builtin int: the np.int alias was removed in NumPy 1.24
        self.num_coefficients = int((self.sh_order+1)**2)

        # evaluate each basis function on all sampling directions at once
        angles = np.asarray(theta_phi_sampling, dtype=float).reshape(self.num_samples, 2)
        theta = angles[:, 0]
        phi = angles[:, 1]

        b = np.zeros((self.num_samples, self.num_coefficients))
        l = np.zeros((self.num_coefficients, self.num_coefficients))

        num_coefficient = 0
        for num_order in range(self.sh_order+1):
            for num_degree in range(-num_order, num_order+1):
                y = sph_harm(np.abs(num_degree), num_order, phi, theta)
                if num_degree < 0:
                    b[:, num_coefficient] = np.real(y) * np.sqrt(2)
                elif num_degree == 0:
                    b[:, num_coefficient] = np.real(y)
                else:
                    b[:, num_coefficient] = np.imag(y) * np.sqrt(2)
                l[num_coefficient, num_coefficient] = self.lb_lambda * num_order ** 2 * (num_order + 1) ** 2
                num_coefficient += 1

        # (B^T B + L)^+ B^T, stored transposed so that a row vector of samples
        # multiplied by convert_mat yields the coefficients
        b_inv = np.linalg.pinv(np.matmul(b.transpose(), b) + l)
        self.convert_mat = np.matmul(b_inv, b.transpose()).transpose()

    def convert(self, r_sampling):
        """Fit coefficients to each radius vector.

        Parameters
        ----------
        r_sampling : iterable of 1D numpy arrays, each holding one radius per
            sampling direction (length ``num_samples``).

        Returns
        -------
        list with one squeezed coefficient array (``num_coefficients`` values) per input.
        """
        converted_samples = []
        for r_sample in r_sampling:
            r_converted = np.matmul(r_sample[np.newaxis], self.convert_mat)
            converted_samples.append(np.squeeze(r_converted))
        return converted_samples
            

            
    
class harmonics2sampling():
    """Evaluate real spherical-harmonic expansions at fixed sampling directions.

    The basis is built from ``scipy.special.sph_harm`` up to order ``sh_order``.
    Coefficients are ordered by order n and, within an order, by degree
    m = -n..n:

    - m < 0 uses sqrt(2)*Re(Y_n^|m|),
    - m == 0 uses Re(Y_n^0),
    - m > 0 uses sqrt(2)*Im(Y_n^|m|).

    Parameters
    ----------
    sh_order : maximum harmonic order (``(sh_order+1)**2`` coefficients).
    theta_phi_sampling : sequence of (theta, phi) directions; theta is passed to
        ``sph_harm`` as the polar angle and phi as the azimuth.

    NOTE(review): ``scipy.special.sph_harm`` is deprecated in SciPy 1.15 in
    favour of ``sph_harm_y``; migrate before it is removed.
    """

    def __init__(self, sh_order, theta_phi_sampling):
        super(harmonics2sampling, self).__init__()
        self.sh_order = sh_order
        self.theta_phi_sampling = theta_phi_sampling
        self.num_samples = len(theta_phi_sampling)
        # builtin int: the np.int alias was removed in NumPy 1.24
        self.num_coefficients = int((self.sh_order+1)**2)

        # evaluate each basis function on all sampling directions at once
        angles = np.asarray(theta_phi_sampling, dtype=float).reshape(self.num_samples, 2)
        theta = angles[:, 0]
        phi = angles[:, 1]

        # convert_mat[k, s] = value of basis function k at sampling direction s
        convert_mat = np.zeros((self.num_coefficients, self.num_samples))

        num_coefficient = 0
        for num_order in range(self.sh_order+1):
            for num_degree in range(-num_order, num_order+1):
                y = sph_harm(np.abs(num_degree), num_order, phi, theta)
                if num_degree < 0:
                    convert_mat[num_coefficient, :] = np.real(y) * np.sqrt(2)
                elif num_degree == 0:
                    convert_mat[num_coefficient, :] = np.real(y)
                else:
                    convert_mat[num_coefficient, :] = np.imag(y) * np.sqrt(2)
                num_coefficient += 1

        self.convert_mat = convert_mat

    def convert(self, r_harmonic):
        """Evaluate each coefficient vector at the sampling directions.

        Parameters
        ----------
        r_harmonic : iterable of 1D numpy arrays with ``num_coefficients`` values each.

        Returns
        -------
        list with one squeezed array of ``num_samples`` radii per input.
        """
        converted_harmonics = []
        for r_sample in r_harmonic:
            r_converted = np.matmul(r_sample[np.newaxis], self.convert_mat)
            converted_harmonics.append(np.squeeze(r_converted))
        return converted_harmonics
            
    
    
def scale_detections_csv(filelist, x_scale=1, y_scale=1, z_scale=1, **kwargs):
    """Divide the detection positions in ';'-separated CSV files by per-axis scales, in place.

    The ``xpos``, ``ypos`` and ``zpos`` columns are divided by ``x_scale``,
    ``y_scale`` and ``z_scale`` and rounded to the nearest integer. The whole
    table is then cast to int and written back to the same file, with the
    row index stored as column ``id``.

    Parameters
    ----------
    filelist : iterable of CSV file paths (overwritten).
    x_scale, y_scale, z_scale : divisors for the respective position columns.
    """
    for filepath in filelist:
        with open(filepath, 'r') as fh:
            data = pd.read_csv(fh, sep=';')

        for column, scale in (('xpos', x_scale), ('ypos', y_scale), ('zpos', z_scale)):
            # round before the int cast below: e.g. 3 / 0.1 == 29.999999999999996,
            # which a plain astype(int) truncates to 29
            data[column] = data[column].div(scale).round()

        # NOTE(review): this casts every column, not only the positions, to int
        # (truncating any float columns) -- confirm that is intended
        data = data.astype(int)

        # NOTE(review): if the input already has an 'id' column (e.g. written by
        # this function), writing the index under 'id' again duplicates it
        data.to_csv(filepath, sep=';', index_label='id')
            
            
