#!/usr/bin/env python
# coding: utf-8

# In[1]:


import argparse
import itertools
import logging
import os

import numpy as np
import torch

from main.pytorch import PPNP
from main.pytorch.training import train_model
from main.pytorch.earlystopping import stopping_args
from main.pytorch.propagation_ import OurPropagate

from main.data.io import load_dataset
from main.pytorch.geometric_baselines import (
    ResConvModel_via_matmul,
    ResConv_predict_first_via_matmul,
)
from main.pytorch.resconv_sing import resconv_sing





# In[2]:
# Turn off TF32 for both cuBLAS matmuls and cuDNN so these ops use full
# float32 precision instead of the reduced-mantissa TF32 format.
torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = False

# Root logger threshold is INFO + 2: plain INFO records are suppressed and only
# records at level >= 22 (e.g. WARNING) are emitted, timestamped.
logging.basicConfig(
    level=logging.INFO + 2,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s: %(message)s",
)


# Seeds for the random data splits: one list for the final test runs and a
# separate one for the validation (grid-search) runs.
_TEST_SEEDS = (
    2144199730,
    794209841,
    2985733717,
    2282690970,
    1901557222,
    2009332812,
    2266730407,
    635625077,
    3538425002,
    960893189,
    497096336,
    3940842554,
    3594628340,
    948012117,
    3305901371,
    3644534211,
    2297033685,
    4092258879,
    2590091101,
    1694925034,
)
_VAL_SEEDS = (
    2413340114,
    3258769933,
    1789234713,
    2222151463,
    2813247115,
    1920426428,
    4272044734,
    2092442742,
    841404887,
    2188879532,
    646784207,
    1633698412,
    2256863076,
    374355442,
    289680769,
    4281139389,
    4263036964,
    900418539,
    119332950,
    1628837138,
)

# Graphs whose splits are read from ``./main/data/<name>_splits.npz``.
_SMALL_GRAPHS = ("cornell", "texas", "wisconsin")


def _split_settings(graph_name):
    """Return ``(idx_split_args, splits_path, niter_per_seed)`` for ``graph_name``.

    ``idx_split_args`` holds the split sizes (``ntrain_per_class``,
    ``nstopping``, ``nknown``) without a seed. ``splits_path`` is the
    precomputed-splits file for cornell/texas/wisconsin and ``None`` for every
    other graph. ``niter_per_seed`` is the number of training runs per seed.
    """
    if graph_name == "ms_academic":
        nknown, ntrain_per_class, nstopping = 5000, 20, 500
    elif graph_name in _SMALL_GRAPHS:
        nknown, ntrain_per_class, nstopping = 150, 6, 50
    else:
        nknown, ntrain_per_class, nstopping = 1500, 20, 500

    idx_split_args = {
        "ntrain_per_class": ntrain_per_class,
        "nstopping": nstopping,
        "nknown": nknown,
    }
    if graph_name in _SMALL_GRAPHS:
        return idx_split_args, f"./main/data/{graph_name}_splits.npz", 10
    return idx_split_args, None, 5


