from typing import Tuple, Dict, List, Optional
import torch
import numpy as np
import faiss
def get_knn_faiss(
    query: torch.Tensor,
    target: torch.Tensor,
    k: int = 10
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the ``k`` nearest rows of ``target`` for each row of ``query``.

    Both inputs are L2-normalised and searched with an exact FAISS
    inner-product index, so the returned scores are cosine similarities.
    The input tensors are never modified.

    Args:
        query: 2-D tensor of query vectors, one per row.
        target: 2-D tensor of vectors to search; its second dimension is
            used as the index dimensionality and must match ``query``.
        k: Number of neighbours to retrieve per query.

    Returns:
        A ``(scores, indices)`` pair of tensors placed on ``query.device``,
        each with one row per query and ``k`` columns. ``indices`` are row
        positions in ``target``. When ``k`` exceeds the number of target
        rows FAISS pads the missing slots with index ``-1``.
    """
    # ``Tensor.numpy()`` shares memory with a CPU tensor and
    # ``faiss.normalize_L2`` works in place, so normalising that view would
    # silently overwrite the caller's data. Take a private copy, which also
    # guarantees the float32 C-contiguous layout FAISS requires.
    query_np = np.array(
        query.detach().cpu().to(torch.float32).numpy(), dtype=np.float32, order="C"
    )
    target_np = np.array(
        target.detach().cpu().to(torch.float32).numpy(), dtype=np.float32, order="C"
    )

    faiss.normalize_L2(query_np)
    faiss.normalize_L2(target_np)

    # Exact (brute-force) inner-product search over the normalised targets.
    index = faiss.IndexFlatIP(target_np.shape[1])
    index.add(target_np)
    scores, indices = index.search(query_np, k)

    device = query.device
    return torch.from_numpy(scores).to(device), torch.from_numpy(indices).to(device)
def jaccard_similarity(indices_1: np.ndarray, indices_2: np.ndarray) -> float:
    """Return the Jaccard similarity of the values in two index collections.

    Each input is flattened and reduced to the set of distinct values it
    contains; the result is ``|A & B| / |A | B|``. Any per-row structure is
    discarded, so values from all rows are pooled into a single set.

    Args:
        indices_1: Array of indices (any shape). A ``torch.Tensor`` is also
            accepted and is converted to numpy first.
        indices_2: Second array of indices, same conventions as ``indices_1``.

    Returns:
        A float in ``[0, 1]``; ``0.0`` when both inputs are empty.
    """
    def _as_value_set(indices) -> set:
        # 0-d torch tensors hash by object identity, so a set built directly
        # from a tensor never matches equal values from another tensor.
        # Convert everything to plain Python scalars before hashing.
        if isinstance(indices, torch.Tensor):
            indices = indices.detach().cpu().numpy()
        return set(np.asarray(indices).ravel().tolist())

    set_1, set_2 = _as_value_set(indices_1), _as_value_set(indices_2)
    union = set_1 | set_2
    if not union:
        return 0.0
    return len(set_1 & set_2) / len(union)
def evaluate_embedding_alignment(
    emb_1: torch.Tensor,
    emb_2: torch.Tensor,
    emb_1_transformed: torch.Tensor,
    k_values: Optional[List[int]] = None,
    query_indices: Optional[torch.Tensor] = None
) -> Tuple[
    Dict[int, float],
    Dict[int, float],
    Dict[int, torch.Tensor],
    Dict[int, torch.Tensor],
    Dict[int, torch.Tensor],
]:
    """Compare k-NN neighbourhoods before and after transforming ``emb_1``.

    For every ``k`` in ``k_values`` three nearest-neighbour searches are run
    for the selected query rows:

    * ``emb_1`` queries against ``emb_1``,
    * ``emb_2`` queries against ``emb_2``,
    * ``emb_2`` queries against ``emb_1_transformed``.

    The neighbour indices of the first and third searches are each compared
    with those of the second using ``jaccard_similarity``, which pools the
    neighbours of all queries into a single set per search.

    Args:
        emb_1: Embeddings of the first space, one vector per row.
        emb_2: Embeddings of the second space, one vector per row.
        emb_1_transformed: ``emb_1`` after the alignment transform.
        k_values: Neighbourhood sizes to evaluate. Defaults to
            ``[5, 10, 20, 50, 100]``.
        query_indices: Row indices used as queries. Defaults to the first
            ten rows (fewer if the embeddings have fewer rows).

    Returns:
        Five dictionaries keyed by ``k``: Jaccard similarity of ``emb_1`` vs
        ``emb_2`` neighbours, Jaccard similarity of transformed vs ``emb_2``
        neighbours, and the score tensors of the ``emb_1``, ``emb_2`` and
        transformed searches.
    """
    if k_values is None:
        k_values = [5, 10, 20, 50, 100]

    emb_1 = torch.nn.functional.normalize(emb_1, p=2, dim=1)
    emb_2 = torch.nn.functional.normalize(emb_2, p=2, dim=1)
    emb_1_transformed = torch.nn.functional.normalize(emb_1_transformed, p=2, dim=1)

    # ``tensor or default`` raises for multi-element tensors and silently
    # discards ``tensor([0])``, so the default must be selected with an
    # explicit ``is None`` check.
    if query_indices is None:
        n_queries = min(10, emb_1.shape[0], emb_2.shape[0])
        query_indices = torch.arange(n_queries)

    original_sims: Dict[int, float] = {}
    transformed_sims: Dict[int, float] = {}
    emb_1_k2score: Dict[int, torch.Tensor] = {}
    emb_2_k2score: Dict[int, torch.Tensor] = {}
    trans_k2score: Dict[int, torch.Tensor] = {}

    # The query rows do not depend on k, so select them once.
    emb_1_queries = emb_1[query_indices]
    emb_2_queries = emb_2[query_indices]

    for k in k_values:
        emb_1_score, emb_1_idx = get_knn_faiss(emb_1_queries, emb_1, k)
        emb_2_score, emb_2_idx = get_knn_faiss(emb_2_queries, emb_2, k)
        # NOTE(review): the transformed space is queried with emb_2 rows, not
        # emb_1_transformed rows — confirm this cross-space lookup is intended.
        trans_score, trans_idx = get_knn_faiss(emb_2_queries, emb_1_transformed, k)

        # Hand numpy arrays to the set-based comparison: 0-d torch tensors
        # hash by identity, which would make every overlap come out as zero.
        emb_1_idx_np = emb_1_idx.detach().cpu().numpy()
        emb_2_idx_np = emb_2_idx.detach().cpu().numpy()
        trans_idx_np = trans_idx.detach().cpu().numpy()

        original_sims[k] = jaccard_similarity(emb_1_idx_np, emb_2_idx_np)
        transformed_sims[k] = jaccard_similarity(trans_idx_np, emb_2_idx_np)
        emb_1_k2score[k] = emb_1_score
        emb_2_k2score[k] = emb_2_score
        trans_k2score[k] = trans_score

    return original_sims, transformed_sims, emb_1_k2score, emb_2_k2score, trans_k2score
