import argparse
import time
import numpy as np
import networkx as nx
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl import DGLGraph

from model import GIN
import pickle
import tqdm
import random
from itertools import chain
import os, sys

torch.set_num_threads(1)

def _build_eval_graph(sample, use_edge_feats):
    """Return ``sample['graph']`` with self-loops and a per-edge weight column ``'_w'``.

    Every edge starts with weight 1, the added self-loops get weight 0, and,
    when ``use_edge_feats`` is true, the original edges take their weights
    from ``sample['edge_feats']``.
    """
    graph = sample['graph']
    # Endpoints of the original edges, captured before self-loops are added.
    srcs, dsts = graph.edges()
    g = dgl.add_self_loop(graph)

    weights = torch.ones(g.num_edges(), 1, device=g.device)
    node_ids = range(g.num_nodes())
    weights[g.edge_ids(node_ids, node_ids)] = 0.
    if use_edge_feats:
        weights[g.edge_ids(srcs, dsts)] = sample['edge_feats'].float().unsqueeze(-1)
    g.edata['_w'] = weights
    return g


def _describe_general_aggregators(model):
    """Format p+/q+/p-/q- of the first three GIN layers' aggregators for the report line."""
    parts = []
    for idx, ordinal in enumerate(('1st', '2nd', '3rd')):
        agg = model.ginlayers[idx].agg_fn
        parts.extend([
            f'p+ ({ordinal}) {agg.p_pos.item():.3f}',
            f'q+ ({ordinal}) {agg.q_pos.item():.3f}',
            f'p- ({ordinal}) {agg.p_neg.item():.3f}',
            f'q- ({ordinal}) {agg.q_neg.item():.3f}',
        ])
    return ' | '.join(parts)


def main(args):
    """Evaluate a trained GIN checkpoint on the test split and print its masked MAE.

    Reads ``data/graph_{args.graph}.pickle`` (a ``(train, val, test)`` triple),
    restores the weights stored under ``./checkpoints/`` for the given
    settings, runs the model over the test samples in batches of
    ``args.batch_size`` and writes one report line to standard output.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``graph``, ``n_layers``, ``n_hidden``, ``aggregator_type``,
            ``gtype``, ``task``, ``batch_size``, ``norm_limit``, ``opt_fn``,
            ``lr_pool``, ``lr`` and ``seed``.

    Raises:
        FileNotFoundError: If the dataset pickle or the checkpoint is missing.
        KeyError: If ``args.task`` is neither ``'bfs'`` nor ``'shortest'``.
    """
    # The launcher at the bottom of this file pins CUDA_VISIBLE_DEVICES to
    # args.gpu, so the selected GPU is index 0; fall back to CPU when CUDA is
    # unavailable instead of crashing.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # NOTE(review): pickle.load can execute arbitrary code -- only load
    # dataset files from a trusted source.
    with open(f"data/graph_{args.graph}.pickle", "rb") as f:
        _train, _val, test = pickle.load(f)
    input_dim = 1

    model = GIN(args.n_layers, 2, input_dim, args.n_hidden, 1, args.aggregator_type, general_mode=args.gtype, use_bias=True).to(device)

    # Task-specific builders. Each sample's 'one' entry is passed as the
    # unweighted distances and its 'w' entry as the weighted distances.
    gt_fn = {
        'bfs': (lambda unweighted_dists, weighted_dists: (unweighted_dists <= 3.1).float()),
        'shortest': (lambda unweighted_dists, weighted_dists: weighted_dists),
    }
    mask_fn = {
        'bfs': (lambda unweighted_dists: torch.ones_like(unweighted_dists)),
        'shortest': (lambda unweighted_dists: (unweighted_dists <= 3.1).float()),
    }
    feat_fn = {
        'bfs': (lambda unweighted_dists: (unweighted_dists < 0.1).float()),
        'shortest': (lambda unweighted_dists: (unweighted_dists > 0.1).float() * (unweighted_dists.shape[-1] * 10)),
    }

    curr_task = args.task
    make_gt, make_mask, make_feats = gt_fn[curr_task], mask_fn[curr_task], feat_fn[curr_task]
    batch_size = args.batch_size
    norm_limit = args.norm_limit
    settings_str = f"{curr_task}_general_{args.aggregator_type}_{args.opt_fn}_{args.lr_pool}_{args.lr}_{norm_limit}_{args.gtype}_{args.seed}"
    checkpoint_name = f"./checkpoints/{settings_str}.pt"

    # map_location lets a checkpoint saved on a GPU be restored on whatever
    # device was selected above.
    model.load_state_dict(torch.load(checkpoint_name, map_location=device))
    model.eval()

    abs_err_sum, n_counted = 0., 0.
    with torch.no_grad():
        for start in range(0, len(test), batch_size):
            stop = min(len(test), start + batch_size)
            batch = [test[k] for k in range(start, stop)]

            graphs = [_build_eval_graph(sample, curr_task != 'bfs') for sample in batch]
            mask = torch.cat([make_mask(sample['one']) for sample in batch], dim=-1).to(device)
            gt = torch.cat([make_gt(sample['one'], sample['w']) for sample in batch], dim=-1).to(device)
            gs = dgl.batch(graphs).to(device)

            node_feats = torch.cat([make_feats(sample['one']) for sample in batch], dim=-1).squeeze().unsqueeze(-1).to(device)
            # The weights are handed to the model explicitly, so drop them
            # from the batched graph.
            edge_feats = gs.edata.pop('_w')

            preds = model(gs, node_feats, edge_attr=edge_feats)
            # NOTE(review): assumes preds broadcasts element-wise against
            # mask/gt -- confirm against GIN's output shape.
            abs_err_sum += (mask * (preds - gt).abs()).sum().item()
            n_counted += mask.sum().item()

    # An empty test set (or an all-zero mask) has no defined MAE; report NaN
    # rather than raising ZeroDivisionError.
    test_loss = abs_err_sum / n_counted if n_counted else float('nan')

    # Write to stdout but never close it: the stream belongs to the process.
    out = sys.stdout
    if args.aggregator_type == 'general':
        out.write(f'Epoch #1 | {_describe_general_aggregators(model)} | test_loss {test_loss}\n')
    else:
        out.write(f'Epoch #1 | test_loss {test_loss}\n')
    out.flush()
    
