from typing import Sequence, Tuple

import torch
import torch.nn.functional as F

#  import torchmetrics

from latent_invariances.cka import CudaCKA as CKA

# from latent_invariances.openfaiss import FaissIndex
from latent_invariances.utils.space import LatentSpace

# Type alias for a pair of strings — presumably a pair of encoder names
# (TODO confirm against callers). Not used by evaluate_retrieval.
EncPair = Tuple[str, str]


@torch.no_grad()
def evaluate_retrieval(
    latent_space1: LatentSpace,
    latent_space2: LatentSpace,
    search_ids: Sequence[str],
    device: torch.device,
    k: int = 5,
    chunk_size: int = 5000,
):
    """Compare two row-aligned latent spaces with representation-similarity metrics.

    Row ``i`` of ``latent_space1.vectors`` is compared with row ``i`` of
    ``latent_space2.vectors``. Rows are processed in roughly equal chunks of at
    most ``chunk_size`` rows, each moved to ``device``, to bound memory use.

    Args:
        latent_space1: First latent space.
        latent_space2: Second latent space; must have the same ``encoding_type``
            and the same number of rows as ``latent_space1``.
        search_ids: Identifiers of the rows. Not used by the active metrics
            (FAISS-based top-k Jaccard / MRR retrieval metrics are disabled);
            kept for interface compatibility.
        device: Device on which the CKA helper runs and each chunk is placed.
        k: Neighbourhood size for the disabled top-k retrieval metrics;
            currently unused.
        chunk_size: Maximum number of rows per chunk.

    Returns:
        A dict of Python floats with keys ``"linear_cka"``, ``"mse"``,
        ``"cosine_sim"`` and ``"l1"``. ``mse``, ``l1`` and ``cosine_sim``
        (row-wise cosine similarity, averaged over rows) equal the
        full-dataset means. ``linear_cka`` is the row-weighted average of the
        per-chunk linear CKA values, which in general differs from CKA
        computed on all rows at once when there is more than one chunk.

    Raises:
        AssertionError: If the encoding types or row counts differ.
        ValueError: If the latent spaces are empty or ``chunk_size`` is not
            positive.
    """
    assert latent_space1.encoding_type == latent_space2.encoding_type
    assert latent_space1.vectors.shape[0] == latent_space2.vectors.shape[0]

    vectors1 = latent_space1.vectors
    vectors2 = latent_space2.vectors
    num_rows = vectors1.shape[0]
    if num_rows == 0:
        raise ValueError("Cannot evaluate empty latent spaces.")
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}.")

    # Chunk by the number of vectors (not len(search_ids)), splitting into
    # ceil(num_rows / chunk_size) near-equal chunks so no chunk is tiny.
    num_chunks = (num_rows + chunk_size - 1) // chunk_size
    cka = CKA(device=device)  # loop-invariant: build once

    # Running sums of (chunk rows * chunk metric); dividing by num_rows at the
    # end gives a row-weighted mean, so a smaller last chunk is not over-weighted.
    totals = {"linear_cka": 0.0, "mse": 0.0, "cosine_sim": 0.0, "l1": 0.0}
    for chunk1, chunk2 in zip(vectors1.chunk(num_chunks), vectors2.chunk(num_chunks)):
        # Honour the requested device instead of hard-coding CUDA.
        chunk1 = chunk1.to(device)
        chunk2 = chunk2.to(device)
        rows = chunk1.shape[0]

        totals["linear_cka"] += rows * cka.linear_CKA(chunk1, chunk2).item()
        totals["mse"] += rows * F.mse_loss(chunk1, chunk2, reduction="mean").item()
        # cosine_similarity defaults to dim=1; assumes vectors are 2-D
        # (rows, features) — TODO confirm.
        totals["cosine_sim"] += rows * F.cosine_similarity(chunk1, chunk2).mean().item()
        totals["l1"] += rows * F.l1_loss(chunk1, chunk2, reduction="mean").item()

    return {name: total / num_rows for name, total in totals.items()}
