import argparse
import os
import logging
from multiprocessing import Pool
import functools
import itertools as it
import socket

import networkx as nx
import numpy as np
from scipy.spatial.distance import cdist
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV
from sklearn.metrics import accuracy_score
from scipy.sparse.csgraph import shortest_path
import ot
from ot.gromov import fused_gromov_wasserstein2
from ot.utils import dist
from tqdm import tqdm
from torch_geometric.datasets import TUDataset

from fngw import fused_network_gromov_wasserstein2


def wl_labeling(F, A, iterations=3):
    """Return the Weisfeiler-Lehman label history of every node of one graph.

    Every node starts from ``str(F[u])``. In each round its new label is its
    current label followed by the sorted current labels of its neighbours in
    ``nx.from_numpy_array(A, create_using=nx.DiGraph)``, all concatenated as
    plain strings.

    Parameters
    ----------
    F : sequence
        Initial per-node labels indexed by node id (callers in this file pass
        integer class ids obtained with ``np.argmax`` on one-hot features).
    A : numpy.ndarray
        Square adjacency matrix used to build the directed graph.
    iterations : int
        Number of relabeling rounds.

    Returns
    -------
    numpy.ndarray
        String array with one row per node and ``iterations + 1`` columns;
        column ``k`` holds the label after ``k`` rounds.

    NOTE(review): labels are joined without a separator and are never
    compressed, so distinct neighbourhoods can produce the same string
    (e.g. "1" + "12" vs "11" + "2" once there are 10+ initial classes) and
    label length grows quickly with ``iterations``.
    """
    graph = nx.from_numpy_array(A, create_using=nx.DiGraph)

    current = {node: str(label) for node, label in enumerate(F)}
    history = [[label] for label in current.values()]

    for _ in range(iterations):
        # The comprehension reads the previous round's labels and only
        # rebinds ``current`` once every node has been relabeled.
        current = {
            node: current[node] + "".join(
                sorted(current[nbr] for nbr in graph.neighbors(node)))
            for node in graph.nodes()
        }
        for node, label in current.items():
            history[node].append(label)

    return np.array(history)


def nested_cv(X,
              y,
              estimator,
              param_grid,
              num_trails=10,
              n_inner=10,
              n_jobs=1,
              logging=None,
              kernel_params=None):
    """Estimate test accuracy with repeated hold-out plus inner grid-search CV.

    For each trial ``i`` in ``range(num_trails)``, ``X``/``y`` are split 90/10
    into train/test with ``train_test_split(random_state=i)``. ``estimator`` is
    tuned over ``param_grid`` with ``GridSearchCV`` on an unshuffled
    ``StratifiedKFold(n_splits=n_inner)`` of the training part. The tuned
    search object is then scored on the test part.

    Parameters
    ----------
    X, y : array-like
        Samples and labels.
    estimator : sklearn estimator
        Model passed to ``GridSearchCV``.
    param_grid : dict
        Grid passed to ``GridSearchCV``; must contain ``'kernel'`` when
        ``kernel_params`` is given.
    num_trails : int
        Number of outer train/test splits.
    n_inner : int
        Number of inner folds.
    n_jobs : int
        Forwarded to ``GridSearchCV``.
    logging : object with an ``info`` method, optional
        Destination for progress messages (the stdlib ``logging`` module works).
        Defaults to this module's logger.
    kernel_params : list of dict, optional
        Parallel to ``param_grid['kernel']``. The entry matching the selected
        kernel is merged into the logged best parameters.

    Returns
    -------
    tuple
        ``(mean accuracy, std of accuracy, list of per-trial accuracies)``.
    """
    # The ``logging`` parameter shadows the stdlib module inside this function,
    # so resolve an explicit fallback instead of calling ``None.info``.
    if logging is None:
        import logging as _std_logging
        log = _std_logging.getLogger(__name__)
    else:
        log = logging

    nested_scores = []
    for trial in tqdm(range(num_trails)):
        # NOTE(review): the outer split is not stratified; kept as-is to
        # preserve existing results.
        X_train, X_test, y_train, y_test = train_test_split(X,
                                                            y,
                                                            test_size=0.1,
                                                            random_state=trial)
        log.info(f"Test: {X_test}")

        # Unshuffled folds are deterministic, so the splits logged here are
        # the same ones GridSearchCV uses below.
        cv = StratifiedKFold(n_splits=n_inner)
        for train_index, valid_index in cv.split(X_train, y_train):
            log.info(f"TRAIN: {train_index}")
            log.info(f"VALID: {valid_index}")

        clf = GridSearchCV(estimator=estimator,
                           param_grid=param_grid,
                           cv=cv,
                           n_jobs=n_jobs,
                           verbose=0)
        clf.fit(X_train, y_train)

        # Only look up the selected kernel when there is metadata to attach;
        # grids without a 'kernel' entry are otherwise valid.
        best_kernel_param = {}
        if kernel_params is not None:
            kernels = list(param_grid['kernel'])
            best_kernel_index = kernels.index(clf.best_params_['kernel'])
            best_kernel_param = kernel_params[best_kernel_index]

        best_parms = {**best_kernel_param, **clf.best_params_}
        log.info(f"BEST PARAMS: {best_parms}")

        y_pred = clf.predict(X_test)
        nested_scores.append(accuracy_score(y_test, y_pred))

    return np.mean(nested_scores), np.std(nested_scores), nested_scores


