import argparse
import time
import numpy as np
import networkx as nx
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl import DGLGraph
# from dgl.data import register_data_args, load_data

from model import GIN
import pickle
import tqdm
import random
from itertools import chain
import os, sys

# Restrict PyTorch's CPU intra-op parallelism to a single thread for this
# process (presumably so several runs can share one machine without
# oversubscribing its cores -- TODO confirm the intent).
torch.set_num_threads(1)

def load_opt_fn(args, model, _ff):
    """Build the optimizer(s) for ``model`` and log the chosen configuration.

    The model parameters are split into two sets:

    * the "non-pool" parameters: the encoder, the prediction heads and, for
      every GIN layer, its ``apply_func``, ``edge_encoder`` and ``eps``;
    * the "pool" parameters ``p_pos``/``p_neg`` and ``q_pos``/``q_neg`` of each
      layer's ``agg_fn``, collected only when ``args.aggregator_type`` is
      ``'general'``.

    Args:
        args: Namespace providing ``opt_fn``, ``aggregator_type``, ``lr``,
            ``lr_pool`` and ``weight_decay``.
        model: Object exposing ``encoder``, ``linears_prediction`` and
            ``ginlayers`` as described above.
        _ff: Writable text stream; one line describing the optimizer is
            written to it.

    Returns:
        ``(optimizer, optimizer_pool)``. ``optimizer`` updates the non-pool
        parameters with ``args.lr`` and ``args.weight_decay``.
        ``optimizer_pool`` is ``None`` when no pool parameters were collected;
        otherwise it updates the ``p`` parameters with ``args.lr_pool`` and the
        ``q`` parameters with ``args.lr``.
    """
    # Non-pool parameters, in the order: encoder, prediction heads, then the
    # per-layer update function, edge encoder and epsilon.
    shared_params = list(model.encoder.parameters())
    shared_params += list(model.linears_prediction.parameters())
    for gin_layer in model.ginlayers:
        shared_params += list(gin_layer.apply_func.parameters())
        shared_params += list(gin_layer.edge_encoder.parameters())
        shared_params.append(gin_layer.eps)

    # Only the 'general' aggregator carries p/q parameters.
    p_params, q_params = [], []
    if args.aggregator_type == 'general':
        for gin_layer in model.ginlayers:
            pooling = gin_layer.agg_fn
            p_params += [pooling.p_pos, pooling.p_neg]
            q_params += [pooling.q_pos, pooling.q_neg]

    # Select the optimizer class, its extra keyword arguments and the log
    # line; any value other than 'rmsprop'/'adamgan' means plain Adam.
    if args.opt_fn == 'rmsprop':
        banner = f"rmsprop with lr ({args.lr_pool}, {args.lr})\n"
        opt_cls, extra = torch.optim.RMSprop, {}
    elif args.opt_fn == 'adamgan':
        banner = f"adam_gan with lr ({args.lr_pool}, {args.lr})\n"
        opt_cls, extra = torch.optim.Adam, {'betas': (0.5, 0.999)}
    else:
        banner = f"adam with lr ({args.lr_pool}, {args.lr})\n"
        opt_cls, extra = torch.optim.Adam, {}
    _ff.write(banner)

    optimizer = opt_cls(shared_params, lr=args.lr,
                        weight_decay=args.weight_decay, **extra)
    if not p_params:
        return optimizer, None

    # The first group inherits the optimizer-level lr (lr_pool); the q group
    # overrides it with the regular learning rate.
    pool_groups = [
        {'params': p_params},
        {'params': q_params, 'lr': args.lr},
    ]
    optimizer_pool = opt_cls(pool_groups, lr=args.lr_pool, **extra)
    return optimizer, optimizer_pool
        
# Task-specific builders, keyed by the --task value.  Every sample carries two
# distance tensors, sample['one'] and sample['w'], which are passed below as
# `unweighted_dists` and `weighted_dists` (presumably one entry per node --
# verify against the dataset generator).

