import numpy as np

from sklearn.metrics.pairwise import rbf_kernel
from functools import partial

from tqdm import tqdm

from sklearn.model_selection import train_test_split
from sklearn.model_selection import ParameterSampler

import pandas

from joblib import Parallel, delayed

from time import time
from Sketch import SubSample

import argparse
import os
import logging
import matplotlib.pyplot as plt


from utility import (
    eval_graph,
    eval_graph_parallel,
    to_networkx,
    fngw_barycentre_func,
    tab_printer,
    plot_qm9,
)


class FNGWEstimator:
    """Kernel ridge style estimator whose predictions are FNGW barycentres.

    ``fit`` computes a matrix ``M`` from the input kernel (either exactly or
    through a ``SubSample`` sketch) and ``pred`` turns it into per-test-sample
    weights over the training outputs, which are handed to
    ``fngw_barycentre_func`` to build each prediction.

    Parameters
    ----------
    gaussien_gamma : float or None
        ``gamma`` forwarded to ``rbf_kernel``.
    fngw_alpha, fngw_beta : float
        ``alpha`` / ``beta`` forwarded to ``fngw_barycentre_func``.
    reg : float
        Ridge regularisation; it is scaled by the number of training samples.
    n_bary : int
        ``n_bary`` forwarded to ``fngw_barycentre_func``.
    use_sketch : bool
        Use the ``SubSample`` sketch instead of the exact kernel inverse.
    mu : float
        Diagonal jitter added before inverting the sketched system.
    s_dim : int
        Sketch size (first dimension passed to ``SubSample``).
    logger : logging-like object or None
        Anything exposing ``info``. When ``None`` a module-level logger is
        used so the sketched ``fit`` does not fail on the default value.
    """

    def __init__(
        self,
        gaussien_gamma=None,
        fngw_alpha=0.33,
        fngw_beta=0.33,
        reg=0.1,
        n_bary=10,
        use_sketch=True,
        mu=1e-8,
        s_dim=1200,
        logger=None,
    ):
        self.M = None
        self.X_tr = None
        self.Y_tr = None
        self.S = None
        self.alpha = fngw_alpha
        self.beta = fngw_beta
        self.kernel = partial(rbf_kernel, gamma=gaussien_gamma)
        self.L = reg
        self.n_bary = n_bary
        self.sketch_dim = s_dim
        self.mu = mu
        self.sketch = use_sketch
        # The sketched fit path logs unconditionally; never leave this as None.
        self.logger = logger if logger is not None else logging.getLogger(__name__)

    def fit(self, X, Y):
        """Store the training data and compute the matrix ``M`` used by ``pred``.

        Parameters
        ----------
        X : array with ``shape[0]`` training samples
            Training inputs, passed to the kernel.
        Y : sequence
            Training outputs; stored as-is and later passed to
            ``fngw_barycentre_func`` as ``Y_tr``.
        """
        self.X_tr = X
        n_tr = X.shape[0]

        if not self.sketch:
            # Exact solution: M = (K + n * lambda * I)^-1
            K = self.kernel(X, X)
            self.M = np.linalg.inv(K + n_tr * self.L * np.eye(n_tr))
        else:
            t0 = time()
            self.S = SubSample(size=(self.sketch_dim, n_tr))
            s = self.S.size[0]
            # NOTE(review): the names suggest KST = K S^T and SKST = S K S^T;
            # confirm against the SubSample implementation.
            KST = self.S.multiply_Gram_one_side(X, self.kernel, X)
            SKST = self.S.multiply_matrix_one_side(KST, right=False)
            B = KST.T.dot(KST) + n_tr * self.L * SKST
            # mu * I is added on the diagonal before inverting B.
            B_inv = np.linalg.inv(B + self.mu * np.eye(s))
            self.M = B_inv.dot(KST.T)
            fit_time = time() - t0
            self.logger.info(f"Fitting time: {fit_time} s")
        self.Y_tr = Y

    def pred(self, X, N_true=None):
        """Predict one barycentre per row of ``X``.

        Parameters
        ----------
        X : array
            Test inputs.
        N_true : sequence
            One entry per test sample, forwarded as ``N`` to
            ``fngw_barycentre_func``. Required despite the ``None`` default.

        Returns
        -------
        list
            The values returned by ``fngw_barycentre_func``, in the order of
            ``N_true``.

        Raises
        ------
        ValueError
            If ``N_true`` is ``None``.
        """
        if N_true is None:
            # Previously this surfaced as an opaque TypeError from zip()/len()
            # after the (possibly expensive) kernel computation.
            raise ValueError("N_true must be provided: one entry per test sample.")

        if not self.sketch:
            K_tr_te = self.kernel(self.X_tr, X)
            A = self.M.dot(K_tr_te)
        else:
            SK_x_tr_te = self.S.multiply_Gram_one_side(X, self.kernel, self.X_tr)
            # Estimate and compute scores <Ph(x) | psi(y_c)>
            A = SK_x_tr_te.dot(self.M).T

        N_preds = N_true

        # One barycentre per test sample; each row of A.T is that sample's
        # weight vector (`lambdas`).
        tic = time()
        y_preds = [
            fngw_barycentre_func(
                Y_tr=self.Y_tr,
                lambdas=lambda_,
                alpha=self.alpha,
                beta=self.beta,
                n_bary=self.n_bary,
                N=N_pred,
            )
            for lambda_, N_pred in tqdm(
                zip(
                    A.T,
                    N_preds,
                ),
                total=len(N_preds),
            )
        ]
        tac = time()
        print(f"{(tac-tic)/60} min")

        return y_preds