def get_fngw_matrix(alpha_betas_wl, X):
    """Compute the symmetric pairwise FNGW matrix for a list of graphs.

    Parameters
    ----------
    alpha_betas_wl : sequence
        ``(alpha, beta, wl)``. ``alpha`` and ``beta`` are forwarded to
        ``fused_network_gromov_wasserstein2``. ``wl`` is the number of
        Weisfeiler-Lehman rounds: 0 keeps the raw node features (squared
        Euclidean ground cost); otherwise features are replaced by WL label
        histories compared with squared Hamming distance.
    X : list of dict
        One dict per graph with keys ``'F'`` (node features), ``'C'`` (edge
        tensor whose first axis indexes nodes), ``'A'`` (structure matrix) and
        ``'A_adj'`` (adjacency used for WL relabeling).

    Returns
    -------
    numpy.ndarray
        ``(len(X), len(X))`` array. Each pair is solved once (``j >= i``) and
        the value is stored at both ``[i, j]`` and ``[j, i]``.
    """
    alpha = alpha_betas_wl[0]
    beta = alpha_betas_wl[1]
    wl = alpha_betas_wl[2]

    Cs = [x['C'] for x in X]
    As = [x['A'] for x in X]
    use_wl = wl != 0
    if use_wl:
        # WL mode relabels nodes from the argmax of each feature row
        # (assumes categorical / one-hot 'F' — TODO confirm for new datasets).
        Fs = [wl_labeling(np.argmax(x['F'], axis=-1), x['A_adj'], iterations=wl)
              for x in X]
    else:
        Fs = [x['F'] for x in X]

    def ground_cost(F1, F2):
        """Node-to-node cost matrix between two graphs' feature sets."""
        if use_wl:
            return cdist(F1, F2, metric='hamming')**2
        return dist(F1, F2)

    n_graphs = len(X)
    dist_mat = np.zeros((n_graphs, n_graphs))
    for i in tqdm(range(n_graphs)):
        # Only the upper triangle (diagonal included) is solved; the result
        # is mirrored below.
        for j in range(i, n_graphs):
            C1, C2 = Cs[i], Cs[j]
            p = ot.unif(C1.shape[0])
            q = ot.unif(C2.shape[0])
            fngw_dist, _ = fused_network_gromov_wasserstein2(
                ground_cost(Fs[i], Fs[j]),
                C1,
                C2,
                As[i],
                As[j],
                p,
                q,
                dist_fun_C='l2_norm',
                dist_fun_A='square_loss',
                alpha=alpha,
                beta=beta,
                numItermax=100,
                stopThr=1e-5,
                verbose=False,
                log=True)
            dist_mat[i, j] = fngw_dist
            dist_mat[j, i] = fngw_dist

    return dist_mat


