#!/usr/bin/env python
# coding: utf-8

# # Prerequisites

# In[13]:


print("Loading dependencies", flush=True)
import numpy as np
import matplotlib.pyplot as plt
import mma
from classif_helper import *
import gudhi as gd
from sklearn.neighbors import KernelDensity
import pickle

from sys import argv
# Command-line arguments: kmin, kmax and nsamples, used to build the grid of
# sample sizes processed by this script.
if len(argv) < 4:
    raise SystemExit(f"usage: {argv[0] if argv else 'script'} KMIN KMAX NSAMPLES")
kmin = int(argv[1])
kmax = int(argv[2])
nsamples = int(argv[3])
# # Dataset generation

# In[2]:


def pt(low=1, high=1.1, k=2, sigma=1):
    """Draw one noisy point from a ``k``-mode annulus.

    The angle is one of ``k`` equally spaced directions ``2*pi*j/k`` plus
    Gaussian noise ``n ~ N(0, sigma)``.

    The radius is first drawn as ``sqrt(U)`` with ``U ~ Uniform(low**2, high**2)``.
    This is the radial distribution of a point uniform in area on the annulus
    ``low <= r <= high``. The radius is then shifted by
    ``-0.1/sigma * (1 - |n|)``, which ties the radial offset to the angular
    noise.

    Args:
        low: inner radius of the annulus.
        high: outer radius of the annulus.
        k: number of angular modes.
        sigma: standard deviation of the angular noise (must be non-zero).

    Returns:
        A tuple ``(x, y)`` of Cartesian coordinates.
    """
    n = np.random.normal(loc=0, scale=sigma)
    # Both bounds must be squared for sqrt(U) to give the area-uniform radius
    # distribution on [low, high].
    base_radius = np.sqrt(np.random.uniform(low=low**2, high=high**2))
    r = base_radius - 0.1 / sigma * (1 - np.abs(n))
    theta = np.random.choice(range(k)) * 2 * np.pi / k + n
    return r * np.cos(theta), r * np.sin(theta)
def orbit(n: int = 100, r=0.5, x0=None) -> list:
    """Return ``n`` successive points of a twist map on the unit torus.

    Each step applies, in order::

        x <- (x + r * y * (1 - y)) mod 1
        y <- (y + r * x * (1 - x)) mod 1

    The ``y`` update uses the freshly updated ``x``. This appears to be the
    "linked twist map" commonly used as a TDA benchmark (TODO confirm).

    Args:
        n: total number of points returned, including the starting point.
            Non-positive values yield an empty list.
        r: strength of the map.
        x0: starting point ``(x, y)``. ``None``, or any sequence whose length
            is not 2, selects a random start drawn uniformly on ``[0, 1)^2``.

    Returns:
        A list of ``[x, y]`` pairs of length ``max(n, 0)``.
    """
    if n <= 0:
        return []
    if x0 is None or len(x0) != 2:
        x, y = np.random.uniform(size=2)
    else:
        x, y = x0
    point_list = [[x, y]]
    for _ in range(n - 1):
        x = (x + r * y * (1 - y)) % 1
        y = (y + r * x * (1 - x)) % 1
        point_list.append([x, y])
    return point_list
def get_pts(dataset: str = "annulus", npts: int = 100, **kwargs) -> np.ndarray:
    """Sample ``npts`` points from the named synthetic dataset.

    * ``"annulus"``: ``npts`` independent draws from ``pt``.
    * ``"orbit"``: one trajectory of length ``npts`` from ``orbit``.

    Extra keyword arguments are forwarded to the selected generator. Any
    other dataset name gives an empty array.
    """
    if dataset == "annulus":
        samples = [pt(**kwargs) for _ in range(npts)]
    elif dataset == "orbit":
        samples = orbit(npts, **kwargs)
    else:
        samples = []
    return np.array(samples)


# In[3]:


# Parameters
# Dataset parameters: number of sampled points, number of angular modes of
# the annulus, and the angular noise scale (passed to ``pt`` as ``sigma``).
npts, k, s = 100_000, 3, 0.5

# In[4]:


print("Generating dataset...", flush=True)
X = get_pts(npts=npts, dataset="annulus", k=k, sigma=s)
# 2-D histogram of the sample, normalised to a density.
heatmap, xedges, yedges = np.histogram2d(X[:, 0], X[:, 1], bins=100, density=True)
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
# histogram2d puts x on the first axis while imshow draws rows as y, hence
# the transpose; ``extent`` labels the axes in data coordinates rather than
# bin indices.
plt.imshow(heatmap.T, origin="lower", extent=extent)
plt.colorbar()
# Explicitly hide the grid; grid(None) would merely toggle its current state.
plt.grid(False)
plt.savefig(f"images/annulus_heatmap_{npts}_pts_{k}_modes.svg")
plt.clf()


# # Filtrations

# In[5]:


# Run configuration. The whole dict is forwarded as keyword arguments to
# get_bf, to persistence_approximation (in mod_dump) and to compute_mods;
# each consumer presumably ignores the keys it does not use.
params = {
    "n_jobs": int(cpu_count()),  # cpu_count presumably comes from classif_helper's star import
    "kmin": kmin,  # smallest sample size (CLI argument 1)
    "kmax": kmax,  # largest sample size (CLI argument 2)
    "nsamples": nsamples,  # grid points per spacing (CLI argument 3)
    "precision": 0.01,
    "degree": 1,
    "resolution": [50, 50],
    "kde_bandwidth": 0.1,  # read by get_bf for KernelDensity
    "box": [[-0.1, -0.1], [10, 10]],
    "kde_kernel": "gaussian",  # read by get_bf for KernelDensity
    "normalize": 1,
    "bandwidth": 0.1,
    "ps": [0, 0.5, 1, 2, np.inf],
    "threshold": 10,  # get_bf passes sqrt(threshold) as max_alpha_square
    "flatten": False,
}


# In[6]:


def get_bf(k, **kwargs):
    """Build the (alpha, codensity) bifiltration on the first ``k`` points of ``X``.

    An alpha complex is built on the sample. Its simplex tree is then
    converted to a 2-parameter ``mma.SimplexTreeMulti``. Parameter 1 is filled
    by a lower-star filtration of the vertex codensity ``-log(KDE density)``.
    Parameter 0 presumably keeps the alpha filtration values (TODO confirm
    against mma docs).

    Args:
        k: number of leading points of the module-level dataset ``X`` to use.
        **kwargs: reads ``threshold`` (default 4), ``kde_kernel`` (default
            ``'gaussian'``) and ``kde_bandwidth`` (default 0.5); other keys
            are ignored.

    Returns:
        The 2-parameter simplex tree.
    """
    # Exactly the first k points, index 0 included.
    sample = X[:k]
    alphacplx = gd.AlphaComplex(points=sample)
    # NOTE(review): gudhi's ``max_alpha_square`` is a *squared* value, yet the
    # square root of ``threshold`` is passed here. Confirm which quantity
    # ``threshold`` is meant to be.
    st = alphacplx.create_simplex_tree(max_alpha_square=np.sqrt(kwargs.get("threshold", 4)))

    # Read vertex coordinates back from the complex so the codensity array is
    # indexed like the simplex-tree vertices (the complex may reorder or
    # deduplicate input points).
    points = np.array([alphacplx.get_point(i) for i in range(st.num_vertices())])
    kde = KernelDensity(
        kernel=kwargs.get("kde_kernel", 'gaussian'),
        bandwidth=kwargs.get("kde_bandwidth", 0.5),
    ).fit(sample)
    # score_samples returns log-density; its negation is the codensity.
    codensity_filtration = -np.array(kde.score_samples(points))

    st = mma.SimplexTreeMulti(st, num_parameters=2)
    st.fill_lowerstar(codensity_filtration, parameter=1)
    return st
def mod_dump(k: int):
    """Return the dumped approximated persistence module for sample size ``k``.

    The bifiltration comes from ``get_bf``. Both steps receive the module-level
    ``params`` as keyword arguments.
    """
    bifiltration = get_bf(k, **params)
    module = bifiltration.persistence_approximation(**params)
    return module.dump()


# # Computation

# In[16]:


# Sample sizes to process: the union of a linear and a logarithmic grid
# between kmin and kmax (nsamples points each).
start = params["kmin"]
stop = params["kmax"]
num = params["nsamples"]


# In[17]:

if min(start, stop) < 1:
    raise ValueError(f"kmin and kmax must be >= 1 for the log grid, got {start} and {stop}")
lin_it = np.linspace(start=start, stop=stop, num=num, dtype=int)
# Round rather than truncate: base**log10(m) can land just below the integer
# m because of floating-point error, and a plain int cast would give m-1.
log_it = np.rint(np.logspace(start=np.log10(start), stop=np.log10(stop), num=num)).astype(int)
# np.unique returns sorted unique values, so no separate sort is needed.
iterator = np.unique(np.concatenate([lin_it, log_it]))
with open(f"modules/synthetic1/iterator_{start}_{stop}_{num}.np", "wb") as f:
    np.save(f, iterator)

# In[21]:
print("Computing modules...", flush=True)

# compute_mods (presumably from classif_helper) is given the bifiltration
# builder and, with dump=True, presumably saves results under this prefix.
save_prefix = f"modules/synthetic1/module_{start}_{stop}_{num}_"
approximation_modules = compute_mods(iterator, get_bf, dump=True, save=save_prefix, **params)
# NOTE: an earlier version also pickled [approximation_modules, params] to
# modules/synthetic/module_{start}_{stop}_{num}.pkl; that step is disabled.
print("Done !")




