import argparse
import pickle as pkl

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from data_loader import load_data
from TAHIN import TAHIN
from utils import (
    evaluate_acc,
    evaluate_auc,
    evaluate_f1_score,
    evaluate_logloss,
)

import dgl


def main(args):
    """Train TAHIN with early stopping, then test the best checkpoint.

    The function selects a device, loads the graph and the three data
    loaders through ``load_data``, trains with Adam on a ``BCELoss``
    objective, checkpoints the weights whenever the mean validation
    accuracy improves, and finally reports loss / ACC / AUC / F1 / logloss
    on the test loader.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``gpu``, ``dataset``, ``batch``, ``num_workers``, ``path``,
            ``in_size``, ``out_size``, ``num_heads``, ``dropout``, ``lr``,
            ``wd``, ``epochs`` and ``early_stop``.

    Side effects:
        Writes the best ``state_dict`` to ``"TAHIN_" + args.dataset`` in the
        current working directory and prints progress to stdout.
    """
    # step 1: Check device
    if args.gpu >= 0 and torch.cuda.is_available():
        device = "cuda:{}".format(args.gpu)
    else:
        device = "cpu"

    # step 2: Load data
    (
        g,
        train_loader,
        eval_loader,
        test_loader,
        meta_paths,
        user_key,
        item_key,
    ) = load_data(args.dataset, args.batch, args.num_workers, args.path)
    g = g.to(device)
    print("Data loaded.")

    # step 3: Create model and training components
    model = TAHIN(
        g, meta_paths, args.in_size, args.out_size, args.num_heads, args.dropout
    )
    model = model.to(device)
    # NOTE(review): BCELoss expects probabilities in [0, 1]; presumably TAHIN
    # already applies a sigmoid to what is called "logits" below — confirm.
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    checkpoint_path = "TAHIN" + "_" + args.dataset
    print("Model created.")

    # step 4: Training
    print("Start training.")
    # Start below any attainable accuracy so the first finished epoch always
    # writes a checkpoint; best_epoch == -1 means "nothing saved yet".
    best_acc = float("-inf")
    best_epoch = -1
    kill_cnt = 0
    for epoch in range(args.epochs):
        # Training and validation using a full graph
        model.train()
        train_loss = []
        for batch in train_loader:
            user, item, label = [_.to(device) for _ in batch]
            # Call the module (not .forward) so registered hooks are honoured.
            logits = model(g, user_key, item_key, user, item)

            # compute loss
            tr_loss = criterion(logits, label)
            # Keep only the value for reporting; holding the attached tensor
            # would keep each step's autograd history alive for the epoch.
            train_loss.append(tr_loss.detach())

            # backward
            optimizer.zero_grad()
            tr_loss.backward()
            optimizer.step()

        train_loss = torch.stack(train_loss).sum().cpu().item()

        model.eval()
        with torch.no_grad():
            validate_loss = []
            validate_acc = []
            for batch in eval_loader:
                user, item, label = [_.to(device) for _ in batch]
                logits = model(g, user_key, item_key, user, item)

                # compute loss
                val_loss = criterion(logits, label)
                val_acc = evaluate_acc(
                    logits.detach().cpu().numpy(), label.detach().cpu().numpy()
                )
                validate_loss.append(val_loss)
                validate_acc.append(val_acc)

            validate_loss = torch.stack(validate_loss).sum().cpu().item()
            validate_acc = np.mean(validate_acc)

        # validate: checkpoint on improvement, otherwise count towards patience
        if validate_acc > best_acc:
            best_acc = validate_acc
            best_epoch = epoch
            torch.save(model.state_dict(), checkpoint_path)
            kill_cnt = 0
            print("saving model...")
        else:
            kill_cnt += 1
            if kill_cnt > args.early_stop:
                print("early stop.")
                print("best epoch:{}".format(best_epoch))
                break

        print(
            "In epoch {}, Train Loss: {:.4f}, Valid Loss: {:.5}\n, Valid ACC: {:.5}".format(
                epoch, train_loss, validate_loss, validate_acc
            )
        )

    # test use the best model
    model.eval()
    if best_epoch >= 0:
        # map_location keeps the load valid when the checkpoint was written
        # from a different device than the one selected for this run.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    else:
        # No epoch produced a checkpoint (e.g. --epochs 0). Do not pick up a
        # stale file left by an earlier run; evaluate the in-memory weights.
        print("no checkpoint saved; testing with current weights.")

    with torch.no_grad():
        test_loss = []
        test_acc = []
        test_auc = []
        test_f1 = []
        test_logloss = []
        for batch in test_loader:
            user, item, label = [_.to(device) for _ in batch]
            logits = model(g, user_key, item_key, user, item)

            # compute loss
            loss = criterion(logits, label)

            # Convert once per batch and share across the metrics; this
            # assumes the evaluate_* helpers do not modify their inputs.
            pred_np = logits.detach().cpu().numpy()
            label_np = label.detach().cpu().numpy()
            acc = evaluate_acc(pred_np, label_np)
            auc = evaluate_auc(pred_np, label_np)
            f1 = evaluate_f1_score(pred_np, label_np)
            log_loss = evaluate_logloss(pred_np, label_np)

            test_loss.append(loss)
            test_acc.append(acc)
            test_auc.append(auc)
            test_f1.append(f1)
            test_logloss.append(log_loss)

        test_loss = torch.stack(test_loss).sum().cpu().item()
        test_acc = np.mean(test_acc)
        test_auc = np.mean(test_auc)
        test_f1 = np.mean(test_f1)
        test_logloss = np.mean(test_logloss)
        print(
            "Test Loss: {:.5}\n, Test ACC: {:.5}\n, AUC: {:.5}\n, F1: {:.5}\n, Logloss: {:.5}\n".format(
                test_loss, test_acc, test_auc, test_f1, test_logloss
            )
        )


