from numpy import ndarray
import numpy as np
import torch
import os
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from loguru import logger
from tqdm import tqdm
import gc
if TYPE_CHECKING:
    from usearch.index import Index
class DistanceCalculator:
    """Distance computation and streaming top-k nearest-neighbour search in torch.

    Query tensors are expected to already live on ``self.device``; the corpus is a
    NumPy array (possibly memory-mapped) that is copied to the device one slice at
    a time so the full corpus never has to fit in device memory.
    """

    def __init__(self, device: Optional[torch.device] = None):
        # A falsy/absent device falls back to CUDA when available, otherwise CPU.
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def compute_batch_distances(self, queries: torch.Tensor, corpus: torch.Tensor, metric: str) -> torch.Tensor:
        """Return the (n_queries, n_corpus) distance matrix between two row batches.

        Args:
            queries: 2-D tensor of query rows.
            corpus: 2-D tensor of corpus rows with the same width as ``queries``.
            metric: ``"cosine"`` (1 - cosine similarity) or ``"euclidean"``.

        Raises:
            ValueError: If ``metric`` is not one of the supported names.
        """
        if metric not in ("cosine", "euclidean"):
            raise ValueError(f"Unsupported distance metric: {metric}")
        if metric == "cosine":
            q_unit = torch.nn.functional.normalize(queries, p=2, dim=1)
            c_unit = torch.nn.functional.normalize(corpus, p=2, dim=1)
            return 1 - torch.mm(q_unit, c_unit.transpose(0, 1))
        # Euclidean via the expansion ||q||^2 + ||c||^2 - 2 q.c, avoiding an
        # explicit (n_queries, n_corpus, dim) difference tensor.
        q_sq = (queries * queries).sum(dim=1, keepdim=True)
        c_sq = (corpus * corpus).sum(dim=1, keepdim=False)
        cross = torch.mm(queries, corpus.transpose(0, 1))
        squared = q_sq + c_sq[None, :] - 2 * cross
        # Clamp guards against small negative values from floating-point cancellation.
        return torch.sqrt(torch.clamp(squared, min=1e-8))

    def get_top_k_indices(self, queries: torch.Tensor, corpus: np.ndarray,
                          top_k: int, metric: str, corpus_batch_size: int = 10000) -> torch.Tensor:
        """Return the ``top_k`` nearest corpus row ids for every query.

        The corpus is streamed in slices of ``corpus_batch_size`` rows; after each
        slice the running best-``top_k`` candidates are merged with the new
        distances. Rows are sorted by ascending distance. When the corpus has
        fewer than ``top_k`` rows, the unused slots keep the id ``-1``.
        """
        n_queries = queries.shape[0]
        n_corpus = corpus.shape[0]
        best_dist = torch.full((n_queries, top_k), float('inf'), device=self.device)
        best_idx = torch.full((n_queries, top_k), -1, device=self.device, dtype=torch.long)
        on_cuda = self.device.type == 'cuda'
        offsets = list(range(0, n_corpus, corpus_batch_size))
        for lo in tqdm(offsets, desc=f"Computing distances for {n_queries} queries"):
            hi = min(lo + corpus_batch_size, n_corpus)
            chunk = torch.from_numpy(corpus[lo:hi]).to(device=self.device, dtype=torch.float32)
            dists = self.compute_batch_distances(queries, chunk, metric)
            ids = torch.arange(lo, hi, device=self.device).unsqueeze(0).expand(n_queries, -1)
            # Merge the running best with this slice and keep the smallest top_k.
            pool_dist = torch.cat([best_dist, dists], dim=1)
            pool_idx = torch.cat([best_idx, ids], dim=1)
            best_dist, pick = torch.topk(pool_dist, k=top_k, dim=1, largest=False, sorted=True)
            best_idx = pool_idx.gather(1, pick)
            del chunk, dists, ids, pool_dist, pool_idx
            if on_cuda:
                torch.cuda.empty_cache()
        return best_idx