# Regression target: for 'bfs', 1.0 where the unweighted distance is <= 3.1 and
# 0.0 elsewhere; for 'shortest', the weighted distance itself.
_GT_FN = {
    'bfs': (lambda unweighted_dists, weighted_dists: (unweighted_dists <= 3.1).float()),
    'shortest': (lambda unweighted_dists, weighted_dists: weighted_dists),
}
# Loss/metric mask: every entry for 'bfs'; only entries whose unweighted
# distance is <= 3.1 for 'shortest'.
_MASK_FN = {
    'bfs': (lambda unweighted_dists: torch.ones_like(unweighted_dists)),
    'shortest': (lambda unweighted_dists: (unweighted_dists <= 3.1).float()),
}
# Input feature: for 'bfs', 1.0 where the distance is < 0.1; for 'shortest',
# 0.0 where the distance is <= 0.1 and 10 * (size of the last dimension)
# elsewhere (presumably a large "unreached" value -- TODO confirm).
_FEAT_FN = {
    'bfs': (lambda unweighted_dists: (unweighted_dists < 0.1).float()),
    'shortest': (lambda unweighted_dists: (unweighted_dists > 0.1).float() * (unweighted_dists.shape[-1] * 10)),
}


def _ordinal(k):
    """Return the English ordinal label ('1st', '2nd', '3rd', '4th', ...) of ``k``."""
    if 10 <= k % 100 <= 20:
        return f"{k}th"
    return f"{k}{ {1: 'st', 2: 'nd', 3: 'rd'}.get(k % 10, 'th') }"


def _safe_mean(total, count):
    """Return ``total / count``, or NaN when ``count`` is zero."""
    return total / count if count else float('nan')


def _prepare_graph(sample, task):
    """Return ``sample['graph']`` with self-loops and a ``'_w'`` edge-weight column."""
    graph = sample['graph']
    endpoints = graph.edges()
    srcs, dsts = endpoints[0], endpoints[1]
    g = dgl.add_self_loop(graph)
    # Weight 1 by default, 0 on the self-loops.
    g.edata['_w'] = torch.ones(g.num_edges(), 1).to(g.device)
    g.edata['_w'][g.edge_ids(range(g.num_nodes()), range(g.num_nodes()))] = torch.zeros(g.num_nodes(), 1).to(g.device)
    if task != 'bfs':
        # Every task except 'bfs' uses the stored weights on the original edges.
        g.edata['_w'][g.edge_ids(srcs, dsts)] = sample['edge_feats'].float().unsqueeze(-1)
    return g


def _collate(samples, task, device):
    """Batch ``samples`` into ``(graph, node_feats, edge_feats, gt, mask)`` on ``device``."""
    graphs = [_prepare_graph(sample, task) for sample in samples]
    mask = torch.cat([_MASK_FN[task](sample['one']) for sample in samples], dim=-1).to(device)
    gt = torch.cat([_GT_FN[task](sample['one'], sample['w']) for sample in samples], dim=-1).to(device)
    gs = dgl.batch(graphs).to(device)
    node_feats = torch.cat([_FEAT_FN[task](sample['one']) for sample in samples], dim=-1).squeeze().unsqueeze(-1).to(device)
    # The weights are handed to the model explicitly, so take them off the graph.
    edge_feats = gs.edata.pop('_w')
    return gs, node_feats, edge_feats, gt, mask


def _iter_batches(dataset, batch_size):
    """Yield consecutive lists of at most ``batch_size`` samples from ``dataset``."""
    for start in range(0, len(dataset), batch_size):
        stop = min(len(dataset), start + batch_size)
        yield [dataset[k] for k in range(start, stop)]


def _train_epoch(model, dataset, task, batch_size, device, optimizer, optimizer_pool, norm_limit):
    """Run one optimisation pass; return (summed masked absolute error, summed mask)."""
    abs_err, count = 0., 0.
    for samples in _iter_batches(dataset, batch_size):
        gs, node_feats, edge_feats, gt, mask = _collate(samples, task, device)

        optimizer.zero_grad()
        if optimizer_pool is not None:
            optimizer_pool.zero_grad()
        preds = model(gs, node_feats, edge_attr=edge_feats)
        err = preds - gt
        # Masked sum of squared errors (not averaged over the batch).
        running_loss = torch.sum(err * err * mask)
        running_loss.backward()

        abs_err += (mask * err).abs().sum().item()
        count += mask.sum().item()
        # One joint gradient-norm clip over all parameters, pool ones included.
        torch.nn.utils.clip_grad_norm_(model.parameters(), norm_limit)
        optimizer.step()
        if optimizer_pool is not None:
            optimizer_pool.step()
    return abs_err, count