def _build_model_args(graph, graph_name, model_type, drop_prob, filter_num):
    """Build the constructor kwargs for ``model_type`` on ``graph``.

    The first matching base class wins, checked in the order
    ResConvModel_via_matmul, ResConv_predict_first_via_matmul, resconv_sing.

    Raises:
        TypeError: if ``model_type`` subclasses none of the supported models.
    """
    if issubclass(model_type, ResConvModel_via_matmul):
        omega = -0.25
        nf_minus = 0.01
        return {
            "adj_matrix": graph.adj_matrix,
            "K_minus": 1,
            "singular_value_normalization": True,
            "zero_order": True,
            "hidden_channel_list": [filter_num],
            "omega": omega,
            "normalizing_factor_minus": nf_minus,
            "resolvent_path": f"./main/data/{graph_name}_{omega}_{nf_minus}_resolvent_mat.npy",
            "dropout": drop_prob,
        }

    if issubclass(model_type, ResConv_predict_first_via_matmul):
        omega = -0.15
        nf_minus = 0.01
        # NOTE(review): "dropout" is fixed at 0.3 for this model; the swept
        # drop_prob is applied as "dprate" instead.
        return {
            "adj_matrix": graph.adj_matrix,
            "zero_order": True,
            "K_minus": 1,
            "singular_value_normalization": True,
            "hidden": filter_num,
            "omega": omega,
            "dropout": 0.3,
            "dprate": drop_prob,
            "normalizing_factor_minus": nf_minus,
            "bias": True,
            "resolvent_path": f"./main/data/{graph_name}_{omega}_{nf_minus}_resolvent_mat.npy",
        }

    if issubclass(model_type, resconv_sing):
        normalizing_factor = 2
        omega = -0.14
        propagation = OurPropagate(
            graph.adj_matrix,
            omega=omega,
            normalizing_factor=normalizing_factor,
            resolvent_path=f"./main/data/{graph_name}_{omega}_{normalizing_factor}_resolvent_mat_ppnp_laplacian.npy",
        )
        return {
            "hiddenunits": [filter_num],
            "drop_prob": drop_prob,
            "propagation": propagation,
        }

    # Previously this fell through with model_args unbound and crashed later
    # with a NameError inside the training loop.
    raise TypeError(f"Unsupported model_type: {model_type!r}")


def _calc_uncertainty(values, n_boot: int = 1000, ci: int = 95) -> dict:
    """Return the mean of ``values`` with a bootstrap confidence interval.

    ``values`` is array-like with a ``.mean()`` method (a pandas Series here).
    The result has keys ``"mean"``, ``"CI"`` (the ``ci``-percent interval of
    ``n_boot`` bootstrapped means) and ``"uncertainty"`` (the larger absolute
    distance from the mean to either CI bound).
    """
    import seaborn as sns

    stats = {"mean": values.mean()}
    boots_series = sns.algorithms.bootstrap(values, func=np.mean, n_boot=n_boot)
    stats["CI"] = sns.utils.ci(boots_series, ci)
    stats["uncertainty"] = np.max(np.abs(stats["CI"] - stats["mean"]))
    return stats


def run(graph_name, model_type, drop_prob=0.4, filter_num=64, learning_rate=0.1, reg_lambda=5e-4, test=False):
    """Train ``model_type`` on ``graph_name`` over many seeded splits and summarize.

    Args:
        graph_name: dataset name passed to ``load_dataset``.
        model_type: model class; must subclass ResConvModel_via_matmul,
            ResConv_predict_first_via_matmul or resconv_sing.
        drop_prob: dropout setting forwarded into the model kwargs.
        filter_num: hidden-layer width.
        learning_rate: forwarded to ``train_model``.
        reg_lambda: regularization weight forwarded to ``train_model``.
        test: selects the test seed list (otherwise the validation seeds) and
            is forwarded to ``train_model`` (presumably choosing test vs.
            validation evaluation). A summary is printed only when True.

    Returns:
        ``(mean, uncertainty)`` of the per-run "valtest" accuracies, in percent.

    Raises:
        TypeError: if ``model_type`` is not a supported model class.
    """
    # Imported before training so a missing dependency fails fast instead of
    # after the whole training loop has run.
    import pandas as pd

    graph = load_dataset(graph_name)
    graph.standardize(make_undirected=True, select_lcc=True)

    seeds = _TEST_SEEDS if test else _VAL_SEEDS
    base_split_args, splits_path, niter_per_seed = _split_settings(graph_name)
    model_args = _build_model_args(graph, graph_name, model_type, drop_prob, filter_num)

    # Use CUDA when present; fall back to CPU instead of crashing.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print_interval = 100

    results = []
    for seed in seeds:
        # Fresh dict per seed so anything train_model retains is not mutated later.
        idx_split_args = {**base_split_args, "seed": seed}
        for split_idx in range(niter_per_seed):
            _, result = train_model(
                graph_name,
                model_type,
                graph,
                model_args,
                learning_rate,
                reg_lambda,
                idx_split_args,
                stopping_args,
                test,
                device,
                seed + split_idx,
                print_interval,
                splits_path,
                idx_standard_split=split_idx,
            )
            print(f"Val. acc. during grid search: {result['valtest']['accuracy']}")
            results.append(
                {
                    "stopping_accuracy": result["early_stopping"]["accuracy"],
                    "valtest_accuracy": result["valtest"]["accuracy"],
                    "runtime": result["runtime"],
                    "runtime_perepoch": result["runtime_perepoch"],
                    "split_seed": seed,
                }
            )

    result_df = pd.DataFrame(results)
    stopping_acc = _calc_uncertainty(result_df["stopping_accuracy"])
    valtest_acc = _calc_uncertainty(result_df["valtest_accuracy"])
    runtime = _calc_uncertainty(result_df["runtime"])
    runtime_perepoch = _calc_uncertainty(result_df["runtime_perepoch"])

    if test:
        print(
            "\n"
            "Early stopping: Accuracy: {:.2f} ± {:.2f}%\n"
            "{}: Accuracy: {:.2f} ± {:.2f}%\n"
            "Runtime: {:.3f} ± {:.3f} sec, per epoch: {:.2f} ± {:.2f}ms".format(
                stopping_acc["mean"] * 100,
                stopping_acc["uncertainty"] * 100,
                "Test",
                valtest_acc["mean"] * 100,
                valtest_acc["uncertainty"] * 100,
                runtime["mean"],
                runtime["uncertainty"],
                runtime_perepoch["mean"] * 1e3,
                runtime_perepoch["uncertainty"] * 1e3,
            )
        )

    return valtest_acc["mean"] * 100, valtest_acc["uncertainty"] * 100

