#!/usr/bin/env python
# coding: utf-8
"""
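Run a classification task on a dataset.

Positional command-line arguments:
- dataset name (str): "immuno", or the name of a UCR time-series dataset, e.g. "Coffee"
- res (int): resolution of the multiparameter persistence images
- recompute (bool, optional): if true, recompute the modules even if they are already cached
- number of cores to use (int, optional): defaults to all available cores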
"""


# In[2]:

# Progress messages shown before the (slow) library imports below.
print("Starting...", "Loading libraries...", sep="\n")
import numpy as np
import pickle as pck
import pandas as pd
#import matplotlib.pyplot as plt
from gudhi.point_cloud.timedelay import TimeDelayEmbedding
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler
from sklearn.impute import SimpleImputer
from sklearn.metrics import pairwise_distances
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import LabelEncoder
from sklearn.svm import SVC
from xgboost import XGBClassifier
from sklearn.ensemble import RandomForestClassifier,AdaBoostClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.neighbors import KernelDensity
from tqdm import tqdm
from pandas import read_csv
from joblib import Parallel, delayed, cpu_count
import matplotlib.pyplot as plt

import gudhi as gd
import mma
from multipers import *
from mma_classif import *


# In[3]:





# Choose the dataset! UCR datasets can be downloaded from: https://www.cs.ucr.edu/~eamonn/time_series_data_2018/

# In[4]:

from sys import argv


def _parse_bool_arg(value):
    """Parse a boolean command-line flag such as "1"/"0" or "true"/"false".

    ``bool(some_str)`` is True for every non-empty string (including "False"
    and "0"), so the flag has to be parsed explicitly.

    Raises:
        SystemExit: if *value* is not a recognised boolean spelling.
    """
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise SystemExit(f"Invalid boolean value for RECOMPUTE: {value!r}")


# DATASET and RESOLUTION are mandatory; RECOMPUTE and NCORES are optional.
if len(argv) < 3:
    raise SystemExit(f"Usage: {argv[0]} DATASET RESOLUTION [RECOMPUTE] [NCORES]")
dataset = str(argv[1])
res = int(argv[2])
force_computation = _parse_bool_arg(argv[3]) if len(argv) >= 4 else False
ncore = int(argv[4]) if len(argv) >= 5 else cpu_count()

print("Number of cores :", ncore, flush=True)
print("Script launched with arguments", dataset, res, force_computation, ncore)


# Define all hyperparameters.

# In[5]:


# Time-delay embedding parameters (embedding dimension, delay, subsampling step).
dimension, delay, skip = 3, 1, 1

# Vineyard parameters (only referenced by the commented-out vineyards code).
nlines, noise = 200, 0

# DTM parameter.
m = 0.1

# KDE parameters. When KDE is True, the KDE log-density (rather than DTM) is
# used to compute the bounds of the second filtration.
kde_bandwidth = 0.5
KDE = True

# Image parameters: the resolution `res` is read from the command line.

# Machine-learning parameters; `cv` is the number of cross-validation folds.
xgbc = XGBClassifier(random_state=1)
rfc = RandomForestClassifier(random_state=1, n_estimators=500)
cv = 5


# Define some global variables.

# In[6]:


from os.path import expanduser, exists

# Root folder holding every dataset, and the per-family sub-folders.
PATH = expanduser("~/Datasets/")
UCR_path = f"{PATH}UCR/{dataset}/"
IMMUNO_path = f"{PATH}1.5mmRegions/"
# Bifiltration names; the -0/-1 suffixes presumably denote homology degrees
# (cf. the commented-out save code) — not used by the live code below.
list_filts = ["Alpha-DTM-0", "Alpha-DTM-1"]


# Read the data sets, and impute the missing values.

# In[7]:


# "immuno" selects the immuno point-cloud dataset; any other name is a UCR dataset.
UCR = dataset != "immuno"
immuno = not UCR


# In[8]:

print("Reading dataset...", flush=True)
if UCR:
    path = UCR_path
    # Each row holds the class label followed by the series values
    # (split apart further below); TRAIN rows come first.
    X1 = np.array(pd.read_csv(path + dataset + "_TRAIN.tsv", sep="\t", header=None))
    X2 = np.array(pd.read_csv(path + dataset + "_TEST.tsv", sep="\t", header=None))
    X = np.vstack([X1, X2])
    split = len(X1)  # rows [0, split) form the official training set
elif immuno:
    import os

    path = IMMUNO_path

    def get_regions():
        """Load every point-cloud file of the immuno dataset.

        One sub-folder per class ("FoxP3", "CD8", "CD68"); every file in it
        (including nested sub-folders) is read with pandas and divided by 1500.

        Returns:
            (X, labels): the list of point arrays and the integer-encoded
            class of each array.
        """
        X, labels = [], []
        for label in ["FoxP3", "CD8", "CD68"]:
            for root, dirs, files in os.walk(os.path.join(IMMUNO_path, label)):
                dirs.sort()  # in-place sort makes os.walk traverse deterministically
                for name in sorted(files):
                    # Join with `root` so files in nested sub-folders resolve correctly.
                    # NOTE(review): the 1/1500 scaling presumably normalises coordinates
                    # of the 1.5mm regions — confirm units.
                    X.append(np.array(read_csv(os.path.join(root, name))) / 1500)
                    labels.append(label)
        return X, LabelEncoder().fit_transform(labels)

    TS, L = get_regions()
    split = np.floor(len(TS) * 0.75).astype(int)

print("Dataset : ", path)
# In[9]:


if UCR:
    # Time-delay embedding turning each 1-D series into a point cloud,
    # configured by the embedding hyperparameters.
    tde = TimeDelayEmbedding(dim=dimension, delay=delay, skip=skip)
    # Column 0 is the class label; the remaining columns are the series.
    L = X[:, 0]
    TS = X[:, 1:]
    # Replace missing values (NaN) by the mean of their column.
    imp = SimpleImputer(missing_values=np.nan, strategy="mean")
    TS = imp.fit_transform(TS)
nts = len(TS)
# Persist the labels; they are reloaded from this file before classification.
np.savetxt(path + "labels.txt", L)


# # Decompositions

# Compute maximal pairwise distance for Alpha complex.

# In[10]:

print("Computing filtration bounds", flush=True)

# Estimate the largest pairwise distance from (at most) the first 30 samples;
# it is used as `max_alpha_square` when building the Alpha complexes.
# min(30, nts) avoids an IndexError on small datasets, and the running maximum
# avoids materialising all distance matrices at once.
maxd = -np.inf
for tsidx in range(min(30, nts)):
    X = tde(TS[tsidx, :]) if UCR else TS[tsidx]
    maxd = max(maxd, float(np.max(pairwise_distances(X))))

# KDE parameters
# kde_bandwidth = 0.1* maxd
# Compute bounding rectangle for multiparameter persistence.

# In[11]:


# Bounding rectangle [mxf, Mxf] x [myf, Myf] over all samples:
# x-axis = Alpha filtration values, y-axis = density values (KDE or DTM).
mxf, Mxf, myf, Myf = np.inf, -np.inf, np.inf, -np.inf

for tsidx in range(nts):
    X = tde(TS[tsidx, :]) if UCR else TS[tsidx]

    # First filtration: values of the Alpha complex simplices.
    st = gd.AlphaComplex(points=X).create_simplex_tree(max_alpha_square=maxd)
    fs = [f for (_, f) in st.get_filtration()]
    mxf, Mxf = min(mxf, np.min(fs)), max(Mxf, np.max(fs))

    # Second filtration: density estimate at every point.
    if KDE:
        # score_samples returns the log-density of each point.
        fs = KernelDensity(bandwidth=kde_bandwidth).fit(X).score_samples(X)
    else:
        fs = DTM(X, X, m)
    # np.min/np.max reduce over all entries, whatever the shape of `fs`.
    myf, Myf = min(myf, np.min(fs)), max(Myf, np.max(fs))


# Compute all multiparameter persistence decompositions (legacy vineyards code, currently commented out).

# In[12]:


# ldgms0, mdgms0 = [], []
# ldgms1, mdgms1 = [], []
# count = 0

# for tsidx in tqdm(range(0, nts)):

#     # Compute time delay embedding and DTM density
#     X = tde(TS[tsidx,:])
#     density = np.squeeze(DTM(X, X, m))
    
#     # Create Alpha complex
#     dcomplex = gd.AlphaComplex(points=X)
#     st = dcomplex.create_simplex_tree(max_alpha_square=maxd)

#     # Use first barycentric subdivision to turn Alpha into a lower-star
#     list_splxs = []
#     st2 = gd.SimplexTree()
#     for (s,_) in st.get_filtration():
#         st2.insert(s, max([density[v] for v in s]))
#         list_splxs.append((s, max([density[v] for v in s])))
#     bary1 = barycentric_subdivision(st, use_sqrt=False)
#     bary2 = barycentric_subdivision(st2, list_splx=list_splxs)

#     # Write inputs for vineyards algorithm
#     cname, fname = path + "complex" + str(count) + ".txt", path + "filtrations" + str(count) + ".txt"
#     complexfile, filtfile = open(cname, "w"), open(fname, "w")
#     for (s,f) in bary1.get_filtration():
#         for v in s:
#             complexfile.write(str(v) + " ")
#         complexfile.write("\n")
#         if len(s) == 1:
#             filtfile.write(str(f) + " " + str(bary2.filtration(s)) + "\n")
#     complexfile.close()
#     filtfile.close()

#     # Compute the vineyards
#     mdg0, lines0, _, _ = sublevelsets_multipersistence(
#         "vineyards", cname, fname, homology=0, num_lines=nlines, corner="dg", extended=False, essential=False,
#         noise=noise, visu=False, plot_per_bar=False, 
#         bnds_filt=[mxf,Mxf,myf,Myf], bnds_visu=[mxf,Mxf,myf,Myf])
#     mdg1, lines1, _, _ = sublevelsets_multipersistence(
#         "vineyards", cname, fname, homology=1, num_lines=nlines, corner="dg", extended=False, essential=False,
#         noise=noise, visu=False, plot_per_bar=False, 
#         bnds_filt=[mxf,Mxf,myf,Myf], bnds_visu=[mxf,Mxf,myf,Myf])

#     mdgms0.append(mdg0)
#     mdgms1.append(mdg1)
#     count += 1

#     os.system("rm " + cname + "*")
#     os.system("rm " + fname + "*")


# Save the data.

# In[13]:


# np.save(path + "lines_Alpha-DTM-0", lines0)
# np.save(path + "lines_Alpha-DTM-1", lines1)
# np.save(path + "bnds_Alpha-DTM-0", np.array([mxf,Mxf,myf,Myf]))
# np.save(path + "bnds_Alpha-DTM-1", np.array([mxf,Mxf,myf,Myf]))
# pck.dump(mdgms0, open(path + "mdgms_Alpha-DTM-0.pkl", "wb"))
# pck.dump(mdgms1, open(path + "mdgms_Alpha-DTM-1.pkl", "wb"))


# Compute mma modules (lower star not needed)

# In[14]:


def compute_ts_mod(tsidx):
    """Compute the MMA module approximation of sample ``tsidx``.

    The sample is turned into a point cloud (time-delay embedding for UCR
    series, raw points otherwise), then bifiltered by its Alpha complex
    (parameter 0) and the lower-star of a KDE log-density (parameter 1).
    The module is approximated on the global box [[mxf, myf], [Mxf, Myf]].

    Args:
        tsidx: index of the sample in ``TS``.

    Returns:
        The dumped (pickle-able) form of the module, i.e. ``mod.dump()``.
    """
    points = tde(TS[tsidx, :]) if UCR else TS[tsidx]
    # The KDE is fitted on the raw points (duplicates included) before
    # deduplication, and evaluated afterwards on the Alpha complex vertices.
    kde = KernelDensity(bandwidth=kde_bandwidth).fit(points)
    points = np.unique(points, axis=0)

    alpha = gd.AlphaComplex(points=points)
    simplex_tree = alpha.create_simplex_tree(max_alpha_square=maxd)
    vertices = [alpha.get_point(k) for k in range(len(points))]
    density = kde.score_samples(vertices)

    # Second parameter: lower-star filtration of the density.
    multi_tree = mma.SimplexTreeMulti(simplex_tree, num_parameters=2)
    multi_tree.fill_lowerstar(density, parameter=1)

    box = [[mxf, myf], [Mxf, Myf]]
    module = multi_tree.persistence_approximation(box=box)
    return module.dump()

# Compute the module list unless it is already cached on disk
# (or recomputation was explicitly requested).
modules_file = path + "mma_mods.pkl"
if force_computation or not exists(modules_file):
    print("Computing modules")
    # NOTE(review): thread backend; actual speed-up depends on gudhi/mma
    # releasing the GIL — confirm before switching to processes.
    module_list = Parallel(n_jobs=ncore, prefer="threads")(
        delayed(compute_ts_mod)(i) for i in range(nts)
    )
    with open(modules_file, "wb") as fh:
        pck.dump(module_list, fh)
    del module_list
else:
    print("Skipping module computation, as file already exists")




# Labels were written with np.savetxt (as floats); re-encode them as 0..n_classes-1.
labels = np.loadtxt(path + "labels.txt", dtype=float)
labels = LabelEncoder().fit_transform(labels.astype(int))
npoints = len(labels)
if UCR:
    # Keep the official UCR train/test split.
    train_index, test_index = np.arange(split), np.arange(split, npoints)
elif immuno:
    from sklearn.model_selection import train_test_split

    # Fixed seed (same as the classifiers) so the split, and hence the
    # saved scores, are reproducible across runs.
    train_index, test_index = train_test_split(
        range(len(labels)), test_size=0.25, random_state=1
    )


# In[24]:


if UCR:
    # Baseline: random forest on the raw (imputed) time series.
    # Direct row indexing replaces the O(n^2) `i in train_index` scans.
    xtrainf = TS[train_index]
    xtestf = TS[test_index]
    rfc.fit(xtrainf, labels[train_index])
    print(
        "RandomForest on UCR, with dataset", dataset, " : ",
        rfc.score(xtrainf, labels[train_index]),
        rfc.score(xtestf, labels[test_index]),
    )


# Classify MMA images

# In[25]:


print("Bounding box : ", mxf, myf, Mxf, Myf)
print("Resolution : ", res)

print("Loading MMA modules from file...")
# The modules are stored in dumped (pickle-able) form, as returned by compute_ts_mod.
# NOTE: pickle.load can execute arbitrary code; only load files this script wrote.
with open(path + "mma_mods.pkl", "rb") as fh:
    Xmmai = pck.load(fh)

# Save the image of one randomly chosen module as a visual sanity check.
from random import choice

plt.figure()
p_ = 1
k_ = choice(range(len(Xmmai)))
mod_ = mma.from_dump(Xmmai[k_]).image(p=p_)
plt.savefig(path + f"image{k_}_{p_}.png", dpi=200)
plt.close("all")  # release the figure instead of only clearing it
del mod_, k_

# Candidate boxes: shrink the x-range symmetrically by a fraction qx of its
# length, and lower the top of the y-range by a fraction qy of its length.
qxs = [0, 0.05, 0.1]
qys = [0, 0.05, 0.1, 0.2]
lx = Mxf - mxf
ly = Myf - myf
boxes = []
for qx in qxs:
    for qy in qys:
        boxes.append([[mxf + qx * lx, myf], [Mxf - qx * lx, Myf - qy * ly]])
# NOTE(review): `boxes` is not used below; the grid searches over mmai__qx/qy instead.
diam = min(lx, ly)
print("Diameter :", diam)

# Hyper-parameter grid for the MMA-image + random-forest pipeline.
params_mmai = {
    # image bandwidth (delta), as fractions of the smallest side of the box
    "mmai__bdw": [0.001 * diam, 0.01 * diam, 0.1 * diam, 0.5 * diam],
    "mmai__power": [0, 1],  # the p image parameter
    # if datasets induce roughly the same number of summands, normalize=0 can be better
    "mmai__normalize": [0, 1],
    "mmai__resolution": [[res, res]],
    # dimensions to include (presumably homology degrees 0 and 1 — confirm in MMAImageWrapper)
    "mmai__dimensions": [[0, 1]],
    "mmai__plot": [False],
    "mmai__box": [[[mxf, myf], [Mxf, Myf]]],  # full bounding box
    # cross-validated box shrinking factors (presumably as in `boxes` above)
    "mmai__qx": [0, 0.1, 0.2],
    "mmai__qy": [0, 0.1, 0.2],
    "clf__n_estimators": [500],
}

# Pipeline: module -> multiparameter persistence image -> random forest.
pipe_mmai = Pipeline([("mmai", MMAImageWrapper()), ("clf", rfc)])
X_train = [Xmmai[i] for i in train_index]
X_test = [Xmmai[i] for i in test_index]
y_train, y_test = labels[train_index], labels[test_index]
model = GridSearchCV(
    estimator=pipe_mmai, param_grid=params_mmai, cv=cv, verbose=1, n_jobs=ncore
)
print("Fitting GridSearch Classifiers...")
model.fit(X_train, y_train)
print("Done !")
# Compute the test score once and reuse it for printing and saving.
score = model.score(X_test, y_test)
print(f"MMA-I train accuracy {dataset} = ", model.score(X_train, y_train))
print(f"MMA-I test accuracy {dataset} = ", score)
print(f"MMA-I best params {dataset} = ", model.best_params_)
results_file = path + f"modelMMAI_CV{cv}_res{res}.pkl"
with open(results_file, "wb") as fh:
    pck.dump([model.best_params_, model.cv_results_, score], fh)
print("Details saved at : ", results_file)

# In[27]:



# In[28]:







