import argparse
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from variation_gnn._models import GIN_lite
from variation_gnn._processor import graph_batch
from variation_gnn._utils import load_data, separate_data, pass_data_iteratively


def train(model, optimizer, train_graphs, args, epoch_idx):
    """Run one training epoch made of ``args.iters_per_epoch`` random mini-batches.

    Each iteration samples up to ``args.batch_size`` distinct graphs from
    ``train_graphs`` and computes the cross-entropy loss of ``model`` on them.
    If ``optimizer`` is not None, it also takes one optimisation step.

    Args:
        model: graph classifier called as ``model(adj, fts, num_nodes)``.
            The first element of its output is used as the class logits.
        optimizer: optimizer over ``model``'s parameters. Pass None to only
            compute losses without updating parameters.
        train_graphs: indexable sequence of graph objects exposing ``label``.
        args: namespace providing ``iters_per_epoch``, ``batch_size`` and
            ``device``.
        epoch_idx: epoch number; used only for the progress-bar label.

    Returns:
        The mean per-batch loss over the epoch as a float. Returns 0.0 when
        ``args.iters_per_epoch`` is 0.
    """
    model.train()

    total_iters = args.iters_per_epoch
    pbar = tqdm(range(total_iters), unit='batch')
    # The label is constant for the whole epoch, so set it once.
    pbar.set_description('epoch: %d' % (epoch_idx))

    loss_accum = 0.0
    for _ in pbar:
        # Random subset of at most batch_size distinct training graphs.
        selected_idx = np.random.permutation(len(train_graphs))[:args.batch_size]
        batch_graph = [train_graphs[idx] for idx in selected_idx]
        gb = graph_batch(batch_graph, args.device)

        output, _ = model(gb.binary_adj, gb.all_fts, gb.num_nodes_)
        labels = torch.LongTensor([graph.label for graph in batch_graph]).to(args.device)

        # compute loss
        loss = F.cross_entropy(output, labels)

        # backprop
        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        loss_accum += loss.item()

    # Divide by the number of batches. Without this, the reported value
    # would be a sum that scales with iters_per_epoch.
    average_loss = loss_accum / total_iters if total_iters > 0 else 0.0
    print("loss training: %f" % (average_loss))

    return average_loss


def eval(model, _graphs, args, epoch_idx):
    """Return the classification accuracy of ``model`` on ``_graphs``.

    Graphs are processed sequentially in chunks of ``args.batch_size``. The
    predicted class is the arg-max of the model's first output.

    Args:
        model: graph classifier called as ``model(adj, fts, num_nodes)``.
        _graphs: sliceable sequence of graphs to evaluate.
        args: namespace providing ``batch_size`` and ``device``.
        epoch_idx: unused; kept so existing callers keep working.

    Returns:
        Fraction of correctly classified graphs, in [0, 1]. Returns 0.0 for
        an empty ``_graphs``.
    """
    model.eval()

    num_graphs = len(_graphs)
    if num_graphs == 0:
        # Avoid a ZeroDivisionError; there is nothing to score.
        return 0.0

    _corrects = 0
    # Gradients are never needed here. Disabling autograd avoids building
    # a computation graph for every batch.
    with torch.no_grad():
        for start in range(0, num_graphs, args.batch_size):
            _batch = _graphs[start:start + args.batch_size]
            gb = graph_batch(_batch, args.device)

            batch_scores, _ = model(gb.binary_adj, gb.all_fts, gb.num_nodes_)
            batch_preds = batch_scores.argmax(dim=1)
            _corrects += batch_preds.eq(gb.labels).sum().item()

    return _corrects / float(num_graphs)


