import datetime
import os
import torch
import logging

import custom_modules
from custom_modules.agg_runs import agg_runs
from custom_modules.optimizer.extra_optimizers import ExtendedSchedulerConfig

from torch_geometric.graphgym.cmd_args import parse_args
from torch_geometric.graphgym.config import (
    cfg,
    dump_cfg,
    set_cfg,
    load_cfg,
    makedirs_rm_exist,
)

# from torch_geometric.graphgym.loader import create_loader
from custom_modules.loader.master_loader import create_loader, create_loader_distributed

from torch_geometric.graphgym.logger import set_printing
from torch_geometric.graphgym.optim import (
    # create_optimizer,
    create_scheduler,
    OptimizerConfig,
)

from custom_modules.optimizer.extra_optimizers import create_optimizer

from torch_geometric.graphgym.model_builder import create_model, GraphGymModule
from torch_geometric.graphgym.train import GraphGymDataModule, train
from torch_geometric.graphgym.utils.comp_budget import params_count
from torch_geometric.graphgym.utils.device import auto_select_device
from torch_geometric.graphgym.register import train_dict
from torch_geometric import seed_everything

import glob
from custom_modules.logger import create_logger

from custom_modules.utils import cfg_to_dict, flatten_dict, make_wandb_name

import random
import string

import torch.multiprocessing as mp

# Share tensors between processes via the "file_system" strategy.
# NOTE(review): presumably chosen to avoid exhausting file descriptors when
# many tensors are shared across workers — confirm.
mp.set_sharing_strategy("file_system")
import torch.distributed as dist

from torch.nn.parallel import DistributedDataParallel as DDP

import re

import numpy as np
import torch.nn as nn

# Allow TF32 for CUDA matmuls and for cuDNN operations.
torch.backends.cuda.matmul.allow_tf32 = True  # Default False in PyTorch 1.12+
torch.backends.cudnn.allow_tf32 = True  # Default True


from torch.distributed.utils import _cast_forward_inputs
from torch.distributed.algorithms.join import Join, Joinable, JoinHook

# Module-level logger; CustomDDP._pre_forward uses it to report bucket rebuilds.
logger = logging.getLogger(__name__)


class CustomDDP(DDP):
    """DistributedDataParallel subclass that overrides ``_pre_forward``.

    The override performs the pre-forward bookkeeping (lazy init, gradient
    sync preparation, join notification, bucket rebuild, buffer sync and the
    optional mixed-precision cast) and then returns ``inputs`` / ``kwargs``.
    Nothing in this method moves the inputs to a device.

    NOTE(review): this relies on private DDP attributes and methods
    (``_lazy_init_ran``, ``_delay_all_reduce_all_params``, ``reducer``,
    ``_join_config``, ...), which may differ between torch releases —
    confirm against the pinned torch version.
    """

    def _pre_forward(self, *inputs, **kwargs):
        """Run DDP's pre-forward steps and return ``(inputs, kwargs)``.

        Args:
            *inputs: Positional arguments of the wrapped module's forward.
            **kwargs: Keyword arguments of the wrapped module's forward.

        Returns:
            tuple: ``(inputs, kwargs)``, cast to
            ``self.mixed_precision.param_dtype`` when mixed precision is
            configured, otherwise returned as received.
        """
        # One-time lazy initialisation on the first forward call.
        if not self._lazy_init_ran:
            self._lazy_init()
        # Early exit: none of the reducer bookkeeping below is performed.
        if self._delay_all_reduce_all_params:
            return inputs, kwargs

        # Only prepare the reducer when gradients will actually be synced.
        if torch.is_grad_enabled() and self.require_backward_grad_sync:
            assert self.logger is not None
            self.logger.set_runtime_stats_and_log()
            self.reducer.prepare_for_forward()

        # Tell the Join context manager that this rank is still active.
        work = Join.notify_join_context(self)
        if work:
            self.reducer._set_forward_pass_work_handle(
                work, self._divide_by_initial_world_size  # type: ignore[arg-type]
            )

        # Rebuild gradient buckets before the forward pass when requested.
        if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
            logger.info("Reducer buckets have been rebuilt in this iteration.")
            self._has_rebuilt_buckets = True

        # Synchronise module buffers ahead of forward when configured to.
        if self._check_sync_bufs_pre_fwd():
            self._sync_buffers()

        if self._join_config.enable:
            # Notify joined ranks whether they should sync in backwards pass or not.
            self._check_global_requires_backward_grad_sync(is_joined_rank=False)

        # Cast inputs to the reduced-precision dtype if mixed precision is on.
        if self.mixed_precision is not None:
            inputs, kwargs = _cast_forward_inputs(
                self.mixed_precision.param_dtype,
                *inputs,
                **kwargs,
            )
        # Inputs are handed back without any device transfer.
        return inputs, kwargs


