"""
Adaptation of the code released by the WEGL authors, which was originally dedicated to OGB datasets.
https://openreview.net/forum?id=AAes_3W-2z
https://github.com/navid-naderi/WEGL
"""
from collections import defaultdict
from tqdm.notebook import tqdm
import itertools
import matplotlib.pyplot as plt

import numpy as np
import torch 
import os
import pandas as pd
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader, Data
from torch_geometric.transforms import OneHotDegree
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.ensemble import RandomForestClassifier
import ot
from sklearn.metrics import accuracy_score
from WEGL.diffusion import Diffusion
from sklearn.model_selection import train_test_split 
#%%

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
# Short name of the benchmark to run; must be a key of `dataset_to_repo`.
dataset_name = 'mutag'
# Torch device and dtype used for the diffusion module and the edge attributes.
device = 'cpu'
dtype = torch.float32

# Single seed shared by torch, numpy, PCA, k-means; keeping one source of truth
# avoids the RNGs drifting apart when the seed is changed.
random_seed = 0
torch.manual_seed(random_seed)
np.random.seed(random_seed)

# Number of PCA components applied on node embeddings; -1 means no PCA
# (e.g. set to 20 to enable it).
num_pca_components = -1
# Seeds of the successive train/validation splits.
list_split_trainval_seed = list(range(10))
# Classifiers evaluated on the graph embeddings.
classifiers = ['RF']
#%%


# Map the short dataset names used in this script to the names expected by
# torch_geometric's TUDataset.
dataset_to_repo = {'mutag': 'MUTAG',
                   'ptc': 'PTC_MR',
                   'enzymes': 'ENZYMES',
                   'protein': 'PROTEINS_full',
                   'nci1': 'NCI1',
                   'imdb-b': 'IMDB-BINARY',
                   'imdb-m': 'IMDB-MULTI',
                   'collab': 'COLLAB'}

# Build the dataset with a single constructor call; only the node-feature
# transform differs between dataset families.
repo_name = dataset_to_repo[dataset_name]
tu_kwargs = {'root': '/tmp/%s' % repo_name,
             'name': repo_name,
             'use_edge_attr': False}
if dataset_name in ['imdb-b', 'imdb-m', 'collab']:
    # These datasets get one-hot degree node features
    # (presumably because they ship without node attributes -- TODO confirm).
    tu_kwargs['transform'] = OneHotDegree(max_degree=500)
dataset = TUDataset(**tu_kwargs)

# Global indices and labels of every graph, used for the stratified splits.
indices = np.array(dataset.indices())
labels = dataset.data.y.numpy()
dataset_size = len(indices)
split_traintest_seed = 0

# Grid of diffusion hyper-parameters explored by the experiment loop.
list_num_hidden_layers = [1, 2, 3, 4]
list_final_node_embedding = ['concat', 'final']

# Results are written under ../results_WEGL/<dataset_name>/.
abspath = os.path.abspath('../')
res_repo = abspath+'/results_WEGL/%s/'%dataset_name
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(res_repo, exist_ok=True)

# ---------------------------------------------------------------------------
# WEGL experiment loop
#
# For every (number of diffusion layers, node-embedding mode, validation
# split seed) configuration whose result file does not exist yet:
#   1. split the graphs into sub-train / validation / test partitions,
#   2. compute node embeddings with the (non-trained) diffusion module,
#   3. standardize (and optionally PCA-reduce) them using training statistics,
#   4. derive a template point cloud with k-means on the training nodes,
#   5. embed each graph through the optimal transport plan to the template,
#   6. grid-search a classifier on the graph embeddings and store accuracies.
# ---------------------------------------------------------------------------

# Integer keys identifying the three data partitions throughout the pipeline.
TRAIN, VAL, TEST = 0, 1, 2

batch_size = 32
n_jobs = 6
verbose = 0

# Parameter grid explored for the random forest classifier.
param_grid_RF = {
    'min_samples_leaf': [1, 2, 5],
    'min_samples_split': [2, 5, 10],
    'n_estimators': [25, 50, 100, 150, 200]
}
param_grid_all = {'RF': param_grid_RF}

# The train/test split only depends on `split_traintest_seed`, so it is the
# same for every configuration and can be computed once.
idx_train, idx_test, y_train, y_test = train_test_split(
    np.arange(dataset_size), labels, test_size=0.1, stratify=labels,
    random_state=split_traintest_seed)

