# -*- coding: UTF-8 -*-

import math

from torch import nn
import torch
from typing import Union
import torch.distributed as dist
# from torch._six import string_classes
import torch.nn.functional as F
import collections.abc as container_abcs
from PIL import Image

# Local stand-in for ``torch._six.string_classes`` (no longer imported above).
# ``bytes`` is included alongside ``str`` so that byte strings are treated as
# atomic string values instead of generic sequences of integers.
string_classes = (str, bytes)


@torch.no_grad()
def topk_accuracy(output, target, topk=(1,)):
    """Computes the top-k accuracy of the model's predictions.

    A sample counts as correct for a given ``k`` when its target label is among
    the ``k`` highest-scoring entries of its row in ``output``.

    Args:
        output (torch.Tensor): Scores indexed as ``(sample, class)``; the
            ranking is taken along dimension 1.
        target (torch.Tensor): The ground truth labels, one per sample.
        topk (tuple of int, optional): The values of k for which to compute the
            top-k accuracy. Defaults to (1,). Values larger than the number of
            classes are clamped to the number of classes.

    Returns:
        list of float: The fraction of correct samples (in ``[0, 1]``) for each
        value of k in topk, in the same order.
    """
    num_classes = output.size(1)
    # Asking ``Tensor.topk`` for more entries than there are classes raises,
    # so never request more than ``num_classes`` of them.
    maxk = min(max(topk), num_classes)
    batch_size = target.size(0)

    # pred: (maxk, batch) after the transpose; row i holds every sample's
    # (i + 1)-th best class.
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    # ``reshape`` (rather than ``view``) also accepts non-contiguous targets.
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        hits = correct[:min(k, maxk)].reshape(-1).float().sum()
        res.append((hits / batch_size).item())
    return res


def concat_all_gather(tensor):
    """
    Gathers a tensor from every process and concatenates the results along dim 0.

    The tensor is sent in a "transport" dtype chosen so that the round trip is
    lossless, then cast back to its original dtype:

    * float64 and complex tensors are sent unchanged;
    * every other floating-point dtype is widened to float32;
    * integer and bool tensors are sent as int64 (a float32 transport would
      corrupt integers above 2**24, e.g. large sample indices).

    Requires an initialized ``torch.distributed`` process group. The result is
    a plain tensor produced by ``all_gather``; no gradient flows back through
    the gathered copies.

    Args:
        tensor (torch.Tensor): The local tensor. All processes are expected to
            pass tensors of the same shape.

    Returns:
        torch.Tensor: The per-process tensors concatenated along dimension 0,
        in rank order, with the same dtype as the input.
    """
    dtype = tensor.dtype
    if dtype.is_complex or dtype == torch.float64:
        transport_dtype = dtype
    elif dtype.is_floating_point:
        transport_dtype = torch.float32
    else:
        transport_dtype = torch.int64

    payload = tensor.to(transport_dtype)
    world_size = torch.distributed.get_world_size()
    # Receive buffers; all_gather overwrites every element of each of them.
    gathered = [torch.empty_like(payload) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, payload, async_op=False)

    return torch.cat(gathered, dim=0).to(dtype)