def new_optimizer_config(cfg):
    """Build an ``OptimizerConfig`` from the ``cfg.optim`` section.

    Args:
        cfg: Configuration node exposing ``optim.optimizer``,
            ``optim.base_lr``, ``optim.weight_decay`` and ``optim.momentum``.

    Returns:
        OptimizerConfig: Config object populated from those four fields.
    """
    optim_cfg = cfg.optim
    settings = {
        "optimizer": optim_cfg.optimizer,
        "base_lr": optim_cfg.base_lr,
        "weight_decay": optim_cfg.weight_decay,
        "momentum": optim_cfg.momentum,
    }
    return OptimizerConfig(**settings)


def new_scheduler_config(cfg):
    """Build an ``ExtendedSchedulerConfig`` from ``cfg.optim`` and ``cfg.train``.

    Args:
        cfg: Configuration node. Scheduler fields are read from ``cfg.optim``;
            ``train_mode`` and ``eval_period`` are read from ``cfg.train``.

    Returns:
        ExtendedSchedulerConfig: Config object populated from those fields.
    """
    optim_cfg = cfg.optim
    train_cfg = cfg.train
    settings = {
        "scheduler": optim_cfg.scheduler,
        "steps": optim_cfg.steps,
        "lr_decay": optim_cfg.lr_decay,
        "max_epoch": optim_cfg.max_epoch,
        "reduce_factor": optim_cfg.reduce_factor,
        "schedule_patience": optim_cfg.schedule_patience,
        "min_lr": optim_cfg.min_lr,
        "num_warmup_epochs": optim_cfg.num_warmup_epochs,
        "train_mode": train_cfg.mode,
        "eval_period": train_cfg.eval_period,
    }
    return ExtendedSchedulerConfig(**settings)


# def params_count(model):
#     """Computes the number of parameters.

#     Args:
#         model (nn.Module): PyTorch model
#     """
#     return sum([p.numel() for p in model.parameters()])


def params_count_no_mlp(model):
    """Count parameter elements, skipping 'feat_emb' and 'readout' parameters.

    A parameter is skipped when its name (as yielded by
    ``model.named_parameters()``) contains either substring.

    Args:
        model (nn.Module): PyTorch model

    Returns:
        int: Total number of elements in the remaining parameters.
    """
    skipped_tags = ("feat_emb", "readout")
    total = 0
    for param_name, tensor in model.named_parameters():
        if any(tag in param_name for tag in skipped_tags):
            continue
        total += tensor.numel()
    return total


def custom_set_out_dir(cfg, cfg_fname, name_tag):
    """Point :obj:`cfg.out_dir` at a sub-directory named after this run.

    The sub-directory name is the config file's base name without extension,
    followed by ``-<name_tag>`` when a non-empty tag is given.

    Args:
        cfg (CfgNode): Configuration node; its ``out_dir`` is updated in place
        cfg_fname (string): Filename for the yaml format configuration file
        name_tag (string): Additional name tag to identify this execution of
            the configuration file, specified in :obj:`cfg.name_tag`
    """
    stem, _extension = os.path.splitext(os.path.basename(cfg_fname))
    suffix = f"-{name_tag}" if name_tag else ""
    cfg.out_dir = os.path.join(cfg.out_dir, stem + suffix)


def custom_set_run_dir(cfg, run_id):
    """Set :obj:`cfg.run_dir` to ``<cfg.out_dir>/<run_id>`` and create it.

    An already existing directory is kept as is; it is never wiped.

    Args:
        cfg (CfgNode): Configuration node; its ``run_dir`` is updated in place
        run_id (int): Main for-loop iter id (the random seed or dataset split)
    """
    cfg.run_dir = os.path.join(cfg.out_dir, f"{run_id}")
    target_dir = cfg.run_dir
    os.makedirs(target_dir, exist_ok=True)