def get_fgw_matrix(alpha_wl, X):
    """Compute the symmetric pairwise FGW matrix for a list of graphs.

    Parameters
    ----------
    alpha_wl : sequence
        ``(alpha, wl)``. ``alpha`` is forwarded to ``fused_gromov_wasserstein2``.
        ``wl`` is the number of Weisfeiler-Lehman rounds: 0 keeps the raw node
        features (squared Euclidean ground cost); otherwise features are
        replaced by WL label histories compared with squared Hamming distance.
    X : list of dict
        One dict per graph with keys ``'F'`` (node features), ``'A'``
        (structure matrix) and ``'A_adj'`` (adjacency used for WL relabeling).

    Returns
    -------
    numpy.ndarray
        ``(len(X), len(X))`` array. Each pair is solved once (``j >= i``) and
        the value is stored at both ``[i, j]`` and ``[j, i]``.

    NOTE(review): ``numItermax``/``stopThr``/``verbose`` are passed as keyword
    arguments, as older POT releases accept; confirm against the installed
    POT version.
    """
    alpha = alpha_wl[0]
    wl = alpha_wl[1]
    # Debug-level instead of print: this runs inside pool workers.
    logging.debug(f"get_fgw_matrix: alpha={alpha}, wl={wl}")

    As = [x['A'] for x in X]
    use_wl = wl != 0
    if use_wl:
        # WL mode relabels nodes from the argmax of each feature row
        # (assumes categorical / one-hot 'F' — TODO confirm for new datasets).
        Fs = [wl_labeling(np.argmax(x['F'], axis=-1), x['A_adj'], iterations=wl)
              for x in X]
    else:
        Fs = [x['F'] for x in X]

    def ground_cost(F1, F2):
        """Node-to-node cost matrix between two graphs' feature sets."""
        if use_wl:
            return cdist(F1, F2, metric='hamming')**2
        return dist(F1, F2)

    n_graphs = len(X)
    dist_mat = np.zeros((n_graphs, n_graphs))
    for i in tqdm(range(n_graphs)):
        # Only the upper triangle (diagonal included) is solved; the result
        # is mirrored below.
        for j in range(i, n_graphs):
            A1, A2 = As[i], As[j]
            p = ot.unif(A1.shape[0])
            q = ot.unif(A2.shape[0])
            fgw_dist, _ = fused_gromov_wasserstein2(
                ground_cost(Fs[i], Fs[j]),
                A1,
                A2,
                p,
                q,
                loss_fun='square_loss',
                alpha=alpha,
                numItermax=100,
                stopThr=1e-5,
                verbose=False,
                log=True)
            dist_mat[i, j] = fgw_dist
            dist_mat[j, i] = fgw_dist

    return dist_mat


def kernel_customed(gamma, kernel_matrix):
    """Build an SVC-compatible callable kernel from a precomputed matrix.

    Samples are identified by integer indices into ``kernel_matrix``. The
    returned callable maps index arrays ``X1``/``X2`` to the Gram block
    ``exp(-gamma * kernel_matrix[X1, X2])``. Callers in this file pass
    pairwise distance matrices, so the result need not be positive
    semi-definite.

    Parameters
    ----------
    gamma : float
        Bandwidth multiplying the stored values before exponentiation.
    kernel_matrix : array-like
        Square matrix of precomputed pairwise values.

    Returns
    -------
    callable
        ``_kernel(X1, X2) -> numpy.ndarray`` of shape ``(len(X1), len(X2))``.
        Index arrays may be 1-D or column vectors of shape ``(n, 1)``, and
        integral floats are accepted.
    """
    kernel_matrix = np.asarray(kernel_matrix)

    def _kernel(X1, X2):
        # Flatten so that (n, 1)-shaped sample arrays work as well as 1-D ones.
        rows = np.asarray(X1, dtype=np.intp).ravel()
        cols = np.asarray(X2, dtype=np.intp).ravel()
        # np.ix_ selects the block directly, avoiding the intermediate
        # (len(X1), N) row copy made by chained indexing.
        return np.exp(-gamma * kernel_matrix[np.ix_(rows, cols)])

    return _kernel


def _dense_edge_features(num_nodes, edge_index, edge_attr, fill):
    """Scatter per-edge attribute vectors into a dense ``(n, n, d)`` edge tensor.

    Parameters
    ----------
    num_nodes : int
        Number of nodes ``n``.
    edge_index : numpy.ndarray
        Rows of ``(src, dst)`` node pairs, aligned with ``edge_attr``.
    edge_attr : numpy.ndarray
        ``(E, d)`` edge attribute vectors.
    fill : numpy.ndarray
        Length-``d`` vector stored for every ordered pair without an edge.
    """
    C = np.broadcast_to(fill, (num_nodes, num_nodes, edge_attr.shape[1])).copy()
    for (src, dst), attr in zip(edge_index, edge_attr):
        C[src, dst] = attr
    return C


def _one_hot_edge_labels(num_nodes, edge_index, edge_attr, num_classes=5):
    """Encode categorical edge labels as a dense ``(n, n, num_classes)`` one-hot tensor.

    Each edge's label is the argmax of its attribute row, shifted by one so
    that class 0 means "no edge".
    """
    labels = np.argmax(edge_attr, axis=-1) + 1
    C = np.zeros((num_nodes, num_nodes))
    for (src, dst), label in zip(edge_index, labels):
        C[src, dst] = label
    return np.eye(num_classes)[C.astype('int')]


