import pandas as pd
import numpy as np
import typer
from tqdm import tqdm, trange
from typer import Typer
from pathlib import Path
from loguru import logger
from typing import Optional, Tuple
import torch
from tqdm import tqdm
import gc
import json
import dotenv
import os
import psutil
from ..embeddings.memmap_dataset import get_embedding_memmap, MemmapEmbeddingDataset
from ..noco_util.util import NocoClient
from pymilvus import MilvusClient, DataType
# Load variables from a local .env file (if any) into the process environment
# via python-dotenv.
dotenv.load_dotenv()
# Typer application object; the `cal_dis` command defined further down
# registers itself on it through the `@cal_dis_app.command()` decorator.
cal_dis_app = Typer()
def sanitize(a: np.ndarray, fill=0.0):
    """Return ``a`` as a float32 array with non-finite entries replaced.

    Args:
        a: Array-like input; it is converted with ``np.asarray(..., float32)``.
        fill: Value written wherever the input holds NaN, +inf or -inf.

    Returns:
        The float32 array itself when every entry is finite, otherwise a
        patched copy. The caller's array is never modified in place.
    """
    values = np.asarray(a, dtype=np.float32)
    non_finite = np.logical_not(np.isfinite(values))
    if not non_finite.any():
        # Nothing to repair: hand back the converted array without copying.
        return values
    cleaned = values.copy()
    cleaned[non_finite] = fill
    return cleaned
def random_sample_vectors(embeddings: np.ndarray, indices: Optional[np.ndarray], sample_size: Optional[int], seed: int = 42) -> Tuple[np.ndarray, np.ndarray]:
    """Pick a reproducible random subset of embedding rows.

    Args:
        embeddings: 2-D array (or memmap) of embedding rows.
        indices: Optional row ids restricting the population that may be
            drawn from; ``None`` means every row of ``embeddings``.
        sample_size: Number of rows to draw without replacement. ``None``
            disables sampling and returns the whole population.
        seed: Seed for the private random generator.

    Returns:
        ``(vectors, row_ids)`` where ``vectors[i]`` is ``embeddings[row_ids[i]]``.
        When rows are actually sampled, ``row_ids`` is sorted ascending. When
        the whole population is returned and ``indices`` is ``None``, the
        original ``embeddings`` object is returned as-is (no copy).
    """
    population_size = len(embeddings) if indices is None else len(indices)

    # No sampling requested, or the population is already small enough.
    if sample_size is None or population_size <= sample_size:
        if sample_size is not None:
            logger.info(f"Requested {sample_size} samples but only {population_size} available, using all")
        if indices is None:
            return embeddings, np.arange(len(embeddings))
        return embeddings[indices], indices

    # A private legacy generator yields exactly the same draw as
    # `np.random.seed(seed); np.random.choice(...)` but leaves the
    # process-wide NumPy random state untouched for other code.
    rng = np.random.RandomState(seed)
    if indices is None:
        chosen = rng.choice(population_size, size=sample_size, replace=False)
        noun = "vectors"
    else:
        chosen = rng.choice(indices, size=sample_size, replace=False)
        noun = "queries"
    chosen = np.sort(chosen)
    logger.info(f"Randomly sampled {sample_size} {noun} from {population_size} available")
    return embeddings[chosen], chosen
def get_device(device: Optional[str] = None) -> torch.device:
    """Resolve the torch device to compute on.

    Args:
        device: ``None`` to auto-select (CUDA when available, else CPU), or
            an explicit device string such as ``"cpu"``, ``"cuda"`` or
            ``"cuda:1"``.

    Returns:
        The resolved ``torch.device``. Any CUDA request (including indexed
        forms like ``"cuda:0"``) falls back to CPU when CUDA is unavailable.
    """
    cuda_available = torch.cuda.is_available()
    if device is None:
        if cuda_available:
            device = "cuda"
            # Report name and memory of the same (current) GPU.
            current = torch.cuda.current_device()
            logger.info(f"Using GPU: {torch.cuda.get_device_name(current)}")
            logger.info(f"GPU Memory: {torch.cuda.get_device_properties(current).total_memory / 1024**3:.1f} GB")
        else:
            device = "cpu"
            logger.info("Using CPU (CUDA not available)")
    else:
        # Match "cuda" as well as "cuda:<n>" so an indexed request does not
        # slip through and fail later on a CUDA-less host.
        if str(device).startswith("cuda") and not cuda_available:
            logger.warning("CUDA requested but not available, falling back to CPU")
            device = "cpu"
        logger.info(f"Using specified device: {device}")
    return torch.device(device)