def run_loop_settings():
    """Derive the main-loop run IDs, seeds and split indices from the config.

    Reads the module-level ``cfg`` and ``args`` objects and selects one of two
    modes:

    1. 'multi-seed' (``cfg.run_multiple_splits`` is empty) - the run is
        repeated ``args.repeat`` times with seeds ``cfg.seed``,
        ``cfg.seed + 1``, ... and the same ``cfg.dataset.split_index`` each
        time. The run IDs are the seeds.
    2. 'multi-split' - one run per entry of ``cfg.run_multiple_splits``, every
        run using ``cfg.seed``. The run IDs are the split indices. Only
        ``args.repeat == 1`` is supported in this mode.

    Returns:
        List of run IDs for each loop iteration
        List of rng seeds to loop over
        List of dataset split indices to loop over

    Raises:
        NotImplementedError: In 'multi-split' mode when ``args.repeat != 1``.
    """
    requested_splits = cfg.run_multiple_splits

    if len(requested_splits) == 0:
        # 'multi-seed' run mode: run IDs and seeds are the same list.
        repeats = args.repeat
        seeds = [cfg.seed + offset for offset in range(repeats)]
        split_indices = [cfg.dataset.split_index] * repeats
        return seeds, seeds, split_indices

    # 'multi-split' run mode
    if args.repeat != 1:
        raise NotImplementedError(
            "Running multiple repeats of multiple "
            "splits in one run is not supported."
        )
    seeds = [cfg.seed] * len(requested_splits)
    return requested_splits, seeds, requested_splits


def generate_random_string(length=10):
    """Return a random string drawn from letters, digits and punctuation.

    Characters are picked one at a time with ``random.choice``, so the result
    follows the state of the global ``random`` generator.

    Args:
        length (int): Number of characters to generate.

    Returns:
        str: The generated string (empty when ``length`` is 0).
    """
    alphabet = string.ascii_letters + string.digits + string.punctuation
    picked = []
    for _ in range(length):
        picked.append(random.choice(alphabet))
    return "".join(picked)


def setup(rank, world_size):
    """Join the NCCL process group and bind this process to its GPU.

    The rendezvous address is fixed to ``localhost:12361``; any existing
    ``MASTER_ADDR`` / ``MASTER_PORT`` environment values are overwritten.

    Args:
        rank (int): Index of this process; also used as the CUDA device index.
        world_size (int): Total number of processes in the group.
    """
    rendezvous = {"MASTER_ADDR": "localhost", "MASTER_PORT": "12361"}
    for env_name, env_value in rendezvous.items():
        os.environ[env_name] = env_value
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    """Destroy the default distributed process group.

    Counterpart of :func:`setup`; called at the end of :func:`main`.
    """
    dist.destroy_process_group()


def calculate_lr(
    num_nodes,
    interpolation="linear",
    min_nodes=256,
    max_nodes=20000,
    min_lr=0.0001,
    max_lr=0.001,
):
    """Map a graph's node count to a learning rate (bigger graph, smaller LR).

    Node counts at or below ``min_nodes`` get ``max_lr`` and counts at or
    above ``max_nodes`` get ``min_lr``; this clamping is applied before the
    interpolation mode is looked at. In between, the rate decreases from
    ``max_lr`` to ``min_lr`` either linearly in the node count or linearly in
    its logarithm.

    Args:
        num_nodes: Number of nodes of the graph.
        interpolation (str): ``"linear"`` or ``"log"``.
        min_nodes: Node count at (or below) which ``max_lr`` is returned.
        max_nodes: Node count at (or above) which ``min_lr`` is returned.
        min_lr (float): Learning rate for the largest graphs.
        max_lr (float): Learning rate for the smallest graphs.

    Returns:
        The interpolated learning rate.

    Raises:
        NotImplementedError: If ``interpolation`` is ``"custom"``; that mode
            was never implemented (it previously failed with
            ``UnboundLocalError``).
        ValueError: If ``interpolation`` is not a recognised mode.
    """
    # Clamp outside the [min_nodes, max_nodes] range.
    if num_nodes <= min_nodes:
        return max_lr
    if num_nodes >= max_nodes:
        return min_lr

    if interpolation == "linear":
        return max_lr - (
            (num_nodes - min_nodes) * (max_lr - min_lr) / (max_nodes - min_nodes)
        )
    if interpolation == "log":
        log_ratio = (np.log(num_nodes) - np.log(min_nodes)) / (
            np.log(max_nodes) - np.log(min_nodes)
        )
        return max_lr - (log_ratio * (max_lr - min_lr))
    if interpolation == "custom":
        raise NotImplementedError(
            "interpolation='custom' is not implemented; use 'linear' or 'log'."
        )
    raise ValueError(
        f"Unknown interpolation {interpolation!r}; expected 'linear' or 'log'."
    )


