import argparse
import datetime
import os
import time
import warnings

import torch

from configs.data_config import add_data_config
from configs.model_config import add_model_config
from configs.training_config import add_training_config
from datasets.load_data import load_directed_graph
from logger import Logger
from model.model_init import ModelZoo
from task.node_classification import OODNodeClassification
from utils import seed_everything, get_params

import torch.multiprocessing as mp
import torch.distributed as dist

warnings.filterwarnings("ignore")

def multiprocess_train(rank, world_size, available_gpus, logger, args, dataset):
    """Build the model zoo and node-classification task inside one distributed worker.

    The signature (``rank`` first) matches what ``torch.multiprocessing.spawn``
    passes to its target. NOTE(review): no caller is visible in this file; confirm.

    Args:
        rank: Index of this worker in the process group; also selects this
            worker's entry in ``available_gpus``.
        world_size: Total number of workers in the process group.
        available_gpus: Sequence indexed by ``rank`` giving each worker's CUDA
            device (anything ``torch.cuda.set_device`` accepts: int, "cuda:N"
            string, or ``torch.device``).
        logger: Project logger; only ``.info`` is used here.
        args: Parsed command-line namespace supplying the hyper-parameters below.
        dataset: Loaded graph dataset handed to ``ModelZoo`` and the task.
    """
    device = available_gpus[rank]
    # Bind the process to the same GPU the task is given below. Using ``rank``
    # here (as before) diverges from ``device`` whenever ``available_gpus`` is
    # not exactly [0, 1, ..., world_size - 1], leaving NCCL on a different GPU
    # than the model.
    torch.cuda.set_device(device)

    # No init_method is given, so the default "env://" rendezvous is used
    # (MASTER_ADDR / MASTER_PORT must be provided by the environment).
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    try:
        model_zoo = ModelZoo(logger, args, dataset, dataset.num_node, dataset.num_features,
                             dataset.num_node_classes, dataset.y, dataset.test_idx, "node")
        # The task instance is not used afterwards; it is presumably constructed
        # for its side effects (training/evaluation). NOTE(review): confirm.
        run = OODNodeClassification(logger, dataset, model_zoo, normalize_times=args.normalize_times,
                                    lr=args.lr, weight_decay=args.weight_decay, epochs=args.num_epochs,
                                    early_stop=args.early_stop, device=device,
                                    walk_times=args.walk_time)
        logger.info("# NodeClassification Params:" + str(get_params(model_zoo.model_init())))
    finally:
        # Tear down the process group even if anything above raises, so the
        # communicator resources are always released.
        dist.destroy_process_group()

if __name__ == "__main__":
    # Options are registered on one shared parser by the three config modules.
    # add_help=False is kept as in the original, so -h/--help is not auto-added.
    parser = argparse.ArgumentParser(add_help=False)
    add_data_config(parser)
    add_model_config(parser)
    add_training_config(parser)
    args = parser.parse_args()
    dataset_name = args.data_name
    model_name = args.model_name

    # One log file per run: log/<model>/<dataset>/<node split>/<timestamp>.log.
    # The timestamp avoids ':' (not allowed in Windows file names).
    now_time = datetime.datetime.now()
    log_dir = os.path.join("log", model_name, dataset_name, args.data_node_split)
    logger_name = os.path.join(log_dir, now_time.strftime('%Y-%m-%d %H-%M-%S') + ".log")
    logger = Logger(logger_name)
    logger.info(f"program start: {now_time}")

    # set up seed
    logger.info(f"random seed: {args.seed}")
    seed_everything(args.seed)
    # Fall back to CPU when CUDA is disabled by flag or unavailable.
    device = torch.device('cuda:{}'.format(args.gpu_id) if (args.use_cuda and torch.cuda.is_available()) else 'cpu')

    # set up datasets
    # perf_counter is monotonic, so the measured load time cannot be skewed by
    # wall-clock adjustments the way time.time() can.
    set_up_datasets_start_time = time.perf_counter()
    logger.info(f"Load unsigned & directed & unweighted network: {args.data_name}")
    dataset = load_directed_graph(logger, args, name=args.data_name, root=args.data_root, k=args.data_dimension_k,
                                  node_split=args.data_node_split, edge_split=args.data_edge_split,
                                  node_split_id=args.data_node_split_id, edge_split_id=args.data_edge_split_id)
    set_up_datasets_end_time = time.perf_counter()
    logger.info(f"datasets: {args.data_name}, root dir: {args.data_root}, node-level split method: "
                f"{args.data_node_split}, id: {args.data_node_split_id}, "
                f"edge-level split method: {args.data_edge_split}, id: {args.data_edge_split_id}, "
                f"the running time is: {round(set_up_datasets_end_time-set_up_datasets_start_time,4)}s")
    logger.info(f"num_epochs: {args.num_epochs}, early_stop: {args.early_stop}, lr: {args.lr}, weight_decay: {args.weight_decay}")
    logger.info(f"dataset.x.shape: {dataset.x.shape}")

    # Use the container's own reductions rather than builtin max()/min(): the
    # builtins iterate element-by-element in Python, which is very slow on
    # large graphs. NOTE(review): assumes dataset.y and dataset.adj.data are
    # 1-D torch tensors or NumPy arrays (both expose .max()/.min()) — confirm.
    label_info = dataset.y.max() + 1 if dataset.y is not None else -1
    logger.info(f"max(dataset.y)+1: {label_info}")
    logger.info(f"dataset.num_node: {dataset.num_node}")
    logger.info(f"Real edges -> dataset.num_edge: {dataset.num_edge}")
    adj_values = dataset.adj.data
    logger.info(f"min(dataset.adj.data): {adj_values.min()}")
    logger.info(f"max(dataset.adj.data): {adj_values.max()}")
    logger.info(f"dataset.adj.data: {adj_values}")
    logger.info(f"cuda device: {args.gpu_id}")
    logger.info(f"walk_len: {args.walk_len}, walk_time: {args.walk_time}")

    model_zoo = ModelZoo(logger, args, dataset, dataset.num_node, dataset.num_features,
                         dataset.num_node_classes, dataset.y, dataset.test_idx, "node")

    # The task instance is not used afterwards; it is presumably constructed
    # for its side effects (training/evaluation). NOTE(review): confirm.
    run = OODNodeClassification(logger, dataset, model_zoo, normalize_times=args.normalize_times,
                                lr=args.lr, weight_decay=args.weight_decay, epochs=args.num_epochs,
                                early_stop=args.early_stop, device=device, walk_times=args.walk_time)
    logger.info("# NodeClassification Params:" + str(get_params(model_zoo.model_init())))

