#!/usr/bin/env python
# coding: utf-8

# In[1]:



import gudhi as gd
import numpy as np
from mma import *
from classif_helper import *
from sklearn.neighbors import KernelDensity
from os.path import expanduser
from os import walk
from pandas import read_csv
import matplotlib.pyplot as plt
from multiprocessing import Pool, cpu_count
from sklearn.pipeline import Pipeline
from mma_classif import *
from sys import argv



# Number of CPUs reported by the OS; used below to size the parallel jobs
ncore = cpu_count()
# In[2]:

# Set to True to recompute the modules even when they have already been saved
force_computation = False




# In[4]:

# Helpers computing the two filtrations of the i-th element of the dataset
def compute_alpha(x, **kwargs):
    """Build the alpha complex of the point cloud ``x``.

    Args:
        x: point cloud handed to ``gd.AlphaComplex``; ``len(x)`` is used as
            the number of points to read back.
        **kwargs: only ``"threshold"`` is read here. It is forwarded as
            ``max_alpha_square`` and defaults to ``np.inf`` (no truncation).

    Returns:
        A pair ``(splx2bf(simplex_tree), points)`` where ``points`` lists
        ``get_point(i)`` for ``i`` in ``range(len(x))``.
    """
    max_alpha_square = kwargs.get("threshold", np.inf)
    alpha_complex = gd.AlphaComplex(points=x)
    simplex_tree = alpha_complex.create_simplex_tree(max_alpha_square=max_alpha_square)
    boundary_and_filtration = splx2bf(simplex_tree)
    # Points are read back from the complex rather than taken from x directly.
    points = [alpha_complex.get_point(idx) for idx in range(len(x))]
    return boundary_and_filtration, points
def get_bf(i:int=0, **kwargs):
    """Compute the boundary data and the two filtrations of dataset element ``i``.

    The point cloud is read from the module-level ``X``. Duplicate points are
    removed before building the alpha complex and fitting the density
    estimator.

    Args:
        i: index into the module-level ``X``.
        **kwargs: forwarded to ``compute_alpha``. The keys read here are:
            ``"kde_bandwidth"`` (default 1) and ``"kde_kernel"`` (default
            ``"gaussian"``), passed to ``KernelDensity``;
            ``"scale"`` (default 1), either a scalar applied to both
            filtrations or a pair ``[geometry_scale, density_scale]``.

    Returns:
        ``(b, [f1, f2])`` where ``b`` and ``f1`` come from ``compute_alpha``
        (``f1`` multiplied by the geometry scale) and ``f2`` is the negated
        ``KernelDensity.score_samples`` value at the complex's points,
        multiplied by the density scale.
    """
    x = np.unique(X[i], axis=0)
    kde = KernelDensity(
        bandwidth=kwargs.get("kde_bandwidth", 1),
        kernel=kwargs.get("kde_kernel", "gaussian"),
    ).fit(x)
    [b, f1], y = compute_alpha(x, **kwargs)

    # The scalar default used to be subscripted directly, which raised
    # TypeError whenever "scale" was omitted; accept both forms.
    scale = kwargs.get("scale", 1)
    if np.ndim(scale) == 0:
        geometry_scale = density_scale = scale
    else:
        geometry_scale, density_scale = scale[0], scale[1]

    # Negated so that denser regions get smaller filtration values.
    f2 = -np.array(kde.score_samples(y))
    f1 = np.array(f1) * geometry_scale
    f2 *= density_scale
    return b, [f1, f2]


# # Dataset

# In[5]:


# Root directory of the dataset. The trailing slash matters: file names are
# appended to this string by plain concatenation further down.
DATASET_PATH = expanduser("~/Datasets/1.5mmRegions/")


# In[6]:

# Report which dataset directory is being used
print("Dataset", DATASET_PATH)
# Retrieves the immuno dataset
def get_regions(dataset_path=None, region_labels=("FoxP3", "CD8", "CD68")):
    """Load every point cloud of the dataset together with its label.

    Each label is a sub-directory of the dataset root. Every file found under
    it, including in nested directories, is read with ``pandas.read_csv``.

    Args:
        dataset_path: dataset root directory. Defaults to the module-level
            ``DATASET_PATH`` when ``None``.
        region_labels: names of the label sub-directories to load, in the
            order they should appear in the output. Pass a subset (e.g.
            ``("FoxP3", "CD8")``) to restrict the classes.

    Returns:
        ``(X, labels)``: ``X`` is a list with one array per file (the CSV
        values divided by 1500) and ``labels`` is the list of matching label
        names, aligned with ``X``.
    """
    # Local import: the file only imports `expanduser` from os.path.
    from os.path import join

    base = DATASET_PATH if dataset_path is None else dataset_path
    X, labels = [], []
    for label in region_labels:
        for root, dirs, files in walk(join(base, label)):
            # os.walk yields entries in arbitrary order; sorting keeps the
            # sample order reproducible from one run to the next.
            dirs.sort()
            for name in sorted(files):
                # Build the path from `root`, not from the label directory:
                # walk() recurses, so `name` may live in a nested directory.
                # NOTE(review): 1500 presumably rescales coordinates of the
                # 1.5mm regions to the unit square; confirm the unit.
                X.append(np.array(read_csv(join(root, name))) / 1500)
                labels.append(label)
    return X, labels
# Load the point clouds and their string labels
X, labels = get_regions()
# Integer-encoded labels, used as classification targets further down
elabels = LabelEncoder().fit_transform(np.asarray(labels))


# In[7]:

#
# i = np.random.randint(len(X))
# x = X[i]
# plt.scatter(x[:,0], x[:,1])
# plt.title(f"Label : {labels[i]}");


