"""
Primitives for multi-GPU communication in distributed training, together with
related helpers: process-group initialization, seeding, progress bars and
subprocess execution.
"""

# from utils.lib import *
import pickle
import os
import torch
import torch.distributed as dist
import subprocess as sp
from datetime import timedelta
import numpy as np
import random
import logging
from tqdm import tqdm


def iter_tqdm(item):
    """Wrap ``item`` in an ASCII tqdm progress bar on the main process only.

    Every other process gets ``item`` back unchanged, so only one bar is
    printed per job.
    """
    if not is_main_process():
        return item
    return tqdm(item, ascii=True)


def _master_uri():
    """Build the TCP rendezvous URI from MASTER_ADDR / MASTER_PORT (defaults localhost:12875)."""
    master_addr = os.environ.get("MASTER_ADDR", 'localhost')
    master_port = os.environ.get("MASTER_PORT", 12875)
    return f"tcp://{master_addr}:{master_port}"


def _join_nccl_group(local_rank, world_size, rank):
    """Pin this process to CUDA device ``local_rank`` and join the NCCL default process group."""
    torch.cuda.set_device(local_rank)
    dist.init_process_group(
        backend='nccl',
        init_method=_master_uri(),
        world_size=world_size,
        rank=rank,
        timeout=timedelta(hours=5),  # 5 hrs
    )


def dist_init(args, distributed=True):
    """Initialise torch.distributed (NCCL) and record the topology on ``args``.

    If a default process group already exists, ``args`` is returned untouched.

    When ``distributed`` is true, the launch mode is detected from the
    environment:

    * Open MPI (``OMPI_COMM_WORLD_SIZE`` set): world size, rank and local rank
      come from the ``OMPI_COMM_WORLD_*`` variables.
    * torch launchers (``WORLD_SIZE`` set): from ``WORLD_SIZE``, ``RANK`` and
      ``LOCAL_RANK``.
    * Neither: a single-process group (rank 0, world size 1) is created, e.g.
      for debugging on one GPU, and ``LOCAL_RANK`` is exported as ``"0"``.

    In every distributed mode the rendezvous address is
    ``tcp://$MASTER_ADDR:$MASTER_PORT`` (defaults ``localhost:12875``) and the
    process-group timeout is 5 hours. When ``distributed`` is false no
    process group is created.

    Attributes written to ``args``: ``n_gpu_per_node`` / ``local_size``
    (visible CUDA device count), ``num_gpus`` and ``word_size`` (both the
    world size; the ``word_size`` spelling is kept because callers read it),
    ``num_nodes``, ``distributed``, ``rank``, ``local_rank`` and ``node_id``.

    Args:
        args: mutable namespace (e.g. argparse.Namespace) to populate.
        distributed: whether to set up a process group at all.

    Returns:
        The same ``args`` object, mutated in place.
    """
    if torch.distributed.is_initialized():
        print("Distributed system has already been initialized")
        print("Skipping dist init ....")
        return args
    args.n_gpu_per_node = torch.cuda.device_count()  # local size
    args.local_size = args.n_gpu_per_node
    print("Init Distributed Training")

    if not distributed:
        print("no distributed training ...")
        args.num_gpus = torch.cuda.device_count()
        args.word_size = args.num_gpus
        args.num_nodes = 1
        args.distributed = False
        # set_seed() and similar helpers read these; the original never set
        # them on this path. Only fill them in when the caller has not.
        for name in ('rank', 'local_rank', 'node_id'):
            if not hasattr(args, name):
                setattr(args, name, 0)
        return args

    if 'OMPI_COMM_WORLD_SIZE' in os.environ:
        world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
    elif 'WORLD_SIZE' in os.environ:
        world_size = int(os.environ['WORLD_SIZE'])
        local_rank = int(os.environ['LOCAL_RANK'])
        rank = int(os.environ['RANK'])
    else:
        print("distributed training ... presumbly debug with 1 GPU")
        args.num_gpus = 1
        args.distributed = True
        args.num_nodes = 1
        args.word_size = args.num_gpus
        args.local_rank = 0
        args.rank = 0
        args.node_id = 0
        os.environ['LOCAL_RANK'] = str(args.local_rank)
        logging.info(
            f"Init distributed training on "
            f"local rank {args.local_rank}")
        _join_nccl_group(args.local_rank, args.word_size, args.rank)
        return args

    args.num_gpus = world_size
    args.word_size = args.num_gpus
    args.distributed = True
    args.local_rank = local_rank
    args.rank = rank
    # NOTE(review): assumes one process per visible GPU on every node.
    args.num_nodes = world_size // args.local_size
    args.node_id = rank // args.local_size
    logging.info(
        f"Init distributed training on "
        f"local rank {args.local_rank}, "
        f"global rank {rank}")
    _join_nccl_group(local_rank, world_size, rank)
    synchronize()
    return args


