# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# A script to launch multinode MAE probe training (main_prob.py) with submitit.
# --------------------------------------------------------

import argparse
import os
import uuid
from pathlib import Path

import main_prob as classification
import submitit


def parse_args():
    """Parse the command line for the submitit launcher.

    Every option defined by ``main_prob.get_args_parser()`` is inherited; the
    launcher-specific resource and Slurm options are appended after them.

    Returns:
        argparse.Namespace holding both the probe options and the launcher options.
    """
    launcher_options = (
        ("--ngpus", {"default": 8, "type": int, "help": "Number of gpus to request on each node"}),
        ("--nodes", {"default": 2, "type": int, "help": "Number of nodes to request"}),
        # Forwarded to submitit as ``timeout_min``, i.e. the value is in minutes.
        ("--timeout", {"default": 4320, "type": int, "help": "Duration of the job"}),
        ("--job_dir", {"default": "", "type": str, "help": "Job dir. Leave empty for automatic."}),
        ("--partition", {"default": "learnfair", "type": str, "help": "Partition where to submit"}),
        ("--use_volta32", {"action": "store_true", "help": "Request 32G V100 GPUs"}),
        ("--aws", {"action": "store_true", "help": "AWS setup"}),
        ("--comment", {"default": "", "type": str, "help": "Comment to pass to scheduler"}),
        ("--exclude", {"default": "", "type": str, "help": "Exclude nodes"}),
    )

    parser = argparse.ArgumentParser(
        "Submitit for MAE probe",
        parents=[classification.get_args_parser()],
    )
    for flag, spec in launcher_options:
        parser.add_argument(flag, **spec)
    return parser.parse_args()


def get_shared_folder(root: str = "/checkpoint/") -> Path:
    """Return the per-user shared folder ``<root>/$USER/mae``, creating it if needed.

    Args:
        root: Shared-filesystem root directory. Defaults to ``/checkpoint/``.

    Returns:
        Path to the (now existing) ``<root>/$USER/mae`` directory.

    Raises:
        RuntimeError: if ``root`` is not an existing directory.
    """
    # NOTE(review): if USER is unset this yields a literal "None" path component.
    user = os.getenv("USER")
    shared_root = Path(root)
    if not shared_root.is_dir():
        raise RuntimeError("No shared folder available")
    folder = shared_root / f"{user}" / "mae"
    # parents=True: the per-user directory under the root may not exist yet.
    folder.mkdir(parents=True, exist_ok=True)
    return folder


def get_init_file():
    """Return a fresh, non-existent file path inside the shared folder.

    The path's ``file://`` URI is used as ``args.dist_url`` for distributed
    initialization, so a new, uniquely named file is generated on each call.

    Returns:
        pathlib.Path of the form ``<shared folder>/<uuid hex>_init``.
    """
    # Init file must not exist, but its parent dir must exist.
    shared = get_shared_folder()
    shared.mkdir(parents=True, exist_ok=True)
    init_file = shared / f"{uuid.uuid4().hex}_init"
    if init_file.exists():
        init_file.unlink()
    return init_file