class CorpusBuilder:
    """Build the retrieval corpus used by each evaluation method.

    * ``target_only``: the target corpus ``corpus_emb_2`` as-is.
    * ``direct_concat``: ``corpus_emb_2`` with the rows in ``non_ref_indices``
      replaced by the raw source embeddings ``corpus_emb_1``.
    * ``our_method``: ``corpus_emb_2`` with those rows replaced by the mapped
      source embeddings ``corpus_emb_1_transformed``.

    Replacement rows are zero-padded or truncated to the target width. The small
    path (``build_corpus_fast``) and the chunked path (``build_corpus_lazy``)
    produce the same result for the same inputs.
    """

    @staticmethod
    def align_dimensions_lazy(emb_slice: np.ndarray, target_dim: int) -> np.ndarray:
        """Zero-pad or truncate the columns of ``emb_slice`` to ``target_dim``.

        Returns ``emb_slice`` itself when the width already matches, a truncated
        view when it is wider, and a new float32 array when it is narrower.
        """
        if emb_slice.shape[1] == target_dim:
            return emb_slice
        elif emb_slice.shape[1] < target_dim:
            result = np.zeros((emb_slice.shape[0], target_dim), dtype=np.float32)
            result[:, :emb_slice.shape[1]] = emb_slice
            return result
        else:
            return emb_slice[:, :target_dim]

    @staticmethod
    def _source_embeddings(method: str, corpus_emb_1: np.ndarray,
                           corpus_emb_1_transformed: np.ndarray) -> np.ndarray:
        """Pick the source-side array that fills the replaced rows for ``method``."""
        if method == "direct_concat":
            return corpus_emb_1
        if method == "our_method":
            return corpus_emb_1_transformed
        raise ValueError(f"Unsupported corpus method: {method}")

    @staticmethod
    def _sorted_index_array(non_ref_indices: Any) -> np.ndarray:
        """Return the row ids in ``non_ref_indices`` as a sorted 1-D int64 array."""
        return np.sort(np.fromiter((int(i) for i in non_ref_indices), dtype=np.int64))

    def build_corpus_lazy(self, method: str, corpus_emb_2: np.ndarray,
                          corpus_emb_1: np.ndarray, corpus_emb_1_transformed: np.ndarray,
                          non_ref_indices: set, chunk_size: int = 500000) -> np.ndarray:
        """Build the corpus for ``method`` chunk by chunk (mmap / large inputs).

        Only ``chunk_size`` target rows, plus the replacement rows inside that
        chunk, are read at a time, so ``corpus_emb_2`` and the source arrays may
        be memory-mapped. Returns a new float32 array, or ``corpus_emb_2`` itself
        for ``target_only``.

        Raises:
            ValueError: If ``method`` is not a known corpus method.
        """
        target_dim = corpus_emb_2.shape[1]
        corpus_size = corpus_emb_2.shape[0]
        if method == "target_only":
            return corpus_emb_2
        # Previously this path always used the transformed embeddings, which made
        # direct_concat identical to our_method on large/mmap corpora.
        source = self._source_embeddings(method, corpus_emb_1, corpus_emb_1_transformed)
        # Sort once so each chunk's replacement rows are a contiguous run found
        # by binary search instead of rescanning every id for every chunk.
        rows = self._sorted_index_array(non_ref_indices)
        result_corpus = np.empty((corpus_size, target_dim), dtype=np.float32)
        chunks = list(range(0, corpus_size, chunk_size))
        for start_idx in tqdm(chunks, desc=f"Building {method} corpus"):
            end_idx = min(start_idx + chunk_size, corpus_size)
            result_corpus[start_idx:end_idx] = corpus_emb_2[start_idx:end_idx]
            lo, hi = np.searchsorted(rows, (start_idx, end_idx))
            chunk_rows = rows[lo:hi]
            if chunk_rows.size:
                result_corpus[chunk_rows] = self.align_dimensions_lazy(source[chunk_rows], target_dim)
            if chunk_size > 10000:
                gc.collect()
        return result_corpus

    def build_corpus(self, method: str, corpus_emb_2: np.ndarray,
                     corpus_emb_1: np.ndarray, corpus_emb_1_transformed: np.ndarray,
                     non_ref_indices: set) -> np.ndarray:
        """Dispatch to the chunked builder for mmap/large corpora, else the in-memory one."""
        is_mmap = isinstance(corpus_emb_2, np.memmap) or bool(getattr(corpus_emb_2, 'filename', None))
        if is_mmap or corpus_emb_2.shape[0] > 100000:
            logger.info(f"Using lazy loading for {method} (detected mmap or large array)")
            return self.build_corpus_lazy(method, corpus_emb_2, corpus_emb_1,
                                          corpus_emb_1_transformed, non_ref_indices)
        return self.build_corpus_fast(method, corpus_emb_2, corpus_emb_1,
                                      corpus_emb_1_transformed, non_ref_indices)

    def build_corpus_fast(self, method: str, corpus_emb_2: np.ndarray,
                          corpus_emb_1: np.ndarray, corpus_emb_1_transformed: np.ndarray,
                          non_ref_indices: set) -> np.ndarray:
        """Build the corpus for ``method`` fully in memory.

        Returns a float32 copy of ``corpus_emb_2`` with the rows in
        ``non_ref_indices`` replaced (``corpus_emb_2`` itself for ``target_only``).

        Raises:
            ValueError: If ``method`` is not a known corpus method.
        """
        target_dim = corpus_emb_2.shape[1]
        if method == "target_only":
            return corpus_emb_2
        source = self._source_embeddings(method, corpus_emb_1, corpus_emb_1_transformed)
        result_corpus = corpus_emb_2.astype(np.float32)
        rows = self._sorted_index_array(non_ref_indices)
        if rows.size:
            # One vectorised gather/scatter instead of a Python loop per row.
            result_corpus[rows] = self.align_dimensions_lazy(source[rows], target_dim)
        return result_corpus