def log_memory_usage(stage: str) -> float:
    """Log the resident memory of the current process and return it in MB.

    Args:
        stage: Free-form label included in the log line.

    Returns:
        The process RSS in megabytes, or ``0.0`` when it cannot be read
        (a warning is logged in that case; this function never raises).
    """
    try:
        rss_bytes = psutil.Process().memory_info().rss
        memory_mb = rss_bytes / 1024 / 1024
        logger.info(f"Memory usage at {stage}: {memory_mb:.1f} MB")
    except Exception as e:
        # Best-effort diagnostics only: report the problem and carry on.
        logger.warning(f"Could not get memory usage: {e}")
        return 0.0
    else:
        return memory_mb
def get_optimal_batch_sizes(query_embeddings: np.ndarray, corpus_embeddings: np.ndarray, 
                           d0_index: np.ndarray, query_index: Optional[np.ndarray], top_k: int, 
                           initial_query_batch_size: int, initial_corpus_batch_size: int, ratio: float = 0.5) -> Tuple[int, int]:
    """Shrink the requested batch sizes so one batch fits a memory budget.

    The budget is ``ratio`` times the currently available system memory as
    reported by psutil. The per-batch cost is a heuristic:
    ``query_batch * corpus_batch * (query_dim + corpus_dim) * 4`` bytes.
    When it exceeds the budget both batch sizes are scaled by the square
    root of the reduction factor.

    Args:
        query_embeddings: 2-D query matrix; only its shape is read.
        corpus_embeddings: 2-D corpus matrix; only its shape is read.
        d0_index: Unused here; kept for signature compatibility.
        query_index: Optional subset of query rows; its length, when given,
            is the number of queries actually processed.
        top_k: Unused here; kept for signature compatibility.
        initial_query_batch_size: Requested query batch size.
        initial_corpus_batch_size: Requested corpus batch size.
        ratio: Fraction of available memory a batch may use.

    Returns:
        ``(query_batch_size, corpus_batch_size)``. If anything goes wrong
        while probing memory, the two initial sizes are returned unchanged.
    """
    try:
        available_memory = psutil.virtual_memory().available / 1024 / 1024
        logger.info(f"Available memory: {available_memory:.1f} MB")

        query_dim = query_embeddings.shape[1]
        corpus_dim = corpus_embeddings.shape[1]
        if query_index is not None:
            actual_query_count = len(query_index)
        else:
            actual_query_count = len(query_embeddings)

        # Heuristic cost model, expressed in MB (4 bytes per float32 value).
        memory_per_similarity = (query_dim + corpus_dim) * 4 / 1024 / 1024
        memory_per_batch = initial_query_batch_size * initial_corpus_batch_size * memory_per_similarity
        effective_query_batch_size = min(initial_query_batch_size, actual_query_count)

        memory_budget = available_memory * ratio
        needs_reduction = memory_per_batch > memory_budget
        if not needs_reduction:
            logger.info(f"Using initial batch sizes: query={effective_query_batch_size}, corpus={initial_corpus_batch_size}")
            logger.info(f"Actual queries to process: {actual_query_count}")
            return effective_query_batch_size, initial_corpus_batch_size

        # Spread the reduction evenly over both dimensions of the batch.
        reduction_factor = memory_budget / memory_per_batch
        scale = reduction_factor ** 0.5
        new_query_batch_size = max(1, int(effective_query_batch_size * scale))
        new_corpus_batch_size = max(1, int(initial_corpus_batch_size * scale))
        new_query_batch_size = min(new_query_batch_size, actual_query_count)
        logger.info("Reducing batch sizes due to memory constraints:")
        logger.info(f"  Query batch size: {initial_query_batch_size} -> {new_query_batch_size}")
        logger.info(f"  Corpus batch size: {initial_corpus_batch_size} -> {new_corpus_batch_size}")
        logger.info(f"  Actual queries to process: {actual_query_count}")
        return new_query_batch_size, new_corpus_batch_size
    except Exception as e:
        logger.warning(f"Could not calculate optimal batch sizes: {e}")
        return initial_query_batch_size, initial_corpus_batch_size