def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> int:
    return random.randint(min_seed_value, max_seed_value)


def validate_seed(seed):
    """Return a usable uint32 seed derived from ``seed``.

    * ``None``: read ``PL_GLOBAL_SEED`` from the environment. If that is
      unset or not an integer, pick a random seed and print a notice.
    * Non-int values are converted with ``int()``.
    * Values outside the uint32 range are replaced by a random seed, with a
      printed notice.
    """
    uint32_info = np.iinfo(np.uint32)
    lo, hi = uint32_info.min, uint32_info.max

    if seed is None:
        raw = os.environ.get("PL_GLOBAL_SEED")
        if raw is None:
            seed = _select_seed_randomly(lo, hi)
            print(f"No seed found, seed set to {seed}")
        else:
            try:
                seed = int(raw)
            except ValueError:
                seed = _select_seed_randomly(lo, hi)
                print(f"Invalid seed found: {repr(raw)}, seed set to {seed}")
    elif not isinstance(seed, int):
        seed = int(seed)

    if seed < lo or seed > hi:
        print(f"{seed} is not in bounds, numpy accepts from {lo} to {hi}")
        seed = _select_seed_randomly(lo, hi)
    return seed


def set_seed(args):
    """Seed Python's ``random``, NumPy and torch for this process.

    When both ``args.do_train`` and ``args.fix_dist_seed`` are truthy, this
    process uses ``(args.rank + 1) * args.seed``, so every rank gets a
    different stream. Otherwise all ranks use ``args.seed``. The chosen value
    goes through ``validate_seed``, which substitutes a *random* seed when
    the value is outside the uint32 range (e.g. a large seed multiplied by a
    high rank). Only rank 0 additionally calls
    ``torch.cuda.manual_seed_all``.

    Args:
        args: namespace providing ``seed``, ``rank``, ``do_train`` and
            ``fix_dist_seed``.
    """
    seed = args.seed
    if args.do_train and args.fix_dist_seed:
        this_seed = validate_seed((args.rank + 1) * seed)
    else:
        this_seed = validate_seed(seed)
    print(f"Using seed {this_seed} for rank {args.rank}")
    random.seed(this_seed)
    np.random.seed(this_seed)
    torch.random.manual_seed(this_seed)
    if args.rank == 0:
        # NOTE(review): torch's manual_seed is documented to seed CUDA devices
        # too, so this may be redundant; kept to preserve existing behavior.
        print(f"Using seed {this_seed} for torch.cuda")
        torch.cuda.manual_seed_all(this_seed)


def get_world_size():
    """Return the job's world size as reported by the launcher's environment.

    ``WORLD_SIZE`` (torch launchers) takes precedence over
    ``OMPI_COMM_WORLD_SIZE`` (Open MPI). Returns 1 when neither is set.
    """
    for var in ('WORLD_SIZE', 'OMPI_COMM_WORLD_SIZE'):
        if var in os.environ:
            return int(os.environ[var])
    return 1


def get_rank():
    """Return this process's global rank as reported by the launcher's environment.

    ``RANK`` (torch launchers) takes precedence over ``OMPI_COMM_WORLD_RANK``
    (Open MPI). Returns 0 when neither is set.
    """
    for var in ('RANK', 'OMPI_COMM_WORLD_RANK'):
        if var in os.environ:
            return int(os.environ[var])
    return 0


def get_local_rank():
    """Return this process's rank within its node as reported by the environment.

    ``LOCAL_RANK`` (torch launchers) takes precedence over
    ``OMPI_COMM_WORLD_LOCAL_RANK`` (Open MPI). Returns 0 when neither is set.
    """
    for var in ('LOCAL_RANK', 'OMPI_COMM_WORLD_LOCAL_RANK'):
        if var in os.environ:
            return int(os.environ[var])
    return 0


def get_local_size():
    """Return the number of processes on this node as reported by the environment.

    Variables are checked in order:

    1. ``LOCAL_SIZE`` (explicit override).
    2. ``LOCAL_WORLD_SIZE`` (set by torchrun; previously ignored, so torchrun
       jobs always saw 1).
    3. ``OMPI_COMM_WORLD_LOCAL_SIZE`` (Open MPI).

    Returns 1 when none are set.
    """
    for var in ('LOCAL_SIZE', 'LOCAL_WORLD_SIZE', 'OMPI_COMM_WORLD_LOCAL_SIZE'):
        if var in os.environ:
            return int(os.environ[var])
    return 1


def is_main_process():
    """Return True if this process is global rank 0.

    The environment rank (see ``get_rank``) must be 0. If a torch.distributed
    process group is initialized, its rank must also be 0. Without an
    initialized group, an environment rank of 0 is sufficient.
    """
    if get_rank() != 0:
        return False
    # Check explicitly instead of catching the exception from dist.get_rank():
    # the exception type for an uninitialized group is not RuntimeError in
    # every torch release, and dist may be unavailable altogether.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    return True

