import argparse
import time
import numpy as np
import networkx as nx
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl import DGLGraph

from network import SAGNetworkHierarchical as SAGPool
import pickle
import tqdm
import random
from itertools import chain
import sys
import os

# Limit PyTorch's intra-op CPU parallelism to a single thread for this process.
torch.set_num_threads(1)
    
def load_opt_fn(args, model, _ff):
    """Create the optimizer for ``model`` as selected by ``args.opt_fn``.

    ``'rmsprop'`` selects RMSprop and ``'adamgan'`` selects Adam with
    ``betas=(0.5, 0.999)``. Any other value falls back to plain Adam. Every
    optimizer covers all model parameters and uses ``args.lr`` and
    ``args.weight_decay``. One line naming the choice is written to ``_ff``.

    Args:
        args: parsed CLI namespace; reads ``opt_fn``, ``lr``, ``lr_pool`` and
            ``weight_decay``. ``lr_pool`` is only logged, never used.
        model: module whose parameters are optimized.
        _ff: writable text stream used as the log.

    Returns:
        ``(optimizer, optimizer_pool)``. ``optimizer_pool`` is always ``None``
        here because no separate optimizer for pooling parameters is built.
    """
    params = list(model.parameters())
    # Kept in the return signature so callers can treat a pool optimizer as optional.
    optimizer_pool = None

    if args.opt_fn == 'rmsprop':
        _ff.write(f"rmsprop with lr ({args.lr_pool}, {args.lr})\n")
        optimizer = torch.optim.RMSprop(params, lr=args.lr, weight_decay=args.weight_decay)
    elif args.opt_fn == 'adamgan':
        _ff.write(f"adam_gan with lr ({args.lr_pool}, {args.lr})\n")
        # Lower beta1, a common choice for GAN-style training.
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.5, 0.999), weight_decay=args.weight_decay)
    else:
        # Any unrecognized opt_fn silently falls back to plain Adam.
        _ff.write(f"adam with lr ({args.lr_pool}, {args.lr})\n")
        optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)

    return optimizer, optimizer_pool
        
def _graph_batches(dataset, batch_size, target_fn, device, input_dim):
    """Yield ``(batched_graph, node_features, targets)`` over consecutive slices of ``dataset``.

    Items are read as ``dataset[i]['graph']``. A self-loop is added to every
    graph before batching, node features are all ones of width ``input_dim``,
    and ``targets`` is a float32 vector holding ``target_fn(graph)`` per graph.
    """
    n_items = len(dataset)
    for start in range(0, n_items, batch_size):
        graphs = [dataset[idx]['graph'] for idx in range(start, min(n_items, start + batch_size))]
        # float() accepts both Python numbers and 0-d tensors such as in_degrees().max().
        targets = torch.tensor([float(target_fn(g)) for g in graphs], dtype=torch.float32, device=device)
        batched = dgl.batch([dgl.add_self_loop(g) for g in graphs]).to(device)
        feats = torch.ones(batched.number_of_nodes(), input_dim, device=device)
        yield batched, feats, targets

def _predict(model, batched, feats):
    """Run ``model`` and flatten its output to one prediction per graph.

    NOTE(review): the model is built with out_dim=1, so its output is presumably
    (batch, 1) -- confirm against network.py. Flattening stops that from
    broadcasting against the (batch,) target vector into a (batch, batch) matrix.
    """
    return model(batched, feats).reshape(-1)

def _abs_pct_error_sum(preds, targets):
    """Return ``sum(|preds - targets| / targets)`` as a Python float."""
    return ((preds - targets).abs() / targets).sum().item()

def _evaluate_mape(model, dataset, batch_size, target_fn, device, input_dim):
    """Return the summed absolute percentage error of ``model`` over ``dataset``.

    The caller is responsible for eval mode and ``torch.no_grad()``.
    """
    total = 0.
    for batched, feats, targets in _graph_batches(dataset, batch_size, target_fn, device, input_dim):
        total += _abs_pct_error_sum(_predict(model, batched, feats), targets)
    return total

