import numpy as np
import torch
import faiss
from typing import List, Dict, Tuple
from tqdm import tqdm
from loguru import logger
def get_retrieval_list(query_emb_2: np.ndarray, corpus_emb_2: np.ndarray, top_k: int = 1000, metric: str = "l2") -> np.ndarray:
    """Return the indices of the ``top_k`` nearest corpus rows for every query row.

    An exact (brute-force) faiss flat index is built over ``corpus_emb_2`` and
    searched with every row of ``query_emb_2``.

    Args:
        query_emb_2: Query matrix, one query per row; its column count must
            match ``corpus_emb_2``.
        corpus_emb_2: Corpus matrix, one item per row.
        top_k: Number of neighbours returned per query.
        metric: ``"l2"`` for Euclidean distance, or ``"cosine"`` for cosine
            similarity (computed as inner product on L2-normalised copies).

    Returns:
        Integer array of shape ``(n_queries, top_k)`` holding corpus row
        indices, best match first. faiss fills slots it cannot populate
        (e.g. ``top_k`` larger than the corpus) with ``-1``.

    Raises:
        ValueError: If ``metric`` is neither ``"l2"`` nor ``"cosine"``.
    """
    if metric not in ("l2", "cosine"):
        raise ValueError(f"Unsupported metric {metric!r}; expected 'l2' or 'cosine'")
    # faiss requires C-contiguous float32 input. Explicit copies also guarantee
    # the in-place normalisation below never touches the caller's arrays.
    corpus = np.array(corpus_emb_2, dtype=np.float32, order="C", copy=True)
    queries = np.array(query_emb_2, dtype=np.float32, order="C", copy=True)
    if metric == "l2":
        index = faiss.IndexFlatL2(corpus.shape[1])
    else:
        # Inner product equals cosine similarity only for unit-norm vectors.
        faiss.normalize_L2(corpus)
        faiss.normalize_L2(queries)
        index = faiss.IndexFlatIP(corpus.shape[1])
    index.add(corpus)
    _, indices = index.search(queries, top_k)
    return indices
