import argparse
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from variation_gnn._models import EPM_VAE, EPM_VAE_Loss
from variation_gnn._processor import graph_batch
from variation_gnn._utils import load_data, separate_data, pass_data_iteratively, debug_print, debugger


def warm_up(model, optimizer, train_graphs, args, epoch_idx):
    """
    Run one warm-up epoch of the EPM-VAE model and return its mean loss.

    Each of the ``args.iters_per_epoch`` iterations samples up to
    ``args.batch_size`` distinct graphs from ``train_graphs``, packs them into a
    ``graph_batch`` on ``args.device``, evaluates ``EPM_VAE_Loss`` on the model
    outputs and, when ``optimizer`` is given, takes one optimizer step.

    Args:
        model: model called as ``model(norm_adj, all_fts)`` returning
            ``(preds, k, lbd)``; put into training mode here.
        optimizer: torch optimizer over the model's parameters, or ``None`` to
            only evaluate the loss without updating any weights.
        train_graphs: indexable sequence of graphs accepted by ``graph_batch``.
        args: parsed CLI namespace; reads ``iters_per_epoch``, ``batch_size``
            and ``device``.
        epoch_idx: epoch number shown in the progress-bar description.

    Returns:
        float: mean loss over the epoch's iterations (0.0 when
        ``args.iters_per_epoch`` is not positive, instead of dividing by zero).
    """
    # The debug prefix is normally set by the script entry point; fall back to
    # the same default so this function also works when imported elsewhere
    # (the bare global lookup raised NameError in that case).
    debug_ckpt = globals().get('global_debug_ckpt', 'epm_main') + '/warm_up'

    model.train()
    tot_iters = args.iters_per_epoch
    pbar = tqdm(range(tot_iters), unit='batch')
    # the description never changes within an epoch, so set it once
    pbar.set_description('epoch: %d' % (epoch_idx))

    loss_epoch = 0.

    for _ in pbar:
        # select a batch of graphs (without replacement) and convert them to graph_batch
        selected_idx = np.random.permutation(len(train_graphs))[:args.batch_size]
        batch_graph = [train_graphs[idx] for idx in selected_idx]
        gb = graph_batch(batch_graph, args.device)

        # forward pass: predictions and inferred parameters
        batch_adj, batch_fts = gb.norm_adj, gb.all_fts
        preds, k, lbd = model(batch_adj, batch_fts)
        # keep k.grad after backward (needed if k is a non-leaf tensor) so it
        # can be inspected, e.g. by the commented-out debug hook below
        k.retain_grad()

        # compute loss
        p_labs, n_labs, kl_w = gb.pos_labels_adj, gb.neg_labels_adj, gb.kl_weights
        L_inf = EPM_VAE_Loss(p_labs, n_labs, kl_w)
        loss = L_inf(preds, k, lbd)

        # backprop (skipped entirely when no optimizer is supplied)
        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            # debug_print(k.grad, ckpt = debug_ckpt + '/if optimizer is not None')
            optimizer.step()

        # scalar loss -> Python float, detached from the autograd graph
        loss_epoch += loss.item()

    avg_loss = loss_epoch / tot_iters if tot_iters > 0 else 0.
    print("training loss = {:.4f}".format(avg_loss))
    return avg_loss


def execute(args):
    """
    Prepare everything for an EPM-VAE run and perform the warm-up phase.

    Steps: fix the torch / numpy (and CUDA, if present) RNG seeds to 0, pick
    the compute device, load the dataset and take the train split for fold
    ``args.fold_idx``, build the model and its Adam optimizer, then call
    ``warm_up`` once per epoch for ``args.wu_epochs`` epochs.

    Side effect: the chosen ``torch.device`` is stored on ``args.device``.
    """
    # Seeds are fixed at 0 here; args.seed is only passed to separate_data.
    torch.manual_seed(0)
    np.random.seed(0)

    if torch.cuda.is_available():
        args.device = torch.device("cuda:" + str(args.gpu_idx))
        torch.cuda.manual_seed_all(0)
    else:
        args.device = torch.device("cpu")

    # load the dataset and keep only the training part of the requested fold
    all_graphs, _n_classes = load_data(args.dataset, args.degree_as_tag)
    train_set, _test_set = separate_data(all_graphs, args.seed, args.fold_idx)

    # the input feature width is taken from the first training graph
    in_dim = train_set[0].node_features.shape[1]
    model = EPM_VAE(
        in_dim, args.inf_h1_dim, args.inf_h2_dim, args.inf_dropout
    ).to(args.device)
    optimizer = optim.Adam(
        model.parameters(), lr=args.wu_lr, weight_decay=args.inf_l2coef
    )

    # warm-up phase
    for epoch in range(args.wu_epochs):
        warm_up(model, optimizer, train_set, args, epoch)


if __name__ == '__main__':
    # Command-line entry point: parse hyperparameters and run execute().
    # Help strings use argparse's %(default)s so the documented defaults can
    # never drift from the real ones (the dataset help previously claimed
    # MUTAG while the default is NCI1).
    parser = argparse.ArgumentParser(
        description='PyTorch graph convolutional neural net for whole-graph classification')
    parser.add_argument('--dataset', type=str, default="NCI1",
                        help='name of dataset (default: %(default)s)')
    parser.add_argument('--gpu_idx', type=int, default=0,
                        help='which gpu to use if any (default: %(default)s)')

    # batch training hyperparameters
    parser.add_argument('--batch_size', type=int, default=32,
                        help='input batch size for training (default: %(default)s)')
    parser.add_argument('--iters_per_epoch', type=int, default=50,
                        help='number of iterations per each epoch (default: %(default)s)')
    # NOTE(review): --epochs is not read anywhere in this script's visible code
    # (only --wu_epochs is used) — confirm whether it is still needed.
    parser.add_argument('--epochs', type=int, default=350,
                        help='number of epochs to train (default: %(default)s)')
    parser.add_argument('--wu_epochs', type=int, default=1000,
                        help='number of epochs to warm-up (default: %(default)s)')

    # inference module hyperparameters
    parser.add_argument('--wu_lr', type=float, default=.01,
                        help='warming-up learning rate (default: %(default)s)')
    parser.add_argument('--inf_h1_dim', type=int, default=200,
                        help='number of hidden units in the 1st layer of the inference module (default: %(default)s)')
    parser.add_argument('--inf_h2_dim', type=int, default=100,
                        help='number of hidden units in the 2nd layer of the inference module (default: %(default)s)')
    parser.add_argument('--inf_dropout', type=float, default=0.5,
                        help='inference module dropout rate (default: %(default)s)')
    parser.add_argument('--inf_l2coef', type=float, default=0.,
                        help='inference module l2 reg coef (default: %(default)s)')

    # global randomness
    parser.add_argument('--seed', type=int, default=0,
                        help='random seed for splitting the dataset into 10 folds (default: %(default)s)')

    # IOs
    parser.add_argument('--degree_as_tag', action="store_true",
                        help='let the input node features be the degree of nodes (heuristics for unlabeled graph)')
    parser.add_argument('--fold_idx', type=int, default=0,
                        help='the index of fold in 10-fold validation. Should be less than 10.')
    # NOTE(review): --filename is not read anywhere in this script's visible code.
    parser.add_argument('--filename', type=str, default="",
                        help='output file')
    args = parser.parse_args()

    # Module-level prefix for debug checkpoints; warm_up() reads this global.
    global_debug_ckpt = 'epm_main'

    execute(args)