def _shortest_path_matrix(A_raw):
    """Return all-pairs hop distances of the undirected graph given by ``A_raw``.

    Pairs in different connected components get 10x the largest finite
    distance. That value is floored at 10, so the nodes of an edgeless graph
    are not all placed at distance 0 from each other.
    """
    A = shortest_path(A_raw, directed=False, method='D')
    unreachable = np.isinf(A)
    if unreachable.any():
        A[unreachable] = 10 * max(np.max(A[~unreachable]), 1.0)
    return A


def main(args):
    """Run nested-CV SVM graph classification with FNGW or FGW distance kernels.

    Steps:

    1. Load ``args.dataset`` from TUDataset.
    2. Build one dict per graph with node features ``F``, edge tensor ``C``,
       hop-distance matrix ``A`` and adjacency ``A_adj``.
    3. In a process pool, compute one pairwise distance matrix per
       hyper-parameter combination.
    4. Evaluate an SVC over ``exp(-gamma * D)`` kernels with ``nested_cv``.

    Everything is logged to ``<args.log_path>/log.txt``.

    Raises
    ------
    ValueError
        If ``args.dataset`` or ``args.method`` is not supported.
    """
    dense_attr_datasets = ('BZR_MD', 'COX2_MD', 'DHFR_MD', 'ER_MD', 'Cuneiform')
    edge_label_datasets = ('MUTAG', 'PTC_MR')
    # Fail fast, before any download or expensive preprocessing.
    if args.dataset not in dense_attr_datasets + edge_label_datasets:
        raise ValueError(f"Unsupported dataset: {args.dataset}")
    if args.method not in ('fngw', 'fgw'):
        raise ValueError(f"Unsupported method: {args.method}")

    np.random.seed(args.random_seed)
    os.makedirs(args.log_path, exist_ok=True)

    logging.basicConfig(format='%(asctime)s - %(message)s',
                        level=logging.INFO,
                        filename=os.path.join(args.log_path, 'log.txt'))

    logging.info(f'Args: {args}')
    host_name = socket.gethostname()
    logging.info(f'Runing machine: {host_name}')

    dataset = TUDataset(root=args.root_dir,
                        name=args.dataset,
                        use_node_attr=True,
                        use_edge_attr=True)
    logging.info(f"Data sample: {dataset[0]}")

    random_neutre = None
    if args.dataset in dense_attr_datasets:
        # One random vector, shared by all graphs, stands for "no edge".
        dim_edge = dataset[0].edge_attr.detach().cpu().numpy().shape[1]
        random_neutre = np.random.randn(dim_edge)
        logging.info(f"Feature vector for no edge: {random_neutre}")

    X_raw = []
    y = []
    for data in dataset:
        num_nodes = len(data.x)
        edge_index = data.edge_index.detach().cpu().numpy().T
        F = data.x.detach().cpu().numpy()
        edge_attr = data.edge_attr.detach().cpu().numpy()

        if args.dataset in dense_attr_datasets:
            C = _dense_edge_features(num_nodes, edge_index, edge_attr,
                                     random_neutre)
        else:
            C = _one_hot_edge_labels(num_nodes, edge_index, edge_attr)

        A_raw = np.zeros((num_nodes, num_nodes))
        for src, dst in edge_index:
            A_raw[src, dst] = 1

        X_raw.append({'C': C, 'F': F, 'A': _shortest_path_matrix(A_raw),
                      'A_adj': A_raw})
        y.append(data.y.item())

    matrix_fn = get_fngw_matrix if args.method == 'fngw' else get_fgw_matrix
    pool_func = functools.partial(matrix_fn, X=X_raw)

    logging.info(f"dataset size: {len(X_raw)}")

    gammas = np.logspace(-10, 10, base=2, num=21)
    logging.info(f"gamma: {gammas}")

    # Numbers of WL rounds to try (0 = raw node features).
    if args.dataset in ['BZR_MD', 'COX2_MD', 'DHFR_MD', 'ER_MD']:
        wls = [0, 1, 2, 3]
    elif args.dataset in ['PTC_MR', 'MUTAG']:
        wls = [0, 1, 2, 3, 4]
    else:
        wls = [0]

    if args.method == 'fngw':
        # Trade-off grid on [0, 0.5], denser near both ends.
        ends = np.geomspace(start=1e-2, stop=0.25, num=3)
        alphas = np.concatenate(([0.0], ends, (0.5 - ends)[:-1][::-1], [0.5]))
        betas = alphas.copy()
        logging.info(f"alphas: {alphas}")
        logging.info(f"betas: {betas}")

        # Safeguard only: with both grids capped at 0.5 the sum never exceeds 1.
        alpha_betas = [cof for cof in it.product(alphas, betas)
                       if cof[0] + cof[1] <= 1]
        alpha_beta_wls = [(alpha, beta, wl)
                          for wl in wls
                          for (alpha, beta) in alpha_betas]
        logging.info(
            f"alpha_beta_wls ({len(alpha_beta_wls)}): {alpha_beta_wls}")

        with Pool(args.n_jobs) as p:
            all_kernel_matrix = p.map(pool_func, alpha_beta_wls)

        kernel_list = [
            kernel_customed(gamma, all_kernel_matrix[index])
            for (gamma, index) in it.product(gammas, range(len(alpha_beta_wls)))
        ]
        kernel_parms = [{
            'gamma': gamma,
            'alpha': alpha_beta_wl[0],
            'beta': alpha_beta_wl[1],
            'wl': alpha_beta_wl[2]
        } for (gamma, alpha_beta_wl) in it.product(gammas, alpha_beta_wls)]

    else:  # 'fgw' (validated above)
        # Trade-off grid on [0, 1], denser near both ends.
        ends = np.geomspace(start=1e-4, stop=0.5, num=7)
        alphas = np.concatenate(([0.0], ends, (1 - ends)[:-1][::-1], [1.0]))
        alpha_wls = list(it.product(alphas, wls))

        logging.info(f"alphas: {alphas}")
        logging.info(f"wls: {wls}")
        logging.info(f"alpha_wls: {alpha_wls}")

        with Pool(args.n_jobs) as p:
            all_kernel_matrix = p.map(pool_func, alpha_wls)

        kernel_list = [
            kernel_customed(gamma, all_kernel_matrix[index])
            for (gamma, index) in it.product(gammas, range(len(alpha_wls)))
        ]
        kernel_parms = [{
            'gamma': gamma,
            'alpha': alpha_wl[0],
            'wl': alpha_wl[1]
        } for (gamma, alpha_wl) in it.product(gammas, alpha_wls)]

    param_grid = {
        'C': np.logspace(-7, 7, 15),
        'kernel': kernel_list,
        'random_state': [args.random_seed]
    }
    if args.method == 'fngw' and args.dataset == 'MUTAG':
        param_grid['C'] = np.logspace(-7, 6, 14)
    logging.info(f"other parmas: {param_grid}")

    clf = SVC()
    # Samples are graph indices; the callable kernels look them up in the
    # precomputed distance matrices.
    X = np.arange(len(X_raw))

    res = nested_cv(X,
                    y,
                    clf,
                    param_grid,
                    num_trails=args.n_trails,
                    n_inner=args.n_inner,
                    n_jobs=args.n_jobs,
                    logging=logging,
                    kernel_params=kernel_parms)
    logging.info(f"Nested scores: {res[2]}")
    logging.info(f"Mean: {res[0]}")
    logging.info(f"std: {res[1]}")