def _l2_normalize_rows(matrix: np.ndarray) -> np.ndarray:
    """Return a copy of ``matrix`` with each row scaled to unit L2 norm (zero rows stay zero)."""
    matrix = np.asarray(matrix)
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    # Avoid 0/0: an all-zero row keeps similarity 0 with everything.
    norms[norms == 0] = 1
    return matrix / norms


def calculate_cosine_similarities_simple(query_embeddings: np.ndarray, corpus_embeddings: np.ndarray, top_k: int, device: Optional[str] = None, query_chunk_size: int = 4096) -> Tuple[np.ndarray, np.ndarray]:
    """Compute the top-k cosine similarities of every query against a corpus with NumPy.

    Both inputs are L2-normalised before the dot product, so the returned
    scores are true cosine similarities even for unnormalised embeddings.
    Queries are processed in row chunks so that only a
    ``query_chunk_size x n_corpus`` similarity block is alive at a time.

    Args:
        query_embeddings: ``(n_queries, dim)`` array.
        corpus_embeddings: ``(n_corpus, dim)`` array.
        top_k: Number of best matches to keep per query. If it exceeds
            ``n_corpus`` only ``n_corpus`` columns are returned.
        device: Accepted for signature parity with the torch-based path;
            this NumPy implementation always runs on the CPU and ignores it.
        query_chunk_size: Number of query rows scored per step.

    Returns:
        ``(top_k_indices, top_k_similarities)``: positions into
        ``corpus_embeddings`` and the matching scores, each of shape
        ``(n_queries, min(top_k, n_corpus))``, best match first.
    """
    queries = _l2_normalize_rows(query_embeddings)
    corpus_t = _l2_normalize_rows(corpus_embeddings).T
    n_queries = len(queries)
    step = max(1, int(query_chunk_size))

    index_chunks = []
    similarity_chunks = []
    # `max(n_queries, 1)` makes the loop run once for an empty query set so
    # correctly shaped (0, k) arrays are still produced.
    for start in range(0, max(n_queries, 1), step):
        block = np.dot(queries[start:start + step], corpus_t)
        order = np.argsort(-block, axis=1)[:, :top_k]
        index_chunks.append(order)
        similarity_chunks.append(np.take_along_axis(block, order, axis=1))

    top_k_indices = np.concatenate(index_chunks, axis=0)
    top_k_similarities = np.concatenate(similarity_chunks, axis=0)
    return top_k_indices, top_k_similarities
