'''
    1. Adapted ``setup`` from OCP's ``ocpmodels/common/distutils.py``:
       https://github.com/Open-Catalyst-Project/ocp/blob/89948582edfb8debb736406d54db9813a5f2c88d/ocpmodels/common/distutils.py#L16
    2. Added OpenMPI multi-node training (enabled by ``config["summit"]``), since Submitit is not supported there.
'''

import logging
import os
import subprocess

import torch
import torch.distributed as dist


def setup(config):
    if config["submit"]:
        node_list = os.environ.get("SLURM_STEP_NODELIST")
        if node_list is None:
            node_list = os.environ.get("SLURM_JOB_NODELIST")
        if node_list is not None:
            try:
                hostnames = subprocess.check_output(
                    ["scontrol", "show", "hostnames", node_list]
                )
                config["init_method"] = "tcp://{host}:{port}".format(
                    host=hostnames.split()[0].decode("utf-8"),
                    port=config["distributed_port"],
                )
                nnodes = int(os.environ.get("SLURM_NNODES"))
                ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE")
                if ntasks_per_node is not None:
                    ntasks_per_node = int(ntasks_per_node)
                else:
                    ntasks = int(os.environ.get("SLURM_NTASKS"))
                    nnodes = int(os.environ.get("SLURM_NNODES"))
                    assert ntasks % nnodes == 0
                    ntasks_per_node = int(ntasks / nnodes)
                if ntasks_per_node == 1:
                    assert config["world_size"] % nnodes == 0
                    gpus_per_node = config["world_size"] // nnodes
                    node_id = int(os.environ.get("SLURM_NODEID"))
                    config["rank"] = node_id * gpus_per_node
                    config["local_rank"] = 0
                else:
                    assert ntasks_per_node == config["world_size"] // nnodes
                    config["rank"] = int(os.environ.get("SLURM_PROCID"))
                    config["local_rank"] = int(os.environ.get("SLURM_LOCALID"))

                logging.info(
                    f"Init: {config['init_method']}, {config['world_size']}, {config['rank']}"
                )

                # ensures GPU0 does not have extra context/higher peak memory
                torch.cuda.set_device(config["local_rank"])

                dist.init_process_group(
                    backend=config["distributed_backend"],
                    init_method=config["init_method"],
                    world_size=config["world_size"],
                    rank=config["rank"],
                )
            except subprocess.CalledProcessError as e:  # scontrol failed
                raise e
            except FileNotFoundError:  # Slurm is not installed
                pass
    elif config["summit"]:
        world_size = int(os.getenv('OMPI_COMM_WORLD_SIZE'))
        world_rank = int(os.getenv('OMPI_COMM_WORLD_RANK'))
        
        # Should be set already
        #get_master = (
        #    "echo $(cat {} | sort | uniq | grep -v batch | grep -v login | head -1)"
        #).format(os.environ["LSB_DJOB_HOSTFILE"])
        #os.environ["MASTER_ADDR"] = str(
        #    subprocess.check_output(get_master, shell=True)
        #)[2:-3]
        #os.environ["MASTER_PORT"] = "23456"
        
        os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
        os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
        
        config["local_rank"] = int(os.getenv('OMPI_COMM_WORLD_LOCAL_RANK'))
        
        # NCCL and MPI initialization
        dist.init_process_group(
            backend="nccl",
            rank=world_rank,
            world_size=world_size,
            init_method="env://",
        )
    else:
        dist.init_process_group(
            backend=config["distributed_backend"], init_method="env://", 
            rank=config['local_rank'],
            world_size=config['world_size']
        )
        torch.cuda.set_device(config["local_rank"])
    # TODO: SLURM
