import argparse
import datetime
import os
import time
import warnings

import torch
import numpy as np
from configs.data_config import add_data_config
from configs.model_config import add_model_config
from configs.training_config import add_training_config
from dataprocess.load_data import load_undirected_graph
from logger import Logger
from utils import seed_everything
from models.gsl.LLaTA import LLaTALearn

if __name__ == "__main__":
    # NOTE(review): this silences *every* warning for the whole run (including
    # library deprecation warnings); narrow the filter if warnings are needed.
    warnings.filterwarnings("ignore")

    # All command-line options are contributed by the three config modules.
    parser = argparse.ArgumentParser(add_help=False)
    add_data_config(parser)
    add_model_config(parser)
    add_training_config(parser)
    args = parser.parse_args()

    dataset_name = args.data_name
    model_name = args.model_name

    now_time = datetime.datetime.now()

    # One log file per run: log/<model>/<dataset>/<timestamp>.log.
    # ':' is replaced because it is not a legal file-name character on Windows.
    log_dir = os.path.join("log", model_name, dataset_name)
    logger_name = os.path.join(log_dir, str(now_time).replace(':', '-') + ".log")
    logger = Logger(logger_name)

    logger.info(f"program start: {now_time}")

    # set up seed
    seed_everything(args.seed)
    # Use the requested GPU only when CUDA is both requested and available.
    device = torch.device('cuda:{}'.format(args.gpu_id) if (args.use_cuda and torch.cuda.is_available()) else 'cpu')

    # set up datasets
    # Datasets that load_undirected_graph is invoked for. Anything else used to
    # leave `dataset` unbound and crash later with a NameError; fail early with
    # an actionable message instead.
    supported_datasets = (
        'cora', 'citeseer', 'pubmed', 'instagram', 'reddit', 'wikics',
        'arxiv2023', 'history', 'photo', 'children', 'amazonratings',
    )
    if dataset_name not in supported_datasets:
        raise ValueError(
            f"Unsupported dataset '{dataset_name}'. "
            f"Expected one of: {', '.join(supported_datasets)}"
        )

    # perf_counter is monotonic, so it is the right clock for measuring intervals.
    set_up_datasets_start_time = time.perf_counter()
    dataset = load_undirected_graph(args, name=dataset_name, root=args.data_root, k=args.data_dimension_k)
    set_up_datasets_end_time = time.perf_counter()

    logger.info(f"Method: {model_name}, Datasets: {dataset_name}, x_dim: {dataset.x.shape}")
    logger.info(f"dataset setup time: {set_up_datasets_end_time - set_up_datasets_start_time:.2f}s")

    model = LLaTALearn(logger, dataset, args, device)

    model.execute()

    # Repeat the experiment (report mean ± standard deviation over 10 seeds).
    # acc_list = []
    # time_list = []
    # for i in range(10):
    #     seed_everything(args.seed+100*i)
    #     acc, t= model.execute()
    #     acc_list.append(acc)
    #     time_list.append(t)
    # print(f'acc: {np.mean(acc_list)}±{np.std(acc_list)}, time: {np.mean(time_list)}s')