# -*- coding: utf-8 -*-


import numpy as np
import itertools
import multiprocessing as mp

from functools import partial
from skimage import morphology, measure
from scipy.ndimage.morphology import distance_transform_edt

from util.utils import spherical_instance_sampling, samples2delaunay, delaunay2indices



def instances2multiclass(data, bg_values=(0,), size_thresh=20, **kwargs):
    '''
    Convert an instance label image into a multi-class mask.

    Output labels (uint8):
        0 = remaining instance interior
        1 = background
        2 = instance centroid
        3 = instance membrane (boundary voxels)
    All labels are grown by one grey-scale dilation at the end, so where two
    labels meet the larger value wins.

    Parameters
    ----------
    data : array
        Instance label image. A trailing channel axis is assumed because
        ``data[..., 0]`` is used for the region analysis (TODO confirm).
    bg_values : iterable of int
        Label values treated as background.
    size_thresh : int
        Instances whose border-stripped area is not larger than this value
        do not receive a centroid.
    **kwargs
        Ignored; accepted so all converters share one call signature.

    Returns
    -------
    numpy.ndarray
        uint8 mask with the same shape as ``data``.
    '''
    # astype returns a copy, so the caller's array is never modified below
    data = data.astype(np.uint16)
    data_multiclass = np.zeros(data.shape, dtype=np.uint8)

    # mark background and remove it from the instance image
    bg_mask = np.isin(data, list(bg_values))
    data_multiclass[bg_mask] = 1
    data[bg_mask] = 0

    # membrane: voxels whose label exceeds the minimum label in their neighbourhood
    membrane = data - morphology.erosion(data)
    data_multiclass[membrane > 0] = 3

    # strip a dilated border from every instance before taking centroids,
    # to exclude sampling artifacts at instance borders
    data[morphology.dilation(membrane) > 0] = 0
    for props in measure.regionprops(data[..., 0]):
        if props.area > size_thresh:
            centroid = tuple(int(np.round(c)) for c in props.centroid)
            data_multiclass[centroid] = 2

    return morphology.dilation(data_multiclass)




def instances2distmap(data, bg_values=(0,), saturation_dist=1, dist_activation='tanh', **kwargs):
    '''
    Convert an instance label image into a distance map to instance membranes.

    Membrane voxels are those whose label is larger than the minimum label in
    their neighbourhood (``data - erosion(data) > 0``).

    If ``dist_activation`` contains ``'tanh'``, the map is signed. Non-background
    voxels hold their distance to the nearest membrane voxel. Background voxels
    hold minus their distance to the nearest non-background voxel. The result
    is squashed with ``tanh(d / saturation_dist)``. Otherwise the unsigned
    membrane distance divided by ``saturation_dist`` is returned for all voxels.

    Parameters
    ----------
    data : array
        Instance label image.
    bg_values : iterable of int
        Label values treated as background.
    saturation_dist : float
        Distance scale applied before the activation.
    dist_activation : str
        Any string containing ``'tanh'`` selects the signed tanh map.
    **kwargs
        Ignored; accepted so all converters share one call signature.

    Returns
    -------
    numpy.ndarray
        Float distance map with the same shape as ``data``.
    '''
    data = data.astype(np.uint16)

    # distance of each background voxel to the nearest non-background voxel.
    # np.isin yields a plain bool mask (the former ``dtype=np.bool`` alias was
    # removed in NumPy 1.24 and raised AttributeError there)
    data_bg = distance_transform_edt(np.isin(data, list(bg_values)))

    # distance of each voxel to the nearest membrane voxel
    data_map = data - morphology.erosion(data)
    data_map = distance_transform_edt(data_map <= 0)

    if 'tanh' in dist_activation:
        # positive inside instances, negative in the background
        data_map = np.multiply((data_bg <= 0), data_map) - np.multiply((data_bg > 0), data_bg)
        data_map = np.tanh(np.divide(data_map, saturation_dist))
    else:
        data_map = data_map / saturation_dist

    return data_map
    



def instances2indices(data, bg_values=(0,), **kwargs):
    '''
    Return the centroid of every instance in a label image.

    Parameters
    ----------
    data : array
        Instance label image. A trailing channel axis is assumed because
        ``data[..., 0]`` is analysed (TODO confirm).
    bg_values : iterable of int
        Label values treated as background; they are set to 0 so that
        ``regionprops`` ignores them.
    **kwargs
        Ignored; accepted so all converters share one call signature.

    Returns
    -------
    list of tuple
        ``regionprops`` centroids (float coordinates), in the order
        ``regionprops`` returns the regions.
    '''
    # astype returns a copy, so the caller's array is never modified
    data = data.astype(np.uint16)
    data[np.isin(data, list(bg_values))] = 0

    return [props.centroid for props in measure.regionprops(data[..., 0])]




def instances2label(data, bg_values=[0], verbose=False, **kwargs):
    '''
    Reduce an instance label image to a single "contains an instance" flag.

    Returns a length-1 boolean array that is True when any element of
    ``data`` is >= 1.

    ``bg_values``, ``verbose`` and ``kwargs`` are accepted for a uniform
    converter interface but are not used.
    NOTE(review): background values >= 1 passed via ``bg_values`` still count
    as instances here; confirm this is intended.
    '''
    foreground = np.array(data >= 1)
    return np.array([foreground.any()])




def instances2harmonicmask(data, s2h_converter, probs=None, shape=(112,112,112), cell_size=(8,8,8), dets_per_region=3, bg_values=[0], verbose=False, **kwargs):
    '''
    Encode the instances of a label image as a grid of per-cell detections.

    Each instance is sampled radially with ``spherical_instance_sampling``, and
    the samplings are converted to harmonic coefficients by
    ``s2h_converter.convert``. The spatial grid has ``int(shape[i] / cell_size[i])``
    cells per axis. Each cell holds up to ``dets_per_region`` detections. An
    instance whose centroid falls into an already full cell is dropped.

    An output mask for 3 possible detections in 3D will have the following channels:
        [p1,p2,...,x1,y1,z1,x2,y2,z2,...,d11,d12,d13,...,d21,d22,d23,...]
    i.e. ``dets_per_region`` probabilities, then ``3 * dets_per_region``
    intra-cell offsets, then ``num_coefficients * dets_per_region`` shape
    coefficients.

    Parameters
    ----------
    data : array
        Instance label image; ``data[..., 0]`` is sampled (a trailing channel
        axis is assumed — TODO confirm).
    s2h_converter : object
        Provides ``theta_phi_sampling``, ``num_coefficients`` and ``convert``.
    probs : sequence of float, optional
        Per-instance probability; defaults to 1 for every instance.
    shape, cell_size : tuple of int
        Image shape and cell size per axis.
    dets_per_region : int
        Detection slots per cell.
    bg_values : list of int
        Passed through to ``spherical_instance_sampling``.
    verbose : bool
        Passed through to ``spherical_instance_sampling``.

    Returns
    -------
    numpy.ndarray
        float64 mask of shape ``grid_shape + ((4 + num_coefficients) * dets_per_region,)``.
    '''
    # sample each instance
    instances, r_sampling, centroids = spherical_instance_sampling(data[...,0], s2h_converter.theta_phi_sampling, bg_values=bg_values, verbose=verbose)

    # if no probabilities are given, assume certain detections
    if probs is None:
        probs = [1,]*len(instances)

    # convert the sampling to harmonics
    r_harmonics = s2h_converter.convert(r_sampling)
    num_coefficients = s2h_converter.num_coefficients

    # create mask; np.float64 replaces the ``np.float`` alias removed in NumPy 1.24
    grid_shape = tuple(int(s/c) for s, c in zip(shape, cell_size))
    mask = np.zeros(grid_shape + ((4+num_coefficients)*dets_per_region,), dtype=np.float64)

    pos_base = dets_per_region        # first positional channel
    shape_base = dets_per_region * 4  # first shape channel
    for idx, descriptor, prob in zip(centroids, r_harmonics, probs):
        # current cell index and intra-cell offset normalised by (c - 1);
        # a cell of size 1 only has offset 0 (avoids division by zero)
        cell_idx = tuple(int(i//c) for i, c in zip(idx, cell_size))
        voxel_offset = [int(i % c)/(c-1) if c > 1 else 0. for i, c in zip(idx, cell_size)]
        # 1D view on the channels of this cell; writes go into ``mask``
        cell = mask[cell_idx]
        # each cell is allowed to contain multiple objects; use the first free slot
        for num_det in range(dets_per_region):
            if cell[num_det] == 0:
                cell[num_det] = prob
                cell[pos_base+num_det*3:pos_base+(num_det+1)*3] = voxel_offset
                cell[shape_base+num_det*num_coefficients:shape_base+(num_det+1)*num_coefficients] = descriptor
                break
    return mask




def harmonicmask2sampling(harmonic_mask, h2s_converter, cell_size=(8,8,8), dets_per_region=3, thresh=0., convert2radii=True, positional_weighting=False, **kwargs):
    '''
    Decode a harmonic detection mask back into individual detections.

    Cells are visited in row-major order. Within each cell, slots are visited
    from 0 to ``dets_per_region - 1``. A slot is skipped when all of its shape
    coefficients are zero, or when its (optionally weighted) probability is not
    greater than ``thresh``.

    The voxel position of a slot is ``ceil((cell + offset) * cell_size - 1)``
    per axis, clipped to ``[0, grid_shape * cell_size - 1]`` and cast to int16.
    With ``positional_weighting``, the probability is multiplied by
    ``min over axes of min(tanh(8 * m / ps), tanh(8 * (ps - m) / ps))``. Here
    ``m`` is the position and ``ps`` the full extent covered by the mask on
    that axis, so the factor is small near the patch boundary.

    Returns
    -------
    (centroids, probs, shape_descriptors)
        Lists of position tuples, probabilities and coefficient slices. When
        ``convert2radii`` is set, the descriptor list is replaced by
        ``h2s_converter.convert(shape_descriptors)``.
    '''
    grid_shape = harmonic_mask.shape[:-1]
    n_coeff = h2s_converter.num_coefficients
    # image extent covered by the mask grid on every axis
    patch_size = tuple(cs * gs for cs, gs in zip(cell_size, grid_shape))

    centroids, probs, shape_descriptors = [], [], []

    for cell in itertools.product(*map(range, grid_shape)):
        cell_info = harmonic_mask[cell]
        for det in range(dets_per_region):
            shape_start = dets_per_region * 4 + det * n_coeff
            descriptor = cell_info[shape_start:shape_start + n_coeff]
            # an all-zero descriptor means the slot holds no shape
            if not np.count_nonzero(descriptor):
                continue

            pos_start = dets_per_region + det * 3
            offsets = cell_info[pos_start:pos_start + 3]
            position = [
                np.clip(np.ceil((t + o) * s - 1), 0, s * g - 1).astype(np.int16)
                for t, o, s, g in zip(cell, offsets, cell_size, grid_shape)
            ]

            confidence = cell_info[det]
            if positional_weighting:
                weights = [
                    np.minimum(np.tanh(m / ps * 8), np.tanh((ps - m) / ps * 8))
                    for m, ps in zip(position, patch_size)
                ]
                confidence = confidence * np.min(weights)

            if not confidence > thresh:
                continue

            shape_descriptors.append(descriptor)
            centroids.append(tuple(position))
            probs.append(confidence)

    if convert2radii:
        shape_descriptors = h2s_converter.convert(shape_descriptors)

    return centroids, probs, shape_descriptors
                



def descriptors2image_poolhelper(descriptor, theta=None, phi=None, shape=(112,112,112), thresh=0, verbose=False):
    '''
    Voxelize a single detection (worker function for ``descriptors2image``).

    Parameters
    ----------
    descriptor : sequence
        ``(centroid, prob, shape_descriptor)`` for one detection, where
        ``shape_descriptor`` holds the radial sampling of the object.
    theta, phi : sequence
        Sampling angles, passed on to ``samples2delaunay``.
    shape : tuple of int
        Shape of the target image; voxel indices are clipped to it.
    thresh : float
        Detections with ``prob < thresh`` are skipped.
    verbose : bool
        Unused.

    Returns
    -------
    tuple
        One integer index array per image axis (usable as ``image[indices]``),
        or ``()`` if the detection is skipped (probability below ``thresh`` or
        an all-zero shape descriptor).
    '''
    centroid = descriptor[0]
    prob = descriptor[1]
    shape_descriptor = descriptor[2]

    if not (prob >= thresh and np.count_nonzero(shape_descriptor) > 0):
        # empty instance
        return ()

    # Object-local voxel grid centred on the origin and large enough for the
    # largest radius; the last axis holds the coordinates
    max_object_extend = int(np.ceil(shape_descriptor.max()))
    local_grid = np.indices((2 * max_object_extend + 1,) * len(shape))
    idx = np.stack(list(local_grid), axis=-1) - max_object_extend

    # Delaunay triangulation of the sampled surface and the grid voxels inside it
    delaunay_tri = samples2delaunay([shape_descriptor, theta, phi], cartesian=False)
    instance_indices = delaunay2indices(delaunay_tri, idx)

    # Shift to image coordinates and clip to the image bounds. This is
    # vectorised instead of a per-voxel Python loop. The cast to an integer
    # dtype keeps an empty result a valid index array (np.array([]) would be
    # float64, which numpy rejects as an index).
    return tuple(
        np.clip(
            np.asarray(instance_indices[axis]) + int(np.round(centroid[axis])) - max_object_extend,
            0,
            shape[axis] - 1,
        ).astype(np.intp)
        for axis in range(len(shape))
    )
    
    

def descriptors2image(descriptors, theta_phi_sampling=None, shape=(112,112,112), thresh=0, verbose=False, num_kernel=4, **kwargs):
    '''
    Render a set of detections into an instance label image.

    descriptors need to be in shape [centroids, probs, radii_sampling]

    Each detection is voxelized by ``descriptors2image_poolhelper`` in a pool
    of ``num_kernel`` worker processes. Detections are labelled 1..N in input
    order (skipped detections still consume their label). Where instances
    overlap, the later one overwrites the earlier one.

    ``theta_phi_sampling`` is iterated twice, once for the theta and once for
    the phi component of each pair.

    Returns
    -------
    numpy.ndarray
        uint16 label image of the given ``shape``.
    '''
    theta = [angle_pair[0] for angle_pair in theta_phi_sampling]
    phi = [angle_pair[1] for angle_pair in theta_phi_sampling]

    label_image = np.zeros(shape, dtype=np.uint16)

    voxelize = partial(descriptors2image_poolhelper, theta=theta, phi=phi, shape=shape, thresh=thresh, verbose=verbose)

    # voxelize all detections in parallel worker processes
    with mp.Pool(processes=num_kernel) as pool:
        voxel_indices = pool.map(voxelize, list(zip(*descriptors)))

    for label, indices in enumerate(voxel_indices, start=1):
        # an empty tuple marks a skipped detection
        if indices:
            label_image[indices] = label

    return label_image


    





