"""
Training and testing for node selection tasks in bAbI
"""

import argparse
import time

import numpy as np
import torch
from data_utils import get_babi_dataloaders
from ggnn_ns import NodeSelectionGGNN
from torch.optim import Adam


def main(args):
    out_feats = {4: 4, 15: 5, 16: 6}
    n_etypes = {4: 4, 15: 2, 16: 2}

    train_dataloader, dev_dataloader, test_dataloaders = get_babi_dataloaders(
        batch_size=args.batch_size,
        train_size=args.train_num,
        task_id=args.task_id,
        q_type=args.question_id,
    )

    model = NodeSelectionGGNN(
        annotation_size=1,
        out_feats=out_feats[args.task_id],
        n_steps=5,
        n_etypes=n_etypes[args.task_id],
    )
    opt = Adam(model.parameters(), lr=args.lr)

    print(f"Task {args.task_id}, question_id {args.question_id}")

    print(f"Training set size: {len(train_dataloader.dataset)}")
    print(f"Dev set size: {len(dev_dataloader.dataset)}")

    # training and dev stage
    for epoch in range(args.epochs):
        model.train()
        for i, batch in enumerate(train_dataloader):
            g, labels = batch
            loss, _ = model(g, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()
            print(f"Epoch {epoch}, batch {i} loss: {loss.data}")

        dev_preds = []
        dev_labels = []
        model.eval()
        for g, labels in dev_dataloader:
            with torch.no_grad():
                preds = model(g)
                preds = (
                    torch.tensor(preds, dtype=torch.long).data.numpy().tolist()
                )
                labels = labels.data.numpy().tolist()
                dev_preds += preds
                dev_labels += labels
        acc = np.equal(dev_labels, dev_preds).astype(np.float).tolist()
        acc = sum(acc) / len(acc)
        print(f"Epoch {epoch}, Dev acc {acc}")

    # test stage
    for i, dataloader in enumerate(test_dataloaders):
        print(f"Test set {i} size: {len(dataloader.dataset)}")

    test_acc_list = []
    for dataloader in test_dataloaders:
        test_preds = []
        test_labels = []
        model.eval()
        for g, labels in dataloader:
            with torch.no_grad():
                preds = model(g)
                preds = (
                    torch.tensor(preds, dtype=torch.long).data.numpy().tolist()
                )
                labels = labels.data.numpy().tolist()
                test_preds += preds
                test_labels += labels
        acc = np.equal(test_labels, test_preds).astype(np.float).tolist()
        acc = sum(acc) / len(acc)
        test_acc_list.append(acc)

    test_acc_mean = np.mean(test_acc_list)
    test_acc_std = np.std(test_acc_list)

    print(
        f"Mean of accuracy in 10 test datasets: {test_acc_mean}, std: {test_acc_std}"
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Gated Graph Neural Networks for node selection tasks in bAbI"
    )
    parser.add_argument(
        "--task_id", type=int, default=16, help="task id from 1 to 20"
    )
    parser.add_argument(
        "--question_id", type=int, default=1, help="question id for each task"
    )
    parser.add_argument(
        "--train_num", type=int, default=50, help="Number of training examples"
    )
    parser.add_argument("--batch_size", type=int, default=10, help="batch size")
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument(
        "--epochs", type=int, default=100, help="number of training epochs"
    )

    args = parser.parse_args()

    main(args)