for num_hidden_layers in list_num_hidden_layers:
    for final_node_embedding in list_final_node_embedding:
        for split_trainval_seed in tqdm(list_split_trainval_seed, desc='splits'):

            experiment_name ='/WEGL_L%s_%s_traintest%s_val%s'%(num_hidden_layers, final_node_embedding, split_traintest_seed, split_trainval_seed)
            file_results = '%s/%s_res.csv'%(res_repo, experiment_name)
            # Skip configurations whose results are already on disk.
            if os.path.exists(file_results):
                continue

            results = {
                'clf': [],
                'val_accuracy': [],
                'test_accuracy': []}
            print('experiment_name:', experiment_name)

            # Split the training part further into sub-train / validation.
            # The returned indices are positions inside `idx_train`, so they
            # are mapped back to dataset indices below.
            idx_subtrain, idx_val, y_subtrain, y_val = train_test_split(
                np.arange(len(idx_train)), y_train, test_size=0.1,
                stratify=y_train, random_state=split_trainval_seed)
            true_idx_subtrain = idx_train[idx_subtrain]
            true_idx_val = idx_train[idx_val]
            phase_indices = {
                TRAIN: true_idx_subtrain,
                VAL: true_idx_val,
                TEST: idx_test,
            }
            # shuffle=False keeps embeddings and labels in the same order.
            loader_dict = {
                phase: DataLoader(dataset[torch.tensor(idx)],
                                  batch_size=batch_size, shuffle=False)
                for phase, idx in phase_indices.items()
            }
            phases = list(loader_dict.keys())

            # Create the diffusion module (used in eval mode only).
            # No node/edge feature encoder is applied for TUDataset graphs.
            diffusion = Diffusion(num_hidden_layers=num_hidden_layers,
                                  final_node_embedding=final_node_embedding).to(device)
            diffusion.eval()

            # Pass all the graphs through the diffusion process and collect
            # one node-embedding matrix and one label per graph.
            X = defaultdict(list)
            Y = defaultdict(list)
            for phase in phases:
                for batch in loader_dict[phase]:
                    # Constant edge attributes: every edge gets the feature 1.
                    new_edge_attr = torch.ones((batch.edge_index.shape[1], 1), device=device, dtype=dtype)
                    batch_ = Data(x=batch.x, edge_index=batch.edge_index, y=batch.y, edge_attr=new_edge_attr)

                    # assumes the diffusion output has one row per node -- TODO confirm
                    z = diffusion(batch_).detach().cpu().numpy()
                    graph_ids = batch.batch.cpu().numpy()
                    # Iterate over the graphs actually present in the batch:
                    # the last batch of a loader may hold fewer than
                    # `batch_size` graphs.
                    for g in range(batch.num_graphs):
                        X[phase].append(z[graph_ids == g])

                    Y[phase].extend(batch.y.detach().cpu().numpy().flatten().tolist())

            # Standardize the features based on mean and std of the training data.
            ss = StandardScaler()
            ss.fit(np.concatenate(X[TRAIN], 0))
            transX = {}
            transY = {}
            for phase in phases:
                transX[phase] = []
                transY[phase] = []
                for x, y in zip(X[phase], Y[phase]):
                    # Graphs without nodes cannot be transported; drop them
                    # together with their label.
                    if x.shape[0] != 0:
                        transX[phase].append(ss.transform(x))
                        transY[phase].append(y)

            # Apply PCA (fitted on the training nodes) if requested.
            if num_pca_components > 0:
                pca = PCA(n_components=num_pca_components, random_state=random_seed)
                pca.fit(np.concatenate(transX[TRAIN], 0))
                for phase in phases:
                    # Iterate over the filtered list itself: it may be shorter
                    # than X[phase] when empty graphs were dropped.
                    transX[phase] = [pca.transform(x) for x in transX[phase]]

                # Plot the variance % explained by PCA components.
                plt.plot(np.arange(1, num_pca_components + 1), pca.explained_variance_ratio_, 'o--')
                plt.grid(True)
                plt.xlabel('Principal component')
                plt.ylabel('Eigenvalue')
                plt.xticks(np.arange(1, num_pca_components + 1, step=2))
                plt.show()

            # Number of samples in the template distribution: the (rounded)
            # mean number of nodes per training graph.
            N = int(round(np.asarray([x.shape[0] for x in transX[TRAIN]]).mean()))

            # Derive the template distribution using k-means on training nodes.
            kmeans = KMeans(n_clusters=N, verbose=verbose, random_state=random_seed)
            kmeans.fit(np.concatenate(transX[TRAIN], 0))
            template = kmeans.cluster_centers_

            # Calculate the final graph embeddings based on linear optimal
            # transport between each graph and the template.
            template_weights = np.ones((N,)) / float(N)  # uniform, loop-invariant
            V = defaultdict(list)
            for phase in phases:
                for x in transX[phase]:
                    M = x.shape[0]
                    C = ot.dist(x, template)
                    graph_weights = np.ones((M,)) / float(M)
                    p = ot.emd(graph_weights, template_weights, C)  # exact linear program
                    V[phase].append(np.matmul((N * p).T, x) - template)
                V[phase] = np.stack(V[phase])

            # Flatten each graph embedding to a feature vector once per phase.
            features = {phase: V[phase].reshape(V[phase].shape[0], -1) for phase in phases}

            # Run the classifiers over their full parameter grid.
            for classifier in classifiers:
                if classifier not in param_grid_all:
                    # Unsupported classifier: skip it.
                    continue
                param_grid = param_grid_all[classifier]
                for key in param_grid:
                    results[key] = []
                param_keys, param_values = zip(*param_grid.items())
                for values in itertools.product(*param_values):
                    exp = dict(zip(param_keys, values))
                    for key, value in exp.items():
                        results[key].append(value)
                    results['clf'].append(classifier)
                    if classifier == 'RF':
                        model = RandomForestClassifier(
                            n_estimators=exp['n_estimators'],
                            min_samples_split=exp['min_samples_split'],
                            min_samples_leaf=exp['min_samples_leaf'],
                            n_jobs=n_jobs, class_weight='balanced', random_state=0)
                    model.fit(features[TRAIN], transY[TRAIN])

                    # `predict` maps the most probable column back to the
                    # actual class label (through `classes_`), so accuracies
                    # stay correct even if labels are not exactly 0..C-1.
                    results['val_accuracy'].append(
                        accuracy_score(transY[VAL], model.predict(features[VAL])))
                    results['test_accuracy'].append(
                        accuracy_score(transY[TEST], model.predict(features[TEST])))

            res_df = pd.DataFrame(results)
            res_df.to_csv(file_results)

