import pyshtools
import random
import torch
import matplotlib.pyplot as plt
import numpy as np

from skimage import io, segmentation, morphology
from sklearn.metrics import pairwise_distances_chunked
from scipy.ndimage import binary_erosion, binary_opening, binary_dilation, binary_fill_holes, convolve, binary_closing
from mpl_toolkits.mplot3d import Axes3D



# Global plotting definitions: one base size for most text, a slightly larger
# one for axis labels and the largest one for figure titles.
SMALL_SIZE = 18
MEDIUM_SIZE = 20
BIGGER_SIZE = 22

plt.rc('font', size=SMALL_SIZE)                              # default text size
plt.rc('axes', titlesize=SMALL_SIZE, labelsize=MEDIUM_SIZE)  # axes title and x/y labels
plt.rc('xtick', labelsize=SMALL_SIZE)                        # x tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)                        # y tick labels
plt.rc('legend', fontsize=SMALL_SIZE)                        # legend text
plt.rc('figure', titlesize=BIGGER_SIZE)                      # figure title



# cuda pairwise distance with torch, approximately halves the computation time for large images
def pairwise_distances(x, y=None):
    """Compute the Euclidean distance matrix between the rows of two tensors.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, so the whole
    matrix is obtained with a single matrix product (fast on the GPU).

    Parameters:
        x (torch.Tensor): N x d matrix.
        y (torch.Tensor, optional): M x d matrix. If None, ``y = x`` is used.

    Returns:
        torch.Tensor: N x M matrix with dist[i, j] = ||x[i, :] - y[j, :]||
        (the Euclidean norm, not its square).
    """
    x_norm = (x ** 2).sum(1).view(-1, 1)
    if y is None:
        y_t = torch.transpose(x, 0, 1)
        y_norm = x_norm.view(1, -1)
    else:
        y_t = torch.transpose(y, 0, 1)
        y_norm = (y ** 2).sum(1).view(1, -1)

    squared = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
    # Floating-point cancellation can make the squared distance of (nearly)
    # identical points slightly negative. Clamp *before* the square root;
    # otherwise sqrt() yields NaN, which then poisons later min() reductions.
    return torch.clamp(squared, min=0.0).sqrt()



