import pickle
import random
import time
from typing import Any, Dict, List, Union
import torch
import numpy as np
from loguru import logger
from tqdm import tqdm
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
import os
from src.evaluation.evaluator import Evaluator
from ..config.models import CrossTranslateConfig, SingleRunConfig
from src.mapper.strategy.base import VectorMapper
from src.embeddings.memmap_dataset import MultiMemmapDatasetLoader
from ..noco_util.util import NocoClient
def cross_fit_mapper_step(cross_data: dict, config: CrossTranslateConfig):
    """Fit a mapper on the ``"train"`` split of a cross-dataset bundle.

    Args:
        cross_data: Mapping with a ``"train"`` entry. That entry is passed to
            ``fit_mapper_step`` and must contain ``"d0"`` (its length is logged
            as the number of reference samples).
        config: Configuration forwarded unchanged to ``fit_mapper_step``.

    Returns:
        The object returned by ``fit_mapper_step``.
    """
    training_split = cross_data["train"]
    reference_count = len(training_split['d0'])
    logger.info(f"Training mapper on {reference_count} reference samples from training dataset")

    fitted_mapper = fit_mapper_step(training_split, config)

    logger.info("Mapper training completed on training dataset")
    return fitted_mapper
def merge_results_step(
    evaluation_results: Dict[str, Any],
    tc_results: Dict[str, Any]
) -> Dict[str, Any]:
    """Merge evaluation metrics and TC metrics into a single dict.

    Either argument may be a plain dict or an object exposing ``read()``, in
    which case ``read()`` is called to obtain the payload. If the evaluation
    payload is a dict wrapping its metrics under ``"evaluation_results"``, the
    wrapper is removed first.

    Args:
        evaluation_results: Evaluation metrics (optionally wrapped / readable).
        tc_results: Translation-confidence metrics (optionally readable).

    Returns:
        A new dict holding the evaluation metrics updated with the TC metrics;
        on a key collision the TC value wins.
    """
    eval_dict = evaluation_results.read() if hasattr(evaluation_results, 'read') else evaluation_results
    tc_dict = tc_results.read() if hasattr(tc_results, 'read') else tc_results

    is_wrapped = isinstance(eval_dict, dict) and "evaluation_results" in eval_dict
    if is_wrapped:
        eval_dict = eval_dict["evaluation_results"]

    # Copy first so neither input is modified by the merge.
    merged_results = dict(eval_dict)
    merged_results.update(tc_dict)

    logger.info(f"Merged {len(eval_dict)} evaluation metrics with {len(tc_dict)} TC metrics")
    return merged_results
def print_results_step(evaluation_results: Dict[str, Any]):
    """Log the evaluation results and hand them back.

    Args:
        evaluation_results: A results dict, or an object exposing ``read()``
            whose return value is used as the results.

    Returns:
        The (possibly unwrapped) results that were logged.
    """
    results = (
        evaluation_results.read()
        if hasattr(evaluation_results, 'read')
        else evaluation_results
    )
    logger.info(f"Evaluation results: {results}")
    return results
def fit_mapper_step_multi(config: Union[CrossTranslateConfig, SingleRunConfig]):
    """Train a mapper over all configured training datasets and pickle it.

    A ``MultiMemmapDatasetLoader`` is built over
    ``config.dataset.train_dataset_list``; the mapper is created by
    ``create_mapper`` with the loader's source/target embedding dimensions and
    trained through ``mapper.fit_multi``.

    Args:
        config: Run configuration. The batch size is read from
            ``config.mapper.gating_moe.mapper_config.batch_size``.

    Returns:
        The trained mapper serialized with ``pickle.dumps``.
    """
    loader_options = dict(
        config=config,
        dataset_names=config.dataset.train_dataset_list,
        batch_size=config.mapper.gating_moe.mapper_config.batch_size,
        shuffle=False,
        num_workers=4,
        device='cuda',
        pin_memory=True,
    )
    train_dataloader = MultiMemmapDatasetLoader(**loader_options)

    # Imported lazily, after the loader has been constructed, as in the rest of this module.
    from ..mapper.strategy import create_mapper

    mapper = create_mapper(
        config.mapper,
        input_dim=train_dataloader.source_embedding_dim,
        output_dim=train_dataloader.target_embedding_dim,
    )
    mapper.fit_multi(train_dataloader)
    return pickle.dumps(mapper)
