import logging
import torch
import sys
import os

from logger import Logger
from utils import set_random_seed, get_args
from task import NodeTask

# Number of random seeds to run; seeds 0 .. NUM_SEEDS-1 each get their own run and log file.
NUM_SEEDS = 5
# Directory (relative to the working directory) that receives the per-run log files.
LOG_DIR = 'log'


def build_log_filename(dataset_name, pretrain_task, shots, k, m, r, seed):
    """Return the log-file path for one (config, seed) run.

    The name encodes every hyperparameter that distinguishes a run, so runs
    with different settings or seeds never share a file.

    Args:
        dataset_name: Dataset identifier (``args.dataset_name``).
        pretrain_task: Pretraining task identifier (``args.pretrain_task``).
        shots: Number of shots (``args.shots``).
        k: Value of ``args.k``.
        m: Value of ``args.m``.
        r: Value of ``args.r``.
        seed: Random seed used for this run.

    Returns:
        A path string of the form
        ``log/<dataset>_<task>_<shots>_k<k>_m<m>_r<r>_<seed>.log``.
    """
    return f'{LOG_DIR}/{dataset_name}_{pretrain_task}_{shots}_k{k}_m{m}_r{r}_{seed}.log'


def main():
    """Parse CLI arguments and run ``NodeTask`` training once per seed.

    For each seed in ``range(NUM_SEEDS)`` this prints a run summary, creates
    a ``Logger`` bound to a per-run file under ``LOG_DIR``, seeds the RNGs via
    ``set_random_seed``, and trains a fresh ``NodeTask``.
    """
    args = get_args()
    pretrain_task = args.pretrain_task
    device = torch.device(f'cuda:{args.gpu_id}' if torch.cuda.is_available() else 'cpu')

    # exist_ok=True avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(LOG_DIR, exist_ok=True)

    # The format is the same for every run, so build the formatter once.
    formatter = logging.Formatter('%(asctime)s - %(message)s')

    for seed in range(NUM_SEEDS):
        print(f'Dataset: {args.dataset_name} | Task: {pretrain_task} | Shots: {args.shots} | '
              f'k: {args.k} | m: {args.m} | r: {args.r} | Seed: {seed}')

        filename = build_log_filename(
            args.dataset_name, pretrain_task, args.shots, args.k, args.m, args.r, seed
        )
        # NOTE(review): Logger is project-local; presumably it writes records to
        # `filename` using `formatter`. Confirm in logger.py.
        logger = Logger(filename, formatter)

        # Seed after the logger is created, matching the original ordering.
        set_random_seed(seed)

        task = NodeTask(args.dataset_name, args.shots, args.hidden_dim, device, pretrain_task, logger, args)
        task.train(args.batch_size, lr=args.lr, epochs=args.epochs)


if __name__ == '__main__':
    main()