if __name__ == "__main__":
    # Command-line entry point: parse the hyper-parameters, echo them, and
    # hand the namespace to main().
    parser = argparse.ArgumentParser(
        description="Parser For Arguments",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Data / bookkeeping options
    parser.add_argument(
        "--dataset",
        default="movielens",
        help="Dataset to use, default: movielens",
    )
    parser.add_argument(
        "--path", default="./data", help="Path to save the data"
    )
    # NOTE: --model is accepted for the command line but main() does not read it.
    parser.add_argument("--model", default="TAHIN", help="Model Name")

    # Optimisation / runtime options
    parser.add_argument("--batch", default=128, type=int, help="Batch size")
    parser.add_argument(
        "--gpu",
        type=int,
        # Integer default to match type=int (was the string "0", which only
        # worked through argparse's implicit conversion of string defaults).
        default=0,
        help="Set GPU Ids : Eg: For CPU = -1, For Single GPU = 0",
    )
    parser.add_argument(
        "--epochs", type=int, default=500, help="Maximum number of epochs"
    )
    parser.add_argument(
        "--wd", type=float, default=0, help="L2 Regularization for Optimizer"
    )
    parser.add_argument("--lr", type=float, default=0.001, help="Learning Rate")
    parser.add_argument(
        "--num_workers",
        type=int,
        default=10,
        help="Number of processes to construct batches",
    )
    parser.add_argument(
        "--early_stop", default=15, type=int, help="Patience for early stop."
    )

    # Model size options
    parser.add_argument(
        "--in_size",
        default=128,
        type=int,
        help="Initial dimension size for entities.",
    )
    parser.add_argument(
        "--out_size",
        default=128,
        type=int,
        help="Output dimension size for entities.",
    )

    parser.add_argument(
        "--num_heads", default=1, type=int, help="Number of attention heads"
    )
    parser.add_argument("--dropout", default=0.1, type=float, help="Dropout.")

    args = parser.parse_args()

    # Echo the effective configuration so a run's log records it.
    print(args)

    main(args)