class Trainer(object):
    """Callable handed to submitit; runs ``main_prob.main`` inside each task.

    Before calling the probe entry point, each task fills the distributed
    fields of ``args`` from submitit's job environment and exports the
    cluster-specific (AWS or FAIR) communication environment variables.
    """

    def __init__(self, args):
        # Namespace from parse_args(); the setup helpers mutate it in place.
        self.args = args

    def __call__(self):
        # Imported locally, presumably so the module is loaded in the worker process.
        import main_prob as classification

        self._setup_gpu_args()
        if self.args.aws:
            self._setup_aws()
        else:
            self._setup_fair()
        classification.main(self.args)

    def checkpoint(self):
        """Return a ``DelayedSubmission`` that resubmits this job with a fresh init file.

        A new init file is allocated so the resubmitted job does not reuse the
        previous run's ``dist_url``.
        """
        # NOTE(review): no resume/checkpoint path is set on args before requeueing;
        # confirm main_prob resumes from output_dir on its own.
        import submitit

        self.args.dist_url = get_init_file().as_uri()
        print("Requeuing ", self.args)
        empty_trainer = type(self)(self.args)
        return submitit.helpers.DelayedSubmission(empty_trainer)

    def _setup_gpu_args(self):
        """Populate output dir, local GPU, rank and world size from the submitit job environment."""
        import submitit
        from pathlib import Path

        job_env = submitit.JobEnvironment()
        # Replace the "%j" placeholder in the output path with the real job id.
        self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id)))
        self.args.log_dir = self.args.output_dir
        self.args.gpu = job_env.local_rank
        self.args.rank = job_env.global_rank
        self.args.world_size = job_env.num_tasks
        print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")

    def _setup_aws(self):
        """Export the EFA / NCCL environment variables used for the AWS setup."""
        os.environ.update(
            {
                # Optional overrides, intentionally left disabled:
                # "GLOO_SOCKET_IFNAME": "ens32,ens65,ens130,ens163",
                # "NCCL_SOCKET_IFNAME": "ens32",
                # "CUDA_LAUNCH_BLOCKING": "1",
                # "NCCL_BLOCKING_WAIT": "1",
                "FI_EFA_MR_CACHE_ENABLE": "1",
                "FI_EFA_USE_DEVICE_RDMA": "1",
                "FI_OFI_RXR_INLINE_MR_ENABLE": "1",
                "FI_OFI_RXR_RX_COPY_UNEXP": "1",
                "FI_OFI_RXR_RX_COPY_OOO": "1",
                "FI_PROVIDER": "efa",
                "NCCL_ALGO": "tree",
                "NCCL_PROTO": "simple",
                "NCCL_ASYNC_ERROR_HANDLING": "1",
                "NCCL_DEBUG": "INFO",
                "NCCL_TREE_THRESHOLD": "0",
                "NCCL_NET_SHARED_BUFFERS": "0",
                "NCCL_TOPO_FILE": "/usr/local/cuda-11.3/efa/share/aws-ofi-nccl/xml/p4d-24xl-topo.xml",
                "RDMAV_FORK_SAFE": "1",
            }
        )

    def _setup_fair(self):
        """Export the Gloo / NCCL environment variables used for the FAIR cluster setup."""
        os.environ.update(
            {
                # Interface names are set to empty strings.
                # NOTE(review): presumably so the library chooses the interface — confirm.
                "GLOO_SOCKET_IFNAME": "",
                "NCCL_SOCKET_IFNAME": "",
                "NCCL_ASYNC_ERROR_HANDLING": "1",
                # Debug toggles, intentionally left disabled:
                # "NCCL_DEBUG": "INFO",
                # "NCCL_BLOCKING_WAIT": "1",
                # "CUDA_LAUNCH_BLOCKING": "1",
            }
        )


def main():
    args = parse_args()
    if args.job_dir == "":
        args.job_dir = get_shared_folder() / "%j"

    # Note that the folder will depend on the job_id, to easily track experiments
    executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30)

    num_gpus_per_node = args.ngpus
    nodes = args.nodes
    timeout_min = args.timeout

    partition = args.partition
    kwargs = {}
    if args.use_volta32:
        kwargs['slurm_constraint'] = 'volta32gb'
    if args.comment:
        kwargs['slurm_comment'] = args.comment
    if args.exclude:
        kwargs['exclude'] = args.exclude

    executor.update_parameters(
        mem_gb=(0 if args.aws else 40 * num_gpus_per_node),
        gpus_per_node=num_gpus_per_node,
        tasks_per_node=num_gpus_per_node, # one task per GPU
        cpus_per_task=10,
        nodes=nodes,
        timeout_min=timeout_min,
        # Below are cluster dependent parameters
        slurm_partition=partition,
        slurm_signal_delay_s=120,
        **kwargs
    )

    executor.update_parameters(name="mae_ttt")

    args.dist_url = get_init_file().as_uri()
    args.output_dir = args.job_dir

    trainer = Trainer(args)
    job = executor.submit(trainer)

    # print("Submitted job_id:", job.job_id)
    print(job.job_id)


# Script entry point: parse the CLI and submit the job only when executed directly.
if __name__ == "__main__":
    main()
