#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
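Generate synthetic point-cloud datasets (random disks/annuli with uniform
noise, orbits of a dynamical system, and point-process samples) and write
them, together with local-density ranks, as input files for function_delaunay.

Created on Mon Jan 22 10:54:08 2024

@author: anonymous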
"""

import numpy as np
import matplotlib.pyplot as plt 
import scipy.spatial as spatial
import pickle
import tqdm
import os 
import shutil



###############################################################################



def annulus_smaple(c,ri,ro,rho):
    """Sample points uniformly (with respect to area) from a planar annulus.

    Parameters
    ----------
    c : Numpy array
        The coordinates of the center.
    ri : Float
        The inner radius of the annulus (0 gives a full disk).
    ro : Float
        The outer radius of the annulus.
    rho : Float
        The sampling density, i.e. number of points per unit area.

    Returns
    -------
    Numpy array
        Array of shape (n, 2) with n = floor(rho * pi * (ro**2 - ri**2)).
    """
    n = np.floor(rho * np.pi * (ro * ro - ri * ri)).astype(int)
    # The area inside radius r grows like r**2, so drawing r uniformly in
    # [ri, ro] would over-sample the inner part of the annulus. Drawing r**2
    # uniformly in [ri**2, ro**2] gives a density that is uniform over the area,
    # matching the rho * area point count above.
    r = np.sqrt(np.random.uniform(ri * ri, ro * ro, n))
    a = np.random.uniform(0, 2 * np.pi, n)
    return np.column_stack((c[0] + r * np.cos(a), c[1] + r * np.sin(a)))



###############################################################################



def create_pointcloud(b,rho_shape,rho_noise,sep,n,k):
    """Create a random planar point cloud of k annuli, n-k disks and uniform noise.

    The n shapes are processed in a random order. For each shape, a center is
    drawn uniformly from the square [-s, s]^2 (s starts at b) until the
    distance to every previously placed center is at least the sum of the two
    outer radii plus `sep`. While placing a single shape, s is increased by 1
    on attempt 101, 201, ..., so the search eventually finds room. The enlarged
    s is kept for later shapes and for the noise.

    Disks have radius drawn from [2, 5). Annuli have outer radius drawn from
    [4, 5) and inner radius drawn from [2, 3). Noise is drawn uniformly from
    [-(s+5), s+5]^2.

    NOTE: the number of noise points is floor((s+5)**2 * rho_noise), but the
    noise square has area 4*(s+5)**2. The effective noise density is therefore
    rho_noise / 4.

    Parameters
    ----------
    b : Float
        Initial half-width of the square in which centers are drawn.
    rho_shape : Float
        Sampling density of the annuli and disks.
    rho_noise : Float
        Sampling density parameter of the uniform noise (see the NOTE above).
    sep : Float
        Minimal separation between the outer circles of disks and annuli.
    n : Int
        Number of disks plus number of annuli.
    k : Int
        Number of annuli.

    Returns
    -------
    Numpy array
        Array of shape (m, 2): all shape points, followed by the noise points.
    """
    # 0.0 marks a disk and 1.0 an annulus; shuffle the order of placement.
    kinds = np.random.permutation(np.concatenate((np.zeros(n - k), np.ones(k))))
    placed = np.empty((0, 3))  # one row per placed shape: (cx, cy, outer radius)
    pieces = []
    s = b

    for kind in kinds:
        is_disk = kind == 0
        if is_disk:
            r = np.random.uniform(2, 5)
        else:
            r = np.random.uniform(4, 5)
            ri = np.random.uniform(2, 3)

        attempts = 0
        while True:
            c = np.random.uniform(-s, s, 2)
            clash = any(
                np.linalg.norm(c - row[0:2]) < r + row[2] + sep for row in placed
            )
            # Grow the canvas periodically. This check runs on every attempt,
            # including an attempt that turns out to be accepted.
            if attempts == 100:
                attempts = 0
                s = s + 1
            attempts = attempts + 1
            if not clash:
                break

        placed = np.append(placed, np.array([[c[0], c[1], r]]), axis=0)
        inner = 0 if is_disk else ri
        pieces.append(annulus_smaple(c, inner, r, rho_shape))

    half_width = s + 5
    n_noise = np.floor(half_width * half_width * rho_noise).astype(int)
    noise = np.random.uniform(-half_width, half_width, (n_noise, 2))
    return np.concatenate([np.empty((0, 2))] + pieces + [noise])



###############################################################################



def generate_pointcloud_dataset(b,rho_shape,rho_Noise,sep,N,n):
    """Create a labelled dataset of random shape point clouds.

    There are n classes. Class k (0 <= k < n) contains N point clouds, each made
    of k annuli, n - k disks and uniform noise. The noise density of every cloud
    is drawn independently from [rho_Noise / 2, rho_Noise). A progress line and
    a tqdm progress bar are printed for each class.

    Parameters
    ----------
    b : Float
        Initial boundaries of the canvas.
    rho_shape : Float
        Sampling density of the annuli and disks.
    rho_Noise : Float
        Upper bound of the per-cloud noise density.
    sep : Float
        Minimal separation of disks and annuli.
    N : Int
        Number of point clouds per class.
    n : Int
        Number of shapes per cloud, which is also the number of classes.

    Returns
    -------
    (List, List)
        The generated point clouds, and for each cloud its class label
        (its number of annuli).
    """

    def _one_cloud(num_annuli):
        # Draw this cloud's noise density, then build the cloud itself.
        density = np.random.uniform(rho_Noise / 2, rho_Noise)
        return create_pointcloud(b, rho_shape, density, sep, n, num_annuli)

    clouds, labels = [], []
    for num_annuli in range(n):
        print('\n')
        print(num_annuli + 1, '/', n, '\n')
        clouds.extend(_one_cloud(num_annuli) for _ in tqdm.tqdm(range(N)))
        labels.extend([num_annuli] * N)
    return clouds, labels



###############################################################################



def generate_orbit(N,r):
    """Generate an orbit of a discrete dynamical system on the unit square.

    Starting from a point drawn uniformly from [0, 1)^2, the orbit iterates

        x_i = (x_{i-1} + r * y_{i-1} * (1 - y_{i-1})) mod 1
        y_i = (y_{i-1} + r * x_i * (1 - x_i)) mod 1

    Parameters
    ----------
    N : Int
        Number of points in the orbit, including the starting point.
    r : Float
        Parameter of the dynamical system.

    Returns
    -------
    Numpy array
        Array of shape (N, 2) whose rows are the successive orbit points.
    """
    pnt = np.zeros((N, 2))
    if N == 0:
        return pnt
    pnt[0] = np.random.uniform(low=0.0, high=1.0, size=2)

    x, y = pnt[0]
    # Iterate through index N-1 inclusive, so the last row is a real orbit
    # point rather than being left at (0, 0).
    for i in range(1, N):
        x = (x + r * y * (1 - y)) % 1
        y = (y + r * x * (1 - x)) % 1
        pnt[i] = x, y

    return pnt



###############################################################################


def generate_orbit_dataset(N,n,R):
    """Generate the orbit dataset: n orbits for each parameter value in R.

    Parameters
    ----------
    N : Int
        Number of points in an orbit.
    n : Int
        Number of orbits per parameter value.
    R : List
        List of parameters of the dynamical system. The label of an orbit is
        the index of its parameter in R.

    Returns
    -------
    (List, List)
        The orbits (grouped by parameter, in the order of R) and their labels.
    """
    data_points = [generate_orbit(N, param) for param in R for _ in range(n)]
    data_labels = [label for label in range(len(R)) for _ in range(n)]
    return data_points, data_labels



###############################################################################

def baddeley_silverman(k):
    """Simulate a sample of a Baddeley-Silverman-type cell process on [0, 1]^2.

    The unit square is divided into a k x k grid of cells. For each cell, a
    number p is drawn uniformly from [0, 1), and then:

      * 0.45 < p < 0.55 -> one point is placed uniformly in the cell,
      * p > 0.55        -> two points are placed uniformly in the cell,
      * otherwise       -> the cell stays empty.

    On average this gives one point per cell.

    Parameters
    ----------
    k : Int
        Divide the interval [0, 1] into k subintervals along each axis.

    Returns
    -------
    Numpy array
        Array of shape (m, 2) with the sampled points. The shape is (0, 2) when
        no cell receives a point (e.g. k == 0).
    """
    P = []
    for i in range(k):
        x_lo, x_hi = i / k, (i + 1) / k
        for j in range(k):
            y_lo, y_hi = j / k, (j + 1) / k
            p = np.random.rand()
            if 0.45 < p < 0.55:
                count = 1
            elif p > 0.55:
                count = 2
            else:
                count = 0
            for _ in range(count):
                # Draw x before y for each point, preserving the random-draw order.
                x = np.random.uniform(x_lo, x_hi)
                y = np.random.uniform(y_lo, y_hi)
                P.append([x, y])
    # reshape keeps a 2-column layout even when P is empty
    return np.array(P).reshape(-1, 2)
            

###############################################################################


def generate_process_dataset():
    """Build the point-process dataset with 4 classes of 1000 samples each.

    Labels 0, 1 and 2 are read from the CSV files 0.csv ... 999.csv in
    Data/Poisson_Processes, Data/Matern_Processes and Data/Strauss_Processes
    (loaded with numpy.genfromtxt, usemask=True). Label 3 is simulated in
    memory with baddeley_silverman(14). The mean number of points per class
    is printed.

    Returns
    -------
    (List, List)
        The point arrays and their integer class labels.
    """
    data_list = []
    data_labels = []

    csv_folders = (
        "Data/Poisson_Processes/",
        "Data/Matern_Processes/",
        "Data/Strauss_Processes/",
    )
    for label, folder in enumerate(csv_folders):
        for idx in range(1000):
            points = np.genfromtxt(folder + str(idx) + ".csv", delimiter=",", usemask=True)
            data_list.append(points)
            data_labels.append(label)

    # The fourth class is not read from disk but simulated here.
    for _ in range(1000):
        data_list.append(baddeley_silverman(14))
        data_labels.append(3)

    sizes = np.array([len(points) for points in data_list])

    print('#Poisson: ', np.mean(sizes[:1000]), '\n')
    print('#Matern: ', np.mean(sizes[1000:2000]), '\n')
    print('#Strauss: ', np.mean(sizes[2000:3000]), '\n')
    print('#Baddeley_Silverman: ', np.mean(sizes[3000:]), '\n')

    return data_list, data_labels

###############################################################################

def to_function_delaunay_inp(data_points,radius):
    """Write each point cloud to a text file that function_delaunay can read.

    For every cloud, each point is scored by 1 / (number of points of the cloud
    within distance `radius` of it, the point itself included). This is an
    inverse local density. Points are sorted by ascending score, so denser
    points come first. Each point is written on its own line as "x y rank",
    where rank is its 0-based position in that order. Ties in the score are
    ordered by numpy's default argsort.

    The directory Data/Pointcloud_Dataset is deleted (if present) and
    recreated. Cloud i is written to Data/Pointcloud_Dataset/<i>.txt.

    Parameters
    ----------
    data_points : List
        List of numpy arrays with point clouds.
    radius : Float
        Radius parameter for local density estimation (assumed >= 0, so
        every point counts itself).
    """
    print('\n')
    print('Write to files','\n')
    out_dir = os.path.join('Data','Pointcloud_Dataset')
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)

    for idx, cloud in tqdm.tqdm(enumerate(data_points), total=len(data_points)):
        tree = spatial.KDTree(cloud)
        counts = np.array([len(ball) for ball in tree.query_ball_tree(tree, radius)])
        score = np.divide(np.ones(counts.shape[0]), counts)
        rows = np.append(cloud, np.reshape(score, (len(score), 1)), axis=1)
        rows = rows[rows[:, 2].argsort()]

        with open('Data/Pointcloud_Dataset/'+str(idx)+'.txt', 'w') as f:
            # The third column is the rank j, not the score itself.
            f.writelines(
                str(rows[j, 0]) + ' ' + str(rows[j, 1]) + ' ' + str(j) + '\n'
                for j in range(rows.shape[0])
            )
                
      


###############################################################################

"""Example of shape pointcloud"""

# Uncomment the lines below to generate and plot a single example point cloud
# with n=5 shapes, 3 of which are annuli and the remaining 2 disks. The
# parameter values match the shape-dataset section below. N and radius are
# not used by this example.

# np.random.seed(123456)

# n=5
# N=1000

# b=10
# rho_shape=3
# rho_noise=5
# sep=1

# radius=1

# P=create_pointcloud(b,rho_shape,rho_noise,sep,n,3)
# x,y=P[:,0:2].T
# plt.figure(figsize=(10,8))
# plt.scatter(x,y,35)
# plt.axis('off')
# plt.show()


###############################################################################

"""Creates the shape dataset"""

# n classes (0..n-1 annuli among n shapes), N point clouds per class.
np.random.seed(123456)

n=5            # shapes per point cloud; also the number of classes
N=1000         # point clouds per class

b=10           # initial half-width of the square for shape centers
rho_shape=3    # sampling density of disks/annuli
rho_noise=5    # noise density per cloud is drawn from [rho_noise/2, rho_noise)
sep=1          # minimal gap between shapes

radius=1       # neighborhood radius for the density ranks in the output files

data_points,data_H1=generate_pointcloud_dataset(b,rho_shape,rho_noise,sep,N,n)

# Make sure Data/ exists before dumping; otherwise a fresh run fails here,
# after the expensive generation step. Use `with` so the files are closed.
os.makedirs('Data', exist_ok=True)
with open('Data/pointcloud_dataset.txt', 'wb') as fh:
    pickle.dump(data_points, fh)
with open('Data/pointcloud_labels.txt', 'wb') as fh:
    pickle.dump(data_H1, fh)

to_function_delaunay_inp(data_points,radius)


###############################################################################


"""Creates the Orbit5k/100k-dataset"""

# Alternative dataset: n orbits of N points for each of the 5 parameter values
# in R (label = index into R). Uncomment EXACTLY ONE of the two `n=` lines:
# n=1000 gives 5 * 1000 orbits (Orbit5k), n=20000 gives 5 * 20000 (Orbit100k).
# NOTE(review): this section writes the same Data/ output files as the active
# shape-dataset section above, so only one dataset section should be active
# at a time. Also consider `with open(...)` for the pickle dumps so the files
# are closed.

# np.random.seed(123456)

# n=1000  #Orbit5k
# n=20000 #Orbit100k
# N=1000

# R=[2.5,3.5,4.0,4.1,4.3]

# radius=0.1

# data_points,data_labels=generate_orbit_dataset(N,n,R)

# pickle.dump(data_points, open('Data/pointcloud_dataset.txt', 'wb'))
# pickle.dump(data_labels, open('Data/pointcloud_labels.txt', 'wb'))

# to_function_delaunay_inp(data_points,radius)


###############################################################################


"""Creates the pointprocess dataset"""

# Alternative dataset: 4 point-process classes (see generate_process_dataset).
# The Poisson/Matern/Strauss CSV files must already exist under Data/; this
# script does not create them.
# NOTE(review): this section writes the same Data/ output files as the active
# shape-dataset section above, so only one dataset section should be active
# at a time. Also consider `with open(...)` for the pickle dumps so the files
# are closed.

# np.random.seed(123456)

# data_points,data_labels=generate_process_dataset()

# pickle.dump(data_points, open('Data/pointcloud_dataset.txt', 'wb'))
# pickle.dump(data_labels, open('Data/pointcloud_labels.txt', 'wb'))

# radius=0.1

# to_function_delaunay_inp(data_points,radius)