def cal_max_distances(
    retrieval_list_1: np.ndarray,
    retrieval_list_2: np.ndarray,
    corpus_emb_1: np.ndarray,
    corpus_emb_2: np.ndarray,
    query_emb_1: np.ndarray,
    query_emb_2: np.ndarray,
    p_index_list: List[int],
    k_list: List[int]
) -> np.ndarray:
    """Symmetric Hausdorff distance between the top-k items of two retrieval lists.

    For query row ``i`` and cut-off ``k``, let ``A`` be the first ``k`` entries of
    ``retrieval_list_1[i]`` and ``B`` the first ``k`` entries of
    ``retrieval_list_2[i]``. Both sets are looked up in ``corpus_emb_2`` so they
    live in the same space. The stored value is::

        max( max_{a in A} min_{b in B} ||a - b||,
             max_{b in B} min_{a in A} ||a - b|| )

    Args:
        retrieval_list_1: Per-query rows of corpus indices (first retrieval).
        retrieval_list_2: Per-query rows of corpus indices (second retrieval).
        corpus_emb_1: Not used; kept for interface compatibility.
        corpus_emb_2: Corpus embeddings used for both retrieved sets. A
            non-tensor input is copied to CUDA when available, else to CPU.
        query_emb_1: Not used; kept for interface compatibility.
        query_emb_2: Not used; kept for interface compatibility.
        p_index_list: Only its length is used, to size the output's columns.
        k_list: Cut-offs ``k`` to evaluate.

    Returns:
        Array of shape ``(len(k_list), len(p_index_list))``. Row ``j`` holds
        ``k_list[j]`` and column ``i`` holds query ``i``.

    NOTE(review): writing column ``i`` assumes ``len(retrieval_list_1) <=
    len(p_index_list)``. Also, a faiss ``-1`` padding entry would be read by torch
    as the last corpus row. Confirm that the lists are unpadded.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Only corpus_emb_2 is read below, so corpus_emb_1 is not moved to the device.
    if not isinstance(corpus_emb_2, torch.Tensor):
        corpus_emb_2 = torch.tensor(corpus_emb_2, device=device)
    # No slice below reaches past the largest k, so later rows are never needed.
    k_max = max(k_list, default=0)
    max_distance_matrix = np.zeros((len(k_list), len(p_index_list)))
    for i, (q_list_1, q_list_2) in tqdm(
        enumerate(zip(retrieval_list_1, retrieval_list_2)),
        desc="Calculating max distances",
        total=len(retrieval_list_1),
    ):
        objects_1 = corpus_emb_2[q_list_1[:k_max]]
        objects_2 = corpus_emb_2[q_list_2[:k_max]]
        # cdist(B, A) is cdist(A, B) transposed, so one matrix serves both
        # directions: dim=1 minima give A->B and dim=0 minima give B->A.
        pairwise = torch.cdist(objects_1, objects_2)
        for j, k in enumerate(k_list):
            block = pairwise[:k, :k]
            a_to_b = block.min(dim=1).values.max().item()
            b_to_a = block.min(dim=0).values.max().item()
            max_distance_matrix[j, i] = max(a_to_b, b_to_a)
    return max_distance_matrix
def merge_and_evaluate_embeddings(
    corpus_emb_1: np.ndarray,
    corpus_emb_2: np.ndarray,
    corpus_emb_1_transformed: np.ndarray,
    query_emb_2: np.ndarray,
    p_index_list: List[int],
    d0: np.ndarray,
    d1: np.ndarray,
    d2: np.ndarray,
    top_k_list: Tuple[int, ...] = (10, 50, 100, 500, 1000),
) -> Dict[str, float]:
    """Compare recall@k of three corpus variants using the same queries.

    Each variant starts as a copy of ``corpus_emb_2``:
      * ``baseline_1``: left unchanged.
      * ``baseline_2``: rows ``d1`` replaced by ``corpus_emb_1[d1]``.
      * ``transformed``: rows ``d1`` replaced by ``corpus_emb_1_transformed[d1]``.

    Query ``i`` (row ``i`` of ``query_emb_2``) is a hit at ``k`` when its positive
    corpus index ``p_index_list[i]`` is among its top-``k`` neighbours. The
    neighbours come from ``get_retrieval_list`` with its default ``"l2"`` metric.

    Args:
        corpus_emb_1: Source embeddings copied into ``baseline_2`` at rows ``d1``.
        corpus_emb_2: Base corpus embeddings.
        corpus_emb_1_transformed: Embeddings copied into ``transformed`` at rows ``d1``.
        query_emb_2: Query matrix, one query per row.
        p_index_list: Positive corpus index for each query.
        d0: Not used; kept for interface compatibility.
        d1: Corpus indices that are replaced, and that define the ``_d1`` subset.
        d2: Corpus indices that define the ``_d2`` subset.
        top_k_list: Cut-offs to evaluate. The default matches the original
            hard-coded list.

    Returns:
        Dict with keys ``"{method}_d1@{k}"``, ``"{method}_d2@{k}"`` and
        ``"{method}@{k}"``. The ``_d1`` / ``_d2`` recalls count only queries whose
        positive lies in ``d1`` / ``d2``. A recall with no qualifying queries is
        ``nan`` rather than raising ``ZeroDivisionError``.
    """
    if not top_k_list:
        return {}

    merged_embeddings = {
        "baseline_1": corpus_emb_2.copy(),
        "baseline_2": corpus_emb_2.copy(),
        "transformed": corpus_emb_2.copy(),
    }
    merged_embeddings["baseline_2"][d1] = corpus_emb_1[d1]
    merged_embeddings["transformed"][d1] = corpus_emb_1_transformed[d1]

    n_queries = len(p_index_list)
    positives = np.asarray(p_index_list).reshape(-1, 1)
    # Build the sets once. `x in ndarray` would rescan the whole array on every test.
    d1_set = set(np.asarray(d1).ravel().tolist())
    d2_set = set(np.asarray(d2).ravel().tolist())
    in_d1 = np.array([p in d1_set for p in p_index_list], dtype=bool)
    in_d2 = np.array([p in d2_set for p in p_index_list], dtype=bool)
    n_d1 = int(in_d1.sum())
    n_d2 = int(in_d2.sum())

    # Search once at the largest cut-off. Neighbours come back best-first, so
    # "hit within top-k" is equivalent to "rank of the positive < k".
    max_top_k = max(top_k_list)
    ranks = {}
    for method, embeddings in merged_embeddings.items():
        results = get_retrieval_list(query_emb_2, embeddings, top_k=max_top_k)
        matches = results[:n_queries] == positives
        # First matching position, or +inf when the positive was not retrieved.
        ranks[method] = np.where(matches.any(axis=1), matches.argmax(axis=1), np.inf)

    def _ratio(count: int, total: int) -> float:
        # Recall is undefined when no query qualifies for the subset.
        return count / total if total else float("nan")

    recalls: Dict[str, float] = {}
    for top_k in top_k_list:
        hits = {method: method_ranks < top_k for method, method_ranks in ranks.items()}
        # Insertion order matches the original: d1 keys, then d2, then overall.
        recalls.update({f"{method}_d1@{top_k}": _ratio(int(h[in_d1].sum()), n_d1) for method, h in hits.items()})
        recalls.update({f"{method}_d2@{top_k}": _ratio(int(h[in_d2].sum()), n_d2) for method, h in hits.items()})
        recalls.update({f"{method}@{top_k}": _ratio(int(h.sum()), n_queries) for method, h in hits.items()})
    return recalls
