import argparse
import time
import numpy as np
import networkx as nx
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl import DGLGraph

from model import GIN
import pickle
import tqdm
import random
from itertools import chain
import sys
import os

# Restrict PyTorch intra-op CPU parallelism to a single thread.
# NOTE(review): presumably chosen for reproducibility / to avoid CPU
# oversubscription when many runs share a machine — confirm intent.
torch.set_num_threads(1)
        
def main(args):
    """Evaluate a saved GIN checkpoint on the test split and report MAPE.

    Loads ``data/graph_{args.graph}.pickle`` (unpacked as a ``(train, val,
    test)`` triple, where ``test[i]['graph']`` is a DGL graph), restores the
    model weights from ``./checkpoints/{settings_str}.pt`` and computes the
    mean absolute percentage error of the model's predictions against the
    ground-truth function selected by ``args.task``.

    When ``args.graph == 'general'`` the MAPE is additionally broken down per
    node count |V| in [50, 100). A final ``Result | ...`` line is written to
    stdout; it also includes the learned p/q parameters of the pooling and
    first GIN layer when both pooling and aggregation are ``'general'``.

    Args:
        args: argparse namespace produced by the CLI in ``__main__``.
    """
    # CUDA_VISIBLE_DEVICES is set to ``args.gpu`` before main() runs, so the
    # chosen GPU is always visible as index 0. Fall back to CPU when no GPU.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # NOTE(review): pickle.load can execute arbitrary code; only load
    # trusted, locally generated dataset files here.
    with open(f"data/graph_{args.graph}.pickle", "rb") as f:
        _train, _val, test = pickle.load(f)

    input_dim = 1
    model = GIN(args.n_layers, 2, input_dim, args.n_hidden, 1, args.pooling_type,
                args.aggregator_type, general_mode=args.gtype, use_bias=False).to(device)

    # Ground-truth graph-level targets, keyed by task name.
    gt_fn = {
        'maxdegree': (lambda g: g.in_degrees().max()),
        'harmonic': (lambda g: 1. / ((1. / (g.in_degrees().float() + 1e-9)).sum())),
        'invsize': (lambda g: (1. / g.num_nodes())),
    }

    curr_task = args.task
    batch_size = args.batch_size
    norm_limit = args.norm_limit

    # Must match the naming scheme used when the checkpoint was saved.
    settings_str = f"{curr_task}_general_{args.aggregator_type}_{args.pooling_type}_{args.opt_fn}_{args.lr_pool}_{args.lr}_{norm_limit}_{args.gtype}_{args.seed}"
    checkpoint_name = f"./checkpoints/{settings_str}.pt"
    out = sys.stdout

    # map_location lets a GPU-saved checkpoint be restored on a CPU-only host.
    model.load_state_dict(torch.load(checkpoint_name, map_location=device))
    model.eval()

    # Per-size buckets for |V| in [min_nodes, min_nodes + n_buckets).
    min_nodes, n_buckets = 50, 50
    errs = np.zeros(n_buckets)
    err_cnts = np.zeros(n_buckets)

    test_ape_sum = 0.
    with torch.no_grad():
        for start in range(0, len(test), batch_size):
            _graphs = [test[j]['graph'] for j in range(start, min(len(test), start + batch_size))]
            gs = dgl.batch(_graphs).to(device)
            feats = torch.ones(gs.num_nodes(), input_dim, device=device)
            # Flatten to one value per graph so it aligns with gt; a no-op if
            # the model already returns shape (B,).
            preds = model(gs, feats).reshape(-1)
            gt = torch.FloatTensor([gt_fn[curr_task](g) for g in _graphs]).to(device)

            ape = (preds - gt).abs() / gt
            test_ape_sum += ape.sum().item()

            # np.add.at accumulates repeated indices. Plain `arr[idx] += v`
            # applies only one update per duplicated index, which undercounted
            # batches containing several graphs of the same size. Sizes outside
            # the bucket range are skipped instead of wrapping/raising.
            offsets = np.array([g.num_nodes() for g in _graphs]) - min_nodes
            in_range = (offsets >= 0) & (offsets < n_buckets)
            np.add.at(errs, offsets[in_range], ape.cpu().numpy()[in_range])
            np.add.at(err_cnts, offsets[in_range], 1)

    if args.graph == 'general':
        # Empty buckets give nan (0/0); silence the associated warning.
        with np.errstate(invalid='ignore', divide='ignore'):
            mape_list = (errs / err_cnts).tolist()
        for i, mape in enumerate(mape_list):
            print(f'Test MAPE where |V|={i + min_nodes}: {mape}')

    test_mape = test_ape_sum / len(test) if len(test) else float('nan')

    if args.pooling_type == 'general' and args.aggregator_type == 'general':
        pool = model.pool
        agg = model.ginlayers[0].agg_fn
        out.write(
            f'Result | p+ (pool) {pool.p_pos.item():.3f} | q+ (pool) {pool.q_pos.item():.3f} '
            f'| p- (pool) {pool.p_neg.item():.3f} | q- (pool) {pool.q_neg.item():.3f} '
            f'| p+ (ginlayer) {agg.p_pos.item():.3f} | q+ (ginlayer) {agg.q_pos.item():.3f} '
            f'| p- (ginlayer) {agg.p_neg.item():.3f} | q- (ginlayer) {agg.q_neg.item():.3f} '
            f'| test_loss {test_mape}\n'
        )
    else:
        out.write(f'Result | test_loss {test_mape}\n')
    # Flush only: closing sys.stdout would break every later write/print.
    out.flush()
    
if __name__ == '__main__':
    cli = argparse.ArgumentParser(description='graph-level extrapolation task')

    # Each row: (flag, type, default, help). A help of None is argparse's own
    # default, i.e. the option is listed without a description.
    cli_options = [
        ("--lr", float, 3e-3, "learning rate for parameters except p"),
        ("--lr-pool", float, 1e-2, "learning rate for p"),
        ("--n-epochs", int, 200, "number of training epochs"),
        ("--n-hidden", int, 32, "number of hidden units"),
        ("--n-layers", int, 1, "number of hidden layers"),
        ("--weight-decay", float, 0, "Weight for L2 loss"),
        ("--aggregator-type", str, "general", "Aggregator type: general/sum/max/mean/min"),
        ("--pooling-type", str, "general", "Pooling type: general/sum/max/mean/min/sort/set2set"),
        ("--batch-size", int, 50, "batch_size"),
        ("--norm-limit", float, 1e2, None),
        ("--task", str, "maxdegree", "Task type: maxdegree/harmonic/invsize"),
        ("--opt-fn", str, "rmsprop", "Function type: rmsprop/adam/adamgan"),
        ("--gpu", str, "0", None),
        ("--gtype", int, 0, "0: use both GNP+/GNP-, 1: use only GNP^+, 2: use only GNP^-"),
        ("--graph", str, "general", None),
        ("--seed", int, 0, None),
    ]
    for flag, arg_type, default_value, help_text in cli_options:
        cli.add_argument(flag, type=arg_type, default=default_value, help=help_text)

    args = cli.parse_args()
    print(args)

    # Must happen before anything initializes CUDA so only the chosen GPU is visible.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # Seed every RNG source used by the script, in a fixed order.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, dgl.random.seed):
        seed_fn(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    main(args)