def transform_step(data: dict, mapper: VectorMapper):
    """Apply ``mapper.transform`` to the source corpus and source queries.

    Args:
        data: Mapping providing ``"src_emb"`` and ``"src_q"``.
        mapper: Object whose ``transform`` method is applied to each entry.

    Returns:
        Dict with ``"mapped_src_emb"`` and ``"mapped_src_q"`` (in that order),
        holding the transformed corpus and query embeddings respectively.
    """
    # Corpus first, then queries: same call order as the output key order.
    sources = (("mapped_src_emb", "src_emb"), ("mapped_src_q", "src_q"))
    return {out_key: mapper.transform(data[in_key]) for out_key, in_key in sources}
def cross_transform_step(cross_data: dict, mapper: VectorMapper):
    """Transform whichever of the ``"train"`` / ``"test"`` splits are present.

    Args:
        cross_data: Mapping that may contain ``"train"`` and/or ``"test"``
            entries, each suitable for ``transform_step``.
        mapper: Mapper forwarded to ``transform_step``.

    Returns:
        Dict with one entry per split found in ``cross_data`` (train before
        test), each holding the output of ``transform_step`` for that split.
    """
    result = {}
    for split, label in (("train", "training"), ("test", "test")):
        if split not in cross_data:
            continue
        logger.info(f"Transforming {label} embeddings")
        result[split] = transform_step(cross_data[split], mapper)
    return result
def evaluate_step(data: dict, mapped_emb, use_low_memory: bool = False):
    """Run ``Evaluator`` on one dataset and its mapped embeddings.

    Note:
        ``data["q2a"]`` is rewritten in place with its keys converted through
        ``int``; the caller's dict therefore sees the converted mapping too.

    Args:
        data: Mapping providing ``"src_emb"``, ``"tgt_emb"``, ``"src_q"``,
            ``"tgt_q"``, ``"q2a"``, ``"d0"``, ``"d1"`` and ``"d2"``.
        mapped_emb: Mapping providing ``"mapped_src_q"`` and
            ``"mapped_src_emb"``.
        use_low_memory: Forwarded to ``Evaluator.evaluate``.

    Returns:
        Whatever ``Evaluator.evaluate`` returns.
    """
    data["q2a"] = {int(query_id): answers for query_id, answers in data["q2a"].items()}

    evaluator = Evaluator(
        corpus_emb_1=data["src_emb"],
        corpus_emb_2=data["tgt_emb"],
        query_emb_1=data["src_q"],
        query_emb_2=data["tgt_q"],
        p_index_list=data["q2a"],
        d0=data["d0"],
        d1=data["d1"],
        d2=data["d2"],
        query_emb_1_transformed=mapped_emb["mapped_src_q"],
        corpus_emb_1_transformed=mapped_emb["mapped_src_emb"],
        k_list=[10, 50, 100, 500, 1000],
    )
    return evaluator.evaluate(use_low_memory=use_low_memory)
def cross_evaluate_step(cross_data: dict, cross_mapped: dict):
    """Evaluate the train and test splits and return split-prefixed metrics.

    A split is evaluated only when it is present in both ``cross_data`` and
    ``cross_mapped``. The train split is evaluated with
    ``use_low_memory=True``; the test split uses ``evaluate_step``'s default.
    A short selection of the metrics is printed with ``rich``.

    Args:
        cross_data: Mapping of split name to the dataset dict for that split.
        cross_mapped: Mapping of split name to the mapped embeddings.

    Returns:
        Dict of every metric from each evaluated split, keyed as
        ``"train_<metric>"`` / ``"test_<metric>"``.
    """
    split_plan = (
        ("train", "Evaluating on training dataset", {"use_low_memory": True}),
        ("test", "Evaluating on test dataset", {}),
    )

    prefixed_metrics = {}
    filtered_keys = []
    for split, message, eval_kwargs in split_plan:
        if split not in cross_data or split not in cross_mapped:
            continue
        logger.info(message)
        split_metrics = evaluate_step(cross_data[split], cross_mapped[split], **eval_kwargs)
        prefixed_metrics.update(
            {f"{split}_{key}": value for key, value in split_metrics.items()}
        )
        # Headline metrics that are echoed to the console for this split.
        filtered_keys += [
            f"{split}_our_method@100",
            f"{split}_corpus_cosine_distance",
            f"{split}_corpus_euclidean_distance",
        ]

    from rich import print as rich_print
    filtered_metrics = {
        key: value for key, value in prefixed_metrics.items() if key in filtered_keys
    }
    rich_print(filtered_metrics)
    return prefixed_metrics
