import os
import torch
import numpy as np
import torch.distributed as dist
from collections.abc import Iterable
from torch import nn


def update_metric(running_metric, curr_metric, smoothing_factor=0.9):
    """Blend the latest metric value into a running (smoothed) value.

    A running value equal to 0 is treated as "nothing accumulated yet", so
    the current value is returned unchanged in that case. Otherwise the
    result is ``smoothing_factor * running_metric +
    (1 - smoothing_factor) * curr_metric``.

    Args:
        running_metric: Smoothed value carried over from previous updates.
        curr_metric: Newly observed value.
        smoothing_factor: Weight given to ``running_metric``; the new value
            receives ``1 - smoothing_factor``.

    Returns:
        The updated running value.
    """
    # Zero doubles as the "no history" marker: start from the fresh value.
    if running_metric == 0:
        return curr_metric

    carried_over = smoothing_factor * running_metric
    fresh_part = (1 - smoothing_factor) * curr_metric
    return carried_over + fresh_part

        
def convert_to_cuda(inputs):
    """Return ``inputs`` with ``.cuda()`` applied.

    Args:
        inputs: Either a single tensor or a ``list`` whose elements each
            expose a ``.cuda()`` method.

    Returns:
        A new list holding the result of ``.cuda()`` for every element when
        ``inputs`` is a list, or the result of ``inputs.cuda()`` when it is
        a tensor.

    Raises:
        ValueError: If ``inputs`` is neither a list nor a tensor (other
            containers such as tuples are rejected too).
    """
    if isinstance(inputs, list):
        return [item.cuda() for item in inputs]
    if torch.is_tensor(inputs):
        return inputs.cuda()
    raise ValueError("Invalid input type")


def convert_to_device(inputs, device):
    """Return ``inputs`` with ``.to(device)`` applied.

    Args:
        inputs: Either a single tensor or a ``list`` whose elements each
            expose a ``.to()`` method.
        device: Target passed straight through to ``.to()``.

    Returns:
        A new list holding the result of ``.to(device)`` for every element
        when ``inputs`` is a list, or the result of ``inputs.to(device)``
        when it is a tensor.

    Raises:
        ValueError: If ``inputs`` is neither a list nor a tensor (other
            containers such as tuples are rejected too).
    """
    if isinstance(inputs, list):
        return [item.to(device) for item in inputs]
    if torch.is_tensor(inputs):
        return inputs.to(device)
    raise ValueError("Invalid input type")


def is_dist_avail_and_initialized():
    """Report whether ``torch.distributed`` is both available and initialized.

    Returns:
        ``True`` only when ``dist.is_available()`` and
        ``dist.is_initialized()`` are both true; ``False`` otherwise.
    """
    # Short-circuit keeps is_initialized() from being queried when the
    # distributed package is not available at all.
    return dist.is_available() and dist.is_initialized()


def get_world_size():
    """Return the distributed world size, or 1 outside distributed runs.

    Returns:
        ``dist.get_world_size()`` when the process group is available and
        initialized; otherwise ``1``.
    """
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1


def get_rank():
    """Return this process's distributed rank, or 0 outside distributed runs.

    Returns:
        ``dist.get_rank()`` when the process group is available and
        initialized; otherwise ``0``.
    """
    return dist.get_rank() if is_dist_avail_and_initialized() else 0


def get_batch_size(batch_size):
    """Split a global batch size into the share handled by this rank.

    Every rank receives ``batch_size // world_size`` samples. When the
    division leaves a remainder, the last rank (``world_size - 1``)
    additionally receives that remainder.

    Args:
        batch_size: Total batch size across all ranks.

    Returns:
        The number of samples assigned to the calling rank.
    """
    world_size = get_world_size()
    rank = get_rank()
    per_rank, leftover = divmod(batch_size, world_size)

    on_last_rank = rank == world_size - 1
    if leftover != 0 and on_last_rank:
        # The last rank absorbs whatever does not divide evenly.
        return per_rank + leftover
    return per_rank


def is_main_process():
    """Return ``True`` when the calling process has rank 0."""
    current_rank = get_rank()
    return current_rank == 0


def save_on_master(*args, **kwargs):
    """Call ``torch.save`` with the given arguments on the main process only.

    On any process for which ``is_main_process()`` is false this is a no-op.
    All positional and keyword arguments are forwarded unchanged.
    """
    if not is_main_process():
        return
    torch.save(*args, **kwargs)


def has_batchnorms(model):
    """Tell whether ``model`` contains any batch-normalization module.

    Args:
        model: Object exposing ``named_modules()`` (e.g. an ``nn.Module``).

    Returns:
        ``True`` if any module yielded by ``model.named_modules()`` is an
        instance of ``BatchNorm1d``, ``BatchNorm2d``, ``BatchNorm3d`` or
        ``SyncBatchNorm``; ``False`` otherwise.
    """
    norm_layers = (
        nn.BatchNorm1d,
        nn.BatchNorm2d,
        nn.BatchNorm3d,
        nn.SyncBatchNorm,
    )
    return any(
        isinstance(layer, norm_layers)
        for _name, layer in model.named_modules()
    )


def fix_random_seeds(seed=25):
    """
    Seed the torch, torch.cuda (all devices) and numpy random generators.

    Args:
        seed: Value handed to each seeding function. Defaults to 25.
    """
    # Same seed for every generator, applied in a fixed order.
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def setup_for_distributed(is_master):
    """
    Replace the builtin ``print`` so that non-master processes stay silent.

    After this call, ``print(...)`` only produces output when ``is_master``
    is true or when the caller passes ``force=True``. The ``force`` keyword
    is consumed here and never forwarded to the real ``print``.

    Args:
        is_master: Whether the calling process should keep printing.
    """
    import builtins

    # Capture whatever ``print`` is installed right now; the wrapper below
    # delegates to it.
    real_print = builtins.print

    def print(*args, **kwargs):
        forced = kwargs.pop('force', False)
        if not (is_master or forced):
            return
        real_print(*args, **kwargs)

    builtins.print = print


def init_distributed_mode(args):
    """Initialize the NCCL process group and bind this process to its GPU.

    When both ``RANK`` and ``WORLD_SIZE`` are present in the environment,
    ``args.rank``, ``args.world_size`` and ``args.gpu`` are overwritten from
    ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK`` respectively. Otherwise those
    attributes must already be set on ``args``, since they are read below.

    Steps, in order: create the process group, select the CUDA device,
    print a one-line status message, then wait on a barrier.

    Args:
        args: Namespace-like object providing ``dist_url`` and receiving or
            providing ``rank``, ``world_size`` and ``gpu``.
    """
    env = os.environ

    # launched with torch.distributed.launch
    if all(key in env for key in ('RANK', 'WORLD_SIZE')):
        args.rank = int(env["RANK"])
        args.world_size = int(env['WORLD_SIZE'])
        args.gpu = int(env['LOCAL_RANK'])

    group_options = dict(
        backend="nccl",
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )
    dist.init_process_group(**group_options)

    torch.cuda.set_device(args.gpu)
    status_line = '| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url)
    print(status_line, flush=True)
    # Hold every process here until all of them have finished initializing.
    dist.barrier()
    # setup_for_distributed(args.rank == 0)