# Note: dataset names follow the loader's naming; use "cora_ml" for Cora and "ms_academic" for MS Academic.
# Grid-search each model on each graph using validation seeds, then evaluate
# the best configuration once on the test seeds.
for graph_name in ['citeseer']:
    for model_type in [
        resconv_sing,
        ResConvModel_via_matmul,
        ResConv_predict_first_via_matmul,
    ]:
        print("\n Running : ", model_type, graph_name)
        # Start below any achievable accuracy so the first configuration is
        # always recorded; with 0, runs that all score exactly 0 would leave
        # best_lr at 0 for the final test run.
        best_acc = float("-inf")
        best_lr = 0
        best_reg_lambda = 0
        best_drop_prob = 0
        best_filter_num = 0
        best_uncertainty = 0
        # Same iteration order as nested loops: drop_prob outermost, lr innermost.
        search_grid = itertools.product(
            [0.3, 0.35, 0.4, 0.45, 0.5],  # drop_prob
            [1e-4, 5e-4],  # reg_lambda
            [64, 128],  # filter_num
            [0.1],  # learning_rate
        )
        for drop_probability, reg_lambda, filter_num, lr in search_grid:
            # filter_num must be forwarded explicitly: it is part of the grid,
            # and omitting it silently trains every config with run()'s default.
            acc, unc = run(
                graph_name,
                model_type,
                drop_prob=drop_probability,
                filter_num=filter_num,
                learning_rate=lr,
                reg_lambda=reg_lambda,
                test=False,
            )
            if acc > best_acc:
                best_acc = acc
                best_lr = lr
                best_reg_lambda = reg_lambda
                best_drop_prob = drop_probability
                best_filter_num = filter_num
                best_uncertainty = unc
        print('best config : ', model_type, graph_name, best_acc, best_drop_prob, best_filter_num, best_lr, best_reg_lambda, best_uncertainty)

        test_acc, test_unc = run(
            graph_name,
            model_type,
            drop_prob=best_drop_prob,
            filter_num=best_filter_num,
            learning_rate=best_lr,
            reg_lambda=best_reg_lambda,
            test=True,
        )
        print('\n test accuracy : ', test_acc, graph_name, model_type, test_unc, best_acc)

