import numpy as np
import faiss
from typing import List, Dict, Tuple
from tqdm import tqdm
import torch
from loguru import logger
def get_retrieval_list(query_emb_2: np.ndarray, corpus_emb_2: np.ndarray, top_k: int = 1000, metric: str = "l2") -> np.ndarray:
    """Return, for every query, the indices of its ``top_k`` closest corpus rows.

    An exact FAISS flat index is built over ``corpus_emb_2`` and searched with
    ``query_emb_2``.

    Args:
        query_emb_2: 2-D array of query embeddings, one row per query.
        corpus_emb_2: 2-D array of corpus embeddings, one row per item. Its
            second dimension determines the index dimensionality.
        top_k: Number of neighbours requested per query.
        metric: ``"l2"`` ranks by Euclidean distance; ``"cosine"`` ranks by
            cosine similarity (inner product of L2-normalised vectors).

    Returns:
        The index array produced by ``index.search``: one row per query with
        ``top_k`` corpus row ids, best match first. When ``top_k`` exceeds the
        corpus size FAISS pads the missing slots with ``-1``.

    Raises:
        ValueError: If ``metric`` is neither ``"l2"`` nor ``"cosine"``.
    """
    if metric not in ("l2", "cosine"):
        raise ValueError(f"Unsupported metric: {metric}")

    # np.array(...) always copies, so the caller's arrays are never modified
    # by the in-place normalisation below; FAISS wants C-contiguous float32.
    corpus = np.array(corpus_emb_2, dtype=np.float32, order="C")
    queries = np.array(query_emb_2, dtype=np.float32, order="C")
    dim = corpus.shape[1]

    if metric == "l2":
        index = faiss.IndexFlatL2(dim)
    else:
        # Inner product equals cosine similarity only for unit-length vectors,
        # so normalise both sides first. All-zero rows are left untouched.
        for matrix in (corpus, queries):
            norms = np.linalg.norm(matrix, axis=1, keepdims=True)
            norms[norms == 0] = 1.0
            matrix /= norms
        index = faiss.IndexFlatIP(dim)

    index.add(corpus)
    _, indices = index.search(queries, top_k)
    return indices
def calculate_baseline_recall(retrieval_list: np.ndarray, p_index_list: List[int], topk: int = 100) -> float:
    """Compute recall@``topk``: the share of queries whose positive is retrieved.

    Args:
        retrieval_list: Per-query ranked corpus indices; row ``i`` belongs to
            query ``i``.
        p_index_list: Positive corpus index for each query.
        topk: Only the first ``topk`` entries of each row are inspected.

    Returns:
        Fraction of entries in ``p_index_list`` whose positive index appears
        in the first ``topk`` retrieved indices, or ``0.0`` when
        ``p_index_list`` is empty.

    Raises:
        IndexError: If ``retrieval_list`` has fewer rows than ``p_index_list``.
    """
    num_queries = len(p_index_list)
    if num_queries == 0:
        # Nothing to evaluate; avoid dividing by zero.
        return 0.0
    hits = sum(
        1
        for i, p_index in enumerate(p_index_list)
        if p_index in retrieval_list[i][:topk]
    )
    return hits / num_queries
def calculate_extended_recall(
    corpus_emb_1: np.ndarray,
    corpus_emb_2: np.ndarray,
    corpus_emb_1_transformed: np.ndarray,
    query_emb_1: np.ndarray,
    query_emb_2: np.ndarray,
    p_index_list: List[int],
    d0: np.ndarray,
    d1: np.ndarray,
    d2: np.ndarray,
    k_to_max_distances: np.ndarray = None,
    k_list: List[int] = None
) -> Dict[str, float]:
    """Compute extended recall at several cutoffs over a partially replaced corpus.

    A copy of ``corpus_emb_2`` has the rows selected by ``d1`` overwritten with
    the matching rows of ``corpus_emb_1_transformed``. ``query_emb_2`` is
    searched against that merged corpus once, and the ranking is truncated to
    each ``k`` in ``k_list`` and scored by ``_cal_single_extended_recall``.

    Args:
        corpus_emb_1: Forwarded unchanged to ``_cal_single_extended_recall``.
        corpus_emb_2: Base corpus embeddings; copied, never modified in place.
        corpus_emb_1_transformed: Source of the replacement rows; also
            forwarded to the scorer.
        query_emb_1: Forwarded unchanged to the scorer.
        query_emb_2: Query embeddings used for the search.
        p_index_list: Positive corpus index for each query.
        d0: Not used by this function.
        d1: Index (or mask) of the corpus rows to replace.
        d2: Not used by this function.
        k_to_max_distances: Indexable collection aligned with ``k_list``;
            element ``i`` is passed to the scorer as ``max_distances`` for
            ``k_list[i]``. Required despite the ``None`` default.
        k_list: Cutoffs to evaluate. Defaults to ``[10, 100, 1000]``.

    Returns:
        Dict mapping ``"extended_recall@{k}"`` to the score for each cutoff;
        empty when ``k_list`` is empty.

    Raises:
        TypeError: If ``k_to_max_distances`` is ``None``.
    """
    if k_to_max_distances is None:
        # Fail before the expensive search instead of on the first subscript.
        raise TypeError(
            "k_to_max_distances must be provided (one entry per cutoff in k_list)"
        )
    # A None default avoids sharing one mutable list between calls.
    k_list = [10, 100, 1000] if k_list is None else list(k_list)
    if not k_list:
        return {}

    merged_corpus_emb = corpus_emb_2.copy()
    merged_corpus_emb[d1] = corpus_emb_1_transformed[d1]

    # Search once at the deepest cutoff; max() keeps this correct even when
    # k_list is not sorted in ascending order.
    full_ranking = get_retrieval_list(query_emb_2, merged_corpus_emb, max(k_list))

    extended_recall_dict = {}
    for position, k in enumerate(k_list):
        avg_recall = _cal_single_extended_recall(
            full_ranking[:, :k],
            corpus_emb_1,
            corpus_emb_2,
            corpus_emb_1_transformed,
            query_emb_1,
            query_emb_2,
            p_index_list,
            k_to_max_distances[position],
        )
        logger.info(f"Top-{k}: Average Extended Recall: {avg_recall:.4f}")
        extended_recall_dict[f"extended_recall@{k}"] = avg_recall
    return extended_recall_dict