def _fit_and_score(
    X_train_cv,
    y_train_cv,
    X_valid_cv,
    y_valid_cv,
    logger,
    alpha,
    beta,
    gamma,
    reg,
    n_bary,
    s_dim,
):
    """Fit one sketched FNGWEstimator and score it on the validation split.

    The model is trained on ``(X_train_cv, y_train_cv)``, predictions are made
    for ``X_valid_cv`` using the row count of each target's ``"F"`` entry as
    the prediction size, and both predictions and targets are converted with
    ``to_networkx`` before being compared by ``eval_graph``.

    Returns a dict holding the hyperparameters of this trial together with
    the resulting ``"edit_distance"``.
    """
    model = FNGWEstimator(
        gaussien_gamma=gamma,
        fngw_alpha=alpha,
        fngw_beta=beta,
        reg=reg,
        n_bary=n_bary,
        use_sketch=True,
        s_dim=s_dim,
        logger=logger,
    )
    model.fit(X_train_cv, y_train_cv)

    # Size passed per validation sample: number of rows of its "F" array.
    valid_sizes = []
    for target in y_valid_cv:
        valid_sizes.append(target["F"].shape[0])

    predictions = model.pred(X=X_valid_cv, N_true=valid_sizes)

    predicted_graphs = []
    for prediction in predictions:
        predicted_graphs.append(to_networkx(prediction, use_edge_feature=True))
    target_graphs = []
    for target in y_valid_cv:
        target_graphs.append(to_networkx(target, use_edge_feature=True))

    scores = eval_graph(predicted_graphs, target_graphs, with_edge_feature=True)

    # Key order is kept because it determines the column order of the CSV
    # written by the caller.
    return {
        "alpha": alpha,
        "beta": beta,
        "gamma": gamma,
        "reg": reg,
        "s_dim": s_dim,
        "edit_distance": scores["edit_distance"],
        "n_bary": n_bary,
    }


