#!/usr/bin/env python
# coding: utf-8

# In[1]:



import gudhi as gd
import numpy as np
from mma import *
from classif_helper import *
from sklearn.neighbors import KernelDensity
from os.path import expanduser
from os import walk
from pandas import read_csv
import matplotlib.pyplot as plt
from multiprocessing import Pool, cpu_count
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis
from sklearn.naive_bayes import GaussianNB
from sklearn.pipeline import Pipeline
from mma_classif import *
from sys import argv

# Command-line arguments: argv[1] is the image bandwidth (used as
# img_params["bandwidth"]), argv[2] is the integer `p` (used as img_params["p"]).
if len(argv) < 3:
    # Fail with a usage hint instead of a bare IndexError.
    raise SystemExit(f"usage: {argv[0]} IMG_BANDWIDTH P")
bdw_ = float(argv[1])
p_ = int(argv[2])

ncore = cpu_count()
# In[2]:


# When True, modules are recomputed even if cached dumps exist, and the script
# exits right after computing them.
force_computation=False


# In[3]:


def randimg(xtrain):
    """Return uniform [0, 1) noise shaped like *xtrain*, as a float tensor.

    The conversion uses ``from_numpy(...).float()``; ``from_numpy`` comes from a
    star import (presumably ``torch.from_numpy`` — confirm).
    """
    noise = np.random.uniform(size=xtrain.shape)
    tensor = from_numpy(noise)
    return tensor.float()


# In[4]:


def compute_alpha(x, **kwargs):
    """Build an alpha complex on point cloud *x*.

    Keyword args:
        threshold: passed as ``max_alpha_square`` to ``create_simplex_tree``
            (default: ``np.inf``, i.e. no truncation).

    Returns:
        ``(splx2bf(simplex_tree), points)`` where ``points`` are the alpha
        complex's vertex coordinates, one per input point, in vertex order.
    """
    max_alpha_sq = kwargs.get("threshold", np.inf)
    complex_ = gd.AlphaComplex(points=x)
    simplex_tree = complex_.create_simplex_tree(max_alpha_square=max_alpha_sq)
    converted = splx2bf(simplex_tree)
    vertices = [complex_.get_point(idx) for idx in range(len(x))]
    return converted, vertices
def get_bf(i:int=0, **kwargs):
    """Build the two-parameter filtration data for sample ``X[i]``.

    Duplicate points are removed first. The first filtration comes from the
    alpha complex via ``compute_alpha``/``splx2bf``. ``b`` and ``f1`` are
    opaque here; presumably they are the complex's boundary data and its
    alpha filtration values — confirm. The second filtration is the negative
    KDE log-density of the sample, evaluated at the alpha complex's points.

    Keyword args:
        kde_bandwidth: KernelDensity bandwidth (default 1).
        kde_kernel: KernelDensity kernel (default "gaussian").
        threshold: forwarded to ``compute_alpha``.
        scale: multiplier for (f1, f2). Either a pair or a single number
            applied to both (default 1, i.e. no rescaling).

    Returns:
        ``(b, [f1, f2])``.
    """
    x = np.unique(X[i], axis=0)
    kde = KernelDensity(
        bandwidth=kwargs.get("kde_bandwidth", 1),
        kernel=kwargs.get("kde_kernel", "gaussian"),
    ).fit(x)
    (b, f1), y = compute_alpha(x, **kwargs)
    # score_samples returns log-density; negate so dense regions come first.
    f2 = -np.array(kde.score_samples(y))
    # The original indexed the default `1` directly (TypeError when "scale"
    # was absent); accept a scalar and apply it to both filtrations.
    scale = kwargs.get("scale", 1)
    if np.ndim(scale) == 0:
        scale = (scale, scale)
    f1 = np.array(f1) * scale[0]
    f2 *= scale[1]
    return b, [f1, f2]


# # Dataset

# In[5]:


# Dataset root: one sub-folder of CSV point clouds per label (see get_regions).
# The trailing slash matters: paths below are built by string concatenation.
DATASET_PATH = expanduser("~/Datasets/1.5mmRegions/")


# In[6]:

print("Dataset", DATASET_PATH)


def get_regions():
    """Load every CSV point cloud found under ``DATASET_PATH/<label>/``.

    Returns:
        X: list of arrays, one per CSV file, divided by 1500.
            NOTE(review): presumably this rescales 1.5 mm (= 1500 um) regions
            to unit size — confirm the coordinate units.
        labels: the label (sub-folder name) of each entry of X.

    File order follows ``os.walk`` and is deliberately left unsorted. The
    sample index is used to name cached module dumps, so reordering would
    mismatch existing caches.
    """
    from os.path import join

    X, labels = [], []
    for label in ["FoxP3", "CD8", "CD68"]:
        # Two-class variant: ["FoxP3", "CD8"]
        for root, _dirs, files in walk(DATASET_PATH + label + "/"):
            for name in files:
                # Use walk's `root` so files in nested sub-folders are read from
                # where they were found. For top-level files this path is
                # identical to DATASET_PATH + label + "/" + name.
                X.append(np.array(read_csv(join(root, name))) / 1500)
                labels.append(label)
    return X, labels


X, labels = get_regions()
# Integer-encoded class labels used as classification targets.
elabels = LabelEncoder().fit_transform(np.array(labels))


# In[7]:


# Quick visual check: scatter the first two columns of one random sample.
i = np.random.randint(len(X))
x = X[i]
plt.scatter(*x[:, :2].T)
plt.title(f"Label : {labels[i]}")


# # image computation

# ## parameters

# In[8]:


