import numpy as np
from typing import List
import torch

def recall_at_k(sims: np.array, k: int) -> float:
    """Return recall@k as a percentage rounded to 2 decimals.

    Row ``i`` of ``sims`` holds the scores of query ``i`` against every
    candidate column, and the correct candidate for query ``i`` is column
    ``i``. A query counts as a hit when column ``i`` is among the ``k``
    highest-scoring columns of its row.
    """
    # One 0/1 entry per query: 1 when its own index survives the top-k cut.
    hits = [
        int(query_idx in np.argsort(scores)[::-1][:k])
        for query_idx, scores in enumerate(sims)
    ]
    return round(sum(hits) * 100 / len(hits), 2)


def get_recalls(sims: np.array, ks: List[int]) -> List[float]:
    """Evaluate ``recall_at_k`` on ``sims`` for each cutoff in ``ks``, in order."""
    scores = []
    for cutoff in ks:
        scores.append(recall_at_k(sims, cutoff))
    return scores


# Adapted from google-research/composed_image_retrieval
def recall_at_k_labels(sim, query_lbls, target_lbls, k=10):
    """Return label-matched recall@k as a percentage rounded to 2 decimals.

    Row ``q`` of ``sim`` scores query ``q`` against every target column, and
    query ``q`` is correct for target ``t`` when
    ``query_lbls[q] == target_lbls[t]``.

    Args:
        sim: (n_queries, n_targets) similarity scores, given either as a
            ``torch.Tensor`` (on any device) or as a NumPy array.
        query_lbls: one label per row of ``sim``.
        target_lbls: one label per column of ``sim``.
        k: rank cutoff.

    Returns:
        Percentage of queries whose matching target ranks in the top ``k``.

    Raises:
        ValueError: if any query label matches zero or several target labels.
    """
    # as_tensor leaves tensors untouched and lets NumPy input work as well
    # (the original crashed in torch.argsort when handed an ndarray).
    sim = torch.as_tensor(sim)
    # Sorting by ascending distance (1 - similarity) ranks best matches first.
    sorted_indices = torch.argsort(1 - sim, dim=-1).cpu().numpy()
    sorted_index_names = np.asarray(target_lbls)[sorted_indices]
    # Broadcast each query label across its row: True where the ranked target
    # carries the query's label.
    labels = torch.as_tensor(
        sorted_index_names
        == np.asarray(query_lbls).reshape(len(query_lbls), 1)
    )
    # The metric is only defined with exactly one ground-truth target per query;
    # an explicit raise (unlike assert) survives `python -O`.
    if not torch.all(labels.sum(dim=-1) == 1):
        raise ValueError(
            "recall_at_k_labels expects each query label to match exactly one target label"
        )
    return round((labels[:, :k].sum() / len(labels)).item() * 100, 2)


def get_recalls_labels(
    sims: np.array, query_lbls, target_lbls, ks: List[int]
) -> List[float]:
    """Compute label-matched recall for every cutoff in ``ks``.

    ``sims``, ``query_lbls`` and ``target_lbls`` are forwarded unchanged to
    ``recall_at_k_labels``; results follow the order of ``ks``.
    """
    results = []
    for cutoff in ks:
        results.append(recall_at_k_labels(sims, query_lbls, target_lbls, cutoff))
    return results


def avg_recall_k(sim, target, k):
    """Return the mean per-query recall@k as a percentage rounded to 2 decimals.

    Query ``i`` is row ``i`` of ``sim`` and carries label ``target[i]``.
    Column ``j`` of ``sim`` scores item ``j``, whose label is ``target[j]``.
    Higher scores rank first. For each query the recall is::

        (# of its top-k items with the query's label)
        / (# of entries in ``target`` with the query's label)

    The denominator counts every entry of ``target``, including entry ``i``
    itself. The top-k list is not filtered, so if ``sim`` compares a set
    against itself the query may retrieve itself.

    Args:
        sim: 2-D array of similarity scores.
        target: one label per item; anything accepted by ``np.asarray``.
        k: rank cutoff.

    Returns:
        Mean recall over the rows of ``sim``, times 100, rounded to 2 decimals.
    """
    # Allows plain lists as well as arrays (fancy indexing below needs an array).
    target = np.asarray(target)
    recalls = []
    for i in range(sim.shape[0]):
        # Indices of the k highest-scoring items for query i.
        top_indices = np.argsort(sim[i])[::-1][:k]
        correct_label = target[i]
        # Vectorised counts replace the former per-element Python list scans.
        num_correct = np.count_nonzero(target[top_indices] == correct_label)
        total_correct = np.count_nonzero(target == correct_label)
        recalls.append(num_correct / total_correct)

    return round(np.mean(recalls) * 100, 2)


def get_avg_recalls(sims: np.array, target: np.array, ks: List[int]) -> List[float]:
    """Evaluate ``avg_recall_k`` on ``sims``/``target`` for each cutoff in ``ks``."""
    per_cutoff = []
    for cutoff in ks:
        per_cutoff.append(avg_recall_k(sims, target, cutoff))
    return per_cutoff
