import random
import argparse
import torch
import numpy as np
from torch_geometric.utils import homophily

from model_opt import Trainer
from utils import load_graph, process_data, print_outcome

import warnings
warnings.filterwarnings('ignore')


def main(args):
    """Run ``args.num_exp`` train/eval repetitions on one graph and report them.

    The graph is loaded once with ``load_graph(args)``. Every repetition then
    calls ``process_data(args, graph_orig)`` to obtain the graph and a
    train/val/test node split, builds a ``Trainer``, fits a model, and
    evaluates the returned best model. Per-run results are printed as they
    arrive (``print_summary=False``) and a summary over all runs is printed
    at the end (``print_summary=True``).

    Args:
        args: Parsed command-line namespace (see ``parameter_parser``). An
            optional ``seed`` attribute overrides the default seed of 44.

    Returns:
        list: The ``test_acc`` value of each repetition, in run order.

    Raises:
        ValueError: If ``args.num_exp`` is smaller than 1 (there would be
            nothing to summarise).
    """
    if args.num_exp < 1:
        # Previously this fell through to an unbound-variable NameError.
        raise ValueError(f"num_exp must be >= 1, got {args.num_exp}")

    # Seed every RNG source and force deterministic cuDNN kernels.
    seed = getattr(args, "seed", 44)
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # init graph
    graph_orig = load_graph(args)

    # Graph statistic: edge homophily of the labels y. It depends only on the
    # loaded graph, so it is computed once and recorded for every run.
    hc = homophily(graph_orig.edge_index.cpu(), graph_orig.y.cpu(), method='edge')

    # init outputs
    train_accs, val_accs, test_accs, test_corrs = [], [], [], []
    # NOTE(review): sp / eo come from Trainer.eval; presumably fairness
    # metrics (statistical parity / equal opportunity) — confirm.
    sp_list, eo_list, hc_list = [], [], []

    for exp in range(args.num_exp):
        # Each run re-derives its data from the untouched original graph.
        graph, train_nodes, val_nodes, test_nodes = process_data(args, graph_orig)

        trainer = Trainer(args, graph, train_nodes, val_nodes, test_nodes)
        model = trainer.model_init()
        best_model = trainer.fit(graph, model)
        train_acc, val_acc, test_acc, test_corr, sp, eo = trainer.eval(graph, best_model)

        # save results
        train_accs.append(train_acc)
        val_accs.append(val_acc)
        test_accs.append(test_acc)
        test_corrs.append(test_corr)
        sp_list.append(sp)
        eo_list.append(eo)
        hc_list.append(hc)
        print_outcome(args, exp, test_acc, sp, eo, hc, print_summary=False)

    # `exp` is the index of the last run (num_exp - 1), as in the original.
    print_outcome(args, exp, test_accs, sp_list, eo_list, hc_list, print_summary=True)

    return test_accs


def parameter_parser(argv=None):
    """Build the experiment's argument parser and parse the settings.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing a list
            allows programmatic use (e.g. ``parameter_parser([])`` for all
            defaults).

    Returns:
        argparse.Namespace: The parsed arguments. argparse turns dashes in
        option names into underscores in attribute names, e.g.
        ``--num-exp`` -> ``args.num_exp``.
    """
    parser = argparse.ArgumentParser()

    # DATASET
    parser.add_argument("--dataset", default='csbms')

    # CSBMS synthetic-graph parameters
    parser.add_argument("--num-nodes", type=int, default=10000)
    parser.add_argument("--d_pp", type=int, default=25)
    parser.add_argument("--d_pn", type=int, default=25)
    parser.add_argument("--d_np", type=int, default=25)
    parser.add_argument("--d_nn", type=int, default=25)
    parser.add_argument("--tau", type=float, default=0.1)
    parser.add_argument("--mu_y", type=float, default=0.0)
    parser.add_argument("--mu_s", type=float, default=0.0)
    parser.add_argument("--sigma", type=float, default=1.0)

    # EXPERIMENT
    parser.add_argument("--device", default='cuda:0')
    parser.add_argument("--num-exp", type=int, default=5)
    parser.add_argument("--model", default='simple-gnn')
    parser.add_argument("--epochs", type=int, default=500)
    parser.add_argument("--patience", type=int, default=100)
    parser.add_argument("--split-type", type=str, default='ratio')
    parser.add_argument("--train-ratio", type=float, default=0.5)
    parser.add_argument("--val-ratio", type=float, default=0.25)

    # HYPER-PARAMETERS
    parser.add_argument("--num-layers", type=int, default=1)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--dr", type=float, default=0.0005)
    parser.add_argument("--hid-dim", type=int, default=64)

    return parser.parse_args(argv)


if __name__ == "__main__":
    for mu_s in [0, 0.0625, 0.125, 0.25, 0.5]:
        for mu_y in [0, 0.0625, 0.125, 0.25, 0.5]:
            for h_y in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
                for h_s in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
                    args = parameter_parser()

                    args.tau = 0.1
                    args.mu_y = mu_y
                    args.mu_s = mu_s

                    n_edge = 100
                    args.d_pp = int(round(n_edge * h_y * h_s))
                    args.d_pn = int(round(n_edge * h_y * (1-h_s)))
                    args.d_np = int(round(n_edge * (1-h_y) * h_s))
                    args.d_nn = int(round(n_edge * (1-h_y) * (1-h_s)))

                    test_accs = main(args)
                    torch.cuda.empty_cache()
