import torch
import torch.nn as nn

import argparse

from subdomain import GraphSelect
from trainer import GraphTrainer
from model import construct_model

from utils import get_writer

def main():
    """Parse CLI options, build data/model/optimizer, and run training.

    Steps performed here:
      1. Parse arguments, default ``args.label`` to ``args.type``, attach the
         device and metric bookkeeping fields to ``args``, and seed torch.
      2. Load train/val/test data via ``GraphSelect.graph_config``.
      3. Build the network with ``construct_model`` plus an Adam optimizer
         and a cross-entropy criterion.
      4. Create a writer with ``get_writer`` and hand everything to
         ``GraphTrainer.train``.
    """
    args = get_args()
    # Fall back to the decomposition type as the run label.
    if args.label is None:
        args.label = args.type
    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Metric bookkeeping fields stored on args; presumably updated by the
    # trainer during training — TODO confirm against GraphTrainer.
    args.train_acc = args.best_val_acc = args.test_acc = 0
    args.best_epoch = 1
    torch.manual_seed(args.seed)

    print('==> Loading dataset..')
    root = '../dataset'
    types = 'transductive'
    # PPI lives in its own sub-directory and is handled in the inductive setting.
    if args.dataset in ['PPI']:
        root += '/PPI'
        types = 'inductive'
    save_dir = f'./DD/{args.dataset}'
    selector = GraphSelect(root, args)
    train_data, val_data, test_data = selector.graph_config(save_dir, types)

    print('==> Loading model..')
    # selector.dataset may be a list; in that case feature/class counts are
    # read from its first element, otherwise from the dataset itself.
    dataset = selector.dataset
    reference = dataset[0] if isinstance(dataset, list) else dataset
    net = construct_model(reference.num_features, reference.num_classes, args)
    print(net)
    net = net.to(args.device)
    optimizer = torch.optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.wd)
    criterion = nn.CrossEntropyLoss()
    trainer = GraphTrainer(types)

    print('==> Loading writer..')
    writer = get_writer(args)

    print('==> Training model..')
    trainer.train(net, criterion, train_data, val_data, test_data, optimizer, writer, args)

def get_args(argv=None):
    """Build the command-line parser and parse the training options.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, so existing
            callers that pass nothing behave exactly as before.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description='PyTorch Model Training')
    parser.add_argument('--lr', default=1e-2, type=float, help='learning rate')
    parser.add_argument('--wd', default=5e-4, type=float, help='weight decay')
    parser.add_argument('--model', '-m', default='GCN', help='select model')
    parser.add_argument('--type', '-t', default='full', help='select decomposition type')
    parser.add_argument('--dataset', '-data', default='Cora', help='select dataset')
    parser.add_argument('--hidden_features', default=16, type=int, help='number of hidden layer features')
    parser.add_argument('--nprocs', '-n', default=1, type=int, help='number of parallel processors')
    parser.add_argument('--depth', '-d', default=2, type=int, help='depth for layers')
    parser.add_argument('--epoch', '-e', default=200, type=int, help='number of training epoch')
    # None means "use the value of --type" (filled in by the caller).
    parser.add_argument('--label', '-lab', default=None, help='label the test')
    parser.add_argument('--seed', default=0, type=int, help='choose seed for test')
    parser.add_argument('--inductive', action='store_true', help='select test option')
    return parser.parse_args(argv)

if __name__ == "__main__":
    # Only start training when run as a script, never on import.
    main()