def _evaluate(model, dataset, task, batch_size, device):
    """Return (summed masked absolute error, summed mask) of ``model`` on ``dataset``.

    The caller is responsible for ``model.eval()`` and ``torch.no_grad()``.
    """
    abs_err, count = 0., 0.
    for samples in _iter_batches(dataset, batch_size):
        gs, node_feats, edge_feats, gt, mask = _collate(samples, task, device)
        preds = model(gs, node_feats, edge_attr=edge_feats)
        abs_err += (mask * (preds - gt).abs()).sum().item()
        count += mask.sum().item()
    return abs_err, count


def main(args):
    """Train a GIN model on a node-level task and log per-epoch errors.

    Loads ``data/graph_{args.graph}.pickle`` (a ``(train, val, test)`` triple),
    trains for ``args.n_epochs`` epochs, writes a log to
    ``./logs/<settings>.txt`` and saves the state dict to
    ``./checkpoints/<settings>.pt`` each time the summed masked absolute error
    on the validation split improves.

    Args:
        args: Parsed command-line namespace.  Reads ``graph``, ``n_layers``,
            ``n_hidden``, ``aggregator_type``, ``gtype``, ``task``,
            ``batch_size``, ``norm_limit``, ``opt_fn``, ``lr``, ``lr_pool``,
            ``seed`` and ``n_epochs``; ``load_opt_fn`` additionally reads
            ``weight_decay``.

    Raises:
        ValueError: If ``args.task`` is not one of the supported tasks.
    """
    # Use the first visible GPU when there is one; otherwise fall back to CPU
    # instead of failing on the first ``.to(device)``.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # NOTE(security): pickle.load can execute arbitrary code -- only point
    # --graph at dataset files you generated or otherwise trust.
    with open(f"data/graph_{args.graph}.pickle", "rb") as f:
        train, val, test = pickle.load(f)
    input_dim = 1

    model = GIN(args.n_layers, 2, input_dim, args.n_hidden, 1, args.aggregator_type, general_mode=args.gtype, use_bias=True).to(device)
    model.train()

    curr_task = args.task
    batch_size = args.batch_size
    norm_limit = args.norm_limit
    settings_str = f"{curr_task}_{args.graph}_{args.aggregator_type}_{args.opt_fn}_{args.lr_pool}_{args.lr}_{norm_limit}_{args.gtype}_{args.seed}"
    log_file_name = f"./logs/{settings_str}.txt"
    checkpoint_name = f"./checkpoints/{settings_str}.pt"
    # Make sure the output directories exist before anything is written.
    os.makedirs(os.path.dirname(log_file_name), exist_ok=True)
    os.makedirs(os.path.dirname(checkpoint_name), exist_ok=True)

    # The context manager closes the log even if training raises.
    with open(log_file_name, "w") as _ff:
        optimizer, optimizer_pool = load_opt_fn(args, model, _ff)
        _ff.write(f"batch_size: {batch_size}, norm_limit: {norm_limit}\n")
        _ff.write(f"Task: {curr_task}, valid: {curr_task in _GT_FN}\n")
        for name, p in model.named_parameters():
            _ff.write(name + "\t" + str(p.shape) + "\n")

        n_params = sum(p.numel() for p in model.parameters())
        n_nonpool = sum(p.numel() for group in optimizer.param_groups for p in group['params'])
        n_pool = 0
        if optimizer_pool is not None:
            n_pool = sum(p.numel() for group in optimizer_pool.param_groups for p in group['params'])
        _ff.write(f"# of parameters: {n_params}, (non-pool {n_nonpool}, pool {n_pool})\n")
        for name, _ in model.named_parameters():
            _ff.write(name + "\n")
        _ff.flush()

        # Fail fast with a clear message rather than a KeyError in the first batch.
        if curr_task not in _GT_FN:
            raise ValueError(f"unsupported task {curr_task!r}; expected one of {sorted(_GT_FN)}")

        best_val_mae = 1e10

        for i in range(args.n_epochs):
            random.shuffle(train)
            train_mae, n_train = _train_epoch(model, train, curr_task, batch_size, device,
                                              optimizer, optimizer_pool, norm_limit)

            model.eval()
            with torch.no_grad():
                val_mae, n_val = _evaluate(model, val, curr_task, batch_size, device)
                test_mae, n_test = _evaluate(model, test, curr_task, batch_size, device)

                # Keep the aggregator exponents p+/p- inside [-50, 50].
                if args.aggregator_type == 'general':
                    for layer in model.ginlayers:
                        layer.agg_fn.p_pos.clamp_(min=-50.0, max=50.0)
                        layer.agg_fn.p_neg.clamp_(min=-50.0, max=50.0)
            model.train()

            # Model selection compares the summed (not averaged) validation error.
            if best_val_mae > val_mae:
                best_val_mae = val_mae
                torch.save(model.state_dict(), checkpoint_name)

            fields = [f'Epoch #{i+1}']
            if args.aggregator_type == 'general':
                # One p/q group per GIN layer, however many layers the model has.
                for idx, layer in enumerate(model.ginlayers, start=1):
                    agg, tag = layer.agg_fn, _ordinal(idx)
                    fields.append(f'p+ ({tag}) {agg.p_pos.item():.3f}')
                    fields.append(f'q+ ({tag}) {agg.q_pos.item():.3f}')
                    fields.append(f'p- ({tag}) {agg.p_neg.item():.3f}')
                    fields.append(f'q- ({tag}) {agg.q_neg.item():.3f}')
            # The "*_loss" labels are kept for log compatibility; the values
            # are masked mean absolute errors.
            fields.append(f'train_loss {_safe_mean(train_mae, n_train)}')
            fields.append(f'val_loss {_safe_mean(val_mae, n_val)}')
            fields.append(f'test_loss {_safe_mean(test_mae, n_test)}')
            _ff.write(' | '.join(fields) + '\n')
            _ff.flush()
    