def calculate_cosine_similarities_manual(query_embeddings: np.ndarray, corpus_embeddings: np.ndarray, d0_index: np.ndarray, query_index: Optional[np.ndarray], top_k: int, query_batch_size: int = 1000, corpus_batch_size: int = 1000, device: Optional[str] = None) -> Tuple[np.ndarray, np.ndarray]:
    """Streaming top-k cosine similarity search, batched over queries and corpus.

    Query rows and corpus rows are L2-normalised per batch with torch, their
    similarity block is computed with ``torch.mm`` and then merged on the CPU
    into a running top-k table for every query.

    Args:
        query_embeddings: Query matrix; rows are selected through
            ``query_index`` when that is given.
        corpus_embeddings: Corpus matrix; only the rows listed in
            ``d0_index`` are scored.
        d0_index: Row ids of ``corpus_embeddings`` that form the corpus.
        query_index: Optional row ids of ``query_embeddings`` to use as
            queries; ``None`` uses every row.
        top_k: Number of best matches kept per query.
        query_batch_size: Number of queries scored per outer step.
        corpus_batch_size: Number of corpus rows scored per inner step.
        device: Device string forwarded to ``get_device``.

    Returns:
        ``(top_k_indices, top_k_similarities)`` of shape ``(n_queries, top_k)``.
        Indices are values taken from ``d0_index`` (stored as int32) and
        similarities are float32, sorted best first. Slots that never
        receive a candidate keep their initial values (index 0, -inf).
    """
    torch_device = get_device(device)
    # Materialise the query rows that will actually be searched.
    if query_index is not None:
        actual_query_embeddings = query_embeddings[query_index]
        n_queries = len(query_index)
        logger.info(f"Using subset of queries: {n_queries} queries from {len(query_embeddings)} total")
    else:
        actual_query_embeddings = query_embeddings
        n_queries = len(query_embeddings)
        logger.info(f"Using all queries: {n_queries} queries")
    logger.info("Calculating cosine similarities with PyTorch acceleration...")
    logger.info(f"Processing {n_queries} queries in batches of {query_batch_size}")
    logger.info(f"Processing {len(d0_index)} D0 embeddings in batches of {corpus_batch_size}")
    logger.info(f"Memory-efficient streaming top-{top_k} calculation")
    logger.info(f"Using device: {torch_device}")
    log_memory_usage("start of calculation")
    n_d0 = len(d0_index)
    # Running top-k table: similarities start at -inf so that any real
    # score replaces them; indices start at 0 as placeholders.
    top_k_indices = np.zeros((n_queries, top_k), dtype=np.int32)
    top_k_similarities = np.full((n_queries, top_k), -np.inf, dtype=np.float32)
    log_memory_usage("after initialization")
    for query_batch_start in trange(0, n_queries, query_batch_size, desc="Processing query batches"):
        query_batch_end = min(query_batch_start + query_batch_size, n_queries)
        query_batch = actual_query_embeddings[query_batch_start:query_batch_end]
        query_batch_tensor = torch.from_numpy(query_batch).to(torch_device)
        # L2-normalise rows; zero-norm rows are divided by 1 instead of 0.
        query_norms = torch.norm(query_batch_tensor, dim=1, keepdim=True)
        query_batch_norm = query_batch_tensor / torch.where(query_norms == 0, torch.ones_like(query_norms), query_norms)
        # 1-based batch counter and ceil-divided batch total for the progress label.
        query_batch_num = query_batch_start // query_batch_size + 1
        total_query_batches = (n_queries + query_batch_size - 1) // query_batch_size
        for corpus_batch_start in trange(0, n_d0, corpus_batch_size, desc=f"Query batch {query_batch_num}/{total_query_batches}"):
            corpus_batch_end = min(corpus_batch_start + corpus_batch_size, n_d0)
            # Row ids of this corpus slice, then the rows themselves.
            batch_d0_indices = d0_index[corpus_batch_start:corpus_batch_end]
            batch_d0_embeddings = corpus_embeddings[batch_d0_indices]
            batch_d0_embeddings_tensor = torch.from_numpy(batch_d0_embeddings).to(torch_device)
            d0_norms = torch.norm(batch_d0_embeddings_tensor, dim=1, keepdim=True)
            batch_d0_embeddings_norm = batch_d0_embeddings_tensor / torch.where(d0_norms == 0, torch.ones_like(d0_norms), d0_norms)
            # (queries in batch) x (corpus rows in batch) cosine similarities.
            batch_similarities = torch.mm(query_batch_norm, batch_d0_embeddings_norm.T)
            batch_similarities_np = batch_similarities.cpu().numpy()
            # NOTE(review): this adds two row offsets and tests divisibility by
            # the product of both batch sizes times 10, so it is true for the
            # very first batch and rarely afterwards — confirm the intended
            # logging cadence.
            if (query_batch_start + corpus_batch_start) % (query_batch_size * corpus_batch_size * 10) == 0:
                log_memory_usage(f"query batch {query_batch_start//query_batch_size + 1}, corpus batch {corpus_batch_start//corpus_batch_size + 1}")
                if torch_device.type == 'cuda':
                    logger.info(f"GPU Memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB / {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
            # Merge this block into the running top-k, one query at a time.
            for local_query_idx in range(len(query_batch)):
                global_query_idx = query_batch_start + local_query_idx
                query_similarities = batch_similarities_np[local_query_idx, :]
                # Rows are kept sorted descending, so the last slot is the
                # weakest score currently retained for this query.
                min_similarity = top_k_similarities[global_query_idx, -1]
                better_mask = query_similarities > min_similarity
                if better_mask.any():
                    better_indices = np.where(better_mask)[0]
                    better_similarities = query_similarities[better_indices]
                    # Translate positions inside the batch to d0_index values.
                    global_indices = batch_d0_indices[better_indices]
                    combined_indices = np.concatenate([top_k_indices[global_query_idx], global_indices])
                    combined_similarities = np.concatenate([top_k_similarities[global_query_idx], better_similarities])
                    # Keep the top_k highest of (current best + new candidates).
                    sort_indices = np.argsort(-combined_similarities)[:top_k]
                    top_k_indices[global_query_idx] = combined_indices[sort_indices]
                    top_k_similarities[global_query_idx] = combined_similarities[sort_indices]
            # Drop references to the per-batch CUDA tensors before asking the
            # caching allocator to release unused blocks.
            if torch_device.type == 'cuda':
                del batch_d0_embeddings_tensor, batch_similarities
                torch.cuda.empty_cache()
        if torch_device.type == 'cuda':
            del query_batch_tensor, query_batch_norm
            torch.cuda.empty_cache()
    log_memory_usage("end of calculation")
    logger.info("Dual-direction batch processing completed successfully")
    return top_k_indices, top_k_similarities
def calculate_cosine_similarities_milvus(client: MilvusClient, collection_name: str, query_embeddings: np.ndarray, query_indices: np.ndarray,
                                       top_k: int, search_params: Optional[dict] = None, batch_size: int = 100) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Run batched top-k vector searches against a Milvus collection.

    Args:
        client: Connected Milvus client; only ``client.search`` is used.
        collection_name: Collection to search; vectors are matched against
            its ``"embedding"`` field.
        query_embeddings: Matrix of query vectors.
        query_indices: Optional row ids selecting which rows of
            ``query_embeddings`` to search with; ``None`` uses every row.
        top_k: Maximum number of hits requested per query.
        search_params: Milvus search parameters; defaults to
            ``{"params": {"nprobe": 10}}``.
        batch_size: Number of query vectors sent per ``search`` call.

    Returns:
        ``(top_k_indices, top_k_similarities, top_k_entities)``:

        * ``top_k_indices``: int64 ``(n_queries, top_k)`` array of hit ids
          (assumes integer primary keys), padded with ``-1``.
        * ``top_k_similarities``: float64 ``(n_queries, top_k)`` array of hit
          distances as reported by Milvus, padded with ``NaN``.
        * ``top_k_entities``: object array of length ``n_queries``; each
          element is the list of ``entity`` dicts of that query's hits.

        Queries whose batch failed, or that got fewer than ``top_k`` hits,
        keep the padding values so all rows stay aligned with the queries
        and the arrays are always rectangular.
    """
    logger.info(f"Calculating cosine similarities using Milvus with batch size {batch_size}...")
    if search_params is None:
        search_params = {
            "params": {"nprobe": 10}
        }
    if query_indices is not None:
        selected_query_embeddings = query_embeddings[query_indices]
    else:
        selected_query_embeddings = query_embeddings

    num_queries = len(selected_query_embeddings)
    total_batches = (num_queries + batch_size - 1) // batch_size

    # Pre-filled with padding so a failed batch or a short hit list cannot
    # produce ragged results or shift later rows.
    top_k_indices = np.full((num_queries, top_k), -1, dtype=np.int64)
    top_k_similarities = np.full((num_queries, top_k), np.nan, dtype=np.float64)
    top_k_entities = np.empty(num_queries, dtype=object)
    for row in range(num_queries):
        top_k_entities[row] = []

    for batch_start in tqdm(range(0, num_queries, batch_size), desc=f"Processing {num_queries} queries"):
        batch_end = min(batch_start + batch_size, num_queries)
        batch_number = batch_start // batch_size + 1
        batch_queries = selected_query_embeddings[batch_start:batch_end]
        logger.info(f"Processing batch {batch_number}/{total_batches} "
                   f"(queries {batch_start+1}-{batch_end})")
        # Milvus rejects NaN/inf components, so clean each vector first.
        batch_query_vectors = [sanitize(query_emb).tolist() for query_emb in batch_queries]
        try:
            results = client.search(
                collection_name=collection_name,
                data=batch_query_vectors,
                anns_field="embedding",
                search_params=search_params,
                limit=top_k,
            )
            for offset, query_results in enumerate(results):
                row = batch_start + offset
                hits = list(query_results)[:top_k]
                found = len(hits)
                top_k_indices[row, :found] = [hit['id'] for hit in hits]
                top_k_similarities[row, :found] = [hit['distance'] for hit in hits]
                top_k_entities[row] = [hit.get('entity', {}) for hit in hits]
        except Exception as e:
            # Best effort: report the batch and reset any rows that were
            # partially written so the whole batch reads as "no result".
            logger.error(f"Failed to search batch {batch_number}: {str(e)}")
            top_k_indices[batch_start:batch_end] = -1
            top_k_similarities[batch_start:batch_end] = np.nan
            for row in range(batch_start, batch_end):
                top_k_entities[row] = []
    return top_k_indices, top_k_similarities, top_k_entities
@cal_dis_app.command()
def cal_dis(
    test_dataset_name: str = typer.Option("scifact", help="Name of the test dataset"),
    train_dataset_name: str = typer.Option("scidocs", help="Name of the train dataset"),
    source_model: str = typer.Option("openai", "--source-model", "-s", help="Source embedding model name"),
    target_model: str = typer.Option("mistral", "--target-model", "-t", help="Target embedding model name"),
    embedding_path: str = typer.Option("./data/processed/embeddings/", "--embedding-path", "-e", help="Path to embeddings directory"),
    metric: str = typer.Option("cosine", "--metric", "-m", help="Distance metric (cosine or euclidean)"),
    top_k: int = typer.Option(5, "--top-k", "-k", help="Number of top results to retrieve"),
    output_dir: str = typer.Option("./output/", "--output-dir", "-o", help="Directory to save results"),
    query_batch_size: int = typer.Option(50000, "--query-batch-size", "-qb", help="Batch size for processing query embeddings"),
    corpus_batch_size: int = typer.Option(500000, "--corpus-batch-size", "-cb", help="Batch size for processing corpus embeddings"),
    sample_queries: Optional[int] = typer.Option(50000, "--sample-queries", "-sq", help="Number of queries to randomly sample (None for all)"),
    sample_corpus: Optional[int] = typer.Option(50000, "--sample-corpus", "-sc", help="Number of corpus vectors to randomly sample (None for all)"),
    part: str = typer.Option("d0", "--part", "-p", help="Part of the dataset to use (d0, d1, or query)"),
    device: Optional[str] = typer.Option(None, "--device", "-d", help="Device to use (cuda or cpu)"),
    use_milvus: bool = typer.Option(False, "--use-milvus", help="Use Milvus for similarity calculation"),
    milvus_host: str = typer.Option("localhost", "--milvus-host", help="Milvus server host"),
    milvus_port: str = typer.Option("19530", "--milvus-port", help="Milvus server port"),
    collection_name: str = typer.Option("embeddings", "--collection-name", help="Milvus collection name"),
    batch_size: int = typer.Option(100, "--batch-size", help="Batch size for Milvus queries")
):
    """Measure how close a test dataset is to the D0 part of a train dataset.

    For every selected query vector of the test dataset the top-k most
    similar vectors are looked up (in Milvus, or in the train dataset's D0
    source embeddings), the mean of all top-k similarities is logged and a
    summary row is uploaded to the NocoDB table ``cal_dis_results_cross``.

    Query rows come from ``part``: the D0 or D1 rows of the test source
    embeddings, or all rows of its query embeddings. ``sample_queries`` /
    ``sample_corpus`` optionally subsample each side. ``metric`` is only
    recorded in the uploaded row; it does not select the computation.

    Raises:
        ValueError: If ``part`` is not one of ``d0``, ``d1`` or ``query``.
    """
    if part not in ("d0", "d1", "query"):
        raise ValueError(f"Invalid part: {part}. Must be 'd0' or 'd1' or 'query'")

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    train_dataset = MemmapEmbeddingDataset(train_dataset_name, source_model, target_model, align_dimension=False, embedding_path=embedding_path)
    train_dataset_emb = train_dataset.source_embeddings
    # Work on a local name so the dataset object itself is never modified.
    d0_index = train_dataset.d0_index
    logger.info(f"Loaded train dataset embeddings: {train_dataset_emb.shape}")
    logger.info(f"D0 indices: {len(d0_index)} samples")

    logger.info(f"Loading test dataset {part} queries...")
    query_dataset = MemmapEmbeddingDataset(test_dataset_name, source_model, target_model, align_dimension=False, embedding_path=embedding_path)
    if part == "query":
        query_embeddings = query_dataset.source_query_embeddings
        query_indices = None
    else:
        query_embeddings = query_dataset.source_embeddings
        query_indices = query_dataset.d0_index if part == "d0" else query_dataset.d1_index
    logger.info(f"Loaded query embeddings: {query_embeddings.shape}")
    logger.info(f"Query embeddings stats - min: {query_embeddings.min():.6f}, max: {query_embeddings.max():.6f}, mean: {query_embeddings.mean():.6f}")

    # The Milvus and in-memory ("simple") paths consume the selected query
    # rows directly, so resolve `part` and the optional sampling into one
    # dense matrix. With sample_size=None the helper returns the full
    # selection, so the `part` index is honoured even without sampling.
    materialize_queries = use_milvus or sample_queries is not None or sample_corpus is not None
    if materialize_queries:
        query_vectors, selected_query_ids = random_sample_vectors(query_embeddings, query_indices, sample_queries)
        if sample_queries is not None:
            logger.info(f"After query sampling: {query_vectors.shape}")
            logger.info(f"Sampled query indices: {len(selected_query_ids)} samples")

    logger.info(f"Searching for top-{top_k} results for each query...")
    if use_milvus:
        # MilvusClient is addressed through a single `uri`; build it from the
        # host/port options unless the host already carries a scheme.
        milvus_uri = milvus_host if "://" in milvus_host else f"http://{milvus_host}:{milvus_port}"
        milvus_client = MilvusClient(uri=milvus_uri)
        search_params = {
            "params": {"nprobe": 10}
        }
        # `query_vectors` is already the selected subset, so no index is
        # passed: handing over the original row ids would index it twice.
        milvus_result = calculate_cosine_similarities_milvus(
            milvus_client, collection_name, query_vectors, None, top_k, search_params, batch_size
        )
        # Only the similarities are needed; positional access works whether
        # or not the helper also returns the hit entities.
        top_k_similarities = milvus_result[1]
        total_queries = len(query_vectors)
    elif materialize_queries:
        # Corpus = D0 rows of the train dataset, optionally subsampled.
        corpus_vectors, corpus_ids = random_sample_vectors(train_dataset_emb, d0_index, sample_corpus)
        if sample_corpus is not None:
            logger.info(f"After corpus sampling: {corpus_vectors.shape}")
            logger.info(f"Sampled D0 indices: {len(corpus_ids)} samples")
        _, top_k_similarities = calculate_cosine_similarities_simple(
            query_vectors, corpus_vectors, top_k, device
        )
        total_queries = len(query_vectors)
    else:
        # No sampling at all: stream over the full memmaps in batches.
        optimal_query_batch_size, optimal_corpus_batch_size = get_optimal_batch_sizes(
            query_embeddings, train_dataset_emb, d0_index, query_indices, top_k,
            query_batch_size, corpus_batch_size
        )
        _, top_k_similarities = calculate_cosine_similarities_manual(
            query_embeddings, train_dataset_emb, d0_index, query_indices, top_k,
            optimal_query_batch_size, optimal_corpus_batch_size, device
        )
        total_queries = len(top_k_similarities)

    topk_similarity_mean = np.mean(top_k_similarities)
    logger.info(f"Mean top-{top_k} similarity: {topk_similarity_mean:.6f}")
    if np.isnan(topk_similarity_mean):
        # NaNs can come from bad input vectors or from padded search results.
        logger.warning("Warning: Found NaN values in similarity calculations!")
        topk_similarity_mean = np.nanmean(top_k_similarities)
        logger.info(f"Using nanmean fallback: {topk_similarity_mean:.6f}")

    # Uploading is best effort: a NocoDB failure is logged, not raised.
    try:
        client = NocoClient()
        upload_data = {
            "train_dataset_name": train_dataset_name,
            "test_dataset_name": test_dataset_name,
            "source_model": source_model,
            "target_model": target_model,
            "metric": metric,
            "top_k": top_k,
            "topk_similarity_mean": float(topk_similarity_mean),
            "total_queries": total_queries,
            "successful_queries": total_queries,
            "calculation_method": 'milvus' if use_milvus else 'manual',
            "sample_queries": sample_queries,
            "sample_corpus": sample_corpus,
            "part": part
        }
        table_name = "cal_dis_results_cross"
        table_id = client.create_table(table_name, list(upload_data.keys()))
        result = client.upload_result(upload_data, table_id)
        if result:
            logger.info(f"Successfully uploaded results to NocoDB table: {table_name}")
        else:
            logger.warning("Failed to upload results to NocoDB")
    except Exception as e:
        logger.error(f"Error uploading to NocoDB: {str(e)}, the dataset pair is {train_dataset_name} -> {test_dataset_name}")
# Script entry point: when the file is executed directly (not imported),
# hand control to the Typer app so it parses the command line.
if __name__ == "__main__":
    cal_dis_app()
