#!/usr/bin/env python
# coding: utf-8
"""
Computes classification task from dataset.

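Input parameters (positional command-line arguments)
- dataset name : str : "immuno" or the name of a UCR time series dataset
- res : int : resolution of the multiparameter persistence images
- recompute : bool : if true, recomputes the modules and images (optional, defaults to false)
- number of cores to use : int (optional, defaults to all available cores)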
"""


# In[2]:

# Progress banner, emitted before the library imports that follow.
print("Starting...", "Loading libraries...", sep="\n")
import numpy as np
import pickle as pck
import pandas as pd
#import matplotlib.pyplot as plt
from gudhi.point_cloud.timedelay import TimeDelayEmbedding
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler
from sklearn.impute import SimpleImputer
from sklearn.metrics import pairwise_distances
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import LabelEncoder
from sklearn.svm import SVC
from xgboost import XGBClassifier
from sklearn.ensemble import RandomForestClassifier,AdaBoostClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.neighbors import KernelDensity
from tqdm import tqdm
from pandas import read_csv
from multiprocessing import Pool, cpu_count
from joblib import Parallel, delayed, cpu_count

import mma
from multipers import *
from mma_classif import *


# In[3]:





# Choose the dataset! The UCR datasets can be downloaded from: https://www.cs.ucr.edu/~eamonn/time_series_data_2018/

# In[4]:

from sys import argv


def _parse_bool_flag(text):
    """Interpret a command-line token as a boolean.

    ``bool("False")`` is ``True`` in Python because every non-empty string is
    truthy, so the usual textual spellings of "false" are recognised
    explicitly here. Any other non-empty token still evaluates to ``True``.
    """
    falsy_spellings = ("", "0", "false", "f", "no", "n", "off", "none")
    return str(text).strip().lower() not in falsy_spellings


# Positional arguments: dataset name, resolution, then the optional
# recompute flag and the optional number of worker processes.
dataset = str(argv[1])
res = int(argv[2])
force_computation = _parse_bool_flag(argv[3]) if len(argv) >= 4 else False
ncore = int(argv[4]) if len(argv) >= 5 else cpu_count()

print(len(argv))
print("Number of cores :", ncore, flush=True)
print("Script launched with arguments", dataset, res, force_computation, ncore)


# Define all hyperparameters.

# In[5]:


# time delay embedding
# Time delay embedding settings (passed to TimeDelayEmbedding).
dimension, delay, skip = 3, 1, 1

# Vineyards settings; nlines is also the number of lines requested when
# extracting barcodes from the modules.
nlines, noise = 200, 0

# DTM parameter (third argument of the DTM calls).
m = 0.1

# KDE settings: kernel bandwidth, and whether the bounds of the second
# filtration are taken from a kernel density estimate instead of the DTM.
kde_bandwidth = 0.5
KDE = False

# image parameters
#res = 50

# ML settings: number of cross-validation folds and the two classifiers.
cv = 5
rfc = RandomForestClassifier(n_estimators=300, random_state=1)
xgbc = XGBClassifier(random_state=1)


# Define some global variables.

# In[6]:


from os.path import exists, expanduser

# Root folder of all datasets, followed by the two dataset-specific folders.
# Every path keeps its trailing slash because file names are appended to it
# by plain string concatenation.
PATH = expanduser("~/Datasets/")
UCR_path = f"{PATH}UCR/{dataset}/"
IMMUNO_path = f"{PATH}1.5mmRegions/"
# Filtration names; they are used as suffixes of the saved image files.
list_filts = [f"Alpha-DTM-{degree}" for degree in range(2)]


# Read the data sets, and impute the missing values.

# In[7]:


# Exactly one flag is True: the name "immuno" selects the immuno regions,
# every other dataset name is handled as a UCR dataset.
UCR = dataset != "immuno"
immuno = not UCR


# In[8]:

# `os` is imported explicitly: no import statement visible in this file binds
# it, although get_regions() relies on os.walk.
import os

print("Reading dataset...", flush=True)
if UCR:
    path = UCR_path
    # Train and test files are stacked; `split` remembers where the test
    # rows start so the original train/test partition can be restored later.
    X1 = np.array(pd.read_csv(path + dataset + "_TRAIN.tsv", sep="\t", header=None))
    X2 = np.array(pd.read_csv(path + dataset + "_TEST.tsv",  sep="\t", header=None))
    X = np.vstack([X1, X2])
    split = len(X1)
elif immuno:
    path = IMMUNO_path

    def get_regions():
        """Load every region file found under the per-label folders.

        Returns a pair ``(X, labels)``: ``X`` is a list with one array per
        file (the file content divided by 1500) and ``labels`` holds the
        integer-encoded name of the folder each file was found under.
        """
        X, labels = [], []
        for label in ["FoxP3", "CD8", "CD68"]:
    #     for label in ["FoxP3", "CD8"]:
            for root, _dirs, files in os.walk(IMMUNO_path + label + "/"):
                for name in files:
                    # Join with `root` (not the label folder) so that files
                    # sitting in nested sub-folders are read from where
                    # os.walk actually found them.
                    X.append(np.array(read_csv(os.path.join(root, name))) / 1500)
                    labels.append(label)
        return X, LabelEncoder().fit_transform(labels)

    TS, L = get_regions()
    split = np.floor(len(TS) * 0.75).astype(int)

print("Dataset : ", path)
# In[9]:


if UCR:
    # Column 0 holds the labels, the remaining columns hold the series.
    L = X[:, 0]
    TS = X[:, 1:]
    # Missing values (NaN) are replaced by the mean of their column.
    TS = SimpleImputer(missing_values=np.nan, strategy="mean").fit_transform(TS)
    tde = TimeDelayEmbedding(dim=dimension, delay=delay, skip=skip)
nts = len(TS)
# The labels are written to disk and read back by the classification part.
np.savetxt(path + "labels.txt", L)


# # Decompositions

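# Compute the distance scale for the Alpha complex: the mean pairwise distance over a sample of the first series (despite its name, `maxd` is a mean, not a maximum); it is later passed as max_alpha_square.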

# In[10]:

print("Computing filtration bounds", flush=True)

# The scale is estimated on at most the first 30 samples. The count is capped
# at nts so that datasets with fewer than 30 samples do not raise IndexError.
ds = []
for tsidx in range(min(30, nts)):
    X = tde(TS[tsidx,:]) if UCR else TS[tsidx]
    ds.append(pairwise_distances(X).flatten())
allds = np.concatenate(ds)
# Mean (not maximum) of all sampled pairwise distances.
maxd = np.mean(allds)


# Compute bounding rectangle for multiparameter persistence.

# In[11]:


# Running extrema of the two filtrations over the whole dataset:
# (mxf, Mxf) for the first one, (myf, Myf) for the second one.
# NOTE(review): `gd` and `DTM` are not bound by any import statement visible
# in this file; presumably they come from the star imports - confirm.
mxf = myf = np.inf
Mxf = Myf = -np.inf

for tsidx in range(nts):
    X = tde(TS[tsidx,:]) if UCR else TS[tsidx]

    # First filtration: values of the Alpha complex simplex tree.
    st = gd.AlphaComplex(points=X).create_simplex_tree(max_alpha_square=maxd)
    alpha_values = [value for (_, value) in st.get_filtration()]
    mxf = min(mxf, min(alpha_values))
    Mxf = max(Mxf, max(alpha_values))

    # Second filtration: kernel density scores when KDE is set, DTM otherwise.
    if KDE:
        fs = KernelDensity(bandwidth=kde_bandwidth).fit(X).score_samples(X)
    else:
        fs = DTM(X, X, m)
        # density = DTM(X, X, m)
        # for (s,f) in st.get_skeleton(0):
        #     st.assign_filtration(s, density[s[0]])
        # for (s,f) in st.get_filtration():
        #     if len(s) > 1:
        #         st.assign_filtration(s, -1e10)
        # st.make_filtration_non_decreasing()
        # fs = [f for (s,f) in st.get_filtration()]
    myf = min(myf, min(fs))
    Myf = max(Myf, max(fs))


# Compute all multipersistence decompositions.

# In[12]:


# ldgms0, mdgms0 = [], []
# ldgms1, mdgms1 = [], []
# count = 0

# for tsidx in tqdm(range(0, nts)):

#     # Compute time delay embedding and DTM density
#     X = tde(TS[tsidx,:])
#     density = np.squeeze(DTM(X, X, m))
    
#     # Create Alpha complex
#     dcomplex = gd.AlphaComplex(points=X)
#     st = dcomplex.create_simplex_tree(max_alpha_square=maxd)

#     # Use first barycentric subdivision to turn Alpha into a lower-star
#     list_splxs = []
#     st2 = gd.SimplexTree()
#     for (s,_) in st.get_filtration():
#         st2.insert(s, max([density[v] for v in s]))
#         list_splxs.append((s, max([density[v] for v in s])))
#     bary1 = barycentric_subdivision(st, use_sqrt=False)
#     bary2 = barycentric_subdivision(st2, list_splx=list_splxs)

#     # Write inputs for vineyards algorithm
#     cname, fname = path + "complex" + str(count) + ".txt", path + "filtrations" + str(count) + ".txt"
#     complexfile, filtfile = open(cname, "w"), open(fname, "w")
#     for (s,f) in bary1.get_filtration():
#         for v in s:
#             complexfile.write(str(v) + " ")
#         complexfile.write("\n")
#         if len(s) == 1:
#             filtfile.write(str(f) + " " + str(bary2.filtration(s)) + "\n")
#     complexfile.close()
#     filtfile.close()

#     # Compute the vineyards
#     mdg0, lines0, _, _ = sublevelsets_multipersistence(
#         "vineyards", cname, fname, homology=0, num_lines=nlines, corner="dg", extended=False, essential=False,
#         noise=noise, visu=False, plot_per_bar=False, 
#         bnds_filt=[mxf,Mxf,myf,Myf], bnds_visu=[mxf,Mxf,myf,Myf])
#     mdg1, lines1, _, _ = sublevelsets_multipersistence(
#         "vineyards", cname, fname, homology=1, num_lines=nlines, corner="dg", extended=False, essential=False,
#         noise=noise, visu=False, plot_per_bar=False, 
#         bnds_filt=[mxf,Mxf,myf,Myf], bnds_visu=[mxf,Mxf,myf,Myf])

#     mdgms0.append(mdg0)
#     mdgms1.append(mdg1)
#     count += 1

#     os.system("rm " + cname + "*")
#     os.system("rm " + fname + "*")


# Save the data.

# In[13]:


# np.save(path + "lines_Alpha-DTM-0", lines0)
# np.save(path + "lines_Alpha-DTM-1", lines1)
# np.save(path + "bnds_Alpha-DTM-0", np.array([mxf,Mxf,myf,Myf]))
# np.save(path + "bnds_Alpha-DTM-1", np.array([mxf,Mxf,myf,Myf]))
# pck.dump(mdgms0, open(path + "mdgms_Alpha-DTM-0.pkl", "wb"))
# pck.dump(mdgms1, open(path + "mdgms_Alpha-DTM-1.pkl", "wb"))


# Compute mma modules (lower star not needed)

# In[14]:


def compute_ts_mod(tsidx):
    """Approximate the 2-parameter module of sample ``tsidx`` and dump it.

    The point cloud is the time delay embedding of the series for UCR
    datasets, or the stored region otherwise. The first filtration comes
    from its Alpha complex and the second from kernel density scores.
    Returns ``mod.dump()`` so the result can be pickled and sent back from
    worker processes.

    NOTE(review): the second filtration always uses KernelDensity scores,
    whatever the value of the module-level ``KDE`` flag, while the bounding
    box used below is computed from DTM values when ``KDE`` is False.
    Confirm that this combination is intended.
    """
    points = tde(TS[tsidx,:]) if UCR else TS[tsidx]
    # The estimator is fitted on the full cloud, before duplicates are removed.
    density_estimator = KernelDensity(bandwidth=kde_bandwidth).fit(points)
    points = np.unique(points, axis=0)

    # Alpha complex of the de-duplicated cloud, truncated at maxd.
    alpha_complex = gd.AlphaComplex(points=points)
    simplex_tree = alpha_complex.create_simplex_tree(max_alpha_square=maxd)
    # Points are read back from the complex by index; presumably this keeps
    # the density values aligned with the complex's vertex numbering.
    vertices = [alpha_complex.get_point(i) for i in range(len(points))]

    # Density score of every vertex.
    density_filtration = density_estimator.score_samples(vertices)

    # Input for mma: boundary matrix plus one filtration per parameter.
    boundary_matrix, alpha_filtration = mma.splx2bf(simplex_tree)
    bifiltration = [np.array(alpha_filtration), density_filtration]

    # Approximate the module inside the dataset-wide bounding box.
    box = [[mxf, myf], [Mxf, Myf]]
    return mma.approx(boundary_matrix, bifiltration, box = box).dump()

# Compute the module list unless a cached file already exists
# (or recomputation was requested on the command line).
modules_file = path + "mma_mods_for_multipers.pkl"
if force_computation or not exists(modules_file):
    print("Computing modules", flush=True)
    with Pool(processes=ncore) as pool:
        module_list = pool.map(compute_ts_mod, range(0, nts))
    # The context manager closes (and therefore flushes) the file
    # deterministically; the pickle is read back later in this script.
    with open(modules_file, "wb") as handle:
        pck.dump(module_list, handle)
    del module_list
else:
    print("Skipping module computation, as file already exists", flush=True)


# # Vectorizations

# Read the data.

# In[16]:


# list_lines, list_bnds, list_mdgms, list_delta = [], [], [], []
# for filtname in list_filts:
#     lines = np.load(path + "lines_" + filtname + ".npy")
#     list_lines.append(lines)
#     delta = np.abs(lines[0,0]-lines[1,0]) if lines[0,0] != lines[1,0] else np.abs(lines[2,2]-lines[1,2])
#     list_delta.append(delta)
#     list_bnds.append(np.load(path + "bnds_" + filtname + ".npy"))
#     list_mdgms.append(pck.load(open(path + "mdgms_" + filtname + ".pkl", "rb")))    


# Compute Multiparameter Persistence Images.

# In[17]:




# Compute Multiparameter Persistence Landscapes.

# In[18]:


# for filtidx in range(len(list_filts)):
    
#     bnds  = list_bnds [filtidx]
#     mdgms = list_mdgms[filtidx]
#     delta = list_delta[filtidx]
#     filtname = list_filts[filtidx]
    
#     MLS = [multipersistence_landscape(mdg,bnds,delta,resolution=[res,res],k=5,return_raw=True) for mdg in mdgms]
#     pck.dump(MLS, open(path + "mls_" + str(res) + "_" + filtname + ".pkl", "wb"))
#     print("MLS done")


# Compute Multiparameter Persistence Kernels.

# In[19]:


# for filtidx in range(len(list_filts)):
    
#     lines = list_lines[filtidx]
#     bnds  = list_bnds [filtidx]
#     mdgms = list_mdgms[filtidx]
#     filtname = list_filts[filtidx]
    
#     MK  = [extract_diagrams(mdg, bnds, lines) for mdg in mdgms]
#     sw = sktda.SlicedWassersteinDistance(num_directions=10)
#     M = multipersistence_kernel(MK, MK, lines, sw, lambda x: 1, same=True, return_raw=False, power=0)
#     pck.dump(M, open(path + "mk_" + filtname + ".pkl", "wb"))
#     print("MK done")


# Collect the diagonal barcodes for 1D persistence.

# In[20]:


# fibs = []
# for filtidx in range(len(list_filts)):
    
#     lines = list_lines[filtidx]
#     bnds  = list_bnds [filtidx]
#     mdgms = list_mdgms[filtidx]
#     filtname = list_filts[filtidx]
    
#     ldgms = []
#     for decomposition in mdgms:
#         if len(decomposition) > 0:
#             mdgm = np.vstack(decomposition)
#             al = int(len(lines)/2)
#             for a in range(len(lines)):
#                 if lines[a,0] == min(lines[:,0]) and lines[a,1] == min(lines[:,1]):
#                     al = a
#                     break
#             dg = []
#             idxs = np.argwhere(mdgm[:,4] == al)[:,0]
#             if len(idxs) > 0:
#                 dg.append(mdgm[idxs][:,:4])
#             if len(dg) > 0:
#                 dg = np.vstack(dg)
#                 dg = intersect_boundaries(dg, bnds)
#                 if len(dg) > 0:
#                     xalpha, yalpha, xAlpha, yAlpha = lines[al,0], lines[al,1], lines[al,2], lines[al,3]
#                     pt = np.array([[xalpha, yalpha]])
#                     st, ed = dg[:,[0,2]], dg[:,[1,3]]
#                     dgm = np.hstack([ np.linalg.norm(st-pt, axis=1)[:,np.newaxis], 
#                                       np.linalg.norm(ed-pt, axis=1)[:,np.newaxis] ])
#                 else:
#                     dgm = np.array([[.5*(bnds[0]+bnds[1]), .5*(bnds[2]+bnds[3])]])
#             else:
#                 dgm = np.array([[.5*(bnds[0]+bnds[1]), .5*(bnds[2]+bnds[3])]])
#         else:
#             dgm = np.array([[.5*(bnds[0]+bnds[1]), .5*(bnds[2]+bnds[3])]])
#         ldgms.append(dgm)
#     fibs.append(ldgms)


# Compute persistence landscapes.

# In[21]:


# for filtidx in range(len(list_filts)):
    
#     fib = fibs[filtidx]
#     filtname = list_filts[filtidx]
    
#     ldgmsLS = [dg for dg in fib]
#     L = sktda.Landscape(num_landscapes=5,resolution=res*res,sample_range=[np.nan, np.nan]).fit_transform(ldgmsLS)
#     pck.dump(L, open(path + "ls_" + filtname + ".pkl", "wb"))
#     print("LS done")


# Compute persistence images.

# In[22]:


# for filtidx in range(len(list_filts)):
    
#     fib = fibs[filtidx]
#     filtname = list_filts[filtidx]
    
#     ldgmsPI = [dg for dg in fib]
#     ldgmsPI = [np.hstack([dgm[:,0:1], dgm[:,1:2]-dgm[:,0:1]]) for dgm in ldgmsPI]
#     PXs, PYs = np.vstack([dgm[:,0:1] for dgm in ldgms]), np.vstack([dgm[:,1:2] for dgm in ldgms])
#     bnds = [PXs.min(), PXs.max(), PYs.min(), PYs.max()]
#     PI = [persistence_image(dgm=dgm, bnds=bnds, return_raw=True) for dgm in ldgmsPI]
#     pck.dump(PI, open(path + "pi_" + str(res) + "_" + filtname + ".pkl", "wb"))
#     print("PI done")


# # Classifications

# Read the labels.

# In[23]:


# Read the labels back from disk and re-encode them as 0..n_classes-1.
# atleast_1d keeps a single-label file usable (loadtxt returns a 0-d array
# in that case); astype(int) truncates like the previous per-element int().
labels = np.atleast_1d(np.loadtxt(path + "labels.txt", dtype=float))
labels = LabelEncoder().fit_transform(labels.astype(int))
npoints = len(labels)
if UCR:
    # UCR datasets come with a fixed partition: the first `split` rows are
    # the training set, the remaining rows the test set.
    train_index, test_index = np.arange(split), np.arange(split, npoints)
elif immuno:
    from sklearn.model_selection import train_test_split
    # Seeded (same seed as the classifiers) so that repeated runs are
    # evaluated on the same random 75/25 split.
    train_index, test_index = train_test_split(range(npoints), test_size=0.25, random_state=1)


# In[24]:


if UCR:
    # Baseline: random forest trained directly on the imputed series.
    # In the UCR branch TS is the 2-D array returned by the imputer, so the
    # rows are selected by index; the previous `i in train_index` test
    # scanned the whole index array once per row.
    xtrainf = TS[train_index]
    xtestf = TS[test_index]
    rfc.fit(xtrainf, labels[train_index])
    print("RandomForest on UCR, with dataset", dataset, " : ",  rfc.score(xtrainf, labels[train_index]),rfc.score(xtestf, labels[test_index]))


# Classify MMA images

# In[25]:


print("Bounding box : ", mxf, myf, Mxf, Myf)
print("Resolution : ", res)

# In[ ]:

print("Loading MMA modules from file...")
# NOTE: unpickling is only safe for trusted files; this one is written by
# the module computation step of this script.
with open(path + "mma_mods_for_multipers.pkl", "rb") as handle:
    Xmmai = pck.load(handle)



# for filtidx in range(len(list_filts)):
#     bnds  = list_bnds [filtidx]
#     mdgms = list_mdgms[filtidx]
#     filtname = list_filts[filtidx]
#
#     MPI = [multipersistence_image(mdg, bnds, resolution=[res,res], return_raw=True) for mdg in mdgms]
#     pck.dump(MPI, open(path + "mpi_" + str(res) + "_" + filtname + ".pkl", "wb"))
#     print("MPI done")


# Bounds in the order expected by multipersistence_image.
bnds = [mxf, Mxf, myf, Myf]
# print(path + f"mpi_{res}_{list_filts[0]}.pkl")
# exit()
# One image file per filtration. All of them are loaded afterwards, so the
# images are recomputed as soon as any one file is missing (not only the last).
mpi_files = [path + "mpi_" + str(res) + "_" + filtname + ".pkl" for filtname in list_filts]
if force_computation or not all(exists(mpi_file) for mpi_file in mpi_files):
    print(list_filts)
    # The position of a filtration in list_filts doubles as the homology
    # dimension requested from the modules.
    for i, (filtname, mpi_file) in enumerate(zip(list_filts, mpi_files)):
        print(f"Dimension {i}")
        mdgms = [mma.from_dump(mod).barcodes(num=nlines, threshold=True, dimension=i).to_multipers() for mod in Xmmai]
        with Parallel(n_jobs=ncore) as p:
            MPI = p(delayed(multipersistence_image)(mdg, bnds, resolution=[res,res], return_raw=True) for mdg in tqdm(mdgms))
        # Close the file deterministically: it is read back right after.
        with open(mpi_file, "wb") as handle:
            pck.dump(MPI, handle)
        print("MPI done", flush=True)
else:
    print("Skipping module computation")

del Xmmai

# Load the images of every filtration: Xmpi[filtration][sample].
Xmpi = []
for filt in list_filts:
    with open(path + "mpi_" + str(res) + "_" + filt + ".pkl", "rb") as handle:
        Xmpi.append(pck.load(handle))

# Hyperparameter grid of the image wrapper, explored by cross-validation.
params_mpi = {
    "mpi__bdw":     [1e-2, 1e-1, 1, 1e1, 1e2],
    "mpi__power":   [0, 1],
    "mpi__step":    [1, 5],
    # "clf":          [rfc],
}
pipe_mpi = Pipeline([("mpi", MultiPersistenceImageWrapper()), ("clf", xgbc)])

# Each sample is the list of its images, one per filtration.
X_train = [[images[n] for images in Xmpi] for n in train_index]
X_test = [[images[n] for images in Xmpi] for n in test_index]
y_train, y_test = labels[train_index], labels[test_index]

model = GridSearchCV(estimator=pipe_mpi, param_grid=params_mpi, cv=cv, verbose=1, n_jobs=ncore)
model.fit(X_train, y_train)
# The test score is computed once and reused for the report and the dump.
score = model.score(X_test, y_test)
print(f"MP-I train accuracy {dataset} = ", model.score(X_train, y_train))
print(f"MP-I test accuracy {dataset} = ", score)
print(f"MP-I best params {dataset} = ", model.best_params_)

# Persist the selected parameters, the full CV results and the test score.
with open(path + "modelMPI_CV" + str(cv) + ".pkl", "wb") as handle:
    pck.dump([model.best_params_, model.cv_results_, score], handle)