class Evaluator:
    """Evaluate cross-domain retrieval quality of mapped source embeddings.

    Three families of embeddings are compared: source-model embeddings
    (``*_emb_1``), the same embeddings after mapping (``*_emb_1_transformed``)
    and target-model embeddings (``*_emb_2``). ``p_index_list`` maps a query row
    to its list of positive corpus ids; only the first positive is scored.
    ``d1`` is passed to the corpus builder as the set of rows replaced by
    source-side embeddings, and ``d1``/``d2`` also define the per-subset recall
    breakdowns. ``d0`` is stored but not used by the metrics computed here.
    """

    # Default recall cutoffs; copied into a fresh list per instance.
    _DEFAULT_K_LIST = (10, 50, 100, 500, 1000)

    def __init__(self, corpus_emb_1: np.ndarray, corpus_emb_2: np.ndarray,
                 query_emb_1: np.ndarray, query_emb_2: np.ndarray,
                 corpus_emb_1_transformed: np.ndarray, query_emb_1_transformed: np.ndarray,
                 p_index_list: Dict[int, List[int]], d0: Optional[np.ndarray],
                 d1: Optional[np.ndarray], d2: Optional[np.ndarray],
                 k_list: Optional[List[int]] = None):
        """Store embeddings and evaluation metadata.

        Args:
            d0, d1, d2: Collections of corpus ids; ``None`` is treated as empty.
            k_list: Recall cutoffs; defaults to ``[10, 50, 100, 500, 1000]``.
                The list is copied, so later changes by the caller do not leak in.

        Raises:
            ValueError: If ``k_list`` is empty.
        """
        self.corpus_emb_1 = corpus_emb_1
        self.corpus_emb_2 = corpus_emb_2
        self.query_emb_1 = query_emb_1
        self.query_emb_2 = query_emb_2
        self.corpus_emb_1_transformed = corpus_emb_1_transformed
        self.query_emb_1_transformed = query_emb_1_transformed
        self.p_index_list = p_index_list
        self.d0 = set(d0) if d0 is not None else set()
        self.d1 = set(d1) if d1 is not None else set()
        self.d2 = set(d2) if d2 is not None else set()
        self.k_list = list(k_list) if k_list is not None else list(self._DEFAULT_K_LIST)
        if not self.k_list:
            raise ValueError("k_list must contain at least one cutoff")
        self.distance_calc = DistanceCalculator()
        self.corpus_builder = CorpusBuilder()

    @staticmethod
    def _has_positives(pos_list: Any) -> bool:
        """True when ``pos_list`` holds at least one id (works for lists and arrays)."""
        return pos_list is not None and len(pos_list) > 0

    def _calculate_jaccard_similarity(self, results_1: np.ndarray, results_2: np.ndarray) -> float:
        """Mean per-row Jaccard similarity of two id matrices (empty union counts as 1.0)."""
        jaccard_scores = []
        for row_1, row_2 in zip(results_1, results_2):
            set_1 = set(row_1.tolist())
            set_2 = set(row_2.tolist())
            union = len(set_1 | set_2)
            jaccard_scores.append(len(set_1 & set_2) / union if union > 0 else 1.0)
        return np.mean(jaccard_scores)

    def _evaluate_mapping_consistency(self) -> Dict[str, float]:
        """Compare cosine neighbour lists before and after the mapping.

        Source queries are searched in the source corpus and transformed queries
        in the transformed corpus; for each cutoff the mean Jaccard overlap of the
        two top-k lists is reported as ``mapping_consistency@k``.
        """
        max_k = max(self.k_list)
        device = self.distance_calc.device
        original_results = self.distance_calc.get_top_k_indices(
            torch.from_numpy(self.query_emb_1.astype(np.float32)).to(device),
            self.corpus_emb_1, max_k, "cosine"
        ).cpu().numpy()
        transformed_results = self.distance_calc.get_top_k_indices(
            torch.from_numpy(self.query_emb_1_transformed.astype(np.float32)).to(device),
            self.corpus_emb_1_transformed, max_k, "cosine"
        ).cpu().numpy()
        return {
            f"mapping_consistency@{top_k}": self._calculate_jaccard_similarity(
                original_results[:, :top_k], transformed_results[:, :top_k]
            )
            for top_k in self.k_list
        }

    def _retrieve_top_k(self, query_tensor: torch.Tensor, corpus: np.ndarray, method: str,
                        max_k: int, distance_metric: str, batch_size: int,
                        corpus_batch_size: int) -> np.ndarray:
        """Search all queries against ``corpus`` in query batches; returns (n_queries, max_k) ids."""
        num_queries = query_tensor.shape[0]
        all_batch_results = []
        query_batches = list(range(0, num_queries, batch_size))
        for start_idx in tqdm(query_batches, desc=f"Processing {method} queries"):
            end_idx = min(start_idx + batch_size, num_queries)
            batch_results = self.distance_calc.get_top_k_indices(
                query_tensor[start_idx:end_idx], corpus, max_k, distance_metric, corpus_batch_size
            ).cpu().numpy()
            all_batch_results.append(batch_results)
            if self.distance_calc.device.type == 'cuda':
                torch.cuda.empty_cache()
        if not all_batch_results:
            # np.concatenate raises on an empty list; zero queries means zero rows.
            return np.empty((0, max_k), dtype=np.int64)
        return np.concatenate(all_batch_results, axis=0)

    def _first_positive_ranks(self, all_results: np.ndarray, num_queries: int) -> List[tuple]:
        """Return ``(rank, positive_idx)`` for each query in ``[0, num_queries)`` with positives.

        ``rank`` is the 0-based position of the first positive in the query's
        retrieved list, or ``None`` if it was not retrieved. A query is a hit at
        cutoff k exactly when ``rank is not None and rank < k``.
        """
        ranked = []
        for query_idx in range(num_queries):
            pos_list = self.p_index_list.get(query_idx)
            if not self._has_positives(pos_list):
                continue
            positive_idx = int(pos_list[0])
            positions = np.flatnonzero(all_results[query_idx] == positive_idx)
            ranked.append((int(positions[0]) if positions.size else None, positive_idx))
        return ranked

    def _evaluate_cross_domain_retrieval(self, distance_metric: str = "cosine",
                                         batch_size: int = 1000, corpus_batch_size: int = 10000) -> Dict[str, float]:
        """Recall@k of target queries against each method's corpus.

        For every method in ``target_only``/``direct_concat``/``our_method``,
        reports ``{method}@k`` (hits / number of labelled queries) and, when the
        subsets contain positives, ``{method}_D1@k`` and ``{method}_D2@k``.
        Overall recall is omitted when ``p_index_list`` is empty instead of
        dividing by zero.
        """
        logger.info(f"Evaluating with {distance_metric} distance on {self.distance_calc.device}")
        first_positives = [pos_list[0] for pos_list in self.p_index_list.values()
                           if self._has_positives(pos_list)]
        all_d1_pos = sum(1 for p in first_positives if p in self.d1)
        all_d2_pos = sum(1 for p in first_positives if p in self.d2)
        num_labelled = len(self.p_index_list)
        if num_labelled == 0:
            logger.warning("p_index_list is empty; overall recall will not be reported")
        methods = ["target_only", "direct_concat", "our_method"]
        recalls = {}
        query_tensor = torch.from_numpy(self.query_emb_2.astype(np.float32)).to(self.distance_calc.device)
        num_queries = query_tensor.shape[0]
        max_k = max(self.k_list)
        for method in tqdm(methods, desc="Processing methods"):
            logger.info(f"Processing method: {method}")
            method_corpus = self.corpus_builder.build_corpus(
                method, self.corpus_emb_2, self.corpus_emb_1,
                self.corpus_emb_1_transformed, self.d1
            )
            all_results = self._retrieve_top_k(query_tensor, method_corpus, method, max_k,
                                               distance_metric, batch_size, corpus_batch_size)
            ranked = self._first_positive_ranks(all_results, num_queries)
            for top_k in self.k_list:
                hits = [p for rank, p in ranked if rank is not None and rank < top_k]
                if num_labelled > 0:
                    recalls[f"{method}@{top_k}"] = len(hits) / num_labelled
                if all_d1_pos > 0:
                    recalls[f"{method}_D1@{top_k}"] = sum(1 for p in hits if p in self.d1) / all_d1_pos
                if all_d2_pos > 0:
                    recalls[f"{method}_D2@{top_k}"] = sum(1 for p in hits if p in self.d2) / all_d2_pos
            del method_corpus, all_results
            gc.collect()
            if self.distance_calc.device.type == 'cuda':
                torch.cuda.empty_cache()
        return recalls

    def _calculate_embedding_distances(self, batch_size: int = 10000) -> Dict[str, float]:
        """Mean row-wise cosine and Euclidean distance between mapped source and target corpus.

        Raises:
            ValueError: If the two corpora differ in dimension or row count.
        """
        logger.info("Calculating embedding distances...")
        mapped_src_emb_tensor = torch.from_numpy(self.corpus_emb_1_transformed.astype(np.float32))
        tgt_emb_tensor = torch.from_numpy(self.corpus_emb_2.astype(np.float32))
        logger.info(f"Corpus embeddings - mapped_src: {mapped_src_emb_tensor.shape}, tgt: {tgt_emb_tensor.shape}")
        # The old code logged "Skipping ..." and then raised anyway; just raise.
        if mapped_src_emb_tensor.shape[1] != tgt_emb_tensor.shape[1]:
            raise ValueError(f"Corpus embedding dimension mismatch: {mapped_src_emb_tensor.shape[1]} vs {tgt_emb_tensor.shape[1]}")
        if mapped_src_emb_tensor.shape[0] != tgt_emb_tensor.shape[0]:
            raise ValueError(f"Corpus embedding row count mismatch: {mapped_src_emb_tensor.shape[0]} vs {tgt_emb_tensor.shape[0]}")
        return {
            "corpus_cosine_distance": self._calculate_cosine_distance_batched(mapped_src_emb_tensor, tgt_emb_tensor, batch_size),
            "corpus_euclidean_distance": self._calculate_euclidean_distance_batched(mapped_src_emb_tensor, tgt_emb_tensor, batch_size),
        }

    def _calculate_cosine_distance_batched(self, emb1: torch.Tensor, emb2: torch.Tensor, batch_size: int) -> float:
        """NaN-ignoring mean of row-wise ``1 - cos(emb1[i], emb2[i])``; -1.0 if nothing valid."""
        logger.info(f"Calculating cosine distance in batches of {batch_size}")
        all_distances = []
        num_samples = emb1.shape[0]
        for i in range(0, num_samples, batch_size):
            end_idx = min(i + batch_size, num_samples)
            emb1_norm = torch.nn.functional.normalize(emb1[i:end_idx], p=2, dim=1)
            emb2_norm = torch.nn.functional.normalize(emb2[i:end_idx], p=2, dim=1)
            all_distances.append(1 - torch.nn.functional.cosine_similarity(emb1_norm, emb2_norm, dim=1))
            if self.distance_calc.device.type == 'cuda':
                torch.cuda.empty_cache()
        if not all_distances:
            return -1.0
        distances = torch.cat(all_distances, dim=0)
        if torch.isnan(distances).all():
            return -1.0
        return torch.nanmean(distances).item()

    def _calculate_euclidean_distance_batched(self, emb1: torch.Tensor, emb2: torch.Tensor, batch_size: int) -> float:
        """NaN-ignoring mean of row-wise ``||emb1[i] - emb2[i]||``; -1.0 if nothing valid."""
        logger.info(f"Calculating euclidean distance in batches of {batch_size}")
        all_distances = []
        num_samples = emb1.shape[0]
        for i in range(0, num_samples, batch_size):
            end_idx = min(i + batch_size, num_samples)
            # Row-wise norm of the difference: O(B*d), unlike cdist(...).diag(),
            # which materialises a B x B matrix only to read its diagonal.
            all_distances.append(torch.linalg.vector_norm(emb1[i:end_idx] - emb2[i:end_idx], ord=2, dim=1))
            if self.distance_calc.device.type == 'cuda':
                torch.cuda.empty_cache()
        if not all_distances:
            return -1.0
        distances = torch.cat(all_distances, dim=0)
        if torch.isnan(distances).all():
            return -1.0
        return torch.nanmean(distances).item()

    def _calculate_cosine_distance(self, emb1: torch.Tensor, emb2: torch.Tensor) -> float:
        """Unbatched variant of the row-wise cosine distance mean (-1.0 if all NaN)."""
        cosine_dist = 1 - torch.nn.functional.cosine_similarity(
            torch.nn.functional.normalize(emb1, p=2, dim=1),
            torch.nn.functional.normalize(emb2, p=2, dim=1),
            dim=1
        )
        if torch.isnan(cosine_dist).all():
            return -1.0
        return torch.nanmean(cosine_dist).item()

    def _calculate_euclidean_distance(self, emb1: torch.Tensor, emb2: torch.Tensor) -> float:
        """Unbatched variant of the row-wise Euclidean distance mean (-1.0 if all NaN)."""
        euclidean_dist = torch.linalg.vector_norm(emb1 - emb2, ord=2, dim=1)
        if torch.isnan(euclidean_dist).all():
            return -1.0
        return torch.nanmean(euclidean_dist).item()

    def evaluate(self, use_low_memory: bool = True, distance_metric: str = "cosine",
                 batch_size: int = 1000, corpus_batch_size: int = 10000) -> Dict[str, float]:
        """Run retrieval, mapping-consistency and distance metrics and merge the results.

        ``use_low_memory`` is accepted for API compatibility but is not used.
        """
        logger.info("Starting evaluation...")
        cross_domain_results = self._evaluate_cross_domain_retrieval(
            distance_metric, batch_size, corpus_batch_size
        )
        consistency_results = self._evaluate_mapping_consistency()
        distance_results = self._calculate_embedding_distances(batch_size=corpus_batch_size)
        all_results = {**cross_domain_results, **consistency_results, **distance_results}
        logger.info(f"Evaluation completed with {len(all_results)} metrics")
        return all_results
