#!/usr/bin/env python
# coding: utf-8

# In[1]:



import gudhi as gd
import numpy as np
from mma import *
from classif_helper import *
from sklearn.neighbors import KernelDensity
from os.path import expanduser
from os import walk
from pandas import read_csv
import matplotlib.pyplot as plt
from multiprocessing import Pool, cpu_count
from sklearn.pipeline import Pipeline
from mma_classif import *
from sys import argv



ncore: int = cpu_count()  # number of CPUs reported by multiprocessing.cpu_count()
# In[2]:

# When True, the modules are recomputed even if they have already been computed.
force_computation: bool = False




# In[4]:

# Helpers that compute the two filtrations (alpha complex and KDE co-density) of the i-th element of the dataset
def compute_alpha(x, **kwargs):
    """Build the alpha complex of the point cloud ``x`` and convert it with ``splx2bf``.

    Keyword arguments:
        threshold: forwarded to ``create_simplex_tree`` as
            ``max_alpha_square``; defaults to ``np.inf`` (no cutoff).
        Any other keyword is ignored here.

    Returns:
        ``(splx2bf(simplex_tree), coords)``, where ``coords[i]`` is the
        point gudhi associates with vertex ``i`` (``AlphaComplex.get_point``),
        for ``i`` in ``0 .. len(x) - 1``.
    """
    alpha_complex = gd.AlphaComplex(points=x)
    simplex_tree = alpha_complex.create_simplex_tree(
        max_alpha_square=kwargs.get("threshold", np.inf)
    )
    # splx2bf is evaluated before the vertex coordinates are read, as before.
    converted = splx2bf(simplex_tree)
    coords = list(map(alpha_complex.get_point, range(len(x))))
    return converted, coords
def get_bf(i: int = 0, **kwargs):
    """Compute the two filtrations of the ``i``-th point cloud of the global ``X``.

    Duplicate points are removed first (``np.unique`` along axis 0).
    The first filtration comes from ``compute_alpha``. The second is the
    negated log-density returned by ``KernelDensity.score_samples``,
    evaluated at the alpha-complex vertex coordinates; denser regions
    therefore get smaller values.

    Keyword arguments (all optional; ``kwargs`` is also forwarded as-is to
    ``compute_alpha``, which reads ``threshold``):
        kde_bandwidth: KDE bandwidth (default 1).
        kde_kernel: KDE kernel name (default ``"gaussian"``).
        scale: pair ``(s1, s2)`` multiplying the first and second
            filtration, or a single number applied to both.
            Default ``(1, 1)``, i.e. no rescaling.

    Returns:
        ``(b, [f1, f2])``: ``b`` is the first item of the ``splx2bf`` output
        (via ``compute_alpha``), and ``f1``/``f2`` are numpy arrays.
    """
    x = np.unique(X[i], axis=0)
    kde = KernelDensity(
        bandwidth=kwargs.get("kde_bandwidth", 1),
        kernel=kwargs.get("kde_kernel", "gaussian"),
    ).fit(x)
    [b, f1], y = compute_alpha(x, **kwargs)

    # A bare number (the natural "no scaling" value 1) is not subscriptable,
    # so expand scalars to a pair before indexing.
    scale = kwargs.get("scale", (1, 1))
    if np.isscalar(scale):
        scale = (scale, scale)

    f1 = np.array(f1) * scale[0]
    f2 = -np.array(kde.score_samples(y)) * scale[1]
    return b, [f1, f2]


# # Dataset

# In[5]:


# Dataset root; get_regions reads one sub-folder per label below it.
DATASET_PATH = expanduser("~/Datasets/1.5mmRegions/")


# In[6]:

print(f"Dataset {DATASET_PATH}")
# Retrieves the immuno dataset
def get_regions(label_names=None, dataset_path=None):
    """Load every region file of the dataset, grouped by label.

    For each label, every file found by ``os.walk`` under
    ``<dataset_path>/<label>/`` (including nested sub-folders) is read
    with ``pandas.read_csv``, converted to a numpy array and divided by 1500.

    Parameters:
        label_names: labels (sub-folder names) to load. Defaults to
            ``["FoxP3", "CD8", "CD68"]``.
        dataset_path: dataset root directory. Defaults to the module-level
            ``DATASET_PATH``.

    Returns:
        ``(X, labels)``: the list of point arrays and, in the same order,
        the label of each array.

    Note:
        Entries appear in ``os.walk`` traversal order (not sorted).
        Cached module files are indexed by position in ``X``
        (``mma_mods_{i}.pkl``), so the order is deliberately left as-is.
    """
    import os

    if label_names is None:
        label_names = ["FoxP3", "CD8", "CD68"]
    if dataset_path is None:
        dataset_path = DATASET_PATH

    X, labels = [], []
    for label in label_names:
        for root, _dirs, files in walk(os.path.join(dataset_path, label)):
            for name in files:
                # Join with ``root`` (not the label folder) so that files in
                # nested sub-folders resolve to their real location.
                points = np.array(read_csv(os.path.join(root, name)))
                # NOTE(review): 1500 presumably rescales coordinates of a
                # 1.5 mm region -- confirm units.
                X.append(points / 1500)
                labels.append(label)
    return X, labels
X, labels = get_regions()
elabels = LabelEncoder().fit_transform(np.array(labels))

# Path of the cached module for each region, indexed like X.
X_str = [DATASET_PATH + f"mma_mods_{i}.pkl" for i in range(len(X))]

img_params = {
    "resolution": [200, 200],
    "plot": 1,
    "box": [[0, 0], [2, 2]],
}

from random import choice

# Load one randomly chosen cached module for the plots below.
# NOTE(review): pickle.load executes arbitrary code -- only use trusted files.
with open(choice(X_str), "rb") as fh:
    mod = from_dump(pickle.load(fh))

# Images: one figure per (dim, p, bandwidth) setting, saved to PNG.
for dim, p, bdw in [(0, 1, 0.001), (1, 0, 0.001)]:
    plt.figure()
    # NOTE(review): dimension=-1 is passed regardless of `dim`, so `dim` only
    # affects the output file name -- confirm whether dimension=dim was meant.
    mod.image(dimension=-1, p=p, bandwidth=bdw, cb=1, **img_params)
    plt.savefig(f"img_H{dim}_p{p}_bdw{bdw}.png", dpi=200)

# Landscapes k = 0..4 for homology dimensions 0 and 1, saved to PNG.
for dim in [0, 1]:
    plt.figure()
    mod.landscapes(dimension=dim, ks=range(5), **img_params)
    plt.colorbar()
    plt.savefig(f"MPL_H{dim}.png", dpi=200)

#def get_img(mod, **img_params):
#    if type(mod) is str:
#        mod_ = from_dump(pickle.load(open(mod, "rb")))
#    elif type(mod) is list:
#        mod_ = from_dump(mod)
#    elif type(mod) is PyModule:
#        mod_ = mod
#    else:
#        print("Bad input")
#        return
#    return np.concatenate([np.sum(mod_.landscapes(dimension=d, ks=range(5), **img_params), axis = 0) for d in [0,1]]).flatten()
#print("Computing imgs...", flush=True)
#with Parallel(n_jobs=int(ncore/4)) as p:
#    X_img = p(delayed(get_img)(x,**img_params) for x in tqdm(X_str))