def main(args):
    """Train SAGPool to regress a graph-level statistic and log per-epoch errors.

    Loads ``(train, val, test)`` from ``data/graph_{args.graph}.pickle``. Trains
    with a summed squared-error loss for ``args.n_epochs`` epochs. Writes
    per-epoch mean absolute percentage errors to ``./logs/<settings>.txt`` and
    saves the state dict with the lowest validation error to
    ``./checkpoints/<settings>.pt``.

    Raises:
        ValueError: if ``args.task`` is not one of the supported targets.
    """
    # When launched through this script's __main__ guard, CUDA_VISIBLE_DEVICES is
    # set to args.gpu, so the selected GPU appears as index 0. Fall back to CPU
    # when CUDA is unavailable instead of crashing.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # NOTE(review): pickle.load can execute arbitrary code; only load trusted data files.
    with open(f"data/graph_{args.graph}.pickle", "rb") as f:
        train, val, test = pickle.load(f)

    input_dim = 1  # every node gets a constant feature of this width

    model = SAGPool(in_dim=input_dim, hid_dim=32, out_dim=1).to(device)
    model.train()

    # Ground-truth graph-level targets, keyed by --task.
    gt_fn = {
        'maxdegree': (lambda g: g.in_degrees().max()),
        'harmonic': (lambda g: 1. / ((1. / (g.in_degrees().float() + 1e-9)).sum())),
        'invsize': (lambda g: (1. / g.num_nodes())),
    }

    curr_task = args.task
    batch_size = args.batch_size
    norm_limit = args.norm_limit
    settings_str = f"{curr_task}_{args.graph}_none_sagpool_{args.opt_fn}_{args.lr}_{args.lr}_{norm_limit}_{args.seed}"
    log_file_name = f"./logs/{settings_str}.txt"
    checkpoint_name = f"./checkpoints/{settings_str}.pt"
    # Create output directories up front rather than failing on open()/torch.save().
    os.makedirs(os.path.dirname(log_file_name), exist_ok=True)
    os.makedirs(os.path.dirname(checkpoint_name), exist_ok=True)

    # The log is closed even if training raises.
    with open(log_file_name, "w") as _ff:
        optimizer, optimizer_pool = load_opt_fn(args, model, _ff)
        _ff.write(f"batch_size: {batch_size}, norm_limit: {norm_limit}\n")
        _ff.write(f"Task: {curr_task}, valid: {curr_task in gt_fn}\n")
        for name, p in model.named_parameters():
            _ff.write(name + "\t" + str(p.shape) + "\n")

        n_params = sum(p.numel() for p in model.parameters())
        n_nonpool = sum(p.numel() for group in optimizer.param_groups for p in group['params'])
        n_pool = 0
        if optimizer_pool is not None:
            n_pool = sum(p.numel() for group in optimizer_pool.param_groups for p in group['params'])
        _ff.write(f"# of parameters: {n_params}, (non-pool {n_nonpool}, pool {n_pool})\n")
        _ff.flush()

        # Fail fast with a clear message instead of a KeyError at the first batch.
        if curr_task not in gt_fn:
            raise ValueError(f"Unknown task {curr_task!r}; expected one of {sorted(gt_fn)}")
        target_fn = gt_fn[curr_task]

        best_val_mape = 1e10

        for i in range(args.n_epochs):
            random.shuffle(train)
            train_mape = 0.
            for gs, feats, gt in _graph_batches(train, batch_size, target_fn, device, input_dim):
                optimizer.zero_grad()
                if optimizer_pool is not None:
                    optimizer_pool.zero_grad()
                preds = _predict(model, gs, feats)
                diff = preds - gt
                running_loss = torch.sum(diff * diff)
                running_loss.backward()

                train_mape += _abs_pct_error_sum(preds, gt)

                torch.nn.utils.clip_grad_norm_(model.parameters(), norm_limit)
                optimizer.step()
                if optimizer_pool is not None:
                    optimizer_pool.step()

            model.eval()
            with torch.no_grad():
                _val_mape = _evaluate_mape(model, val, batch_size, target_fn, device, input_dim)
                _test_mape = _evaluate_mape(model, test, batch_size, target_fn, device, input_dim)

            # Keep the learnable p parameters in [-50, 50] when the "general"
            # aggregator/pooling variants are selected.
            with torch.no_grad():
                if args.aggregator_type == 'general':
                    for layer in model.ginlayers:
                        layer.agg_fn.p_pos.clamp_(min=-50.0, max=50.0)
                        layer.agg_fn.p_neg.clamp_(min=-50.0, max=50.0)
                if args.pooling_type == 'general':
                    model.pool.p_pos.clamp_(min=-50.0, max=50.0)
                    model.pool.p_neg.clamp_(min=-50.0, max=50.0)
            model.train()

            # Checkpoint whenever the summed validation error improves.
            if best_val_mape > _val_mape:
                best_val_mape = _val_mape
                torch.save(model.state_dict(), checkpoint_name)

            # The logged "*_loss" values are mean absolute percentage errors.
            if args.pooling_type == 'general' and args.aggregator_type == 'general':
                _ff.write(f'Epoch #{i+1} | {model.pool.p_pos.item():.3f} | {model.pool.q_pos.item():.3f} | {model.pool.p_neg.item():.3f} | {model.pool.q_neg.item():.3f} | {model.ginlayers[0].agg_fn.p_pos.item():.3f} | {model.ginlayers[0].agg_fn.q_pos.item():.3f} | {model.ginlayers[0].agg_fn.p_neg.item():.3f} | {model.ginlayers[0].agg_fn.q_neg.item():.3f} | train_loss {(train_mape / len(train))} | val_loss {(_val_mape / len(val))} | test_loss {(_test_mape / len(test))}\n')
            else:
                _ff.write(f'Epoch #{i+1} | train_loss {(train_mape / len(train))} | val_loss {(_val_mape / len(val))} | test_loss {(_test_mape / len(test))}\n')
            _ff.flush()
    
