import argparse

import torch
from torch.utils.data import Subset
from torch_geometric.loader import DataLoader

from utils import *


def args_parser(argv=None):
    """Build the training command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before; passing an
            explicit list lets callers and tests parse without touching
            ``sys.argv``.

    Returns:
        argparse.Namespace holding the parsed options.

    Raises:
        SystemExit: raised by argparse on invalid or missing arguments,
            including a missing ``--path``.
    """
    parser = argparse.ArgumentParser(description="Train a model and save a checkpoint.")
    # NOTE(review): --dataset and --model_type default to None; presumably
    # get_dataset/get_model require real values — confirm and consider
    # making them required as well.
    parser.add_argument("--dataset", type=str, default=None, help="Dataset name")
    parser.add_argument("--split", type=int, default=-1, help="Split index")
    parser.add_argument("--seed", type=int, default=123, help="Seed")
    parser.add_argument("--model_type", type=str, default=None, help="Type of model")
    # Required: the checkpoint is written here with torch.save after training.
    # A None path would only fail once training has already finished.
    parser.add_argument("--path", type=str, required=True, help="Save path")
    parser.add_argument("--hd", type=int, default=32, help="Hidden dim")
    parser.add_argument("--nl", type=int, default=3, help="Number of layers")
    parser.add_argument("--ld", type=int, default=32, help="Linear dim")
    parser.add_argument("--epochs", type=int, default=500, help="Number of epochs")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
    parser.add_argument("--wd", type=float, default=0.005, help="Weight decay")
    parser.add_argument("--degree_attr", action="store_true", default=False, help="Degree as an attribute")
    return parser.parse_args(argv)


def _make_loader(dataset, idxs, batch_size, shuffle):
    """Wrap the samples of ``dataset`` at ``idxs`` in a PyG DataLoader."""
    return DataLoader(Subset(dataset, idxs), batch_size=batch_size, shuffle=shuffle)


def main():
    """Train a model from command-line options and save a checkpoint.

    Parses the CLI options, loads the dataset and its train/val/test index
    split, wraps each split in a ``DataLoader``, builds the model with
    ``get_model`` and trains it with ``train_model``. It then writes
    ``{"state_dict": ..., "args": model_args}`` to ``args.path`` with
    ``torch.save``. Because ``model_args`` are exactly the keyword arguments
    passed to ``get_model``, the architecture can be rebuilt with
    ``get_model(**ckpt["args"])`` before loading ``ckpt["state_dict"]``.
    """
    args = args_parser()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): args.seed is only passed to get_splits; torch's RNG
    # (weight init, train-loader shuffling) is not seeded here — confirm
    # whether get_model/train_model seed it.

    dataset = get_dataset(args.dataset, degree_attr=args.degree_attr)
    train_idxs, val_idxs, test_idxs = get_splits(
        args.dataset, size=len(dataset), seed=args.seed, split=args.split
    )

    # Only the training loader shuffles; val/test keep a fixed order.
    dataloader_train = _make_loader(dataset, train_idxs, args.batch_size, shuffle=True)
    dataloader_val = _make_loader(dataset, val_idxs, args.batch_size, shuffle=False)
    dataloader_test = _make_loader(dataset, test_idxs, args.batch_size, shuffle=False)

    model_args = {
        "model_type": args.model_type,
        # Read the feature width from the full dataset rather than the test
        # subset, so an empty test split doesn't raise IndexError. Assumes all
        # samples share the same node-feature width.
        "num_node_features": dataset[0].num_node_features,
        "num_classes": dataset.num_classes,
        "hidden_dim": args.hd,
        "num_layers": args.nl,
        "linear_dim": args.ld,
    }

    model = get_model(**model_args).to(device)
    # NOTE(review): batches are not moved to `device` here; presumably
    # train_model does that using the model's device — confirm.
    model = train_model(
        model,
        dataloader_train,
        dataloader_val,
        dataloader_test,
        epochs=args.epochs,
        lr=args.lr,
        weight_decay=args.wd,
    )
    torch.save({"state_dict": model.state_dict(), "args": model_args}, args.path)


if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()