if __name__ == '__main__':
    # Command-line options as (flag, add_argument keyword arguments), listed
    # in the order they are registered with the parser.
    option_table = [
        ("--lr", {"type": float, "default": 3e-3,
                  "help": "learning rate for parameters except p"}),
        ("--lr-pool", {"type": float, "default": 3e-2,
                       "help": "learning rate for p"}),
        ("--n-epochs", {"type": int, "default": 200,
                        "help": "number of training epochs"}),
        ("--n-hidden", {"type": int, "default": 32,
                        "help": "number of hidden units"}),
        ("--n-layers", {"type": int, "default": 3,
                        "help": "number of hidden layers"}),
        ("--weight-decay", {"type": float, "default": 0,
                            "help": "Weight for L2 loss"}),
        ("--aggregator-type", {"type": str, "default": "general",
                               "help": "Aggregator type: general/sum/max/mean/min"}),
        ("--batch-size", {"type": int, "default": 50,
                          "help": "batch_size"}),
        ("--norm-limit", {"type": float, "default": 1e2}),
        ("--task", {"type": str, "default": "bfs",
                    "help": "Task type: bfs/shortest"}),
        ("--opt-fn", {"type": str, "default": "rmsprop",
                      "help": "Function type: rmsprop/adam/adamgan"}),
        ("--gpu", {"type": str, "default": "0"}),
        ("--gtype", {"type": int, "default": 0,
                     "help": "0: use both GNP+/GNP-, 1: use only GNP^+, 2: use only GNP^-"}),
        ("--graph", {"type": str, "default": "general"}),
        ("--seed", {"type": int, "default": 0}),
    ]

    cli = argparse.ArgumentParser(description='node-level tasks')
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    args = cli.parse_args()
    print(args)

    # Restrict CUDA to the device(s) named by --gpu.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # Seed every random number generator the script relies on with --seed.
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed, dgl.random.seed):
        seed_rng(args.seed)
    # Ask cuDNN for deterministic kernels and disable its autotuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    main(args)