if __name__ == '__main__':
    # Command-line entry point: parse flags, pin the visible GPU, seed every
    # RNG the script touches, then hand off to main().
    # NOTE(review): --n-hidden, --n-layers and --gtype are parsed but not read
    # by main()/load_opt_fn() as written; confirm whether they are still needed.
    parser = argparse.ArgumentParser(description='graph-level extrapolation task')
    add = parser.add_argument

    add("--lr", type=float, default=3e-3, help="learning rate for parameters except p")
    add("--lr-pool", type=float, default=1e-2, help="learning rate for p")
    add("--n-epochs", type=int, default=200, help="number of training epochs")
    add("--n-hidden", type=int, default=32, help="number of hidden units")
    add("--n-layers", type=int, default=1, help="number of hidden layers")
    add("--weight-decay", type=float, default=0, help="Weight for L2 loss")
    add("--aggregator-type", type=str, default="none")
    add("--pooling-type", type=str, default="sagpool")
    add("--batch-size", type=int, default=50, help="batch_size")
    add("--norm-limit", type=float, default=1e2)
    add("--task", type=str, default="maxdegree", help="Task type: maxdegree/harmonic/invsize")
    add("--opt-fn", type=str, default="rmsprop", help="Function type: rmsprop/adam/adamgan")
    add("--gpu", type=str, default="0")
    add("--gtype", type=int, default=0)
    add("--graph", type=str, default="general")
    add("--seed", type=int, default=0)

    args = parser.parse_args()
    print(args)

    # Expose only the requested GPU. This only takes effect if CUDA has not
    # been initialized yet in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # Seed Python, NumPy, PyTorch and DGL in the same order as before, then
    # force deterministic cuDNN kernels.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, dgl.random.seed):
        seed_fn(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    main(args)
