# import numpy as np
# import random
import torch
import torch.distributed as dist

import os
import time
from pathlib import Path
import logging
import subprocess

def init_environ(cfg):
    """Set up the (optionally distributed) runtime described by ``cfg``.

    The launcher is read from ``cfg.ddp.launcher``:

    * ``'slurm'``   -- initialise through ``_init_dist_slurm`` using
      ``cfg.ddp.port`` as the master port.
    * ``'pytorch'`` -- initialise through ``_init_dist_pytorch``.
    * anything else -- single-process mode; no process group is created.

    Both distributed paths use the ``'nccl'`` backend.

    Args:
        cfg: Configuration object. It is mutated in place: ``world_size``,
            ``gpu_id``, ``rank`` and ``distributed`` are written onto it.
    """
    launcher = cfg.ddp.launcher
    use_slurm = launcher == 'slurm'

    # Guard clause: any unrecognised launcher means a plain one-GPU run.
    if not use_slurm and not (launcher == 'pytorch'):
        cfg.world_size = 1
        cfg.gpu_id = 0
        cfg.rank = 0
        cfg.distributed = False
        return

    if use_slurm:
        # Works for one or multiple GPUs.
        _init_dist_slurm('nccl', cfg, cfg.ddp.port)
    else:
        _init_dist_pytorch('nccl', cfg)

    cfg.distributed = True
    # Wait until every rank has joined before touching the print builtin.
    dist.barrier()
    # Only rank 0 keeps printing by default.
    setup_for_distributed(cfg.rank == 0)

def setup_for_distributed(is_master):
    """Replace the built-in ``print`` with a version gated on ``is_master``.

    After this call, ``print(...)`` only produces output when ``is_master``
    is truthy or when the caller passes the extra keyword ``force=True``.
    The ``force`` keyword is consumed here and never forwarded to the real
    ``print``.

    Args:
        is_master: Whether the current process should print by default.
    """
    import builtins

    # Capture whatever ``print`` currently is so the wrapper can delegate.
    real_print = builtins.print

    def gated_print(*args, **kwargs):
        forced = kwargs.pop('force', False)
        if not (is_master or forced):
            return
        real_print(*args, **kwargs)

    builtins.print = gated_print

def _init_dist_pytorch(backend, cfg):
    """Initialize distributed training from launcher-provided env variables.

    Reads ``RANK`` and ``WORLD_SIZE`` from the environment, binds this
    process to one CUDA device and joins the default process group.

    The device is chosen from ``LOCAL_RANK`` when the launcher exports it
    (the per-node process index), falling back to the global ``RANK``
    otherwise. The value is taken modulo the number of visible GPUs, so
    launching more processes than GPUs still maps onto valid devices.

    Args:
        backend (str): Backend of torch.distributed.
        cfg: Configuration object; ``rank``, ``world_size`` and ``gpu_id``
            are written onto it.

    Raises:
        KeyError: If ``RANK`` or ``WORLD_SIZE`` is missing from the
            environment.
        RuntimeError: If no CUDA device is visible to this process.
    """
    cfg.rank = int(os.environ['RANK'])
    cfg.world_size = int(os.environ['WORLD_SIZE'])

    num_gpus = torch.cuda.device_count()
    if num_gpus < 1:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise RuntimeError(
            'Distributed training requires at least one visible CUDA '
            'device, but torch.cuda.device_count() returned 0.')

    # Prefer the node-local index; the global rank is only a correct device
    # index when every node runs exactly ``num_gpus`` processes.
    local_rank = int(os.environ.get('LOCAL_RANK', cfg.rank))
    cfg.gpu_id = local_rank % num_gpus
    torch.cuda.set_device(cfg.gpu_id)

    dist.init_process_group(backend=backend)
    print(f'Distributed training on {cfg.rank}/{cfg.world_size}')
        
def _init_dist_slurm(backend, cfg, port=None):
    """Initialize slurm distributed training environment.

    If argument ``port`` is not specified, then the master port will be system
    environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
    environment variable, then a default port ``29500`` will be used.

    ``MASTER_ADDR`` is kept if it is already present in the environment;
    otherwise it is set to the first hostname that
    ``scontrol show hostname $SLURM_NODELIST`` reports.

    The function exports ``MASTER_PORT``, ``MASTER_ADDR``, ``WORLD_SIZE``,
    ``LOCAL_RANK`` and ``RANK`` into ``os.environ`` before joining the
    process group.

    Args:
        backend (str): Backend of torch.distributed.
        cfg: Configuration object; ``world_size``, ``gpu_id`` and ``rank``
            are written onto it.
        port (int, optional): Master port. Defaults to None.

    Raises:
        KeyError: If ``SLURM_PROCID``, ``SLURM_NTASKS`` or ``SLURM_NODELIST``
            is missing from the environment.
        RuntimeError: If no CUDA device is visible, or if the master address
            has to be resolved and ``scontrol`` fails or returns nothing.
    """
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']

    num_gpus = torch.cuda.device_count()
    if num_gpus < 1:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise RuntimeError(
            'Distributed training requires at least one visible CUDA '
            'device, but torch.cuda.device_count() returned 0.')
    gpu_id = proc_id % num_gpus
    torch.cuda.set_device(gpu_id)

    # specify master port
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)
    else:
        # keep MASTER_PORT from the environment if present;
        # 29500 is torch.distributed default port
        os.environ.setdefault('MASTER_PORT', '29500')

    # use MASTER_ADDR in the environment variable if it already exists;
    # only query slurm when the address is actually needed
    if 'MASTER_ADDR' not in os.environ:
        try:
            # Argument list (no shell) so the node list is never interpreted
            # by a shell, and a failing scontrol is reported instead of its
            # error text silently becoming the master address.
            result = subprocess.run(
                ['scontrol', 'show', 'hostname', node_list],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                universal_newlines=True,
                check=True)
        except (OSError, subprocess.CalledProcessError) as err:
            raise RuntimeError(
                'Could not resolve the master address with '
                f'"scontrol show hostname {node_list}"') from err
        hostnames = result.stdout.split()
        if not hostnames:
            raise RuntimeError(
                f'"scontrol show hostname {node_list}" returned no hostname')
        os.environ['MASTER_ADDR'] = hostnames[0]

    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['LOCAL_RANK'] = str(gpu_id)
    os.environ['RANK'] = str(proc_id)
    cfg.world_size = ntasks
    cfg.gpu_id = gpu_id
    cfg.rank = proc_id

    dist.init_process_group(backend=backend)
    print(f'Distributed training on {proc_id}/{ntasks}')
