import argparse
import time
import numpy as np
import networkx as nx
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl import DGLGraph

# from model import GIN
from network import SAGNetworkHierarchical as SAGPool
import pickle
import tqdm
import random
from itertools import chain
import sys
import os

# Restrict PyTorch's intra-op CPU parallelism to a single thread; done at
# import time so it is in effect before main() runs any tensor work.
torch.set_num_threads(1)

def main(args):
    """Evaluate a trained SAGPool checkpoint on the test split and print its MAPE.

    Loads ``data/graph_{args.graph}.pickle`` (unpacked as a ``(train, val, test)``
    tuple whose ``test`` items are indexed with ``['graph']``). It then restores
    the checkpoint whose file name is derived from the CLI settings. The model
    runs on the test graphs in batches of ``args.batch_size``. Finally, the mean
    absolute percentage error against the graph statistic selected by
    ``args.task`` is written to stdout.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).

    Raises:
        KeyError: if ``args.task`` is not one of the supported tasks.
        ValueError: if the test split is empty.
    """
    # CUDA_VISIBLE_DEVICES is set to args.gpu before main() is called, so
    # 'cuda:0' is the selected GPU; fall back to CPU when CUDA is unavailable.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    gt_fn = {
        'maxdegree': (lambda g: g.in_degrees().max()),
        'harmonic': (lambda g: 1. / ((1. / (g.in_degrees().float() + 1e-9)).sum())),
        'invsize': (lambda g: (1. / g.num_nodes())),
    }
    curr_task = args.task
    target_fn = gt_fn[curr_task]  # fail fast on an unknown task name

    # NOTE(review): pickle.load and torch.load can execute arbitrary code from
    # the file; only use this with trusted data files and checkpoints.
    with open(f"data/graph_{args.graph}.pickle", "rb") as f:
        _train, _val, test = pickle.load(f)
    if len(test) == 0:
        raise ValueError(f"test split of data/graph_{args.graph}.pickle is empty")

    input_dim = 1  # every node receives the constant feature 1.0
    model = SAGPool(in_dim=input_dim, hid_dim=32, out_dim=1).to(device)

    batch_size = args.batch_size
    norm_limit = args.norm_limit
    # NOTE(review): args.lr appears twice and "general_none_sagpool" is fixed;
    # kept byte-identical so names match existing checkpoints — presumably this
    # mirrors the training script's naming, verify before changing.
    settings_str = f"{curr_task}_general_none_sagpool_{args.opt_fn}_{args.lr}_{args.lr}_{norm_limit}_{args.seed}"
    checkpoint_name = f"./checkpoints/{settings_str}.pt"

    # map_location lets a GPU-saved checkpoint load on the chosen device.
    model.load_state_dict(torch.load(checkpoint_name, map_location=device))
    model.eval()

    # Per-graph-size error accumulators for graphs with
    # min_nodes <= num_nodes < max_nodes; other sizes are not bucketed.
    min_nodes, max_nodes = 50, 100
    errs = np.zeros(max_nodes - min_nodes)
    err_cnts = np.zeros(max_nodes - min_nodes)

    total_ape = 0.
    with torch.no_grad():
        for start in range(0, len(test), batch_size):
            graphs = [test[i]['graph'] for i in range(start, min(len(test), start + batch_size))]
            gt = torch.tensor([float(target_fn(g)) for g in graphs],
                              dtype=torch.float32, device=device)
            batched = dgl.batch([dgl.add_self_loop(g) for g in graphs]).to(device)
            feats = torch.ones(batched.number_of_nodes(), input_dim, device=device)
            # Flatten so predictions line up element-wise with gt of shape (B,);
            # a (B, 1) output would otherwise broadcast to a (B, B) matrix.
            preds = model(batched, feats).reshape(-1)

            # NOTE(review): gt can be 0 (e.g. maxdegree of an edgeless graph),
            # which makes the percentage error inf.
            ape = (preds - gt).abs() / gt
            total_ape += ape.sum().item()

            ape_np = ape.detach().cpu().numpy()
            sizes = np.array([g.num_nodes() for g in graphs])
            in_range = (sizes >= min_nodes) & (sizes < max_nodes)
            # np.add.at accumulates repeated indices (several same-size graphs
            # in one batch); fancy-index `+=` would keep only one of them.
            np.add.at(errs, sizes[in_range] - min_nodes, ape_np[in_range])
            np.add.at(err_cnts, sizes[in_range] - min_nodes, 1)

    # sys.stdout is flushed but deliberately NOT closed.
    out = sys.stdout
    # Optional per-size dump:
    # out.write('\n'.join(map(str, (errs / err_cnts).tolist())) + '\n')
    out.write(f'Result | test_loss {(total_ape / len(test))}\n')
    out.flush()
    
if __name__ == '__main__':
    # Command-line options in registration order (order determines the
    # --help listing). Not every option is consumed by this evaluation
    # script; NOTE(review): presumably shared with the training CLI — confirm
    # before pruning.
    _cli_options = [
        ("--lr", dict(type=float, default=3e-3,
                      help="learning rate for parameters except p")),
        ("--lr-pool", dict(type=float, default=1e-2,
                           help="learning rate for p")),
        ("--n-epochs", dict(type=int, default=200,
                            help="number of training epochs")),
        ("--n-hidden", dict(type=int, default=32,
                            help="number of hidden units")),
        ("--n-layers", dict(type=int, default=1,
                            help="number of hidden layers")),
        ("--weight-decay", dict(type=float, default=0,
                                help="Weight for L2 loss")),
        ("--aggregator-type", dict(type=str, default="none")),
        ("--pooling-type", dict(type=str, default="sagpool")),
        ("--batch-size", dict(type=int, default=50,
                              help="batch_size")),
        ("--norm-limit", dict(type=float, default=1e2)),
        ("--task", dict(type=str, default="maxdegree",
                        help="Task type: maxdegree/harmonic/invsize")),
        ("--opt-fn", dict(type=str, default="rmsprop",
                          help="Function type: rmsprop/adam/adamgan")),
        ("--gpu", dict(type=str, default="0")),
        ("--gtype", dict(type=int, default=0)),
        ("--graph", dict(type=str, default="general")),
        ("--seed", dict(type=int, default=0)),
    ]

    parser = argparse.ArgumentParser(description='graph-level extrapolation task')
    for _flag, _options in _cli_options:
        parser.add_argument(_flag, **_options)

    args = parser.parse_args()
    print(args)

    # Expose only the requested GPU; must happen before CUDA is initialised.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # Seed every RNG source the script touches, in a fixed order.
    for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, dgl.random.seed):
        _seed_fn(args.seed)

    # Trade cuDNN autotuning speed for reproducible kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    main(args)