if __name__ == '__main__':
    # Command-line options as (flag, keyword arguments) pairs; they are
    # registered in this order, which is also the order shown by --help.
    cli_options = (
        ("--lr", dict(type=float, default=3e-3,
                      help="learning rate for parameters except p")),
        ("--lr-pool", dict(type=float, default=3e-2,
                           help="learning rate for p")),
        ("--n-epochs", dict(type=int, default=200,
                            help="number of training epochs")),
        ("--n-hidden", dict(type=int, default=32,
                            help="number of hidden units")),
        ("--n-layers", dict(type=int, default=3,
                            help="number of hidden layers")),
        ("--weight-decay", dict(type=float, default=0,
                                help="Weight for L2 loss")),
        ("--aggregator-type", dict(type=str, default="general",
                                   help="Aggregator type: general/sum/max/mean/min")),
        ("--batch-size", dict(type=int, default=50,
                              help="batch_size")),
        ("--norm-limit", dict(type=float, default=1e2)),
        ("--task", dict(type=str, default="bfs",
                        help="Task type: bfs/shortest")),
        ("--opt-fn", dict(type=str, default="rmsprop",
                          help="Function type: rmsprop/adam/adamgan")),
        ("--gpu", dict(type=str, default="0")),
        ("--gtype", dict(type=int, default=0,
                         help="0: use both GNP+/GNP-, 1: use only GNP^+, 2: use only GNP^-")),
        ("--graph", dict(type=str, default="general")),
        ("--seed", dict(type=int, default=0)),
    )

    parser = argparse.ArgumentParser(description='node-level tasks')
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()
    print(args)

    # Expose only the requested GPU(s) to CUDA.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # Seed every random number generator the script relies on, then ask
    # cuDNN for deterministic behaviour.
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed, dgl.random.seed):
        seed_rng(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    main(args)