class USearchEvaluator:
    """Recall and expert-diversity evaluation backed by usearch indices.

    One cosine usearch index is built per corpus ``.npy`` file, and
    ``query_emb_list[i]`` / ``q2a_list[i]`` are evaluated against the i-th
    index. ``q2a_list[i]`` maps a query row to its positive corpus ids; only
    the first positive is scored. Optional per-corpus expert-id arrays enable
    expert-diversity metrics over the retrieved lists.
    """

    # Default recall cutoffs; copied into a fresh list per instance.
    _DEFAULT_K_LIST = (10, 50, 100, 500, 1000)

    def __init__(
        self,
        corpus_emb_paths: List[str],
        query_emb_list: List[np.ndarray],
        q2a_list: List[Dict[int, List[int]]],
        k_list: Optional[List[int]] = None,
        corpus_expert_ids_paths: Optional[List[str]] = None
    ):
        """Load optional expert ids and build one index per corpus file.

        Args:
            k_list: Recall cutoffs; defaults to ``[10, 50, 100, 500, 1000]``.
            corpus_expert_ids_paths: Optional ``.npy`` paths aligned with
                ``corpus_emb_paths``. Missing, falsy or nonexistent entries (and
                entries beyond the end of a shorter list) disable the diversity
                metrics for that corpus.

        Raises:
            ImportError: If the ``usearch`` package is not installed.
        """
        try:
            from usearch.index import Index
            self.Index = Index
        except ImportError as err:
            logger.error("usearch library not found. Please install with: pip install usearch")
            raise ImportError("usearch is required for USearchEvaluator") from err
        self.corpus_emb_paths = corpus_emb_paths
        self.query_emb_list = query_emb_list
        self.q2a_list = q2a_list
        self.k_list = list(k_list) if k_list is not None else list(self._DEFAULT_K_LIST)
        self.corpus_expert_ids_paths = corpus_expert_ids_paths
        self.corpus_expert_ids_list: List[Optional[np.ndarray]] = [
            self._load_expert_ids(path) for path in (corpus_expert_ids_paths or [])
        ]
        # evaluate() indexes this list by corpus position, so pad a short list.
        missing = len(corpus_emb_paths) - len(self.corpus_expert_ids_list)
        if missing > 0:
            self.corpus_expert_ids_list.extend([None] * missing)
        self.indices: List = []
        self._build_indices()

    @staticmethod
    def _load_expert_ids(path: Optional[str]) -> Optional[np.ndarray]:
        """Load an expert-id array from ``path``, or return None if it is unset/missing."""
        if path and os.path.exists(path):
            expert_ids = np.load(path)
            logger.info(f"Loaded expert IDs from {path}: shape={expert_ids.shape}")
            return expert_ids
        return None

    def _build_indices(self, batch_size: int = 100000) -> None:
        """Build an exact-capable cosine index per corpus file, adding rows in batches."""
        logger.info(f"Building usearch indices for {len(self.corpus_emb_paths)} corpus files")
        for corpus_path in tqdm(self.corpus_emb_paths, desc="Building indices"):
            # np.load(mmap_mode='r') parses every .npy header version and honours
            # fortran_order; a hand-rolled read_array_header_1_0 + np.memmap does neither.
            corpus_memmap = np.load(corpus_path, mmap_mode='r')
            n_samples, embedding_dim = corpus_memmap.shape
            index = self.Index(ndim=embedding_dim, metric='cos', dtype='f32')
            for start_idx in range(0, n_samples, batch_size):
                end_idx = min(start_idx + batch_size, n_samples)
                # Explicit C-contiguous float32 copy of just this batch.
                batch = np.array(corpus_memmap[start_idx:end_idx], dtype=np.float32, order='C')
                batch_ids = np.arange(start_idx, end_idx, dtype=np.uint32)
                index.add(batch_ids, batch)
            del corpus_memmap
            self.indices.append(index)
            logger.info(f"✓ Built index for {corpus_path}: {n_samples} vectors, dim={embedding_dim}")

    def _search_usearch(self, index: "Index", queries: np.ndarray, k: int) -> np.ndarray:
        """Exact top-``k`` search per query row.

        Returns an ``(n_queries, k)`` int64 array of corpus ids. Rows with fewer
        than ``k`` matches are padded with ``-1``, so the result is never ragged.
        """
        results = np.full((len(queries), k), -1, dtype=np.int64)
        for row, query in enumerate(queries):
            matches = index.search(query.reshape(1, -1).astype(np.float32), k, exact=True)
            keys = np.asarray(matches.keys, dtype=np.int64).ravel()[:k]
            results[row, :keys.size] = keys
        return results

    def evaluate(self) -> Dict[str, float]:
        """Compute recall@k and, when expert ids are available, expert-diversity stats.

        Recall is ``hits / len(q2a)`` and is omitted for an empty ``q2a``. Diversity
        is the number of distinct expert ids among the valid (non-padding)
        retrieved ids of each labelled query.
        """
        logger.info("Starting USearch evaluation...")
        results = {}
        max_k = max(self.k_list)
        for idx, (query_emb, q2a, index) in enumerate(zip(self.query_emb_list, self.q2a_list, self.indices)):
            search_results = self._search_usearch(index, query_emb, max_k)
            corpus_expert_ids = self.corpus_expert_ids_list[idx]
            for k in self.k_list:
                hits = 0
                expert_diversity_list = []
                skipped = 0
                for query_idx, positive_indices in q2a.items():
                    if positive_indices is None or len(positive_indices) == 0:
                        continue
                    top_k = search_results[query_idx, :k]
                    if int(positive_indices[0]) in top_k:
                        hits += 1
                    if corpus_expert_ids is not None:
                        retrieved = top_k[top_k >= 0]  # drop -1 padding
                        try:
                            expert_diversity_list.append(len(np.unique(corpus_expert_ids[retrieved])))
                        except IndexError:
                            # Retrieved id outside the expert-id table; skip, but report below.
                            skipped += 1
                if skipped:
                    logger.warning(f"Dataset {idx} @{k}: {skipped} queries had ids outside the expert-id table")
                if q2a:
                    results[f"dataset_{idx}_recall@{k}"] = hits / len(q2a)
                if expert_diversity_list:
                    results[f"dataset_{idx}_avg_expert_diversity@{k}"] = np.mean(expert_diversity_list)
                    results[f"dataset_{idx}_max_expert_diversity@{k}"] = np.max(expert_diversity_list)
                    results[f"dataset_{idx}_min_expert_diversity@{k}"] = np.min(expert_diversity_list)
                    multi_expert_queries = sum(1 for d in expert_diversity_list if d > 1)
                    results[f"dataset_{idx}_multi_expert_ratio@{k}"] = multi_expert_queries / len(expert_diversity_list)
            logger.info(f"✓ Dataset {idx}: {len(q2a)} queries evaluated")
            if corpus_expert_ids is not None:
                logger.info("  ✓ Expert diversity metrics calculated")
        return results
