import argparse
import datetime
import os
import time
import warnings

import torch

from configs.data_config import add_data_config
from configs.model_config import add_model_config
from configs.training_config import add_training_config
from datasets.load_data import load_directed_graph
from logger import Logger
from model.model_init import ModelZoo
from task.unsupervised_node_classification import UnsupervisedNodeClassification
from utils import seed_everything, get_params

import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

warnings.filterwarnings("ignore")

def multiprocess_train(rank, world_size, available_gpus, logger, args, dataset):
    """Run one worker of the distributed unsupervised node-classification job.

    Meant to be used as the target of ``torch.multiprocessing.spawn``, which
    supplies ``rank`` as the first positional argument. The process group is
    initialised with the NCCL backend and no explicit ``init_method``, so the
    rendezvous settings (``MASTER_ADDR`` / ``MASTER_PORT``) have to be present
    in the environment before this function is called.

    Args:
        rank: Index of this worker in the process group; also the position of
            this worker's GPU inside ``available_gpus``.
        world_size: Total number of worker processes in the group.
        available_gpus: CUDA device indices, one per worker. Entry ``rank`` is
            the GPU this worker runs on.
        logger: Logger handed to ``ModelZoo`` and the task, and used for the
            final parameter summary.
        args: Parsed command-line namespace. Reads ``train_batch_size``,
            ``normalize_times``, ``lr``, ``weight_decay``, ``num_epochs``,
            ``num_val_epochs``, ``early_stop`` and ``walk_time``.
        dataset: Loaded graph dataset. Reads ``num_node``, ``num_features``,
            ``num_node_classes``, ``y``, ``test_idx`` and ``train_idx``.
    """
    # The rank is only a position in `available_gpus`; the CUDA index to use is
    # the value stored there. Using the rank itself as the device index would
    # place the model on GPU `rank` while the task is told to use
    # `available_gpus[rank]`.
    gpu_id = available_gpus[rank]
    torch.cuda.set_device(gpu_id)
    device = gpu_id

    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    try:
        model_zoo = ModelZoo(logger, args, dataset, dataset.num_node, dataset.num_features,
                             dataset.num_node_classes, dataset.y, dataset.test_idx, "node")
        model = model_zoo.model_init()
        model = model.to(gpu_id)
        model = DDP(model, device_ids=[gpu_id])

        # Each worker iterates over its own shard of the training indices.
        # NOTE(review): with shuffle=True the sampler needs `set_epoch()` once
        # per epoch to reshuffle; that has to happen wherever the epoch loop
        # lives — confirm the task does it.
        sampler = DistributedSampler(dataset.train_idx, num_replicas=world_size, rank=rank, shuffle=True)
        dataloader = DataLoader(dataset.train_idx, batch_size=args.train_batch_size,
                                sampler=sampler, drop_last=False)

        # Nothing else is invoked on the task object here, so training is
        # presumably driven from its constructor — TODO confirm.
        run = UnsupervisedNodeClassification(logger, args, dataset, model, normalize_times=args.normalize_times,
                                             lr=args.lr, weight_decay=args.weight_decay, epochs=args.num_epochs,
                                             logepochs=args.num_val_epochs, early_stop=args.early_stop,
                                             device=device, walk_time=args.walk_time, train_loader=dataloader)

        # NOTE(review): this calls model_init() a second time just to report
        # the parameter summary; kept as-is because its side effects are not
        # visible from this file.
        logger.info("# NodeClassification Params:" + str(get_params(model_zoo.model_init())))
    finally:
        # Always tear the process group down, even when setup or training fails.
        dist.destroy_process_group()