def cal_rank_recall(
    corpus_emb_1: np.ndarray,
    corpus_emb_2: np.ndarray,
    corpus_emb_1_transformed: np.ndarray,
    query_emb_1: np.ndarray,
    query_emb_2: np.ndarray,
    p_index_list: List[int],
    k_list: List[int]
) -> float:
    """Score how well the transformed corpus preserves each positive's rank.

    Every query in ``query_emb_2`` is ranked against the whole of
    ``corpus_emb_2`` (baseline) and the whole of ``corpus_emb_1_transformed``.
    A query scores ``1`` when its positive ranks at least as high in the
    transformed ranking as in the baseline, otherwise
    ``1 / log2(transformed_rank - baseline_rank + 1)``.

    Args:
        corpus_emb_1: Not used by this function.
        corpus_emb_2: Baseline corpus embeddings.
        corpus_emb_1_transformed: Transformed corpus embeddings.
        query_emb_1: Not used by this function.
        query_emb_2: Query embeddings searched against both corpora.
        p_index_list: Positive corpus index for each query.
        k_list: Not used by this function.

    Returns:
        Mean per-query score, or ``0.0`` when ``p_index_list`` is empty.

    Raises:
        IndexError: If a positive index does not occur in a query's ranking.
    """
    if len(p_index_list) == 0:
        # Nothing to score; skip both full-corpus searches and the division.
        return 0.0

    baseline_ranking = get_retrieval_list(query_emb_2, corpus_emb_2, len(corpus_emb_2))
    transformed_ranking = get_retrieval_list(
        query_emb_2, corpus_emb_1_transformed, len(corpus_emb_1_transformed)
    )

    scores = []
    for i, p_index in enumerate(p_index_list):
        # 0-based position of the positive in each ranking.
        baseline_rank = np.where(baseline_ranking[i] == p_index)[0][0]
        transformed_rank = np.where(transformed_ranking[i] == p_index)[0][0]
        if transformed_rank <= baseline_rank:
            scores.append(1)
        else:
            # The rank drop is >= 1 here, so the log2 argument is >= 2 and
            # the denominator is never zero.
            scores.append(1 / np.log2(transformed_rank - baseline_rank + 1))
    return float(sum(scores) / len(scores))
def _cal_single_extended_recall(
    retrieval_list_ours: np.ndarray,
    corpus_emb_1: np.ndarray,
    corpus_emb_2: np.ndarray,
    corpus_emb_1_transformed: np.ndarray,
    query_emb_1: np.ndarray,
    query_emb_2: np.ndarray,
    p_index_list: List[int],
    max_distances: np.ndarray
) -> float:
    """Compute extended recall for one cutoff.

    A query counts as a hit when either

    * its positive index is among its retrieved indices, or
    * the smallest Euclidean distance, measured in ``corpus_emb_2``, between
      the positive and any retrieved item does not exceed that query's
      entry in ``max_distances``.

    Args:
        retrieval_list_ours: Per-query retrieved corpus indices, already
            truncated to the cutoff being evaluated.
        corpus_emb_1: Not used by this function.
        corpus_emb_2: Embeddings used for the distance check; indexed by
            corpus index.
        corpus_emb_1_transformed: Not used by this function.
        query_emb_1: Not used by this function.
        query_emb_2: Not used by this function.
        p_index_list: Positive corpus index for each query.
        max_distances: Per-query distance threshold, indexed by query position.

    Returns:
        Fraction of evaluated queries that are hits, or ``0.0`` when there is
        nothing to evaluate. Queries are paired with ``zip``, so evaluation
        stops at the shorter of ``retrieval_list_ours`` and ``p_index_list``.

    Raises:
        ValueError: If a query misses directly and has no retrieved items.
    """
    hits = []
    for i, (retrieved, p_index) in enumerate(zip(retrieval_list_ours, p_index_list)):
        if p_index in retrieved:
            # Direct hit: the distance fallback is irrelevant, skip the work.
            hits.append(1)
            continue
        # Distance from the positive to every retrieved item, in one
        # vectorised call instead of a per-row Python loop.
        gaps = np.linalg.norm(corpus_emb_2[retrieved] - corpus_emb_2[p_index], axis=-1)
        hits.append(1 if gaps.min() <= max_distances[i] else 0)

    if not hits:
        return 0.0
    return sum(hits) / len(hits)