def fit_mapper_step(data: dict, config: CrossTranslateConfig):
    """Create the mapper selected by ``config.mapper.mapper_name`` and fit it.

    Recognised names are ``"gating-moe"``, ``"procrustes"``, ``"diffusion"``,
    ``"spnt-diffusion"``, ``"spnt"``, ``"simple_linear"``, ``"simple_la2m"``
    and ``"la2m"``. Any other name falls back to the linear mapper, which is
    additionally given the source/target query embeddings when fitting.

    Args:
        data: Mapping providing ``"src_emb"``, ``"tgt_emb"`` and ``"d0"``
            (passed as ``reference_indices``); the linear fallback also reads
            ``"src_q"`` and ``"tgt_q"``.
        config: Configuration whose ``mapper`` section selects and
            parameterises the mapper.

    Returns:
        The fitted mapper.
    """
    from ..mapper.strategy import (
        create_gating_moe_mapper,
        create_linear_mapper,
        create_procrustes_mapper,
        create_diffusion_mapper,
        create_spnt_mapper,
        create_spnt_diffusion_mapper,
        create_simple_la2m_mapper,
        create_simple_linear_mapper
    )

    src_emb = data["src_emb"]
    tgt_emb = data["tgt_emb"]
    mapper_cfg = config.mapper
    mapper_name = mapper_cfg.mapper_name

    # Only the linear fallback passes extra keyword arguments to ``fit``.
    extra_fit_kwargs = {}

    if mapper_name == "gating-moe":
        mapper = create_gating_moe_mapper(mapper_cfg.gating_moe)
    elif mapper_name == "procrustes":
        mapper = create_procrustes_mapper(mapper_cfg.procrustes)
    elif mapper_name == "diffusion":
        mapper = create_diffusion_mapper(mapper_cfg.diffusion, src_emb.shape[1])
    elif mapper_name == "spnt-diffusion":
        mapper = create_spnt_diffusion_mapper(
            mapper_cfg.spnt,
            mapper_cfg.diffusion,
            src_emb.shape[1],
        )
    elif mapper_name == "spnt":
        mapper = create_spnt_mapper(
            mapper_cfg.spnt,
            src_emb.shape[1],
            tgt_emb.shape[1],
            mapper_cfg.linear,
        )
    elif mapper_name == "simple_linear":
        mapper = create_simple_linear_mapper(
            mapper_cfg.simple_linear,
            src_emb.shape[1],
            tgt_emb.shape[1],
        )
    elif mapper_name == "simple_la2m":
        mapper = create_simple_la2m_mapper(mapper_cfg.la2m, mapper_cfg.linear)
    elif mapper_name == "la2m":
        from ..mapper import LA2MMapper
        mapper = LA2MMapper()
    else:
        mapper = create_linear_mapper(mapper_cfg.linear)
        extra_fit_kwargs = {
            "query_emb_1": data["src_q"],
            "query_emb_2": data["tgt_q"],
        }

    mapper.fit(src_emb, tgt_emb, reference_indices=data["d0"], **extra_fit_kwargs)
    return mapper
def _format_metric(value: Any) -> str:
    """Render a metric with four decimals when it is a real number, else via ``str``."""
    import numbers

    # numbers.Real also matches NumPy scalars such as np.float32 / np.int64,
    # which a plain isinstance(value, (int, float)) check does not.
    if isinstance(value, numbers.Real):
        return f"{float(value):.4f}"
    return str(value)


def log_results_step(
    config: CrossTranslateConfig,
    train_results: Dict[str, Any], 
    test_results: Dict[str, Any]
) -> None:
    """Print and log the headline ``our_method@100`` metric for train and test.

    The metric is read from ``train_results["train_our_method@100"]`` and
    ``test_results["test_our_method@100"]``; a missing key is shown as
    ``"N/A"``. The values are rendered in a ``rich`` table and also written to
    the logger.

    Args:
        config: Unused by this step; accepted for a uniform step signature.
        train_results: Metrics dict for the training split.
        test_results: Metrics dict for the test split.
    """
    console = Console()
    train_metric = train_results.get("train_our_method@100", "N/A")
    test_metric = test_results.get("test_our_method@100", "N/A")

    table = Table(title="Cross-Translation Results", show_header=True, header_style="bold magenta")
    table.add_column("Dataset", style="cyan", no_wrap=True)
    table.add_column("our_method@100", style="green", justify="right")
    table.add_row("Train", _format_metric(train_metric))
    table.add_row("Test", _format_metric(test_metric))

    panel = Panel(table, title="[bold blue]Evaluation Metrics[/bold blue]", border_style="blue")
    console.print(panel)
    logger.info(f"Train our_method@100: {train_metric}, Test our_method@100: {test_metric}")