# def is_main_process():
#     return get_rank() == 0


def synchronize():
    """Block until every rank of the default process group reaches this call.

    This is a no-op when torch.distributed is unavailable, not initialized,
    or the group has a single member. The barrier is an ``all_reduce`` on a
    scalar CUDA tensor on the current device, followed by
    ``torch.cuda.synchronize()``.
    """
    if not dist.is_available() or not dist.is_initialized():
        return
    # Use the initialized group's own size rather than environment variables,
    # so the barrier always matches the group's actual membership.
    if dist.get_world_size() == 1:
        return
    # NOTE(review): an all_reduce stands in for dist.barrier(); presumably to
    # sidestep NCCL barrier device-selection issues — TODO confirm.
    # torch.zeros (not torch.randn) so the CUDA RNG state is not advanced.
    token = torch.zeros((), device='cuda')
    dist.all_reduce(token)
    torch.cuda.synchronize()


def gather_on_master(data):
    """Gather picklable ``data`` from every rank onto rank 0, using CPU tensors.

    Same as ``all_gather``, but only the main process receives the result.
    Because the tensors live on CPU, this does not work with the NCCL backend
    unless it adds CPU support.

    The memory consumption of this function is ~3x the data size. In
    principle it should be ~2x, but Python does not reliably release memory
    immediately, so peak usage can reach 3x.

    Args:
        data: any picklable object. Peers' payloads are unpickled, so only
            use this among trusted processes of the same job.

    Returns:
        With a world size of 1: ``[data]``. On the main process: a list with
        one deserialized object per rank, in rank order. On other ranks:
        ``None``.
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # Serialize to a uint8 tensor. The dels try to lower peak memory, but
    # release is not guaranteed.
    buffer = pickle.dumps(data)
    del data
    storage = torch.ByteStorage.from_buffer(buffer)
    del buffer
    tensor = torch.ByteTensor(storage)

    # Exchange payload sizes as plain ints so the padding and trimming below
    # use ordinary integer arithmetic.
    local_size = tensor.numel()
    size_list = [torch.LongTensor([0]) for _ in range(world_size)]
    dist.all_gather(size_list, torch.LongTensor([local_size]))
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # gather requires equal shapes, so pad (with zeros) up to the largest payload.
    if local_size != max_size:
        padding = torch.zeros(max_size - local_size, dtype=torch.uint8)
        tensor = torch.cat((tensor, padding), dim=0)
        del padding

    if not is_main_process():
        dist.gather(tensor, gather_list=[], dst=0)
        return None

    tensor_list = [torch.empty(max_size, dtype=torch.uint8) for _ in size_list]
    dist.gather(tensor, gather_list=tensor_list, dst=0)
    del tensor

    data_list = []
    for size, received in zip(size_list, tensor_list):
        # Drop padding so only that rank's pickle bytes are decoded.
        payload = received.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(payload))
    return data_list


def all_gather(data):
    """Run all_gather on arbitrary picklable data (not necessarily tensors).

    The payload is pickled into a uint8 CUDA tensor on the current device,
    padded to the largest rank's size, exchanged with ``dist.all_gather``,
    then trimmed and unpickled.

    Args:
        data: any picklable object. Peers' payloads are unpickled, so only
            use this among trusted processes of the same job.

    Returns:
        list: one object per rank, in rank order (``[data]`` when the world
        size is 1).
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # Serialize to a uint8 tensor on the current CUDA device.
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # Exchange payload sizes; keep the local size as a plain int.
    local_size = tensor.numel()
    size_list = [torch.zeros(1, dtype=torch.long, device="cuda")
                 for _ in range(world_size)]
    dist.all_gather(size_list,
                    torch.tensor([local_size], dtype=torch.long, device="cuda"))
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # torch all_gather needs equal shapes, so pad every payload to max_size.
    # Receive buffers are allocated directly on the GPU (no CPU round trip).
    tensor_list = [torch.empty(max_size, dtype=torch.uint8, device="cuda")
                   for _ in size_list]
    if local_size != max_size:
        padding = torch.zeros(max_size - local_size, dtype=torch.uint8,
                              device="cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    # Trim each buffer back to its sender's true size before unpickling.
    return [pickle.loads(received.cpu().numpy().tobytes()[:size])
            for size, received in zip(size_list, tensor_list)]


def reduce_dict(input_dict, average=True):
    """Reduce a dict of tensors from all processes onto rank 0.

    Args:
        input_dict (dict): name -> tensor. All tensors must have the same
            shape, because they are combined with ``torch.stack``.
        average (bool): if True, rank 0 divides the summed values by the world
            size; if False, rank 0 keeps the sum.

    Returns:
        dict: same keys as ``input_dict``. Only rank 0 holds the reduced
        (summed or averaged) values. On other ranks the values are whatever
        ``dist.reduce`` leaves in the stacked buffer, so do not rely on them.
        With a world size below 2, or an empty dict, ``input_dict`` itself is
        returned.
    """
    world_size = get_world_size()
    # An empty dict cannot be stacked; there is nothing to reduce anyway.
    if world_size < 2 or not input_dict:
        return input_dict
    with torch.no_grad():
        # Sort the keys so that every process stacks values in the same order.
        names = sorted(input_dict.keys())
        values = torch.stack([input_dict[k] for k in names], dim=0)
        dist.reduce(values, dst=0)
        if dist.get_rank() == 0 and average:
            # Only the main process holds the accumulated sum. Out-of-place
            # division lets integer tensors promote to float instead of raising.
            values = values / world_size
        reduced_dict = dict(zip(names, values))
    return reduced_dict


class NoOp:
    """Placeholder object whose every unknown attribute is a do-nothing callable.

    Useful in distributed training where, for example, only the main process
    should perform logging: other ranks can hold a ``NoOp`` and call the same
    methods harmlessly.
    """

    def noop(self, *args, **kwargs):
        """Accept any arguments and return ``None``."""
        return None

    def __getattr__(self, name):
        # Called only for attributes not found normally; hand back the no-op.
        return self.noop


def decode_to_str(x):
    """Decode bytes ``x`` as UTF-8, falling back to Latin-1 if that fails."""
    try:
        text = x.decode('utf-8')
    except UnicodeDecodeError:
        text = x.decode('latin-1')
    return text


def cmd_run(list_cmd,
            return_output=False,
            env=None,
            working_dir=None,
            stdin=sp.PIPE,
            shell=False,
            dry_run=False,
            silent=False,
            process_input=None,
            stdout=None,
            ):
    """Run an external command, optionally capturing its output.

    The child inherits a copy of ``os.environ`` with ``SSH_AUTH_SOCK``
    removed and ``env`` entries layered on top. ``working_dir`` is created if
    it does not exist.

    Args:
        list_cmd: command and arguments. With ``shell=True`` they are str()-ed
            and space-joined into a single shell command line.
        return_output: if True, capture stdout via ``subprocess.check_output``
            and return it decoded. ``stdin``, ``stdout`` and ``process_input``
            are not used in this mode.
        env: extra environment variables for the child.
        working_dir: child working directory.
        stdin: stdin for ``Popen``. The default ``sp.PIPE`` avoids "stdin is
            not a tty" complaints when running inside another process.
        shell: run through the shell.
        dry_run: log the command and return ``None`` without running it.
        silent: suppress logging.
        process_input: data passed to ``Popen.communicate``.
        stdout: stdout for ``Popen``. ``sp.PIPE`` captures output instead of
            printing it.

    Returns:
        ``None`` on dry run; the decoded stdout ``str`` when ``return_output``
        is set; otherwise the ``(stdout, stderr)`` tuple from
        ``Popen.communicate``.

    Raises:
        ValueError: non-zero exit status when ``return_output`` is False.
        subprocess.CalledProcessError: non-zero exit status when
            ``return_output`` is True.
    """
    # str() every argument once so the log line and the shell command line
    # agree (joining raw items raised TypeError on non-str arguments).
    str_cmd = [str(c) for c in list_cmd]
    if not silent:
        logging.info('start to cmd run: {}'.format(' '.join(str_cmd)))
        if working_dir:
            logging.info(working_dir)
    e = os.environ.copy()
    e.pop('SSH_AUTH_SOCK', None)
    if working_dir:
        os.makedirs(working_dir, exist_ok=True)
    if env:
        e.update(env)
    if dry_run:
        # Return only after logging so dry runs still show the command.
        return
    # NOTE(review): with shell=True the arguments are joined without quoting
    # and interpreted by the shell; never pass untrusted input in that mode.
    command = ' '.join(str_cmd) if shell else list_cmd
    if not return_output:
        p = sp.Popen(command,
                     stdin=stdin,
                     env=e,
                     shell=shell,
                     stdout=stdout,
                     cwd=working_dir)
        message = p.communicate(input=process_input)
        if p.returncode != 0:
            raise ValueError(message)
        return message
    message = sp.check_output(command,
                              env=e,
                              cwd=working_dir,
                              shell=shell)
    if not silent:
        logging.info('finished the cmd run')
    return decode_to_str(message)


def broadcast(data, src=0):
    """Broadcast tensor ``data`` in place from rank ``src`` to every rank.

    Thin wrapper around ``torch.distributed.broadcast``; returns whatever that
    call returns.
    """
    return dist.broadcast(data, src=src)