def create_param_groups(
    model,
    nodes_num_dict,
    multi_graph=True,
    min_lr=0.0001,
    max_lr=0.001,
    base_lr=0.0001,
    lr_dict=None,
):
    """Build one optimizer parameter group per trainable parameter.

    With ``multi_graph`` enabled, a parameter whose name contains
    ``"<dataset>."`` for some key of ``nodes_num_dict`` (first matching key
    wins) is treated as dataset specific. Its learning rate is
    ``lr_dict[dataset]`` when available, otherwise it is derived from the
    dataset's node count with :func:`calculate_lr` (log interpolation between
    the smallest and largest node count in ``nodes_num_dict``). All other
    trainable parameters, and every trainable parameter when ``multi_graph``
    is false, use ``base_lr`` under the group name ``"base_perceiver"``.

    Args:
        model: Module exposing ``named_parameters()``.
        nodes_num_dict (dict): Dataset name -> number of nodes. Only read when
            ``multi_graph`` is true.
        multi_graph (bool): Whether to look for dataset-specific parameters.
        min_lr (float): Lower bound for the node-count based learning rate.
        max_lr (float): Upper bound for the node-count based learning rate.
        base_lr (float): Learning rate of parameters not tied to a dataset.
        lr_dict (dict, optional): Dataset name -> manually chosen learning
            rate. ``None`` or a missing key selects the node-count fallback.

    Returns:
        list[dict]: Groups with keys ``"params"``, ``"lr"`` and ``"name"``,
        in ``named_parameters()`` order; frozen parameters are skipped.
    """
    dataset_names = list(nodes_num_dict) if multi_graph else []
    num_nodes_list = [nodes_num_dict[data_name] for data_name in dataset_names]
    node_range = None  # (min, max) node count, computed on first fallback use

    param_groups = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        data_name = next(
            (candidate for candidate in dataset_names if candidate + "." in name),
            None,
        )
        if data_name is None:
            param_groups.append(
                {"params": param, "lr": base_lr, "name": "base_perceiver"}
            )
            continue

        print(name)
        # Explicit lookup instead of a bare ``except`` so that unrelated
        # errors are no longer silently turned into the fallback rate.
        if lr_dict is not None and data_name in lr_dict:
            lr = lr_dict[data_name]
        else:
            if node_range is None:
                node_range = (np.min(num_nodes_list), np.max(num_nodes_list))
            lr = calculate_lr(
                nodes_num_dict[data_name],
                interpolation="log",
                min_nodes=node_range[0],
                max_nodes=node_range[1],
                min_lr=min_lr,
                max_lr=max_lr,
            )
        param_groups.append({"params": param, "lr": lr, "name": data_name})

    return param_groups