def cross_validation(
    X_train,
    y_train,
    save_dir,
    valid_size=1000,
    cv_trail=64,
    logger=None,
    n_jobs=32,
):
    """Random hyperparameter search on a single held-out validation split.

    Parameters
    ----------
    X_train, y_train
        Data that is split into a training part and a validation part.
    save_dir : str
        Directory where ``valid_res_fngw.csv`` is written; created if missing.
    valid_size : int
        ``test_size`` passed to ``train_test_split`` for the validation part.
    cv_trail : int
        Number of hyperparameter settings drawn by ``ParameterSampler``.
    logger : logging-like object or None
        Anything exposing ``info``. A module-level logger is used when
        ``None`` so the default value no longer raises.
    n_jobs : int
        Number of joblib workers.

    Returns
    -------
    dict
        The trial record (hyperparameters plus ``"edit_distance"``) with the
        smallest edit distance.
    """
    if logger is None:
        logger = logging.getLogger(__name__)

    # Get validation set
    X_train_cv, X_valid_cv, y_train_cv, y_valid_cv = train_test_split(
        X_train, y_train, test_size=valid_size, random_state=42
    )

    tic = time()

    # Get hyper parameter list
    param_grid = {
        "alpha": [0.01, 0.1, 0.5, 0.9, 0.99],
        "beta": [0.],
        "gamma": [1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1],
        "reg": [1e-10, 1e-8, 1e-6, 1e-4, 1e-2],
        "n_bary": [15, 20, 25],
        "s_dim": [1000, 5000],
    }

    # Fixed seed so the sampled settings are reproducible across runs.
    rng = np.random.RandomState(42)
    param_list = list(ParameterSampler(param_grid, n_iter=cv_trail, random_state=rng))
    logger.info(param_list)

    parallel = Parallel(n_jobs=n_jobs, verbose=50)
    valid_res = parallel(
        delayed(_fit_and_score)(
            X_train_cv,
            y_train_cv,
            X_valid_cv,
            y_valid_cv,
            logger=logger,
            alpha=params["alpha"],
            beta=params["beta"],
            gamma=params["gamma"],
            reg=params["reg"],
            n_bary=params["n_bary"],
            s_dim=params["s_dim"],
        )
        for params in param_list
    )

    tac = time()
    logger.info(f"Validation time: {(tac - tic)/60} mins")

    # Lower edit distance is better, so the best trial comes first.
    sorted_valid_res = sorted(valid_res, key=lambda res: res["edit_distance"])

    # Do not lose a finished search because the output directory is missing.
    os.makedirs(save_dir, exist_ok=True)
    df = pandas.DataFrame.from_records(sorted_valid_res)
    df.to_csv(os.path.join(save_dir, "valid_res_fngw.csv"))
    best_param = sorted_valid_res[0]
    logger.info(f"Best hyperparmeter{best_param}")
    return best_param