# # image computation

# ## parameters

# In[8]:

# parameters to compute the modules
params = {
    "box": [[0, 0], [1, 2]],
    "dimensions": [0, 1],     # homology dimensions (original note: maximum homology dimension to compute)
    "threshold": 2,           # read by compute_alpha as max_alpha_square
    # Image parameters
    "kde_bandwidth": 0.5,     # bandwidth of the KernelDensity fitted in get_bf
    "precision": 0.001,       # module approximation precision
    "flatten": False,
    "sigmoid": False,
    "scale": [200, 1],        # scale of geometry (ie AlphaComplex) filtration vs log density estimation filtration
    # NOTE(review): get_bf reads "kde_kernel", not "kernel"; confirm which
    # consumer this entry is meant for.
    "kernel": "gaussian",
}



# In[32]:


# i = np.random.randint(len(X))
# # compute_img(get_bf(i, **params), plot=1,size=5,**params);
# b,f = get_bf(i, **params)
# mod_i = approx(b,f, box= [[0,0], [2,2]])
# mod_i.image(**params)
# mod_i.landscape(dimension = 0);
# mod_i.landscape(dimension = 1);


# In[10]:


# np.random.choice(range(len(X)), 5, replace=False)


# In[11]:


# Estimate the bounds of both filtrations from a random sample of the dataset
m = []
M = []
# Sampling without replacement raises ValueError when asked for more items
# than exist, so cap the sample size at the dataset size.
n_bound_samples = min(100, len(X))
for i in np.random.choice(range(len(X)), n_bound_samples, replace=False):
    _, F = get_bf(i, **params)
    m.append([np.min(f) for f in F])
    M.append([np.max(f) for f in F])
# Per-filtration minimum / maximum over all sampled elements
m = np.min(m, axis=0)
M = np.max(M, axis=0)
box = [m, M]
print(box)


# In[29]:


from os.path import exists
# Computes the modules of the dataset, and saves them to pickle files (if necessary).
# Every expected file is checked, not only the last one: an interrupted
# earlier run can leave the last pickle on disk while others are missing,
# and the image step below loads one pickle per dataset element.
_mod_paths = [DATASET_PATH + f"mma_mods_{i}.pkl" for i in range(len(X))]
if force_computation or not all(exists(path) for path in _mod_paths):
    print("Computing modules...", flush=1)
    compute_mods(range(len(X)), get_bf, n_jobs=ncore, save=DATASET_PATH+"mma_mods_", **params)
else:
    print("Skipping computation of modules")
# out = compute_imgs(indices,get_bf, n_jobs=8, multithreads=True,**params)
# out = np.array(out, dtype=np.float32)


# In[15]:


# Path of the saved module of each dataset element, in dataset order
X_str = [DATASET_PATH + f"mma_mods_{i}.pkl" for i, _ in enumerate(X)]
# from sklearn.model_selection import train_test_split
# train_index, test_index = train_test_split(range(len(labels)), test_size=0.25)
# X_train = [X_str[i] for i in train_index]
# X_test = [X_str[i] for i in test_index]
# y_train, y_test = elabels[train_index], elabels[test_index]


# In[35]:


# Keyword arguments forwarded by get_img to the landscapes computation
img_params = {
    "resolution": [200, 200],
    "plot": False,
    "box": [[0, 0.5], [2, 2]],
}
def get_img(mod, **img_params):
    """Turn a module into a flat feature vector built from its landscapes.

    Args:
        mod: one of
            - ``str``: path to a pickle file whose content is passed to
              ``from_dump``;
            - ``list``: a dump passed directly to ``from_dump``;
            - ``PyModule``: used as is.
        **img_params: forwarded to ``landscapes``.

    Returns:
        A 1-D array: for each dimension in ``[0, 1]`` the landscapes for
        ``ks=range(5)`` are summed along axis 0, and the results are
        concatenated and flattened. Returns ``None`` (after printing
        "Bad input") when ``mod`` has an unsupported type.
    """
    if isinstance(mod, str):
        # Local import: `pickle` is otherwise only in scope through the
        # wildcard imports at the top of the file.
        import pickle

        # NOTE(security): unpickling runs arbitrary code; only load files
        # produced by this pipeline.
        # The context manager closes the file, which the bare open() did not.
        with open(mod, "rb") as dump_file:
            mod_ = from_dump(pickle.load(dump_file))
    elif isinstance(mod, list):
        mod_ = from_dump(mod)
    elif isinstance(mod, PyModule):
        mod_ = mod
    else:
        print("Bad input")
        return None
    summed_landscapes = [
        np.sum(mod_.landscapes(dimension=d, ks=range(5), **img_params), axis=0)
        for d in [0, 1]
    ]
    return np.concatenate(summed_landscapes).flatten()
print("Computing imgs...", flush=True)
# Use a quarter of the cores, but never fewer than one worker:
# int(ncore/4) is 0 on machines with fewer than 4 cores, and joblib
# rejects n_jobs=0.
n_img_jobs = max(1, ncore // 4)
with Parallel(n_jobs=n_img_jobs) as p:
    # One feature vector per saved module, in dataset order
    X_img = p(delayed(get_img)(x, **img_params) for x in tqdm(X_str))



# In[38]:


# Stack the per-sample feature vectors into a single array
X_ = np.asarray(X_img)


# In[39]:

# Report the image parameters alongside the score they produced
print(img_params, flush=True)
# 10-fold evaluation of a 300-tree random forest on the image features
classifiers = [RandomForestClassifier(n_estimators=300, n_jobs=ncore)]
print(kfold_acc(classifiers, X_, elabels, k=10), flush=True)