def main(
    rank,
    world_size,
    shared_graph_dataset_dict,
    saint_samplers,
    node_dataset_loaded_list_train,
    nodes_num_dict,
    syn_graph_config,
    optim_lr_dict=None,
):
    """Per-process training entry point, launched once per GPU via ``mp.spawn``.

    Loads the configuration, joins the process group, builds the distributed
    loader, the DDP-wrapped model, the optimizer (one parameter group per
    parameter) and the scheduler, then runs the training routine registered
    under ``cfg.train.mode`` and finally destroys the process group.

    Rank 0 additionally initialises Weights & Biases (when ``cfg.wandb.use``)
    or creates the run directory, and prints the learning-rate table and the
    parameter counts.

    Args:
        rank (int): Index of this process; also the CUDA device it uses.
        world_size (int): Total number of processes.
        shared_graph_dataset_dict: Forwarded to ``create_loader_distributed``.
        saint_samplers: Forwarded to ``create_loader_distributed``.
        node_dataset_loaded_list_train: Forwarded to
            ``create_loader_distributed``.
        nodes_num_dict (dict): Dataset name -> number of nodes, forwarded to
            ``create_param_groups``.
        syn_graph_config (list): Stored in
            ``cfg.dataset_multi.syn_graph_config`` when
            ``cfg.dataset_multi.use_synthetic`` is set.
        optim_lr_dict (dict, optional): Dataset name -> manual learning rate.
            Every rate is multiplied by ``world_size * cfg.train.batch_size``.
            The caller's dict is left untouched. ``None`` means no manual
            rates are supplied.

    Raises:
        ImportError: If ``cfg.wandb.use`` is set on rank 0 but ``wandb``
            cannot be imported.
    """
    from torch_geometric.graphgym.config import (
        cfg,
        dump_cfg,
        set_cfg,
        load_cfg,
    )

    # Raise the soft limit on open file descriptors for this process.
    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
    resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
    # Load cmd line args
    args = parse_args()
    # Load config file
    set_cfg(cfg)
    load_cfg(cfg, args)
    custom_set_out_dir(cfg, args.cfg_file, cfg.name_tag)

    # Generated before seeding, independently in every process.
    group_id = generate_random_string()
    # setup the process groups
    setup(rank, world_size)

    cfg.dataset.split_index = 0
    cfg.seed = 0
    cfg.run_id = 0
    cfg.group_id = group_id
    cfg.rank = rank
    cfg.world_size = world_size

    seed_everything(cfg.seed)

    # Only rank 0 sets up experiment tracking / the run directory.
    if cfg.rank == 0:

        if cfg.wandb.use:
            try:
                import wandb
            except ImportError as err:
                raise ImportError("WandB is not installed.") from err

            custom_dir = f"{cfg.out_dir}/distributed_results/wandblogging"
            os.makedirs(custom_dir, exist_ok=True)
            wandb.init(
                project="GraphFM",
                entity="bottleneck",
                dir=custom_dir,
            )
            wandb.config.update(cfg_to_dict(cfg))
            # run_name[0] is the full regex match, i.e. "run-<...>/files".
            # NOTE(review): fails with TypeError if wandb.run.dir does not
            # contain that pattern — confirm the directory layout.
            pattern = r"run-([^/]+)/files"
            run_name = re.search(pattern, wandb.run.dir)
            cfg.run_dir = (
                f"{cfg.out_dir}/distributed_results/local_folder/wandb/{run_name[0]}"
            )
            cfg.out_dir = wandb.run.dir

            # define our custom x axis metric
            wandb.define_metric("epoch")
            # define which metrics will be plotted against it
            wandb.define_metric("train_loss", step_metric="epoch")
            wandb.define_metric("LR/*", step_metric="epoch")
            wandb.define_metric("ooms", step_metric="epoch")
        else:
            custom_set_run_dir(cfg, cfg.run_id)

    dump_cfg(cfg)
    # prepare the dataloader
    train_loader = create_loader_distributed(
        rank,
        world_size,
        shared_graph_dataset_dict,
        saint_samplers,
        node_dataset_loaded_list_train,
    )

    if cfg.dataset_multi.use_synthetic:
        cfg.dataset_multi.syn_graph_config = syn_graph_config
    model = create_model(to_device=False).to(rank)
    model = CustomDDP(model, device_ids=[rank], find_unused_parameters=True)
    multi_graph = len(cfg.dataset_multi.name_list) > 0

    # Scale the manual learning rates by world size and batch size. A new dict
    # is built so the caller's mapping is not modified, and ``None`` (the
    # default) is passed through instead of crashing.
    if optim_lr_dict is not None:
        optim_lr_dict = {
            key: value * world_size * cfg.train.batch_size
            for key, value in optim_lr_dict.items()
        }
    optimizer = create_optimizer(
        create_param_groups(
            model,
            nodes_num_dict,
            multi_graph,
            min_lr=cfg.optim.dataset_min_lr * world_size * cfg.train.batch_size,
            max_lr=cfg.optim.dataset_max_lr * world_size * cfg.train.batch_size,
            base_lr=cfg.optim.base_lr * world_size * cfg.train.batch_size,
            lr_dict=optim_lr_dict,
        ),
        new_optimizer_config(cfg),
    )

    lr_log = {}
    logged_names = set()  # Set to keep track of logged names

    if cfg.rank == 0:
        # Collect one learning rate per group name (first occurrence wins).
        for group in optimizer.param_groups:
            group_name = group.get(
                "name", "base_perceiver"
            )  # Ensure there's a default name
            # Check if the name has already been logged
            if group_name not in logged_names:
                lr_log[f"LR/{group_name}"] = group["lr"]
                logged_names.add(group_name)  # Mark this name as logged
        # Create a PrettyTable
        table = PrettyTable()
        table.field_names = ["Group Name", "Learning Rate"]

        # Fill the table with data
        for name, lr in lr_log.items():
            table.add_row([name.split("/")[-1], lr])
        print(table)

    scheduler = create_scheduler(optimizer, new_scheduler_config(cfg))

    if rank == 0:
        cfg.params = params_count(model)
        print("Number of parameters: ", cfg.params)
        print("Number of parameters without MLP: ",params_count_no_mlp(model))
    if len(cfg.dataset_multi.name_list) > 0:
        train_dict[cfg.train.mode](
            train_loader, model, optimizer, scheduler, start_epoch=0
        )
    else:
        train_dict[cfg.train.mode](train_loader, model, optimizer, scheduler)
    cleanup()