class dataset_with_indices(torch.utils.data.Dataset):
    """
    Dataset wrapper that pairs every sample with the index used to fetch it.

    The wrapped dataset is stored by reference and queried lazily; its length
    is reported unchanged.

    Args:
        dataset: The underlying dataset to wrap. It must support ``len()`` and
            indexing.

    Returns:
        Indexing the wrapper yields a two-element list ``[sample, index]``,
        where ``sample`` is whatever the wrapped dataset returns for ``index``.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, idx):
        # The index is passed through untouched so callers can map a sample
        # back to its position in the wrapped dataset.
        return [self.dataset[idx], idx]

    def __len__(self):
        return len(self.dataset)


def convert_to_cuda(data):
    r"""
    Recursively moves every ``torch.Tensor`` found in ``data`` to the GPU.

    Containers are traversed and rebuilt; anything that is not a tensor or a
    supported container (numbers, ``None``, NumPy arrays, ...) is returned
    unchanged. ``str`` and ``bytes`` values are treated as atomic and are never
    split into their elements.

    ``Tensor.cuda`` is called for every tensor that is not already on a CUDA
    device, so CUDA must be usable whenever such a tensor is encountered.

    Args:
        data (Union[torch.Tensor, Mapping, Sequence, Any]): The input structure.

    Returns:
        Union[torch.Tensor, Mapping, Sequence, Any]: The same structure with
        its tensors on the GPU. Mappings come back as plain ``dict``,
        namedtuples keep their own type, and every other sequence (including
        plain tuples) comes back as a ``list``.
    """
    if isinstance(data, torch.Tensor):
        # Tensors that already live on a CUDA device are returned as-is.
        return data if data.is_cuda else data.cuda(non_blocking=True)
    if isinstance(data, (str, bytes)):
        # Both are Sequences; without this guard bytes would be exploded into
        # a list of integers by the generic sequence branch below.
        return data
    if isinstance(data, container_abcs.Mapping):
        return {key: convert_to_cuda(data[key]) for key in data}
    if isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
        return type(data)(*(convert_to_cuda(item) for item in data))
    if isinstance(data, container_abcs.Sequence):
        return [convert_to_cuda(item) for item in data]
    return data


def is_root_worker():
    """
    Returns True if the current process is the root worker, False otherwise.

    When ``torch.distributed`` is available and a process group has been
    initialized, the root worker is the process with rank 0. In every other
    situation (distributed support not built in, or no process group
    initialized) the process is treated as the root worker.

    Returns:
        bool: True for rank 0 or for non-distributed runs, False for any other
        rank.
    """
    # ``is_available`` is checked first: on PyTorch builds without distributed
    # support the rest of the ``torch.distributed`` API is not defined.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    return True


def load_network(state_dict):
    """
    Loads a PyTorch state dictionary and strips the ``module`` wrapper level from its keys.

    A ``module`` path component is removed wherever it is followed by a further
    component, e.g. ``module.conv.weight`` -> ``conv.weight`` and
    ``encoder.module.fc.bias`` -> ``encoder.fc.bias``. Names that merely end in
    "module" (``submodule.weight``) and a trailing ``module`` component are
    left untouched.

    Args:
        state_dict (str or dict): Either a path to a checkpoint file or an
            already loaded mapping from parameter names to values.

    Returns:
        OrderedDict: A new dictionary with the cleaned keys, in the original
        order. If two keys become identical after cleaning, the later one wins.
    """
    if isinstance(state_dict, str):
        # NOTE: torch.load unpickles the file; only load checkpoints from
        # trusted sources.
        state_dict = torch.load(state_dict, map_location='cpu')

    from collections import OrderedDict

    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        # Work on whole dotted components so that names such as ``submodule``
        # are not mangled by a plain substring replacement.
        *parents, leaf = key.split('.')
        kept = [part for part in parents if part != 'module']
        new_state_dict['.'.join(kept + [leaf])] = value
    return new_state_dict


def convert_to_ddp(modules: Union[list, nn.Module], **kwargs):
    """
    Moves module(s) to the GPU and wraps them for data-parallel execution.

    Every module is first moved with ``.cuda()``. If ``torch.distributed`` is
    available and a process group is initialized, each module is wrapped with
    ``torch.nn.parallel.DistributedDataParallel`` bound to the current CUDA
    device. Otherwise each module is wrapped with ``torch.nn.DataParallel``.

    Args:
        modules (Union[list, nn.Module]): A PyTorch module or a list of
            PyTorch modules to be wrapped.
        **kwargs: Additional keyword arguments passed to
            ``DistributedDataParallel``. They are not used in the
            ``DataParallel`` fallback. ``device_ids`` and ``output_device``
            are set by this function and must not be passed.

    Returns:
        Union[list, nn.Module]: A list of wrapped modules if a list was given
        (one wrapper per element, same order), otherwise a single wrapped
        module.
    """
    is_list = isinstance(modules, list)
    on_gpu = [m.cuda() for m in modules] if is_list else [modules.cuda()]

    # ``is_available`` guards builds without distributed support, where
    # ``is_initialized`` is not defined.
    if dist.is_available() and dist.is_initialized():
        device = torch.cuda.current_device()
        wrapped = [
            torch.nn.parallel.DistributedDataParallel(
                m, device_ids=[device], output_device=device, **kwargs)
            for m in on_gpu
        ]
    else:
        # Wrap each module on its own: DataParallel expects a single module,
        # not a list of them.
        wrapped = [torch.nn.DataParallel(m) for m in on_gpu]

    return wrapped if is_list else wrapped[0]