if __name__ == "__main__":
    # Positional first argument of ArgumentParser is ``prog``; this text is a
    # description, so pass it as such.
    parser = argparse.ArgumentParser(
        description="FNGW/FGW for graph classification")
    # main() cannot run without these, so reject missing values at parse
    # time instead of crashing later inside main().
    parser.add_argument("--root_dir", type=str, required=True,
                        help="root directory passed to TUDataset")
    parser.add_argument("--dataset", type=str, required=True,
                        choices=['BZR_MD', 'COX2_MD', 'DHFR_MD', 'ER_MD',
                                 'Cuneiform', 'MUTAG', 'PTC_MR'],
                        help="TUDataset name")
    parser.add_argument("--method", type=str, default='fngw',
                        choices=['fngw', 'fgw'],
                        help="graph distance used to build the SVM kernels")

    parser.add_argument("--random_seed", type=int, default=42)

    parser.add_argument("--n_jobs", type=int, default=20,
                        help="worker processes for distance matrices and grid search")
    parser.add_argument("--n_trails", type=int, default=10,
                        help="number of outer train/test splits")
    parser.add_argument("--n_inner", type=int, default=10,
                        help="number of inner cross-validation folds")

    parser.add_argument("--log_path", type=str, required=True,
                        help="directory that receives log.txt")

    args = parser.parse_args()

    main(args)