# if __name__ == "__main__":
from custom_modules.memory_monitor import MemoryMonitor
import resource
from custom_modules.loader.utils import GraphSAINTRandomWalkSampler_custom
from custom_modules.loader.master_loader import (
    load_dataset_from_pt,
    load_dataset_from_pt_network_repo,
    load_dataset_from_pt_syn,
)
from custom_modules.loader.node_loader import NodeDataset
from torch_geometric.loader.cluster import ClusterData
from custom_modules.loader.dataset.network_repository import NetworkRepository
from prettytable import PrettyTable

# Script entry point: load every configured dataset once in the parent
# process, build the samplers and node datasets, then spawn one `main` worker
# per visible CUDA device.
if __name__ == "__main__":
    # main(0, 1)
    # Raise the soft limit on open file descriptors for the parent process.
    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
    resource.setrlimit(resource.RLIMIT_NOFILE, (32368, rlimit[1]))
    torch.multiprocessing.set_sharing_strategy("file_system")

    # One worker process per visible CUDA device.
    world_size = torch.cuda.device_count()
    # mp.spawn(main, args=(world_size,), nprocs=world_size)

    # Load cmd line args
    args = parse_args()
    # Load config file
    set_cfg(cfg)
    load_cfg(cfg, args)

    # Manager that backs the dicts handed to the spawned workers.
    manager = mp.Manager()
    cfg.dataset_config_list = []
    graph_dataset_dict = {}
    graph_ratio_in_epoch = {}
    optim_lr_dict = {}
    # Collect the per-dataset config node, sampling ratio and manual learning
    # rate of every dataset listed in cfg.dataset_multi.name_list.
    for dataset_name in cfg.dataset_multi.name_list:
        dataset_cfg = getattr(cfg, dataset_name)
        dataset_cfg.enable = True
        cfg.dataset_config_list.append(dataset_cfg)
        graph_dataset_dict[dataset_cfg.dataset_name] = dataset_cfg
        graph_ratio_in_epoch[dataset_cfg.dataset_name] = (
            dataset_cfg.graph_ratio_in_epoch
        )
        optim_lr_dict[dataset_cfg.dataset_name] = dataset_cfg.manual_lr

    # Filled while loading: graph key -> data object / node count / ratio.
    shared_graph_dataset_dict = {}
    nodes_num_dict = {}
    shared_graph_ratio_in_epoch = {}
    num_nodes_indv_graphs_dict = {}

    # Create a PrettyTable object
    table = PrettyTable()

    # Define the columns
    table.field_names = [
        "Index",
        "Dataset Name",
        "Number of Graphs",
        "Number of Nodes",
        "Feat Dim",
        "Num Classes",
        "Graph Ratio",
    ]
    counter = 1
    for key, value in graph_dataset_dict.items():
        data_to_add = load_dataset_from_pt(value)

        if data_to_add is None:
            # Not loadable as a single graph: fall back to the network
            # repository loader, which yields a collection of graphs.
            data_to_add = load_dataset_from_pt_network_repo(value)
            num_graphs = len(data_to_add)
            num_nodes = 0
            for idx in range(len(data_to_add)):
                dat = data_to_add[idx]
                # Normalise each graph: consecutive node ids, float features,
                # list-valued names and a 2-D feature matrix.
                dat.node_id = torch.tensor(list(range(len(dat.y))), dtype=torch.long)
                dat.x = dat.x.float()
                if isinstance(dat.dataset_name, str):
                    dat.dataset_name = [dat.dataset_name]
                    dat.dataset_task_name = [dat.dataset_task_name]
                if len(dat.x.shape) == 1:
                    dat.x = dat.x.unsqueeze(-1)
                num_nodes += dat.x.shape[0]
                # Every graph of the collection is registered under its own key.
                shared_graph_dataset_dict[f"{key}_graph_{idx}"] = dat
                shared_graph_ratio_in_epoch[f"{key}_graph_{idx}"] = (
                    graph_ratio_in_epoch[key]
                )
            # The table reports the average number of nodes per graph.
            num_nodes = num_nodes / num_graphs

            table.add_row(
                [
                    counter,
                    graph_dataset_dict[key].dataset_name,
                    num_graphs,
                    num_nodes,
                    graph_dataset_dict[key].feat_dim,
                    graph_dataset_dict[key].task_dim,
                    graph_dataset_dict[key].graph_ratio_in_epoch,
                ]
            )
            counter += 1
            # NOTE(review): stores `data_to_add.data.num_nodes`, not the
            # per-graph average shown in the table — confirm which is intended.
            nodes_num_dict[key] = data_to_add.data.num_nodes
            num_nodes_indv_graphs_dict[key] = data_to_add.data.num_nodes
        else:
            data_to_add = data_to_add.data
            # shared_graph_dataset_dict[key] = data_to_add
            # shared_graph_ratio_in_epoch[key] = graph_ratio_in_epoch[key]
            table.add_row(
                [
                    counter,
                    graph_dataset_dict[key].dataset_name,
                    1,
                    data_to_add.num_nodes,
                    graph_dataset_dict[key].feat_dim,
                    graph_dataset_dict[key].task_dim,
                    graph_dataset_dict[key].graph_ratio_in_epoch,
                ]
            )
            counter += 1
            nodes_num_dict[key] = data_to_add.num_nodes
            num_nodes_indv_graphs_dict[key] = data_to_add.num_nodes
            # Graphs with more than one million nodes are loaded as up to two
            # pre-computed partitions instead of as a whole.
            if data_to_add.num_nodes > 1000000:
                del data_to_add
                # load partitions
                print("loading partitions")
                for num_part in range(2):
                    data_to_add = load_dataset_from_pt(value, num_part)
                    if data_to_add is not None:
                        shared_graph_dataset_dict[f"{key}_num_part_{num_part}"] = (
                            data_to_add
                        )
                        shared_graph_ratio_in_epoch[f"{key}_num_part_{num_part}"] = (
                            graph_ratio_in_epoch[key]
                        )
                    else:
                        # NOTE(review): if partition 0 loaded but partition 1
                        # is missing, partition 0 stays registered next to the
                        # full dataset. Also, this branch stores the return
                        # value of load_dataset_from_pt directly, whereas the
                        # small-graph branch stores its `.data` — confirm both.
                        print("partition not available so loading original dataset")
                        data_to_add = load_dataset_from_pt(value)
                        shared_graph_dataset_dict[key] = data_to_add
                        shared_graph_ratio_in_epoch[key] = graph_ratio_in_epoch[key]
                        break
            else:
                shared_graph_dataset_dict[key] = data_to_add
                shared_graph_ratio_in_epoch[key] = graph_ratio_in_epoch[key]

    syn_graph_config = []
    # Load the synthetic graphs, build a config entry for each and register
    # them alongside the real datasets.
    if cfg.dataset_multi.use_synthetic:

        syntheitc_pt_files = glob.glob(
            cfg.dataset_multi.synthetic_data_dir[0] + "*/eigen_processed/*"
        )
        print("loading synthetic graphs: ", len(syntheitc_pt_files))
        for filename in syntheitc_pt_files:
            dataset = load_dataset_from_pt_syn(filename)
            dataset_name = dataset.data.dataset_name[0]
            dataset.data.x = dataset.data.x.float()
            # Every synthetic graph is a node-classification task with a fixed
            # sampling ratio of 0.2.
            syn_graph_config.append(
                {
                    "dataset_name": dataset_name,
                    "task": "node",
                    "task_type": "classification",
                    "loss_fun": "cross_entropy",
                    "task_dim": dataset.data.config["num_clusters"],
                    "feat_dim": dataset.data.config["feature_dim"],
                    "node_feat_encoder_name": "LinearNode",
                    "graph_ratio_in_epoch": 0.2,
                }
            )
            shared_graph_dataset_dict[dataset_name] = dataset.data
            shared_graph_ratio_in_epoch[dataset_name] = 0.2
            nodes_num_dict[dataset_name] = dataset.data.num_nodes

    cfg.dataset_multi.syn_graph_config = syn_graph_config
    print(len(shared_graph_dataset_dict))

    # Print the table
    print(table)
    # Move the loaded graphs into a manager-backed dict for the workers.
    shared_graph_dataset_dict = manager.dict(shared_graph_dataset_dict)
    # Create a managed dictionary with loaded data
    # shared_graph_dataset_dict = manager.dict(
    #     {
    #         key: load_dataset_from_pt(value).data
    #         for key, value in graph_dataset_dict.items()
    #     }
    # )

    # samp_list = []
    # for i, key in enumerate(shared_graph_dataset_dict.keys()):
    #     print(i, key)

    #     samp_list.append(
    #         manager.dict(
    #             {
    #                 key: GraphSAINTRandomWalkSampler_custom(
    #                     shared_graph_dataset_dict[key],
    #                     batch_size=cfg.model.hop_cutoff,  # Controls the size of the subgraph
    #                     walk_length=cfg.model.hop_cutoff,  # Length of the random walks
    #                     num_steps=5,  # Number of steps (subgraphs) to sample
    #                     sample_coverage=0,  # Set to 0 for no specific coverage; adjust based on needs
    #                     save_dir=None,  # Temporary directory to save the walks
    #                 )
    #             }
    #         )
    #     )

    # print('heyo')
    # print(len(shared_graph_dataset_dict.keys()))
    # One random-walk sampler per registered graph, keyed like the graph dict.
    saint_samplers = manager.dict(
        {
            key: GraphSAINTRandomWalkSampler_custom(
                shared_graph_dataset_dict[key],
                batch_size=cfg.model.hop_cutoff,  # Controls the size of the subgraph
                walk_length=cfg.model.hop_cutoff,  # Length of the random walks
                num_steps=5,  # Number of steps (subgraphs) to sample
                sample_coverage=0,  # Set to 0 for no specific coverage; adjust based on needs
                save_dir=None,  # Temporary directory to save the walks
            )
            for i, key in enumerate(shared_graph_dataset_dict.keys())
        }
    )
    # Build one NodeDataset per registered graph.
    node_dataset_loaded_list_train = []
    for key in shared_graph_dataset_dict.keys():
        # dataset_mask = shared_graph_dataset_dict[key]["train_mask"]
        # dict_for_node_dataset = {
        #     "node_id": shared_graph_dataset_dict[key].node_id[dataset_mask],
        #     "y": shared_graph_dataset_dict[key].y[dataset_mask],
        #     "node_dataset_name": shared_graph_dataset_dict[key].dataset_name[0],
        #     "dataset_task_name": shared_graph_dataset_dict[key].dataset_task_name[0],
        #     "main_graph_dict_key": key,
        # }

        # use 100% of labels
        # The mask starts all-True (the values of train_mask are ignored, only
        # its shape is used); for 1-D labels, entries with y == -1 are then
        # switched off.
        if hasattr(shared_graph_dataset_dict[key], "train_mask"):
            if len(shared_graph_dataset_dict[key].y.shape) == 2:
                dataset_mask = torch.ones_like(
                    shared_graph_dataset_dict[key]["train_mask"], dtype=torch.bool
                )
            else:
                dataset_mask = torch.ones_like(
                    shared_graph_dataset_dict[key]["train_mask"], dtype=torch.bool
                )
                dataset_mask[shared_graph_dataset_dict[key].y == -1] = False
        else:
            # No train_mask attribute: shape the mask after the labels.
            dataset_mask = torch.ones_like(
                shared_graph_dataset_dict[key].y, dtype=torch.bool
            )
            dataset_mask[shared_graph_dataset_dict[key].y == -1] = False

        dict_for_node_dataset = {
            "node_id": shared_graph_dataset_dict[key].node_id[dataset_mask],
            "y": shared_graph_dataset_dict[key].y[dataset_mask],
            "node_dataset_name": shared_graph_dataset_dict[key].dataset_name[0],
            "dataset_task_name": shared_graph_dataset_dict[key].dataset_task_name[0],
            "main_graph_dict_key": key,
            "graph_ratio_in_epoch": shared_graph_ratio_in_epoch[key],
        }
        node_dataset_loaded_list_train.append(
            NodeDataset(dict_for_node_dataset, mask="train_mask")
        )
    # NOTE(review): self-assignment with no effect; the list is passed to the
    # workers as a plain list (the manager.list variant below is disabled).
    node_dataset_loaded_list_train = node_dataset_loaded_list_train
    # node_dataset_loaded_list_train = manager.list(node_dataset_loaded_list_train)

    # Launch one training process per device; `rank` is supplied by mp.spawn
    # as the first argument of `main`.
    mp.spawn(
        main,
        args=(
            world_size,
            shared_graph_dataset_dict,
            saint_samplers,
            node_dataset_loaded_list_train,
            nodes_num_dict,
            syn_graph_config,
            optim_lr_dict,
        ),
        nprocs=world_size,
    )
