# -*- coding: UTF-8 -*-
"""
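Utility functions for (distributed) training.

This module provides helpers used throughout the application:
- extracting features and labels for a whole dataset across processes;
- gathering tensors from all processes;
- shuffling a batch across processes around an encoder forward pass, and undoing the shuffle afterwards;
- momentum (moving-average) parameter updates;
- a transform wrapper that produces two views of the same input.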
"""


from torchvision import transforms
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np
from munkres import Munkres
from sklearn import metrics
import warnings
import matplotlib.pyplot as plt
import tqdm
import torch.distributed as dist

from utils.ops import convert_to_cuda, is_root_worker
from .knn_monitor import knn_monitor


@torch.no_grad()
def extract_features(extractor, loader):
    """
    Extracts features and labels for every sample of ``loader.dataset``.

    Each process runs ``extractor`` over the samples its loader yields and
    scatters the outputs into dataset-sized buffers at the positions given by
    ``loader.sampler``. When ``torch.distributed`` is initialized the buffers
    are summed across processes. Every sample is then divided by the number of
    times it was processed. Samples computed more than once, on one or several
    processes, are therefore averaged rather than double-counted.

    ``extractor`` is switched to eval mode and is left in eval mode.

    Args:
        extractor (torch.nn.Module): The feature extractor model.
        loader (torch.utils.data.DataLoader): The data loader to extract features from.
            Assumes that iterating ``loader.sampler`` again reproduces the order
            in which the loader yielded samples. Also assumes the loader yields
            exactly one sample per sampler index, i.e. no dropped batches.
            TODO confirm both for custom samplers.

    Returns:
        tuple: ``(features, labels)``.

        - ``features`` is a float32 tensor of shape ``(len(loader.dataset), D)``,
          where ``D`` is the extractor's output width.
        - ``labels`` is a ``torch.long`` tensor of shape ``(len(loader.dataset),)``.

        Samples that no process visited keep all-zero features and label 0.
    """

    extractor.eval()

    local_features = []
    local_labels = []
    for inputs in tqdm.tqdm(loader, disable=not is_root_worker()):
        images, labels = convert_to_cuda(inputs)
        local_labels.append(labels)
        local_features.append(extractor(images))
    # The accumulation buffers are float32 and index_add_ requires matching
    # dtypes, so e.g. half-precision extractor outputs are upcast here.
    local_features = torch.cat(local_features, dim=0).float()
    local_labels = torch.cat(local_labels, dim=0).float()
    device = local_features.device

    # Build the indices directly as int64: round-tripping through a float32
    # ``torch.Tensor`` silently corrupts indices above 2**24.
    indices = torch.as_tensor(list(loader.sampler), dtype=torch.long, device=device)

    num_samples = len(loader.dataset)
    features = torch.zeros(num_samples, local_features.size(1), device=device)
    all_labels = torch.zeros(num_samples, device=device)
    counts = torch.zeros(num_samples, device=device)
    features.index_add_(0, indices, local_features)
    all_labels.index_add_(0, indices, local_labels)
    # Count every occurrence, not just "seen at least once", so that the
    # division below matches what index_add_ accumulated above.
    counts.index_add_(0, indices, torch.ones_like(indices, dtype=counts.dtype))

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(features, op=dist.ReduceOp.SUM)
        dist.all_reduce(all_labels, op=dist.ReduceOp.SUM)
        dist.all_reduce(counts, op=dist.ReduceOp.SUM)

    # Average the samples that were computed more than once. Clamp first so
    # that samples nobody computed yield 0 instead of NaN (0 / 0).
    counts.clamp_(min=1.)
    labels = (all_labels / counts).round().long()
    features /= counts[:, None]

    return features, labels


def concat_all_gather(tensor):
    """
    Gathers ``tensor`` from every process and concatenates the results along dim 0.

    Every process must pass a tensor of the same shape and dtype, as required by
    ``torch.distributed.all_gather``. The result has
    ``world_size * tensor.shape[0]`` rows, ordered by rank.

    The gathered tensors are freshly allocated buffers filled by ``all_gather``,
    so the result carries no autograd history. No gradient flows back to
    ``tensor``.

    When ``torch.distributed`` is unavailable or not initialized (a
    single-process run), a detached copy of ``tensor`` is returned instead.

    Args:
        tensor (torch.Tensor): The input tensor to be gathered.

    Returns:
        torch.Tensor: The concatenated tensor containing the gathered results from all processes.
    """
    if not (dist.is_available() and dist.is_initialized()):
        # Gathering over a single process is just a copy of the local tensor.
        return tensor.detach().clone()

    world_size = dist.get_world_size()
    # all_gather overwrites every element, so the buffers need no initialisation.
    tensors_gather = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensors_gather, tensor, async_op=False)
    return torch.cat(tensors_gather, dim=0)


@torch.no_grad()
def shuffling_forward(inputs, encoder):
    """
    Runs ``encoder`` on a batch that has been shuffled across processes.

    The global batch is redistributed with ``_batch_shuffle_ddp``, so each
    process encodes a random subset of it. The point is BatchNorm: each
    process's batch statistics are computed over a random mix of samples. The
    encoded rows are then routed back with ``_batch_unshuffle_ddp``, so row
    ``i`` of the result corresponds to row ``i`` of ``inputs``.

    Args:
        inputs (torch.Tensor): This process's share of the batch.
        encoder (callable): Applied to the shuffled batch.

    Returns:
        torch.Tensor: ``encoder`` outputs in the original order of ``inputs``.
    """
    shuffled, restore_idx = _batch_shuffle_ddp(inputs)
    encoded = encoder(shuffled)
    return _batch_unshuffle_ddp(encoded, restore_idx)


@torch.no_grad()
def _batch_shuffle_ddp(x):
    """
    Randomly redistributes the global batch across processes.

    Every process gathers all local batches and draws a permutation of the
    global batch. Rank 0's permutation is broadcast so that all processes agree
    on it. Each process keeps the slice of the permutation that belongs to its
    rank. Local batch sizes are assumed equal on every process, as required by
    the underlying all-gather.

    Args:
        x (torch.Tensor): This process's local batch.

    Returns:
        torch.Tensor: The local batch after shuffling (rows drawn from any process).
        torch.Tensor: The inverse permutation, to be passed to ``_batch_unshuffle_ddp``.
    """
    local_bs = x.shape[0]
    gathered = concat_all_gather(x)
    global_bs = gathered.shape[0]
    n_chunks = global_bs // local_bs

    # Only rank 0's permutation survives the broadcast, so all ranks share it.
    perm = torch.randperm(global_bs).cuda()
    torch.distributed.broadcast(perm, src=0)

    inverse_perm = torch.argsort(perm)

    rank = torch.distributed.get_rank()
    my_rows = perm.view(n_chunks, -1)[rank]
    return gathered[my_rows], inverse_perm


@torch.no_grad()
def _batch_unshuffle_ddp(x, idx_unshuffle):
    """
    Reverses the cross-process shuffle performed by ``_batch_shuffle_ddp``.

    All processes gather their (shuffled) local rows. Each process then selects
    the rows named by its rank's slice of ``idx_unshuffle``. This restores the
    order the batch had before ``_batch_shuffle_ddp`` was applied. Local batch
    sizes are assumed equal on every process.

    Args:
        x (torch.Tensor): This process's rows, in shuffled order.
        idx_unshuffle (torch.Tensor): The inverse permutation returned by
            ``_batch_shuffle_ddp``.

    Returns:
        torch.Tensor: This process's rows in their original, pre-shuffle order.
    """
    gathered = concat_all_gather(x)
    n_chunks = gathered.shape[0] // x.shape[0]

    rank = torch.distributed.get_rank()
    my_rows = idx_unshuffle.view(n_chunks, -1)[rank]
    return gathered[my_rows]


@torch.no_grad()
def _momentum_update(q_params, k_params, m):
    """
    Performs a momentum update on the parameters `q_params` and `k_params` using the momentum factor `m`.
    
    The momentum update is calculated as:
    `param_k.data = param_k.data * m + param_q.data * (1. - m)`
    
    This function can handle a single pair of parameters or a list/tuple of parameter pairs.
    """


    if not isinstance(q_params, (list, tuple)):
        q_params, k_params = [q_params, ], [k_params, ]
    for param_q, param_k in zip(q_params, k_params):
        param_k.data = param_k.data * m + param_q.data * (1. - m)


class TwoCropTransform:
    """
    Produces two views of one input by applying two transforms to it.

    Args:
        transform1 (callable): Transform used for the first view.
        transform2 (callable, optional): Transform used for the second view.
            Falls back to ``transform1`` when omitted or ``None``.

    Calling an instance with ``x`` returns ``[transform1(x), transform2(x)]``.
    """

    def __init__(self, transform1, transform2=None):
        if transform2 is None:
            transform2 = transform1
        self.transform1 = transform1
        self.transform2 = transform2

    def __call__(self, x):
        first_view = self.transform1(x)
        second_view = self.transform2(x)
        return [first_view, second_view]

    def __str__(self):
        return 'transform1 %s transform2 %s' % (str(self.transform1), str(self.transform2))
