#!/usr/bin/env python
# coding: utf-8

# # Prerequisites

# In[13]:


# Announce the import phase. flush=True forces the message out before the
# imports below run, even when stdout is buffered (e.g. redirected to a file).
print("Loading dependencies", flush=True)
import numpy as np
import matplotlib.pyplot as plt
from mma import *
from classif_helper import *
import gudhi as gd
from sklearn.neighbors import KernelDensity
import pickle

from sys import argv
# Positional arguments: KMIN and KMAX bound the sampled values of k (the
# number of dataset points used per module), NSAMPLES is how many values of
# k to sample between them. Extra arguments are ignored.
if len(argv) < 4:
    # Fail with a readable usage line instead of an IndexError traceback.
    raise SystemExit(f"usage: {argv[0]} KMIN KMAX NSAMPLES")
kmin = int(argv[1])
kmax = int(argv[2])
nsamples = int(argv[3])
# # Dataset generation

# In[2]:


n_pts = 5_000
# Every group below receives 20% of n_pts points.
# NOTE(review): four groups of 20% give 4 * n_pts / 5 = 4000 points in total,
# not n_pts — confirm whether a fifth group was intended.
n_per_group = int(n_pts * 2 / 10)
np.random.seed(0)
# Three noisy annuli (inner/outer radius, size, centre) followed by uniform
# background noise on [-2, 2]^2. Keep this order: np.random.uniform draws
# from the global RNG seeded above, and noisy_annulus presumably does too.
groups = [
    noisy_annulus(0.4, 0.45, n_per_group, center=[1.2, 1.3]),
    noisy_annulus(0.3, 0.31, n_per_group, center=[0.2, -1]),
    noisy_annulus(0.2, 0.201, n_per_group, center=[-1, 0.5]),
    np.random.uniform(low=-2, high=2, size=(n_per_group, 2)),
]
# Stack the groups row-wise into one point array, then shuffle the rows so
# that any prefix dataset[:k] mixes all groups.
dataset = np.vstack([np.array(group) for group in groups])
np.random.shuffle(dataset)


# # Filtrations

# In[5]:


# Run configuration. The whole dict is forwarded as keyword arguments to
# compute_mods and approx, so key names must match what those helpers accept.
# NOTE(review): kde_kernel / kde_bandwidth should agree with the KDE used to
# build the density filtration — confirm.
params = dict(
    n_jobs=int(cpu_count()),
    kmin=kmin,
    kmax=kmax,
    nsamples=nsamples,
    precision=0.01,
    dimension=1,
    resolution=[50, 50],
    kde_bandwidth=0.05,
    box=[[-0.1, -1], [1, 2]],
    kde_kernel="gaussian",
    normalize=1,
    bandwidth=0.1,
    ps=[0, 0.5, 1, 2, np.inf],
    threshold=10,
    flatten=False,
)


# In[6]:


def get_bf(k, **params):
    """Build the bifiltered complex on the first ``k`` points of ``dataset``.

    The complex is a Rips complex with edges up to length 1 and simplices up
    to dimension 2. The first filtration is the one ``splx2bf`` extracts from
    the simplex tree. The second is the negated log-density of a kernel
    density estimate fitted on the same points. Under a sublevel-set
    convention, dense regions therefore appear first.

    Parameters
    ----------
    k : int
        Number of leading points of the (shuffled) module-level ``dataset``
        to use.
    **params
        Optional settings. ``kde_kernel`` (default ``"gaussian"``) and
        ``kde_bandwidth`` (default ``0.05``) configure the density estimate.
        Other keys are accepted and ignored.

    Returns
    -------
    tuple
        ``(boundary, [filtration_rips, filtration_density])``.
    """
    # First k points, index 0 included.
    X = dataset[:k]
    rips_complex = gd.RipsComplex(points=X, max_edge_length=1)
    simplex_tree = rips_complex.create_simplex_tree(max_dimension=2)
    kde = KernelDensity(
        kernel=params.get("kde_kernel", "gaussian"),
        bandwidth=params.get("kde_bandwidth", 0.05),
    ).fit(X)
    # score_samples returns log-densities; negate them so that high density
    # maps to a small filtration value.
    filtration_density = -np.asarray(kde.score_samples(X))
    # NOTE(review): filtration_density holds one value per point, while the
    # splx2bf filtration presumably holds one value per simplex — confirm
    # that approx handles vertex-only values as intended.
    boundary, filtration_rips = splx2bf(simplex_tree)
    return boundary, [filtration_rips, filtration_density]
def mod_dump(k: int):
    """Approximate the module built from ``get_bf(k)`` and return its dump.

    Both the bifiltration construction and ``approx`` receive the
    module-level ``params`` as keyword arguments.
    """
    boundary, bifiltration = get_bf(k, **params)
    module = approx(boundary, bifiltration, **params)
    return module.dump()


# # Computation

# In[16]:


# Sampling grid for k: `num` values evenly spaced from `start` to `stop`
# (both ends included), converted to integers by dtype=int. Neighbouring
# samples may therefore coincide when num exceeds stop - start + 1.
start, stop, num = params["kmin"], params["kmax"], params["nsamples"]

iterator = np.linspace(start=start, stop=stop, num=num, dtype=int)


# In[21]:
from pathlib import Path

print("Computing modules...", flush=True)

# Path prefix passed to compute_mods via save=. Presumably each module is
# written to a file whose name starts with this prefix — TODO confirm.
save_prefix = f"modules/synthetic2/module_{start}_{stop}_{num}_"
# Create the output directory up front so that a missing folder cannot
# abort the run when compute_mods saves its results.
Path(save_prefix).parent.mkdir(parents=True, exist_ok=True)

compute_mods(iterator, get_bf, dump=True, save=save_prefix, **params)
print("Done !")




