# -*- coding: UTF-8 -*-


import torch
import tqdm
import torch.nn.functional as F
import torch.distributed as dist


# https://colab.research.google.com/github/facebookresearch/moco/blob/colab-notebook/colab/moco_cifar10_demo.ipynb
# test using a knn monitor
@torch.no_grad()
def knn_monitor(memory_features,
                memory_labels,
                test_features,
                test_labels,
                knn_k,
                knn_t):
    """
    Score test features with a weighted k-nearest-neighbour classifier backed by a memory bank.

    Args:
        memory_features (torch.Tensor): Features of the memory bank; forwarded to ``knn_predict``
            as the feature bank.
        memory_labels (torch.Tensor): Non-negative integer class ids of the memory bank entries.
            Must be non-empty (the maximum label is used to size the class axis).
        test_features (torch.Tensor): Features of the samples to classify.
        test_labels (torch.Tensor): Ground-truth class ids of the test samples, compared
            element-wise against the top-ranked prediction.
        knn_k (int): The number of nearest neighbors to consider.
        knn_t (float): The temperature parameter for the KNN prediction.

    Returns:
        torch.Tensor: A 0-dim float tensor holding the top-1 accuracy on the test set.
    """

    # Size the class axis from the largest label id rather than from the number of
    # distinct labels: when a class id is absent from the memory bank, counting unique
    # values yields fewer columns than the largest id needs and the one-hot scatter in
    # the predictor indexes out of range. For contiguous ids 0..C-1 both give C.
    classes = int(memory_labels.max().item()) + 1

    # [N_test, classes]: class ids ranked from best to worst score for every test sample
    pred_labels = knn_predict(test_features, memory_features, memory_labels, classes, knn_k, knn_t)

    # column 0 is the best-scoring class
    top1 = (pred_labels[:, 0] == test_labels).float().mean()

    return top1


# knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
# implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR
def knn_predict_internal(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
    """
    Rank classes for each query feature with a temperature-weighted k-nearest-neighbour vote.

    Similarity is the plain dot product between query and bank rows; it equals cosine
    similarity only if both inputs are L2-normalised (assumed by callers, not enforced here).

    Args:
        feature (torch.Tensor): Query features, with shape [B, D].
        feature_bank (torch.Tensor): The feature bank, with shape [N, D]. Must be non-empty.
        feature_labels (torch.Tensor): Integer class ids of the bank entries, with shape [N].
            Every id must lie in ``[0, classes)``.
        classes (int): The number of classes (size of the class axis).
        knn_k (int): The number of nearest neighbors to consider. Clamped to N when the
            bank holds fewer than ``knn_k`` entries.
        knn_t (float): The temperature; each neighbour votes with weight ``exp(sim / knn_t)``.

    Returns:
        torch.Tensor: int64 tensor with shape [B, classes]; row ``i`` lists the class ids
        sorted by descending vote score, so column 0 is the predicted label.

    Raises:
        ValueError: If ``feature_bank`` has no entries.
    """

    batch_size = feature.size(0)
    bank_size = feature_bank.size(0)

    if bank_size == 0:
        raise ValueError("feature_bank is empty; cannot run KNN prediction")
    if batch_size == 0:
        # Nothing to classify: return an empty ranking with the regular column count
        # (the view below cannot infer shapes from a zero-element tensor).
        return torch.empty((0, classes), dtype=torch.long, device=feature_labels.device)

    # topk raises when k exceeds the searched dimension, so never ask for more
    # neighbours than the bank contains
    k = min(knn_k, bank_size)

    # similarity between each feature vector and feature bank ---> [B, N]
    sim_matrix = feature.mm(feature_bank.T)
    # [B, K]
    sim_weight, sim_indices = sim_matrix.topk(k=k, dim=-1)
    # [B, K]; scatter needs int64 indices, so normalise the label dtype here
    sim_labels = torch.gather(feature_labels.expand(batch_size, -1), dim=-1, index=sim_indices).long()
    sim_weight = (sim_weight / knn_t).exp()

    # one-hot encoding of every neighbour's label ---> [B*K, C]
    one_hot_label = torch.zeros(batch_size * k, classes, device=sim_labels.device)
    one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.reshape(-1, 1), value=1.0)
    # weighted score ---> [B, C]
    pred_scores = torch.sum(one_hot_label.view(batch_size, k, classes) * sim_weight.unsqueeze(dim=-1), dim=1)

    # full ranking of class ids per query, best first ---> [B, C]
    pred_labels = pred_scores.argsort(dim=-1, descending=True)
    return pred_labels


def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t, split_size=512):
    """
    Run KNN prediction over the given features in fixed-size chunks.

    The queries are split along dim 0 and each chunk is passed to ``knn_predict_internal``,
    so only a [chunk, N] similarity matrix is materialised at a time instead of one for
    the whole query set.

    Args:
        feature (torch.Tensor): The feature tensor to make predictions on; split along dim 0.
        feature_bank (torch.Tensor): The feature bank to use for the KNN prediction.
        feature_labels (torch.Tensor): The labels corresponding to the feature bank.
        classes (int): The number of classes.
        knn_k (int): The number of nearest neighbors to consider.
        knn_t (float): The temperature parameter for the neighbour weighting.
        split_size (int): Maximum number of query rows handled per chunk. Defaults to 512.

    Returns:
        torch.Tensor: The per-chunk predictions concatenated along dim 0, in the same
        order as the rows of ``feature``.
    """

    pred_labels = [
        knn_predict_internal(chunk, feature_bank, feature_labels, classes, knn_k, knn_t)
        for chunk in feature.split(split_size, dim=0)
    ]
    return torch.cat(pred_labels, dim=0)