if __name__ == "__main__":
    # Script entry point: load the QM9 arrays, either run the hyperparameter
    # search or take hyperparameters from the command line (--test), retrain
    # on the full training split, evaluate on the test split and save plots
    # of the first 40 predictions/targets.
    parser = argparse.ArgumentParser()
    parser.add_argument("--test", action="store_true")
    parser.add_argument("--n_bary", type=int, default=5)
    parser.add_argument("--gamma", type=float, default=0.001)
    parser.add_argument("--alpha", type=float, default=0.9)
    parser.add_argument("--reg", type=float, default=0.0001)
    parser.add_argument("--s_dim", type=int, default=500)
    parser.add_argument("--save_dir", type=str, default="./exper")
    parser.add_argument("--valid_size", type=int, default=1000)
    parser.add_argument("--cv_trail", type=int, default=64)
    parser.add_argument("--n_jobs", type=int, default=64)
    parser.add_argument("--data_dir", type=str, default="./data")
    parser.add_argument("--rand_seed", type=int, default=42)

    args = parser.parse_args()

    # fix random seed
    np.random.seed(args.rand_seed)

    os.makedirs(args.save_dir, exist_ok=True)

    logging.basicConfig(
        filename=os.path.join(args.save_dir, "logs.txt"),
        filemode="a",
        format="%(asctime)s - %(message)s",
        datefmt="%d-%b-%y %H:%M:%S",
        level=logging.INFO,
    )

    tprint = tab_printer(args)
    logging.info(tprint)

    X_tr_tmp = np.load(
        os.path.join(args.data_dir, "X_fingerprint_train_qm9.npy"), allow_pickle=True
    )
    y_tr_tmp = np.load(
        os.path.join(args.data_dir, "y_train_qm9.npy"), allow_pickle=True
    )

    X_te_tmp = np.load(
        os.path.join(args.data_dir, "X_fingerprint_test_qm9.npy"), allow_pickle=True
    )
    y_te_tmp = np.load(os.path.join(args.data_dir, "y_test_qm9.npy"), allow_pickle=True)

    # The provided train/test files are merged and re-split with the
    # script's own seed, holding out 2000 samples for testing.
    X = np.concatenate((X_tr_tmp, X_te_tmp))
    y = np.concatenate((y_tr_tmp, y_te_tmp))
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=2000, random_state=args.rand_seed
    )

    logging.info(f"Training size is {X_train.shape[0]}")
    logging.info(f"Test size is {X_test.shape[0]}")

    if not args.test:
        best_param = cross_validation(
            X_train,
            y_train,
            args.save_dir,
            args.valid_size,
            args.cv_trail,
            logging,
            n_jobs=args.n_jobs,
        )
    else:
        # No "--thres" option is defined, so reading args.thres here raised
        # AttributeError; the key was never used afterwards and is omitted.
        best_param = {
            "alpha": args.alpha,
            "beta": 0.0,
            "gamma": args.gamma,
            "reg": args.reg,
            "s_dim": args.s_dim,
            "n_bary": args.n_bary,
        }
    # Retrain the model
    best_estimator = FNGWEstimator(
        best_param["gamma"],
        best_param["alpha"],
        best_param["beta"],
        best_param["reg"],
        best_param["n_bary"],
        use_sketch=True,
        s_dim=best_param["s_dim"],
        logger=logging,
    )
    best_estimator.fit(X_train, y_train)
    N_test = [single["F"].shape[0] for single in y_test]

    tic = time()
    y_pred_test = best_estimator.pred(X=X_test, N_true=N_test)
    with open(os.path.join(args.save_dir, 'y_pred_test.npy'), mode='wb') as f:
        np.save(f, y_pred_test)
    tac = time()
    logging.info(f"Predicting time: {tac - tic}s")

    G_preds = [to_networkx(y_pred, use_edge_feature=True) for y_pred in y_pred_test]
    G_trgts = [to_networkx(y_target, use_edge_feature=True) for y_target in y_test]

    test_metric = eval_graph_parallel(G_preds, G_trgts, with_edge_feature=True, n_jobs=args.n_jobs)
    logging.info(
        f"Edit distance on Test (take edge feature into account): {test_metric['edit_distance']}",
    )
    test_metric = eval_graph_parallel(G_preds, G_trgts, with_edge_feature=False, n_jobs=args.n_jobs)
    logging.info(f"Edit distance on Test (no edge feature): {test_metric['edit_distance']}")
    eds_no_ef = test_metric["eds"]

    # Save a prediction/target figure pair for the first 40 test samples.
    for i, (y_pred, y_tgt, edist) in enumerate(
        zip(y_pred_test[:40], y_test[:40], eds_no_ef[:40])
    ):
        G_pred = to_networkx(y_pred, use_edge_feature=True, thres=0.5)
        fig, axs = plt.subplots(1, 1, figsize=(3 * 1, 3 * 1))
        plot_qm9(G_pred, axs, f"GED w/o EF = {edist}", draw_edge_feature=True)
        fig.tight_layout()
        fig.savefig(os.path.join(args.save_dir, f"ile_fngw_pred{i}.pdf"))
        # Figures stay registered with pyplot until closed; release them so
        # the loop does not accumulate 80 open figures.
        plt.close(fig)

        G_tgt = to_networkx(y_tgt, use_edge_feature=True, thres=0.5)
        fig, axs = plt.subplots(1, 1, figsize=(3 * 1, 3 * 1))
        plot_qm9(G_tgt, axs, draw_edge_feature=True)
        fig.tight_layout()
        fig.savefig(os.path.join(args.save_dir, f"ile_fngw_tgt{i}.pdf"))
        plt.close(fig)