# p=1 norm = 0 imgbdw = 0.01 "box":[[0,0.75], [1,2]]
# Module-computation parameters (forwarded as **kwargs to get_bf / compute_mods).
params = dict(
    box=[[0, 0], [2, 2]],
    dimensions=[0, 1],          # homology dimensions to compute
    threshold=2,                # max alpha-square passed to the alpha complex
    kde_bandwidth=0.5,          # bandwidth of the KDE filtration
    # Image parameters
    ps=[1],
    img_bandwidths=[0.001],
    precision=0.0001,           # module approximation precision
    resolution=[100, 100],      # resolution of the images
    normalizes=[0],             # normalize the images if 1
    flatten=False,
    sigmoid=False,
    scale=[500, 1],             # multipliers for the two filtrations in get_bf
    kernel="gaussian",
)
r1, r2 = params["resolution"]
# Number of image-parameter combinations in the grid.
nparams = int(np.prod([len(params[key]) for key in ("dimensions", "ps", "img_bandwidths", "normalizes")]))
nparams


# In[32]:


# i = np.random.randint(len(X))
# # compute_img(get_bf(i, **params), plot=1,size=5,**params);
# b,f = get_bf(i, **params)
# mod_i = approx(b,f, box= [[0,0], [2,2]])
# mod_i.image(**params)
# mod_i.landscape(dimension = 0);
# mod_i.landscape(dimension = 1);


# In[10]:


# np.random.choice(range(len(X)), 5, replace=False)


# In[11]:


# Estimate per-filtration [min, max] bounds from a random subset of samples.
# Capped at len(X): sampling 100 without replacement raises ValueError
# when the dataset is smaller than that.
n_probe = min(100, len(X))
mins, maxs = [], []
for i in np.random.choice(range(len(X)), n_probe, replace=False):
    _, F = get_bf(i, **params)
    mins.append([np.min(f) for f in F])
    maxs.append([np.max(f) for f in F])
m = np.min(mins, axis=0)
M = np.max(maxs, axis=0)
# NOTE(review): `box` is only printed; params["box"] and img_params["box"]
# are hard-coded elsewhere — confirm whether this estimate should feed them.
box = [m, M]
print(box)


# In[29]:


from os.path import exists

# compute_mods is assumed to write one dump per sample as
# <save><index>.pkl (matching X_str below) — confirm.
module_files = [DATASET_PATH + f"mma_mods2_{j}.pkl" for j in range(len(X))]
# Require every dump, not just the last one: parallel jobs can finish out of
# order, so the last file existing does not imply the others do.
if force_computation or not all(exists(path) for path in module_files):
    print("Computing modules...", flush=True)
    compute_mods(range(len(X)), get_bf, n_jobs=ncore, save=DATASET_PATH + "mma_mods2_", **params)
else:
    print("Skipping computation of modules")
# out = compute_imgs(indices,get_bf, n_jobs=8, multithreads=True,**params)
# out = np.array(out, dtype=np.float32)
if force_computation:
    # exit() is injected by the `site` module and may be missing (e.g. python -S);
    # SystemExit is what it raises anyway.
    raise SystemExit

# In[15]:


# Paths to the per-sample module dumps, in the same order as X / elabels.
X_str = [f"{DATASET_PATH}mma_mods2_{idx}.pkl" for idx in range(len(X))]
# from sklearn.model_selection import train_test_split
# train_index, test_index = train_test_split(range(len(labels)), test_size=0.25)
# X_train = [X_str[i] for i in train_index]
# X_test = [X_str[i] for i in test_index]
# y_train, y_test = elabels[train_index], elabels[test_index]


# In[35]:


# Keyword arguments forwarded to each module's .image() call via get_img.
img_params = {
    "resolution" : [200,200],
    # "normalize" used to appear twice (0, then False). Python keeps the first
    # position and the last value, so it is listed once, here, as False.
    "normalize" : False,
    "plot" : False,
    # NOTE(review): hard-coded and different from params["box"] and the
    # estimated `box` — confirm intended.
    "box" : [[0,0.], [1,2]],
    "p" : p_,
    "bandwidth" : bdw_,
}
def get_img(mod, **img_params):
    """Return the flattened images of a module in dimensions 0 and 1.

    Args:
        mod: a path to a pickled module dump (str), an already-loaded dump
            (list), or a ``PyModule``.
        **img_params: forwarded to ``PyModule.image``.

    Returns:
        A 1-D array: the dimension-0 and dimension-1 images concatenated, then
        flattened. Returns None (after printing "Bad input") for any other
        input type.
    """
    import pickle

    if isinstance(mod, str):
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # module dumps this pipeline produced itself.
        with open(mod, "rb") as handle:  # closed even if loading fails
            mod_ = from_dump(pickle.load(handle))
    elif isinstance(mod, list):
        mod_ = from_dump(mod)
    elif isinstance(mod, PyModule):
        mod_ = mod
    else:
        print("Bad input")
        return None
    return np.concatenate([mod_.image(dimension=d, **img_params) for d in [0, 1]]).flatten()

print(img_params, flush=True)
print("Computing imgs...", flush=True)
# int(ncore/4) is 0 on machines with fewer than 4 cores, and joblib rejects
# n_jobs=0, so use at least one worker.
n_img_jobs = max(1, ncore // 4)
with Parallel(n_jobs=n_img_jobs) as p:
    X_img = p(delayed(get_img)(x, **img_params) for x in tqdm(X_str))



# In[38]:


# Stack the per-sample image vectors into a feature matrix.
X_ = np.asarray(X_img)


# In[39]:

# 10-fold accuracy of a 300-tree random forest on the image features.
print(img_params, flush=True)
forest = RandomForestClassifier(n_estimators=300, n_jobs=ncore)
scores = kfold_acc([forest], X_, elabels, k=10)
print(scores, flush=True)