# Class for generation of a synthetic membrane mask
class SyntheticCellMembraneMasks:
    """Generate synthetic 3D cell membrane masks, or load them from a file.

    Cells are weighted Voronoi regions around random centre points that lie
    inside one or more spheres whose surfaces are deformed with random
    spherical harmonics (or inside a user-supplied mask). The results are
    stored in ``self.label_image`` (one label per cell, 0 = background) and
    ``self.boundary_image``.
    """

    def __init__(self, size=128, radius=None, roundness=None, cell_count=None, sphere_count=None,
                 init_mode='exponentialfunction', cos_freq=2, variation_factor=8, boundary_mode='inner',
                 n_jobs=1, use_cuda=True, split_factor=16, printpdf=False, filepath="", weight_range=0.15,
                 init_points=None, init_weight=None, init_points_sphere=None, morph=5, init_sph=False,
                 maxradius=False):
        """Load cell membranes from ``filepath``, or generate them with the given parameters.

        Parameters:
            size (int): edge length of the generated cubic image (when
                init_points_sphere is given, the mask's shape is used instead).
            radius (float): maximal radius of the deformed sphere; defaults to size/2.
            roundness (float): decay rate of the harmonic power spectrum (larger is
                smoother). If None: 0.1 for 'exponentialfunction_with_cosine', random
                in [0.1, 2.1) for 'exponentialfunction', random in [2, 4) for 'exponential'.
            cell_count (int): number of random cell centres. If None:
                int(size**3 * 0.00006682326644) (8968 cells for size 512, at least 1),
                or the number of init_points when those are given.
            sphere_count (int): number of overlapping spheres (random 1 to 4 if None).
                The first one is centred; every further sphere is shifted along each
                axis by +-[size/4, 3*size/8).
            init_mode (string): power spectrum of the spherical harmonics:
                'exponentialfunction_with_cosine' (adds cosine ripples on the surface)
                | 'exponentialfunction' | 'exponential'.
            cos_freq (float): frequency of the cosine ripples.
            variation_factor (float): the surface deviates by at most
                1/variation_factor of the mean radius.
            boundary_mode (string): mode of skimage.segmentation.find_boundaries(),
                used when a label image is loaded from filepath.
            n_jobs (int): number of jobs for the CPU distance computation; ignored
                when the GPU is used. Not necessarily faster.
            use_cuda (bool): use the GPU if torch reports one as available.
            split_factor (int): the grid points are processed on the GPU in
                size*split_factor chunks (the author notes the default 16 fits in
                ~6 GB of GPU memory).
            printpdf (bool): plot centres and sphere to 'SphereGenPlot.pdf'.
            filepath (string): if non-empty, load an image instead of generating one.
                More than two distinct values means a label image; otherwise the
                image is used as the boundary image.
            weight_range (float): each cell centre gets a random weight in
                [1 + weight_range, 1 + 2*weight_range) that divides its distances.
            init_points (sequence of 3 arrays): x, y and z coordinates of the centres
                (instead of random ones).
            init_weight (array): one weight per cell centre that remains after the
                centres outside the sphere(s) were removed.
            init_points_sphere (ndarray): 3D mask that replaces the generated spheres.
            morph (int): radius of the ball used to round cell edges at the outer
                border; 0 disables it.
            init_sph (bool): fit a spherical harmonic surface to init_points_sphere
                instead of using its voxels directly.
            maxradius (bool): also label voxels outside the sphere(s) whose weighted
                distance does not exceed the largest one inside. Only supported
                without init_points_sphere.

        Images can be saved with self.saveBoundaries(filepath) and
        self.saveLabel(filepath), and accessed as self.boundary_image and
        self.label_image.
        """
        # Spherical harmonic grid of the last created sphere. It stays None when
        # loading from a file or when using a mask without init_sph.
        self.grid = None
        if filepath == "":
            # generate synthetic cell membranes
            self.generate(size, radius=radius, roundness=roundness, cell_count=cell_count,
                          sphere_count=sphere_count, init_mode=init_mode, cos_freq=cos_freq,
                          variation_factor=variation_factor, boundary_mode=boundary_mode,
                          n_jobs=n_jobs, use_cuda=use_cuda, split_factor=split_factor,
                          printpdf=printpdf, weight_range=weight_range, init_points=init_points,
                          init_weight=init_weight, init_points_sphere=init_points_sphere,
                          morph=morph, init_sph=init_sph, maxradius=maxradius)
        else:
            # load synthetic cell membranes
            image = io.imread(filepath)
            if np.unique(image).shape[0] > 2:
                # more than two distinct values: a label image
                self.label_image = image
                self.boundary_image = segmentation.find_boundaries(
                    self.label_image, mode=boundary_mode).astype(np.uint8) * 255
            else:
                # at most two distinct values: already a boundary image
                self.label_image = None
                self.boundary_image = image

    # Generation process
    def generate(self, gridsize, radius=None, roundness=None, cell_count=None, sphere_count=None,
                 init_mode='exponentialfunction', cos_freq=2, variation_factor=8, boundary_mode='inner',
                 n_jobs=1, use_cuda=True, split_factor=16, printpdf=False, weight_range=0.15,
                 init_points=None, init_weight=None, init_points_sphere=None, morph=5, init_sph=False,
                 maxradius=False):
        """Fill deformed spheres (or a given mask) with random points and build Voronoi cells.

        See ``__init__`` for the parameters; ``gridsize`` is its ``size``.
        ``boundary_mode`` is accepted for interface compatibility but not used here.

        Sets ``self.label_image`` (uint16 labels, 0 = background) and
        ``self.boundary_image`` (bool).

        Raises:
            ValueError: if maxradius is combined with init_points_sphere, if fewer
                than one sphere is requested, or if no cell centre lies inside the
                sphere(s).
        """
        if maxradius and init_points_sphere is not None:
            # voxels outside the region are only tracked for generated spheres
            raise ValueError("maxradius is only supported when init_points_sphere is None")

        n_spheres = random.randint(1, 4) if sphere_count is None else sphere_count
        if radius is None:
            radius = gridsize / 2

        # cell centres, centred around the origin when generated randomly
        if init_points is None:
            if cell_count is None:
                # example cell density in a 512x512x512 block, that are 8968 cells;
                # at least one cell so that tiny grids do not divide by zero
                cell_count = max(1, int(gridsize**3 * 0.00006682326644))
            xrand = (np.random.rand(cell_count, 1) * gridsize).astype(int) - gridsize / 2
            yrand = (np.random.rand(cell_count, 1) * gridsize).astype(int) - gridsize / 2
            zrand = (np.random.rand(cell_count, 1) * gridsize).astype(int) - gridsize / 2
        else:
            xrand, yrand, zrand = init_points[0], init_points[1], init_points[2]
            if cell_count is None:
                # only needed to derive the ripple degree of the spheres
                cell_count = max(1, np.size(xrand))
        # Flat float vectors. Without this, (N, 1) columns that are never passed
        # through np.delete would be stacked into a single (1, 3N) "point".
        xrand = np.ravel(xrand).astype(np.float64)
        yrand = np.ravel(yrand).astype(np.float64)
        zrand = np.ravel(zrand).astype(np.float64)

        # one figure for the whole run (only created when plotting)
        fig = plt.figure() if printpdf else None

        if init_points_sphere is None:
            if n_spheres < 1:
                raise ValueError("sphere_count must be at least 1")
            self.label_image = np.zeros((gridsize, gridsize, gridsize), dtype='uint16')
            # coordinates of all voxels, centred around the origin
            axis = np.arange(0, gridsize) - gridsize / 2
            xgrid, ygrid, zgrid = (g.flatten() for g in np.meshgrid(axis, axis, axis))
            # degree from which the cosine ripples start (cosine init mode only)
            deg = int((gridsize**3 / cell_count)**(1 / 3))
            for i in range(n_spheres):
                r, t, p = self.__createsphere(gridsize, radius, roundness, init_mode=init_mode,
                                              cos_freq=cos_freq, variation_factor=variation_factor,
                                              deg=deg)
                if i == 0:
                    # the first sphere is centred: points lying outside of it
                    delindrand = self.__delOutOfRange(r, t, p, xrand, yrand, zrand)
                    delindgrid = self.__delOutOfRange(r, t, p, xgrid, ygrid, zgrid)
                else:
                    # Randomly position the next sphere around the first. A point is
                    # only deleted if it lies outside every sphere (intersection).
                    posx = random.choice((-1, 1)) * int(random.random() * gridsize / 8 + gridsize / 4)
                    posy = random.choice((-1, 1)) * int(random.random() * gridsize / 8 + gridsize / 4)
                    posz = random.choice((-1, 1)) * int(random.random() * gridsize / 8 + gridsize / 4)
                    delindrand = np.intersect1d(
                        delindrand, self.__delOutOfRange(r, t, p, xrand + posx, yrand + posy, zrand + posz))
                    delindgrid = np.intersect1d(
                        delindgrid, self.__delOutOfRange(r, t, p, xgrid + posx, ygrid + posy, zgrid + posz))
            if printpdf:
                # the saved figure shows the last created sphere
                self.__plotSphere(fig, r, t, p, gridsize)
        else:
            self.label_image = np.zeros(init_points_sphere.shape, dtype='uint16')
            gridsize = np.max(init_points_sphere.shape)
            # NOTE(review): randomly generated centres lie in [-gridsize/2, gridsize/2),
            # while the voxels below are in image coordinates [0, shape) - confirm
            # whether callers are expected to pass init_points in image coordinates.
            if init_sph:
                xgrid, ygrid, zgrid = self.__maskPointsFromHarmonics(init_points_sphere, gridsize)
            else:
                # use all voxels of the mask directly (image coordinates)
                xgrid, ygrid, zgrid = (c.astype(np.float64) for c in np.where(init_points_sphere))

        if printpdf:
            # cell centres before removing those outside the sphere(s)
            ax = fig.add_subplot(131, projection='3d')
            ax.scatter(xrand, yrand, zrand, color='#00549F')
            self.__formatAxes(ax, gridsize)

        # delete points not in spheres
        if init_points_sphere is None:
            if maxradius:
                xgrid_outside = xgrid[delindgrid]
                ygrid_outside = ygrid[delindgrid]
                zgrid_outside = zgrid[delindgrid]
            xgrid = np.delete(xgrid, delindgrid)
            ygrid = np.delete(ygrid, delindgrid)
            zgrid = np.delete(zgrid, delindgrid)
            xrand = np.delete(xrand, delindrand)
            yrand = np.delete(yrand, delindrand)
            zrand = np.delete(zrand, delindrand)

        if printpdf:
            # remaining cell centres
            ax = fig.add_subplot(133, projection='3d')
            ax.scatter(xrand, yrand, zrand, color='#00549F')
            self.__formatAxes(ax, gridsize)
            fig.set_size_inches(12, 4)
            fig.tight_layout(rect=[0.03, 0.03, 0.95, 0.95])
            fig.savefig('SphereGenPlot.pdf')

        # stack points as (N, 3) float arrays
        randpoints = np.column_stack((xrand, yrand, zrand))
        gridpoints = np.column_stack((xgrid, ygrid, zgrid)).astype(np.float64)
        if len(randpoints) == 0:
            raise ValueError("no cell centre lies inside the sphere(s); increase cell_count")
        if maxradius:
            gridpointsOutside = np.column_stack(
                (xgrid_outside, ygrid_outside, zgrid_outside)).astype(np.float64)

        if init_weight is None:
            weight = np.random.rand(len(randpoints)) * weight_range + (1 + weight_range)
        else:
            weight = np.asarray(init_weight).astype(np.float64)

        # nearest (weighted) cell centre for every voxel; the GPU is faster
        use_gpu = use_cuda and torch.cuda.is_available()
        n_chunks = gridsize * split_factor
        indices_inside, values_inside = self.__nearestCentre(
            gridpoints, randpoints, weight, use_gpu, n_chunks, n_jobs)
        if maxradius:
            indices_outside, values_outside = self.__nearestCentre(
                gridpointsOutside, randpoints, weight, use_gpu, n_chunks, n_jobs)
            uniqueind = np.unique(np.hstack([indices_inside, indices_outside]))
        else:
            uniqueind = np.unique(indices_inside)
        # consecutive labels 1..K in order of the centre index (0 stays background)
        labels_inside = np.searchsorted(uniqueind, indices_inside) + 1

        # voxel coordinates of generated spheres are centred; shift them back
        offset = gridsize / 2 if init_points_sphere is None else 0
        # create label image
        self.label_image[np.floor(xgrid + offset).astype(np.intp),
                         np.floor(ygrid + offset).astype(np.intp),
                         np.floor(zgrid + offset).astype(np.intp)] = labels_inside
        if maxradius:
            labels_outside = np.searchsorted(uniqueind, indices_outside) + 1
            # outside voxels farther (weighted) than any inside voxel stay background
            max_radius = np.max(values_inside)
            labels_outside[values_outside > max_radius] = 0
            self.label_image[np.floor(xgrid_outside + gridsize / 2).astype(np.intp),
                             np.floor(ygrid_outside + gridsize / 2).astype(np.intp),
                             np.floor(zgrid_outside + gridsize / 2).astype(np.intp)] = labels_outside

        # NOTE(review): boundary_image is computed before the morphological
        # clean-up below and is not updated afterwards - confirm this is intended.
        self.boundary_image = self.labelBoundaries(self.label_image)
        if morph > 0:
            pad = ((morph, morph),) * 3
            inner = (slice(morph, -morph),) * 3
            # cell interiors: neither membrane nor background
            interior = np.logical_not(self.boundary_image)
            interior[self.label_image == 0] = False
            # ball-shaped structuring element
            ball = morphology.ball(morph)
            # an opening with the ball rounds off the cell edges
            openedimg = binary_opening(np.pad(interior, pad, mode='symmetric'), ball)
            # only use the rounded edges in a band along the border of the overall shape
            padded = np.pad(self.label_image > 0, pad, mode='symmetric')
            border = np.logical_xor(padded, binary_erosion(padded, ball))[inner]
            not_rounded = np.logical_not(
                binary_fill_holes(binary_erosion(binary_dilation(openedimg))))[inner]
            self.label_image[np.logical_and(border, not_rounded)] = 0

    @staticmethod
    def __nearestCentre(points, centres, weight, use_gpu, n_chunks, n_jobs):
        """Return (index, weighted distance) of the nearest centre for every point.

        The Euclidean distance from a point to centre j is divided by weight[j].
        On the GPU the points are processed in ``n_chunks`` chunks; on the CPU
        sklearn's chunked pairwise distances are used with ``n_jobs`` jobs.
        """
        if len(points) == 0:
            return np.zeros(0, dtype=np.int64), np.zeros(0, dtype=np.float64)
        ind_chunks, val_chunks = [], []
        if use_gpu:
            with torch.no_grad():
                weight_gpu = torch.from_numpy(weight[np.newaxis, :]).cuda()
                centres_gpu = torch.from_numpy(centres).cuda()
                # split the data into many chunks so they fit into GPU memory
                for chunk in np.array_split(points, n_chunks):
                    dist = pairwise_distances(torch.from_numpy(chunk).cuda(), y=centres_gpu)
                    values, indices = (dist / weight_gpu).min(1)
                    val_chunks.append(values.cpu().numpy())
                    ind_chunks.append(indices.cpu().numpy())
        else:
            def reduce_func(D_chunk, start):
                weighted = D_chunk / weight[np.newaxis, :]
                idx = np.argmin(weighted, 1)
                return idx, weighted[np.arange(len(idx)), idx]

            for idx, val in pairwise_distances_chunked(points, Y=centres, reduce_func=reduce_func,
                                                       n_jobs=n_jobs):
                ind_chunks.append(idx)
                val_chunks.append(val)
        return np.hstack(ind_chunks), np.hstack(val_chunks)

    def __maskPointsFromHarmonics(self, mask, gridsize):
        """Return voxel coordinates that lie inside a harmonic surface fitted to ``mask``.

        The boundary voxels of ``mask`` are converted to spherical coordinates
        around their mean and binned by angle into a radius map. That map becomes
        ``self.grid`` (a pyshtools.SHGrid). All voxels of the mask's bounding box
        inside this surface are returned in image coordinates.
        """
        phi = np.linspace(0, 2 * np.pi, gridsize * 4)
        theta = np.linspace(0, np.pi, gridsize * 4, endpoint=True)
        bint = theta[1:] - (theta[2] - theta[1]) / 2
        binp = phi[1:] - (phi[2] - phi[1]) / 2

        instance_boundary = np.logical_xor(mask, morphology.binary_erosion(mask))
        # ensure there are no holes created at the image boundary
        instance_boundary[..., 0] = mask[..., 0]
        instance_boundary[..., -1] = mask[..., -1]
        instance_boundary[:, 0, :] = mask[:, 0, :]
        instance_boundary[:, -1, :] = mask[:, -1, :]
        instance_boundary[0, ...] = mask[0, ...]
        instance_boundary[-1, ...] = mask[-1, ...]
        bx, by, bz = np.where(instance_boundary)

        # coordinates of every voxel of the bounding box, centred on the box
        half = [s / 2 for s in mask.shape]
        xgrid, ygrid, zgrid = np.meshgrid(np.arange(0, mask.shape[0]) - half[0],
                                          np.arange(0, mask.shape[1]) - half[1],
                                          np.arange(0, mask.shape[2]) - half[2])
        xgrid = xgrid.flatten()
        ygrid = ygrid.flatten()
        zgrid = zgrid.flatten()

        # NOTE(review): the radii are measured around the boundary centroid, but
        # the voxels are tested around the box centre - confirm the shift is intended.
        r, t, p = self.__cart2spher(bx - np.mean(bx), by - np.mean(by), bz - np.mean(bz))
        binst = np.digitize(t, bint)
        binsp = np.digitize(p, binp)
        # angle bins without a boundary voxel get the mean radius
        r_map = np.ones((len(theta), len(phi))) * np.mean(r)
        for k, (i, j) in enumerate(zip(binst, binsp)):
            r_map[i, j] = r[k]
        self.grid = pyshtools.SHGrid.from_array(r_map)
        # get spherical coordinate grid from the spherical harmonic grid
        r, t, p = self.__grid2spher(self.grid, 1, variation_factor=0)
        delindgrid = self.__delOutOfRange(r, t, p, xgrid, ygrid, zgrid)
        return (np.delete(xgrid, delindgrid) + half[0],
                np.delete(ygrid, delindgrid) + half[1],
                np.delete(zgrid, delindgrid) + half[2])

    def __plotSphere(self, fig, r, t, p, gridsize):
        """Draw the sphere given on a (theta, phi) grid into subplot 132 of ``fig``."""
        x, y, z = self.__spher2cart(r.flatten(), t, p)
        ax = fig.add_subplot(132, projection='3d')
        # Negated coordinates: __cart2spher maps a point to the angles of the
        # opposite direction, so the tested surface is the point-mirrored one.
        ax.plot_surface(-x, -y, -z, rstride=1, cstride=1, color='#00549F')
        self.__formatAxes(ax, gridsize)

    @staticmethod
    def __formatAxes(ax, gridsize):
        """Set equal limits around the origin and hide ticks and axes of a 3D plot."""
        limits = (-gridsize // 2, gridsize // 2)
        ax.set_xlim(limits)
        ax.set_ylim(limits)
        ax.set_zlim(limits)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_zticklabels([])
        ax.set_axis_off()

    def __cart2spher(self, x, y, z):
        """Convert cartesian to spherical coordinates.

        Returns (r, t, p) with t = elevation + pi/2 in [0, pi] and
        p = azimuth + pi in [0, 2*pi].
        """
        xy = x**2 + y**2
        r = np.sqrt(xy + z**2)
        t = np.arctan2(z, np.sqrt(xy)) + np.pi / 2
        p = np.arctan2(y, x) + np.pi
        return r, t, p

    def __spher2cart(self, r, t, p):
        """Convert a spherical coordinate grid to cartesian coordinates.

        Parameters:
            r: radii flattened in (t, p) row-major order, len(t) * len(p) values.
            t: polar angles in radians.
            p: azimuth angles in radians.

        Returns:
            ndarray of shape (3, len(t), len(p)) holding x, y and z.
        """
        # rows follow t and columns follow p, matching the flattening order of r
        sshape = len(t), len(p)

        x = np.sin(t)[:, None] * np.cos(p)[None, :]
        y = np.sin(t)[:, None] * np.sin(p)[None, :]
        z = np.cos(t)[:, None] * np.ones_like(p)[None, :]

        points = np.vstack((x.flatten(), y.flatten(), z.flatten()))
        points *= r
        x = points[0].reshape(sshape)
        y = points[1].reshape(sshape)
        z = points[2].reshape(sshape)
        return np.array([x, y, z])

    def __grid2spher(self, g, s=1, variation_factor=8):
        """Convert a spherical harmonics grid into a closed (t, p) radius map.

        Returns (r, t, p): r has shape (nlat + 1, nlon + 1); t are polar angles
        and p longitudes (radians). If variation_factor != 0 the data is scaled to
        r = (1 + data / max|data| / variation_factor) * s / (1 + 1 / variation_factor),
        so that the maximal radius is s. Otherwise r is the raw grid data.
        """
        nlat, nlon = g.nlat, g.nlon
        data = g.data

        # close the grid: append the south pole (-90 deg) and 360 deg (= 0 deg)
        lats_circular = np.append(g.lats(), [-90.])
        lons_circular = np.append(g.lons(), [360])

        # make uv sphere and store all points
        p = np.radians(lons_circular)
        t = np.radians(90. - lats_circular)
        # fill data for all points
        magn_point = np.zeros((nlat + 1, nlon + 1))
        magn_point[:-1, :-1] = data
        # Avoid holes and peaks on the surface: the 360 deg column repeats 0 deg,
        # then the pole row copies the last latitude (including the new corner,
        # which would otherwise stay 0).
        magn_point[:-1, -1] = data[:, 0]
        magn_point[-1, :] = magn_point[-2, :]

        if variation_factor != 0:
            # compute maximum
            magnmax_point = np.max(np.abs(magn_point))
            # displace the points
            r = (1 + magn_point / magnmax_point / variation_factor) * (s / (1 + 1 / variation_factor))
        else:
            r = magn_point
        return r, t, p

    def __delOutOfRange(self, r, t, p, x, y, z):
        """Return the indices of the points lying outside the sphere surface.

        The points are converted to spherical angles and binned onto the (t, p)
        grid of ``r``. A point is outside if its radius exceeds r of its bin.
        """
        # convert to spherical coordinates
        r1, t1, p1 = self.__cart2spher(x, y, z)
        # bin the angles like the spherical grid (bin edges halfway between samples)
        bint = t[1:] - (t[2] - t[1]) / 2
        binp = p[1:] - (p[2] - p[1]) / 2
        indt = np.digitize(t1, bint)
        indp = np.digitize(p1, binp)
        # all points with a radius greater than the sphere in their direction
        return np.flatnonzero(r1 > r[indt, indp])

    def saveLabel(self, filepath):
        """Save the label image with skimage.io.imsave."""
        io.imsave(filepath, self.label_image)

    def saveBoundaries(self, filepath):
        """Save the boundary image with skimage.io.imsave."""
        io.imsave(filepath, self.boundary_image)

    def labelBoundaries(self, input):
        """Mark every voxel whose label differs from its predecessor along any axis.

        Voxel i is a boundary voxel if input[i] differs from the voxel at i - 1
        along at least one axis. Voxels at index 0 of an axis have no predecessor
        along that axis. Explicit comparisons are used, because a summed
        difference (3*v - n1 - n2 - n3) can cancel to zero for differing labels.

        Parameters:
            input (ndarray): label image.

        Returns:
            ndarray of bool with the shape of input.
        """
        labels = np.asarray(input)
        output = np.zeros(labels.shape, dtype=bool)
        for axis in range(labels.ndim):
            current = [slice(None)] * labels.ndim
            previous = [slice(None)] * labels.ndim
            current[axis] = slice(1, None)
            previous[axis] = slice(None, -1)
            output[tuple(current)] |= labels[tuple(current)] != labels[tuple(previous)]
        return output

    def __createsphere(self, size, radius, roundness=None, init_mode='exponentialfunction', cos_freq=2,
                       variation_factor=8, deg=24):
        """Create a randomly deformed sphere from random spherical harmonics.

        Sets ``self.clm`` (the coefficients) and ``self.grid`` (their expansion)
        and returns (r, t, p) as produced by ``__grid2spher``.

        Raises:
            ValueError: for an unknown init_mode.
        """
        coeffs = 2 * size
        # power spectrum with falling magnitudes for a smooth sphere; degree 0
        # gets zero power (exp(-inf) == inf**-a == 0)
        degrees = np.arange(coeffs, dtype=float)
        degrees[0] = np.inf
        if init_mode == 'exponentialfunction_with_cosine':
            if roundness is None:
                roundness = 0.1
            power = np.exp(-roundness * degrees)
            # add small cosine ripples from degree deg//2 on
            ripple_degrees = np.arange(len(power) - deg // 2)
            power[deg // 2:] += 0.00001 * np.absolute(np.cos(ripple_degrees / cos_freq * np.pi)) \
                * np.exp(-ripple_degrees * 0.06)
        elif init_mode == 'exponentialfunction':
            if roundness is None:
                roundness = random.random() * 2 + 0.1
            power = np.exp(-roundness * degrees)
        elif init_mode == 'exponential':
            if roundness is None:
                roundness = random.random() * 2 + 2
            power = degrees ** (-roundness)
        else:
            raise ValueError("unknown init_mode %r; expected 'exponentialfunction_with_cosine', "
                             "'exponentialfunction' or 'exponential'" % (init_mode,))

        # spherical harmonic coefficient
        self.clm = pyshtools.SHCoeffs.from_random(power)
        # generate real grid
        self.grid = self.clm.expand()
        # get spherical coordinate grid from the spherical harmonic grid
        r, t, p = self.__grid2spher(self.grid, radius, variation_factor=variation_factor)
        return r, t, p