def upload_results_step(
    config: CrossTranslateConfig,
    train_results: Dict[str, Any], 
    test_results: Dict[str, Any]
) -> None:
    """Best-effort upload of train/test results plus the config to NocoDB.

    The uploaded row is the train results, updated with the test results,
    updated with ``config.dict()`` (so configuration keys win on collision).
    Any exception raised while building or uploading the row is logged as a
    warning and not propagated.

    Args:
        config: Run configuration; ``config.dict()`` supplies the parameters.
        train_results: Metrics for the training split.
        test_results: Metrics for the test split.
    """
    try:
        combined_data = {}
        # Later sources override earlier ones: train < test < config params.
        for part in (train_results, test_results, config.dict()):
            combined_data.update(part)

        client = NocoClient()
        column_names = list(combined_data)
        table_id = client.create_table("cross_translate_results_use_gating_moe", column_names)
        client.upload_result(combined_data, table_id)
        logger.info("Results uploaded to NocoDB successfully")
    except Exception as e:
        logger.warning(f"Failed to upload results to NocoDB: {e}")
def set_seed(seed: int = 42):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Value passed to ``random.seed``, ``np.random.seed``,
            ``torch.manual_seed``, ``torch.cuda.manual_seed`` and
            ``torch.cuda.manual_seed_all``.
    """
    import random

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def create_train_dataset_loader(config: CrossTranslateConfig):
    """Build a ``MultiMemmapDatasetLoader`` over the configured training datasets.

    Args:
        config: Run configuration; dataset names come from
            ``config.dataset.train_dataset_list``.

    Returns:
        A loader created with batch size 10000, no shuffling, no worker
        processes, ``device='cuda'`` and pinned memory.
    """
    loader_options = dict(
        config=config,
        dataset_names=config.dataset.train_dataset_list,
        batch_size=10000,
        shuffle=False,
        num_workers=0,
        device='cuda',
        pin_memory=True,
    )
    return MultiMemmapDatasetLoader(**loader_options)
def transform_test_dataset(config: Union[CrossTranslateConfig, SingleRunConfig], serialized_mapper: bytes, test_dataset_names: List[str]):
    """Transform each test dataset with the mapper and write the result to disk.

    Outputs go to a fresh, timestamped directory
    ``<transformed_cache_path>/<source_model>_<target_model>/<unix_time>/`` as
    ``<dataset_name>.npy``. Each file is reopened (memory-mapped) after writing
    and checked to be a 2-D array.

    Args:
        config: Run configuration (cache path, model names, transform batch size).
        serialized_mapper: Pickled mapper exposing ``transform_dataset``.
        test_dataset_names: Names of the datasets to transform, in order.

    Returns:
        List of output ``.npy`` paths, in the same order as ``test_dataset_names``.

    Raises:
        Exception: Re-raised after logging if an output file cannot be loaded;
            ``ValueError`` if it is not 2-D.
    """
    # NOTE(security): pickle.loads executes arbitrary code from its input;
    # serialized_mapper must only ever come from a trusted in-pipeline producer.
    mapper = pickle.loads(serialized_mapper)

    cache_path = os.path.join(config.mapper.transformed_cache_path, f"{config.model.source_model}_{config.model.target_model}/", f"{int(time.time())}")
    # exist_ok avoids the check-then-create race of an explicit os.path.exists
    # test and matches how the TC step creates its output directory.
    os.makedirs(cache_path, exist_ok=True)

    test_mapped_path_list = []
    for dataset_name in test_dataset_names:
        test_dataset_loader = MultiMemmapDatasetLoader(
            config=config,
            dataset_names=[dataset_name],
            batch_size=config.mapper.transformed_batch_size,
            shuffle=False,
            num_workers=0,
            device='cuda',
            pin_memory=True
        )
        output_path = os.path.join(cache_path, f"{dataset_name}.npy")
        mapper.transform_dataset(config, test_dataset_loader, output_path)

        # Sanity-check the written file before handing its path downstream.
        try:
            arr = np.load(output_path, mmap_mode='r')
            if arr.ndim != 2:
                raise ValueError(f"Expected 2D array, got shape {arr.shape}")
            logger.info(f"✓ Verified: {output_path}, shape={arr.shape}")
        except Exception as e:
            logger.error(f"❌ Invalid output file: {output_path} - {e}")
            raise
        test_mapped_path_list.append(output_path)
    return test_mapped_path_list
def evaluate_test_dataset_step(
    config: Union[CrossTranslateConfig, SingleRunConfig], 
    test_mapped_path_list: List[str],
    serialized_mapper: bytes
) -> Dict[str, Any]:
    """Evaluate the transformed test corpora with ``USearchEvaluator``.

    For each transformed corpus file, a sibling ``*_expert_ids.npy`` file is
    picked up when it exists. Target query embeddings and ``q2a`` mappings are
    taken from the test datasets loaded via ``MultiMemmapDatasetLoader``.

    Args:
        config: Run configuration; test dataset names come from
            ``config.dataset.test_dataset_list``.
        test_mapped_path_list: Paths of the transformed corpus ``.npy`` files.
            NOTE(review): these are paired positionally with the loaded test
            datasets by the evaluator — presumably both follow
            ``test_dataset_list`` order; verify against the caller.
        serialized_mapper: Not used by this step. The parameter is kept for
            interface compatibility; the mapper is intentionally not
            unpickled, since nothing here needs it.

    Returns:
        ``{"evaluation_results": <metrics returned by USearchEvaluator.evaluate()>}``.
    """
    from ..evaluation.evaluator import USearchEvaluator

    logger.info(f"Loading test datasets with memmap: {config.dataset.test_dataset_list}")
    test_loader = MultiMemmapDatasetLoader(
        config=config,
        dataset_names=config.dataset.test_dataset_list,
        batch_size=10000,
        shuffle=False,
        num_workers=0,
        device='cuda',
        pin_memory=True
    )

    corpus_emb_paths = test_mapped_path_list
    corpus_expert_ids_paths = []
    for corpus_path in corpus_emb_paths:
        # NOTE(review): str.replace rewrites every '.npy' occurrence in the
        # path, not just the extension. Left unchanged because it must match
        # the naming used by whatever writes the expert-ID files.
        expert_ids_path = corpus_path.replace('.npy', '_expert_ids.npy')
        if os.path.exists(expert_ids_path):
            corpus_expert_ids_paths.append(expert_ids_path)
            logger.info(f"Found expert IDs file: {expert_ids_path}")
        else:
            corpus_expert_ids_paths.append(None)

    query_emb_list = []
    q2a_list = []
    dataset_count = len(test_loader.datasets)
    for position, dataset in enumerate(test_loader.datasets, start=1):
        dataset_name = dataset.dataset_name
        logger.info(f"Extracting data for dataset {position}/{dataset_count}: {dataset_name}")
        tgt_query = dataset.target_query_embeddings
        q2a = dataset.q2a
        query_emb_list.append(tgt_query)
        q2a_list.append(q2a)
        logger.info(
            f"✓ Prepared data for {dataset_name}: "
            f"{len(tgt_query)} queries, {len(q2a)} q2a mappings"
        )

    logger.info(f"Creating USearchEvaluator with {len(corpus_emb_paths)} corpus files")
    evaluator = USearchEvaluator(
        corpus_emb_paths=corpus_emb_paths,
        query_emb_list=query_emb_list,
        q2a_list=q2a_list,
        k_list=[10, 50, 100, 500, 1000],
        # Pass None when no corpus has an expert-ID file at all.
        corpus_expert_ids_paths=corpus_expert_ids_paths if any(corpus_expert_ids_paths) else None
    )
    logger.info("Running evaluation...")
    results = evaluator.evaluate()
    logger.info(f"✓ Evaluation completed with {len(results)} metrics")
    return {"evaluation_results": results}
def flatten_dict(d: dict, parent_key: str = "", sep: str = "."):
    """Flatten nested dicts into a single level with joined keys.

    Nested keys are joined as ``f"{parent_key}{sep}{key}"``. A key that has no
    (truthy) parent prefix is kept as-is. Nested dicts contribute only their
    leaves, so an empty nested dict produces no entry.

    Args:
        d: Dict to flatten; values that are dicts are flattened recursively.
        parent_key: Prefix applied to every key of ``d``.
        sep: Separator placed between the prefix and each key.

    Returns:
        A new flat dict, in depth-first insertion order.
    """
    flat = {}
    for key, value in d.items():
        path = f"{parent_key}{sep}{key}" if parent_key else key
        if not isinstance(value, dict):
            flat[path] = value
            continue
        flat.update(flatten_dict(value, path, sep=sep))
    return flat
def _tc_pairwise_euclidean(sample: np.ndarray, batch_size: int = 500) -> np.ndarray:
    """Return every unordered pairwise Euclidean distance between rows of ``sample``, computed block-wise."""
    total = len(sample)
    chunks = []
    for i in tqdm(range(0, total, batch_size), desc="Computing pairwise distances"):
        rows = sample[i:i + batch_size]
        # Only blocks on or above the diagonal are visited, so each pair is counted once.
        for j in range(i, total, batch_size):
            cols = sample[j:j + batch_size]
            block = np.linalg.norm(
                rows[:, np.newaxis, :] - cols[np.newaxis, :, :],
                axis=2
            )
            if i == j:
                # Diagonal block: keep the strict upper triangle (no self-pairs, no duplicates).
                block = block[np.triu_indices(len(rows), k=1)]
            chunks.append(np.asarray(block, dtype=np.float64).ravel())
    if not chunks:
        return np.empty(0, dtype=np.float64)
    return np.concatenate(chunks)


def _tc_sigma_data(X_train: np.ndarray, max_samples: int = 5000, seed: int = 42) -> float:
    """Return the standard deviation of pairwise Euclidean distances over (a sample of) ``X_train``."""
    sample_size = min(max_samples, len(X_train))
    if len(X_train) > sample_size:
        logger.info(f"Sampling {sample_size} training samples for σ_data calculation")
        # A private RandomState gives the same draw as np.random.seed(seed) followed by
        # np.random.choice, without resetting the process-wide NumPy RNG as a side effect.
        rng = np.random.RandomState(seed)
        sample = X_train[rng.choice(len(X_train), sample_size, replace=False)]
    else:
        sample = X_train

    logger.info("Computing pairwise distances for σ_data...")
    try:
        from scipy.spatial.distance import pdist
    except ImportError:
        logger.info("scipy not available, using manual batch calculation")
        pairwise_distances = _tc_pairwise_euclidean(sample)
    else:
        logger.info("Using scipy.pdist for efficient pairwise distance calculation")
        # Euclidean, so that σ_data is in the same units as the nearest-neighbour
        # distances δ it normalises, and so that both branches agree.
        pairwise_distances = pdist(sample, metric='euclidean')
    return float(np.std(pairwise_distances))


def _tc_nearest_train_distances(X_source: np.ndarray, X_train: np.ndarray, train_index: Any, dataset_name: str) -> np.ndarray:
    """Return, for each row of ``X_source``, the Euclidean distance to its nearest row of ``X_train``."""
    use_faiss = train_index is not None
    batch_size_nn = 10000 if use_faiss else 1000
    parts = []
    for start in tqdm(range(0, len(X_source), batch_size_nn), desc=f"NN search for {dataset_name}"):
        batch_source = X_source[start:start + batch_size_nn]
        if use_faiss:
            squared, _ = train_index.search(batch_source, k=1)
            # The index returns squared L2 distances; clamp at zero so that
            # floating-point round-off can never feed a negative value to sqrt.
            parts.append(np.sqrt(np.maximum(squared.ravel(), 0.0)))
        else:
            distances = np.linalg.norm(
                batch_source[:, np.newaxis, :] - X_train[np.newaxis, :, :],
                axis=2
            )
            parts.append(np.min(distances, axis=1))
    if not parts:
        return np.empty(0, dtype=np.float64)
    return np.concatenate(parts).astype(np.float64)


def _tc_summary_stats(
    dataset_name: str,
    tc_values: np.ndarray,
    E_values: np.ndarray,
    delta_values: np.ndarray,
    sigma_data: float,
    num_samples: int,
) -> Dict[str, Any]:
    """Build the per-dataset summary statistics dict, keyed by ``<dataset_name>_<stat>``."""
    return {
        f"{dataset_name}_tc_mean": float(np.mean(tc_values)),
        f"{dataset_name}_tc_std": float(np.std(tc_values)),
        f"{dataset_name}_tc_min": float(np.min(tc_values)),
        f"{dataset_name}_tc_max": float(np.max(tc_values)),
        f"{dataset_name}_tc_median": float(np.median(tc_values)),
        f"{dataset_name}_tc_p25": float(np.percentile(tc_values, 25)),
        f"{dataset_name}_tc_p75": float(np.percentile(tc_values, 75)),
        f"{dataset_name}_E_mean": float(np.mean(E_values)),
        f"{dataset_name}_E_std": float(np.std(E_values)),
        f"{dataset_name}_E_min": float(np.min(E_values)),
        f"{dataset_name}_E_max": float(np.max(E_values)),
        f"{dataset_name}_E_median": float(np.median(E_values)),
        f"{dataset_name}_delta_mean": float(np.mean(delta_values)),
        f"{dataset_name}_delta_std": float(np.std(delta_values)),
        f"{dataset_name}_sigma_data": float(sigma_data),
        f"{dataset_name}_num_samples": int(num_samples),
    }


def calculate_translation_confidence_step(
    config: Union[CrossTranslateConfig, SingleRunConfig],
    serialized_mapper: bytes
) -> Dict[str, Any]:
    """Compute Translation Confidence (TC) and translation error E(o) per dataset.

    For every train and test dataset:

    * ``E`` is the row-wise Euclidean norm of ``mapper.transform(source) - target``.
    * ``δ`` is the Euclidean distance from each source embedding to its nearest
      training source embedding (FAISS ``IndexFlatL2`` when importable,
      otherwise a NumPy brute-force search).
    * ``TC = exp(-δ / σ_data)``, where ``σ_data`` is the standard deviation of
      pairwise Euclidean distances among up to 5000 training source embeddings.

    Per-dataset arrays are saved as ``<name>_tc.npy``, ``<name>_E.npy`` and
    ``<name>_delta.npy`` under
    ``<transformed_cache_path>/<source_model>_<target_model>/tc_results``,
    together with a ``metadata.json``.

    Args:
        config: Run configuration (cache path, model names, dataset lists).
        serialized_mapper: Pickled mapper exposing ``transform``.

    Returns:
        Flat dict of summary statistics keyed ``<dataset_name>_<stat>`` plus
        ``"tc_output_dir"`` and ``"tc_saved_files_count"``.
    """
    import json

    logger.info("Calculating Translation Confidence (TC) and E(o) for all transformed embeddings...")
    # NOTE(security): pickle.loads executes arbitrary code from its input;
    # serialized_mapper must only ever come from a trusted in-pipeline producer.
    mapper = pickle.loads(serialized_mapper)

    output_dir = os.path.join(
        config.mapper.transformed_cache_path,
        f"{config.model.source_model}_{config.model.target_model}",
        "tc_results"
    )
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"TC results will be saved to: {output_dir}")

    # Reference set: all training source embeddings stacked into one float32 matrix.
    train_loader = MultiMemmapDatasetLoader(
        config=config,
        dataset_names=config.dataset.train_dataset_list,
        batch_size=10000,
        shuffle=False,
        num_workers=0,
        device='cuda',
        pin_memory=True
    )
    X_train = np.vstack(
        [dataset.source_embeddings for dataset in train_loader.datasets]
    ).astype(np.float32)
    logger.info(f"Loaded {len(X_train)} training embeddings with dimension {X_train.shape[1]}")

    sigma_data = _tc_sigma_data(X_train)
    logger.info(f"σ_data (intrinsic dispersion) = {sigma_data:.6f}")
    if not np.isfinite(sigma_data) or sigma_data <= 0:
        # Happens e.g. with fewer than two distinct training points; the
        # division below then yields inf/nan instead of a usable score.
        logger.warning(f"σ_data is {sigma_data}; TC scores exp(-δ/σ_data) will not be meaningful")

    try:
        import faiss
    except ImportError:
        train_index = None
        logger.warning("FAISS not available, using numpy implementation (slower)")
    else:
        train_index = faiss.IndexFlatL2(X_train.shape[1])
        train_index.add(X_train)
        logger.info("Built FAISS index for training set")

    all_dataset_names = config.dataset.train_dataset_list + config.dataset.test_dataset_list
    all_loader = MultiMemmapDatasetLoader(
        config=config,
        dataset_names=all_dataset_names,
        batch_size=10000,
        shuffle=False,
        num_workers=0,
        device='cuda',
        pin_memory=True
    )

    all_tc_results = {}
    saved_files = {}
    dataset_count = len(all_loader.datasets)
    for dataset_idx, dataset in enumerate(all_loader.datasets):
        dataset_name = dataset.dataset_name
        is_train = dataset_name in config.dataset.train_dataset_list
        dataset_type = "train" if is_train else "test"
        logger.info(f"Processing {dataset_type} dataset {dataset_idx+1}/{dataset_count}: {dataset_name}")

        X_source = dataset.source_embeddings.astype(np.float32)
        Y_target = dataset.target_embeddings.astype(np.float32)
        logger.info(f"Loaded {len(X_source)} embeddings: source shape {X_source.shape}, target shape {Y_target.shape}")

        logger.info(f"Transforming embeddings for {dataset_name}...")
        X_transformed = mapper.transform(X_source)
        X_transformed = X_transformed.astype(np.float32)

        logger.info(f"Computing E(o) (translation error) for {dataset_name}...")
        E_values = np.linalg.norm(X_transformed - Y_target, axis=1)

        logger.info(f"Computing nearest neighbor distances (δ) for {dataset_name}...")
        delta_values = _tc_nearest_train_distances(X_source, X_train, train_index, dataset_name)

        tc_values = np.exp(-delta_values / sigma_data)

        tc_file = os.path.join(output_dir, f"{dataset_name}_tc.npy")
        e_file = os.path.join(output_dir, f"{dataset_name}_E.npy")
        delta_file = os.path.join(output_dir, f"{dataset_name}_delta.npy")
        np.save(tc_file, tc_values)
        np.save(e_file, E_values)
        np.save(delta_file, delta_values)
        logger.info(f"✓ Saved TC scores to: {tc_file}")
        logger.info(f"✓ Saved E(o) values to: {e_file}")
        logger.info(f"✓ Saved δ values to: {delta_file}")
        saved_files[f"{dataset_name}_tc_file"] = tc_file
        saved_files[f"{dataset_name}_E_file"] = e_file
        saved_files[f"{dataset_name}_delta_file"] = delta_file

        tc_stats = _tc_summary_stats(
            dataset_name, tc_values, E_values, delta_values, sigma_data, len(X_source)
        )
        all_tc_results.update(tc_stats)
        logger.info(
            f"✓ {dataset_name} TC: mean={tc_stats[f'{dataset_name}_tc_mean']:.4f}, "
            f"std={tc_stats[f'{dataset_name}_tc_std']:.4f}, "
            f"min={tc_stats[f'{dataset_name}_tc_min']:.4f}, "
            f"max={tc_stats[f'{dataset_name}_tc_max']:.4f}"
        )
        logger.info(
            f"✓ {dataset_name} E(o): mean={tc_stats[f'{dataset_name}_E_mean']:.4f}, "
            f"std={tc_stats[f'{dataset_name}_E_std']:.4f}, "
            f"min={tc_stats[f'{dataset_name}_E_min']:.4f}, "
            f"max={tc_stats[f'{dataset_name}_E_max']:.4f}"
        )

    metadata = {
        "sigma_data": float(sigma_data),
        "num_train_samples": int(len(X_train)),
        "output_dir": output_dir,
        "datasets": all_dataset_names
    }
    metadata_file = os.path.join(output_dir, "metadata.json")
    with open(metadata_file, 'w', encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)
    logger.info(f"✓ Saved metadata to: {metadata_file}")

    all_tc_results["tc_output_dir"] = output_dir
    all_tc_results["tc_saved_files_count"] = len(saved_files)
    logger.info("✓ Translation Confidence and E(o) calculation completed")
    logger.info(f"TC results saved to: {output_dir}")
    return all_tc_results
def upload_results_to_noco_step(config: CrossTranslateConfig, results: Dict[str, Any]) -> None:
    """Upload the configuration merged with the results to the NocoDB ``main_results`` table.

    The row is ``config.dict()`` updated with the results (results win on a
    key collision), flattened with ``flatten_dict`` so nested keys become
    dotted column names. Errors from the client are not caught here.

    Args:
        config: Run configuration; ``config.dict()`` seeds the uploaded row.
        results: Results dict, or an object exposing ``read()`` returning one.
    """
    from ..noco_util.util import NocoClient

    results_dict = results.read() if hasattr(results, 'read') else results

    upload_data = config.dict()
    upload_data.update(results_dict)
    flatten_upload_data = flatten_dict(upload_data)

    client = NocoClient()
    column_names = list(flatten_upload_data)
    table_id = client.create_table("main_results", column_names)
    client.upload_result(flatten_upload_data, table_id)
    logger.info("Results uploaded to NocoDB successfully")
def multi_train_pipeline(config: CrossTranslateConfig) -> Dict[str, Any]:
    """Run the full multi-dataset pipeline: fit, transform, evaluate, TC, upload.

    Steps, in order: seed the RNGs, fit and pickle the mapper, transform the
    test datasets to disk, evaluate them, compute translation-confidence
    statistics, merge both result sets, log them and upload them to NocoDB.

    Args:
        config: Run configuration driving every step.

    Returns:
        The merged evaluation + TC results, as returned by ``print_results_step``.
    """
    logger.info("Starting expert distribution calculation pipeline")
    logger.info(f"Train: {config.train_dataset} -> Test: {config.test_dataset} -> Expert: {config.mapper.gating_moe.num_experts}")
    logger.info(f"Models: {config.model.source_model} -> {config.model.target_model}")
    logger.info(f"Mapper: {config.mapper.mapper_name}")
    set_seed()

    serialized_mapper = fit_mapper_step_multi(config)
    test_dataset_names = config.dataset.test_dataset_list
    test_mapped_path_list = transform_test_dataset(config, serialized_mapper, test_dataset_names)
    evaluation_results = evaluate_test_dataset_step(config, test_mapped_path_list, serialized_mapper)
    tc_results = calculate_translation_confidence_step(config, serialized_mapper)

    merged_results = merge_results_step(evaluation_results, tc_results)
    final_results = print_results_step(merged_results)
    upload_results_to_noco_step(config, final_results)
    # The signature promises a results dict; previously the function fell off
    # the end and returned None.
    return final_results