if __name__ == '__main__':
    import contextlib

    # Training settings
    # Note: Hyper-parameters need to be tuned in order to obtain results reported in the paper.
    parser = argparse.ArgumentParser(description='PyTorch graph convolutional neural net for whole-graph classification')
    parser.add_argument('--dataset', type=str, default="MUTAG",
                        help='name of dataset (default: MUTAG)')
    parser.add_argument('--device_idx', type=int, default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--iters_per_epoch', type=int, default=50,
                        help='number of iterations per each epoch (default: 50)')
    parser.add_argument('--epochs', type=int, default=350,
                        help='number of epochs to train (default: 350)')
    parser.add_argument('--lr', type=float, default=0.01,
                        help='learning rate (default: 0.01)')
    parser.add_argument('--seed', type=int, default=0,
                        help='random seed for splitting the dataset into 10 (default: 0)')
    parser.add_argument('--fold_idx', type=int, default=0,
                        help='the index of fold in 10-fold validation. Should be less than 10.')
    parser.add_argument('--num_layers', type=int, default=5,
                        help='number of layers INCLUDING the input one (default: 5)')
    parser.add_argument('--num_mlp_layers', type=int, default=2,
                        help='number of layers for MLP EXCLUDING the input one (default: 2). 1 means linear model.')
    parser.add_argument('--hidden_dim', type=int, default=64,
                        help='number of hidden units (default: 64)')
    # NOTE(review): --final_dropout, --graph_pooling_type,
    # --neighbor_pooling_type and --learn_eps are parsed but not passed to
    # GIN_lite below. Confirm whether they should be forwarded.
    parser.add_argument('--final_dropout', type=float, default=0.5,
                        help='final layer dropout (default: 0.5)')
    parser.add_argument('--graph_pooling_type', type=str, default="sum", choices=["sum", "average"],
                        help='Pooling for over nodes in a graph: sum or average')
    parser.add_argument('--neighbor_pooling_type', type=str, default="sum", choices=["sum", "average", "max"],
                        help='Pooling for over neighboring nodes: sum, average or max')
    parser.add_argument('--learn_eps', action="store_true",
                        help='Whether to learn the epsilon weighting for the center nodes. Does not affect training accuracy though.')
    parser.add_argument('--degree_as_tag', action="store_true",
                        help='let the input node features be the degree of nodes (heuristics for unlabeled graph)')
    parser.add_argument('--filename', type = str, default = "",
                        help='output file')
    args = parser.parse_args()

    # Seed every RNG with a fixed value. --seed is only used for the
    # train/test split (separate_data below). Model initialisation and
    # mini-batch sampling always use seed 0.
    torch.manual_seed(0)
    np.random.seed(0)
    args.device = torch.device("cuda:" + str(args.device_idx)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    graphs, num_classes = load_data(args.dataset, args.degree_as_tag)

    # 10-fold cross validation. Conduct an experiment on the fold specified by args.fold_idx.
    train_graphs, test_graphs = separate_data(graphs, args.seed, args.fold_idx)

    model = GIN_lite(args.num_layers, args.num_mlp_layers, train_graphs[0].node_features.shape[1], args.hidden_dim, num_classes).to(args.device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # Halve the learning rate every 50 epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

    # If --filename is given, write one "avg_loss acc_train acc_test" line per
    # epoch to it (the file is overwritten). Otherwise nullcontext() yields None.
    out_cm = open(args.filename, 'w') if args.filename else contextlib.nullcontext()
    with out_cm as out_file:
        for epoch_idx in range(args.epochs):
            avg_loss = train(model, optimizer, train_graphs, args, epoch_idx)
            # eval on training set:
            acc_train = eval(model, train_graphs, args, epoch_idx)
            # eval on test set:
            acc_test = eval(model, test_graphs, args, epoch_idx)
            print(f"in epoch {epoch_idx}, train acc = {acc_train:.4f}, test acc = {acc_test:.4f}")

            if out_file is not None:
                out_file.write(f"{avg_loss:f} {acc_train:f} {acc_test:f}\n")
                out_file.flush()

            # Since PyTorch 1.1 the LR scheduler must be stepped after the
            # epoch's optimizer steps. Stepping it first shifts the schedule
            # one epoch early and makes PyTorch emit a warning.
            scheduler.step()