if __name__ == "__main__":
    # Script entry point: parse the configuration, load the directed graph
    # dataset, log a summary of it, then run unsupervised node classification
    # either across several GPUs (for "arxivdir") or in this single process.
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:1024'
    parser = argparse.ArgumentParser(add_help=False)
    add_data_config(parser)
    add_model_config(parser)
    add_training_config(parser)
    args = parser.parse_args()

    dataset_name = args.data_name
    model_name = args.model_name

    now_time = datetime.datetime.now()

    # One log file per run: log/<model>/<dataset>/<node split>/<timestamp>.log
    log_dir = os.path.join("log", model_name, dataset_name, args.data_node_split)
    logger_name = os.path.join(log_dir, now_time.strftime('%Y-%m-%d %H-%M-%S') + ".log")
    logger = Logger(logger_name)

    logger.info(f"program start: {now_time}")

    # set up seed
    logger.info(f"random seed: {args.seed}")
    seed_everything(args.seed)
    # Fall back to the CPU when CUDA is not requested or not available.
    device = torch.device('cuda:{}'.format(args.gpu_id) if (args.use_cuda and torch.cuda.is_available()) else 'cpu')

    # set up datasets
    set_up_datasets_start_time = time.time()
    logger.info(f"Load unsigned & directed & unweighted network: {args.data_name}")
    dataset = load_directed_graph(logger, args, name=args.data_name, root=args.data_root, k=args.data_dimension_k,
                                  node_split=args.data_node_split, edge_split=args.data_edge_split,
                                  node_split_id=args.data_node_split_id, edge_split_id=args.data_edge_split_id)
    set_up_datasets_end_time = time.time()
    logger.info(f"datasets: {args.data_name}, root dir: {args.data_root}, node-level split method: "
                f"{args.data_node_split}, id: {args.data_node_split_id}, "
                f"edge-level split method: {args.data_edge_split}, id: {args.data_edge_split_id}, "
                f"the running time is: {round(set_up_datasets_end_time-set_up_datasets_start_time,4)}s")
    logger.info(f"num_epochs: {args.num_epochs}, early_stop: {args.early_stop}, lr: {args.lr}, weight_decay: {args.weight_decay}")
    logger.info(f"dataset.x.shape: {dataset.x.shape}")
    # -1 is logged when the dataset carries no labels.
    label_info = max(dataset.y)+1 if dataset.y is not None else -1
    logger.info(f"max(dataset.y)+1: {label_info}")
    logger.info(f"dataset.num_node: {dataset.num_node}")
    logger.info(f"Real edges -> dataset.num_edge: {dataset.num_edge}")
    logger.info(f"min(dataset.adj.data): {min(dataset.adj.data)}")
    logger.info(f"max(dataset.adj.data): {max(dataset.adj.data)}")
    logger.info(f"dataset.adj.data: {dataset.adj.data}")
    logger.info(f"device: {device}")
    logger.info(f"walk_len: {args.walk_len}, walk_time: {args.walk_time}")

    # Exact comparison on purpose: `x in ("arxivdir")` tests against a plain
    # string (the parentheses do not make a tuple), so any substring of
    # "arxivdir" — including the empty string — would select this branch.
    if args.data_name == "arxivdir":
        # Multi-GPU path: one spawned worker per listed GPU, rendezvous over
        # localhost. The workers read MASTER_ADDR / MASTER_PORT from the
        # environment they inherit from this process.
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29523'
        # Hard-coded CUDA indices of the GPUs used for distributed training.
        available_gpus = [3, 4, 5, 6, 7]
        world_size = len(available_gpus)
        mp.spawn(multiprocess_train, args=(world_size, available_gpus, logger, args, dataset), nprocs=world_size, join=True)
    else:
        # Single-process path. It is mutually exclusive with the distributed
        # path so a dataset is not trained a second time from scratch after
        # the spawned workers have finished.
        model_zoo = ModelZoo(logger, args, dataset, dataset.num_node, dataset.num_features,
                             dataset.num_node_classes, dataset.y, dataset.test_idx, "node")
        # Unlike the distributed worker, the ModelZoo itself (not a built
        # model) is handed to the task here.
        run = UnsupervisedNodeClassification(logger, args, dataset, model_zoo, normalize_times=args.normalize_times,
                                             lr=args.lr, weight_decay=args.weight_decay, epochs=args.num_epochs,
                                             logepochs=args.num_val_epochs, early_stop=args.early_stop,
                                             device=device, walk_time=args.walk_time)
        logger.info("# NodeClassification Params:" + str(get_params(model_zoo.model_init())))

