import pickle
import random
import time
import os
from typing import Any, Dict, List, Union
from torch.utils.data import dataset
import torch
import numpy as np
from loguru import logger
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
from ..config.models import ManyToOneConfig, SingleRunConfig, CrossTranslateConfig
from ..embeddings.memmap_dataset import MultiMemmapDatasetLoader, MemmapEmbeddingDataset
from .multi_train_pipeline import (
    evaluate_test_dataset_step,
    print_results_step,
    set_seed,
)
from ..mapper.strategy import create_gating_moe_mapper
from ..mapper.strategy import create_mapper
import os
import numpy as np
from typing import List, Dict, Any
from tqdm import tqdm
from loguru import logger
class USearchEvaluator:
    """Recall@k evaluation of queries against a corpus fused from several embedding files.

    Every document id ``0..N-1`` is assigned, by a seeded balanced random
    router, to exactly one of ``corpus_emb_paths``; the index stores that
    file's row for the document. All corpus files must therefore hold arrays
    of the same ``(N, D)`` shape, row-aligned by document id.

    Args:
        corpus_emb_paths: ``.npy`` files, each an ``(N, D)`` corpus matrix.
        query_emb: ``(Q, D)`` query embeddings; row ``i`` is query id ``i``
            (``q2a`` keys index rows of the search results).
        q2a: query id -> list of relevant document ids.
        k_list: cut-offs to report; defaults to ``[10, 50, 100, 500, 1000]``.
        seed: seed of the document router's RNG.
        batch_size: number of documents added to the index per ``add`` call.
        metric: metric name passed to the usearch ``Index``.
        exact: passed through to ``Index.search``.

    Raises:
        ValueError: on empty ``corpus_emb_paths``/``k_list``, non-2-D inputs,
            corpus files of differing shapes, or a query/corpus dim mismatch.
    """

    def __init__(
        self,
        corpus_emb_paths: List[str],
        query_emb: np.ndarray,
        q2a: Dict[int, List[int]],
        k_list: Union[List[int], None] = None,
        seed: int = 42,
        batch_size: int = 100_000,
        metric: str = "cos",
        exact: bool = True,
    ):
        # Imported here rather than at module top; presumably so usearch is
        # only required when an evaluator is actually constructed.
        from usearch.index import Index

        # None sentinel instead of a shared mutable list default; copy so later
        # mutation of the caller's list cannot change this evaluator.
        k_list = [10, 50, 100, 500, 1000] if k_list is None else list(k_list)
        if not k_list:
            raise ValueError("k_list must contain at least one cut-off")
        self.Index = Index
        self.corpus_emb_paths = list(corpus_emb_paths)
        self.query_emb = np.asarray(query_emb, dtype=np.float32)
        if self.query_emb.ndim != 2:
            raise ValueError(
                f"query_emb must be 2-D (n_queries, dim), got shape {self.query_emb.shape}"
            )
        self.q2a = q2a
        self.k_list = k_list
        self.max_k = max(k_list)
        self.rng = np.random.default_rng(seed)
        self.batch_size = batch_size
        self.metric = metric
        self.exact = exact
        self.index = self._build_index()

    def _memmap_npy(self, path: str):
        """Open a ``.npy`` file read-only as a memory map.

        Returns ``(array, shape)``. Header parsing is delegated to ``np.load``
        so every ``.npy`` format version and Fortran-ordered arrays are
        handled correctly.
        """
        mm = np.load(path, mmap_mode="r")
        return mm, mm.shape

    def _balanced_random_router(self, n_docs: int, n_mappers: int) -> np.ndarray:
        """Assign each of ``n_docs`` ids to one of ``n_mappers`` owners.

        Ids are shuffled with ``self.rng`` and split into ``n_mappers`` parts
        whose sizes differ by at most one (``np.array_split``).
        """
        idx = np.arange(n_docs)
        self.rng.shuffle(idx)
        router = np.empty(n_docs, dtype=np.int32)
        for m, part in enumerate(np.array_split(idx, n_mappers)):
            router[part] = m
        return router

    def _build_index(self):
        """Build the usearch index over the router-merged corpus, batch by batch."""
        if not self.corpus_emb_paths:
            raise ValueError("corpus_emb_paths must contain at least one path")
        memmaps = [self._memmap_npy(p)[0] for p in self.corpus_emb_paths]
        shapes = [mm.shape for mm in memmaps]
        if len(shapes[0]) != 2:
            raise ValueError(f"Corpus embeddings must be 2-D, got {shapes}")
        n_docs, dim = shapes[0]
        if any(s != (n_docs, dim) for s in shapes):
            raise ValueError(f"All corpus embeddings must have same shape, got {shapes}")
        if self.query_emb.shape[1] != dim:
            raise ValueError(
                f"Query dim {self.query_emb.shape[1]} does not match corpus dim {dim}"
            )
        n_mappers = len(memmaps)
        router = self._balanced_random_router(n_docs, n_mappers)
        index = self.Index(ndim=dim, metric=self.metric, dtype="f32")
        for s in tqdm(range(0, n_docs, self.batch_size), desc="Indexing merged corpus"):
            e = min(s + self.batch_size, n_docs)
            ids = np.arange(s, e, dtype=np.uint32)
            owners = router[s:e]
            batch = np.empty((e - s, dim), dtype=np.float32)
            for m, mm in enumerate(memmaps):
                rows = np.flatnonzero(owners == m)
                if rows.size:
                    # Fancy-index the memmap so only the rows this file owns are read.
                    batch[rows] = mm[s + rows]
            index.add(ids, batch)
        logger.info(f"✓ Built merged index: N={n_docs}, D={dim}, mappers={n_mappers}")
        return index

    def _search(self, k: int) -> np.ndarray:
        """Return an ``(n_queries, k)`` int64 array of retrieved ids, ``-1``-padded."""
        q = self.query_emb
        out = np.full((q.shape[0], k), -1, dtype=np.int64)
        for i in range(q.shape[0]):
            m = self.index.search(q[i].reshape(1, -1), k, exact=self.exact)
            # Flatten in case the result for a 1-row batch comes back 2-D.
            keys = np.asarray(m.keys, dtype=np.int64).reshape(-1)[:k]
            out[i, : keys.size] = keys
        return out

    def _recall_at_k(self, sr: np.ndarray, k: int) -> float:
        """Fraction of queries with >=1 relevant doc in their top ``k``.

        Queries whose positive list is empty are skipped; returns 0.0 when no
        query has positives.
        """
        hit, tot = 0, 0
        for qid, pos in self.q2a.items():
            if not pos:
                continue
            tot += 1
            topk = set(sr[qid, :k].tolist())
            if any(p in topk for p in pos):
                hit += 1
        return hit / tot if tot else 0.0

    def evaluate(self) -> Dict[str, float]:
        """Search once at ``max_k`` and report ``{"recall@k": value}`` for each cut-off."""
        sr = self._search(self.max_k)
        return {f"recall@{k}": self._recall_at_k(sr, k) for k in self.k_list}
def _load_single_dataset(config: ManyToOneConfig, single_dataset_test_mapped_path_list: List[List[str]]) -> List[str]:
    emb_list = []
    for test_mapped_path in single_dataset_test_mapped_path_list:
        emb = np.load(test_mapped_path, mmap_mode='r')
        emb_list.append(emb)
    index = np.arange(emb_list[0].shape[0])
    np.random.shuffle(index)
    for mapper_index, mapper_emb in enumerate(emb_list):
        pass
class FusionEmbResults:
    """Balanced random fusion of several row-aligned embedding matrices.

    Each row index ``i`` of the shared ``(N, ...)`` shape is owned by exactly
    one input matrix, chosen by ``router[i]``. ``self[i]`` returns row ``i``
    of that matrix. Matrices are opened read-only via ``np.load(mmap_mode='r')``.

    Args:
        mapping_emb_path_list: ``.npy`` paths; every file must hold an array
            of the same shape.

    Raises:
        ValueError: if the list is empty or the arrays' shapes differ.
    """

    def __init__(self, mapping_emb_path_list: List[str]):
        self.mapping_emb_path_list = mapping_emb_path_list
        self.emb_list = self._load_all_mapping_embs()
        self.router = self._average_shuffle_embs_and_re_index()

    def __len__(self) -> int:
        """Number of fused rows (documents)."""
        return len(self.router)

    def __getitem__(self, idx: int) -> np.ndarray:
        """Return row ``idx`` taken from the matrix that owns it."""
        return self.emb_list[self.router[idx]][idx]

    def _load_all_mapping_embs(self) -> List[np.ndarray]:
        """Memory-map every input file and check that all shapes agree."""
        if not self.mapping_emb_path_list:
            raise ValueError("mapping_emb_path_list must contain at least one path")
        emb_list = [np.load(path, mmap_mode='r') for path in self.mapping_emb_path_list]
        reference_shape = emb_list[0].shape
        # Explicit raise instead of assert: asserts are stripped under `python -O`.
        if any(emb.shape != reference_shape for emb in emb_list):
            raise ValueError("All mapping embs should have the same shape")
        return emb_list

    def _average_shuffle_embs_and_re_index(self) -> np.ndarray:
        """Build the row -> owning-matrix router.

        Shuffles row ids with the global NumPy RNG and splits them into
        ``len(emb_list)`` parts whose sizes differ by at most one.
        """
        num_docs = self.emb_list[0].shape[0]
        num_mappers = len(self.emb_list)
        all_idx = np.arange(num_docs)
        np.random.shuffle(all_idx)
        router = np.empty(num_docs, dtype=np.int32)
        for mapper_index, part in enumerate(np.array_split(all_idx, num_mappers)):
            router[part] = mapper_index
        return router
def _average_retrieve_single_dataset(config: ManyToOneConfig, single_dataset_test_mapped_path_list: List[str], if_add_original: bool = False, dataset_name: str = None) -> FusionEmbResults:
    """Fuse one dataset's mapped corpus embeddings into a ``FusionEmbResults``.

    Args:
        config: many-to-one config; ``target_model`` and
            ``dataset.embedding_path`` are read when ``if_add_original`` is set.
        single_dataset_test_mapped_path_list: mapped ``.npy`` paths, one per
            run. The caller's list is not modified.
        if_add_original: when True, the target model's own corpus embedding
            path (from ``MemmapEmbeddingDataset.get_embedding_path``) is
            prepended as an extra fusion member.
        dataset_name: dataset to look up; needed when ``if_add_original``.

    Returns:
        ``FusionEmbResults`` whose ``mapping_emb_path_list`` is
        ``[original, *mapped]`` or a copy of the mapped list.
    """
    # Work on a copy so the caller's list is not mutated.
    mapping_paths = list(single_dataset_test_mapped_path_list)
    if if_add_original:
        original_emb_path = MemmapEmbeddingDataset.get_embedding_path('corpus', config.target_model, dataset_name, config.dataset.embedding_path)
        # Insert only in this branch; otherwise original_emb_path would be unbound.
        mapping_paths.insert(0, original_emb_path)
    return FusionEmbResults(mapping_paths)
def multi_retrieve_step(config: ManyToOneConfig, test_mapped_path_list: List[List[str]]) -> Dict[str, Dict[str, float]]:
    """Evaluate fused retrieval for every test dataset of a many-to-one config.

    For each test dataset, the mapped corpus embeddings from every run plus
    the target model's own corpus embeddings are fused by a balanced random
    router. They are then searched with the dataset's
    ``target_query_embeddings``, and recall@k is computed.

    Args:
        config: many-to-one config providing ``runs`` and
            ``dataset.test_dataset_list``.
        test_mapped_path_list: ``test_mapped_path_list[r][d]`` is the mapped
            corpus ``.npy`` path for run ``r`` on test dataset ``d``.

    Returns:
        ``{dataset_name: {"recall@k": value, ...}}`` for every test dataset.

    Raises:
        ValueError: if the outer list length differs from the number of runs,
            or a run has fewer paths than there are test datasets.
    """
    test_dataset_names = config.dataset.test_dataset_list
    if len(config.runs) != len(test_mapped_path_list):
        raise ValueError("The number of runs and test_mapped_path_list must be the same")
    for run_index, run_paths in enumerate(test_mapped_path_list):
        if len(run_paths) < len(test_dataset_names):
            raise ValueError(
                f"Run {run_index} has {len(run_paths)} mapped paths, "
                f"expected at least {len(test_dataset_names)}"
            )

    all_results: Dict[str, Dict[str, float]] = {}
    for test_dataset_index, dataset_name in enumerate(test_dataset_names):
        single_dataset_test_mapped_path_list = [
            run_paths[test_dataset_index] for run_paths in test_mapped_path_list
        ]
        logger.info(f"Dataset: {dataset_name}")
        logger.info(f"SingleDatasetTestMappedPathList: {single_dataset_test_mapped_path_list}")
        fusion_emb_results = _average_retrieve_single_dataset(config, single_dataset_test_mapped_path_list, if_add_original=True, dataset_name=dataset_name)
        # Named eval_dataset to avoid shadowing the module-level `dataset` import.
        eval_dataset = MemmapEmbeddingDataset(
            config=config,
            dataset_name=dataset_name,
            source_model=config.runs[0].model.source_model,
            target_model=config.runs[0].model.target_model,
            d0_ratio=1/3,
            split_strategy="random"
        )
        evaluator = USearchEvaluator(
            corpus_emb_paths=fusion_emb_results.mapping_emb_path_list,
            query_emb=eval_dataset.target_query_embeddings,
            q2a=eval_dataset.q2a,
            k_list=[10, 50, 100, 500, 1000],
        )
        results = evaluator.evaluate()
        logger.info(f"Results: {results}")
        all_results[dataset_name] = results
    return all_results
def _fit_mapper(config: SingleRunConfig) -> bytes:
    """Train the run's mapper on its training datasets and return it pickled.

    NOTE(review): the loader's batch size is read from
    ``config.mapper.gating_moe.mapper_config`` even though the mapper is built
    by the generic ``create_mapper``, and training uses ``shuffle=False``.
    Confirm both are intended.

    Returns:
        ``pickle.dumps`` of the fitted mapper.
    """
    train_names = config.dataset.train_dataset_list
    loader_batch_size = config.mapper.gating_moe.mapper_config.batch_size
    loader = MultiMemmapDatasetLoader(
        config=config,
        dataset_names=train_names,
        batch_size=loader_batch_size,
        shuffle=False,
        num_workers=4,
        device='cuda',
        pin_memory=True,
    )
    fitted = create_mapper(
        config.mapper,
        input_dim=loader.source_embedding_dim,
        output_dim=loader.target_embedding_dim,
    )
    fitted.fit(loader)
    return pickle.dumps(fitted)
def _transform_test_dataset(
    config: SingleRunConfig,
    serialized_mapper: bytes,
    test_dataset_names: List[str]
) -> List[str]:
    """Apply a pickled mapper to each test dataset and save the outputs as ``.npy``.

    Outputs go to
    ``<transformed_cache_path>/<source>_<target>/<unix seconds>/<dataset>.npy``.

    Args:
        config: run config (cache path, transform batch size, model names).
        serialized_mapper: ``pickle`` bytes of a fitted mapper; in this file
            they come from ``_fit_mapper``. Never pass untrusted bytes, since
            unpickling can execute arbitrary code.
        test_dataset_names: datasets to transform, one output file each.

    Returns:
        Output paths in the order of ``test_dataset_names``.

    Raises:
        ValueError: if a written output is not a 2-D array.
    """
    mapper = pickle.loads(serialized_mapper)
    # NOTE(review): the directory is keyed by whole seconds, so two runs with
    # the same source/target models started within one second would share it
    # and overwrite each other's per-dataset files.
    cache_path = os.path.join(
        config.mapper.transformed_cache_path,
        f"{config.model.source_model}_{config.model.target_model}/",
        f"{int(time.time())}"
    )
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(cache_path, exist_ok=True)

    test_mapped_path_list: List[str] = []
    for dataset_name in test_dataset_names:
        test_dataset_loader = MultiMemmapDatasetLoader(
            config=config,
            dataset_names=[dataset_name],
            batch_size=config.mapper.transformed_batch_size,
            shuffle=False,
            num_workers=0,
            device='cuda',
            pin_memory=True
        )
        output_path = os.path.join(cache_path, f"{dataset_name}.npy")
        # Presumably writes the mapped embeddings to output_path (re-opened below).
        mapper.transform_dataset(config, test_dataset_loader, output_path)
        # Memory-mapped re-open: checks the header without reading all data.
        arr = np.load(output_path, mmap_mode='r')
        if arr.ndim != 2:
            raise ValueError(f"Expected 2D array, got shape {arr.shape}")
        logger.info(f"✓ Verified: {output_path}, shape={arr.shape}")
        test_mapped_path_list.append(output_path)
    return test_mapped_path_list
def _run_single_mapping(run_config: SingleRunConfig, target_model: str) -> Dict[str, Any]:
    """Train one run's mapper and transform its test datasets.

    Test datasets come from ``run_config.dataset.test_dataset_list`` when that
    exists and is non-empty; otherwise ``[run_config.test_dataset]`` is used.

    Returns:
        Dict with keys ``"source_model"``, ``"target_model"`` and
        ``"test_mapped_path_list"`` (paths of the transformed ``.npy`` files).
    """
    source_model = run_config.model.source_model
    pair = f"{source_model}→{target_model}"
    logger.info(f"Processing run: {pair}")
    logger.info(f"Training mapper for {pair}")
    serialized_mapper = _fit_mapper(run_config)

    dataset_cfg = getattr(run_config, 'dataset', None)
    configured_names = getattr(dataset_cfg, 'test_dataset_list', None) if dataset_cfg else None
    test_dataset_names = configured_names if configured_names else [run_config.test_dataset]

    logger.info(f"Transforming test datasets for {pair}")
    mapped_paths = _transform_test_dataset(run_config, serialized_mapper, test_dataset_names)
    return dict(
        source_model=source_model,
        target_model=target_model,
        test_mapped_path_list=mapped_paths,
    )
def run_all_mappings_step(config: ManyToOneConfig) -> List[List[str]]:
    """Run every configured source->target mapping in order.

    Returns:
        One list of mapped test-dataset paths per run, in ``config.runs`` order.
    """
    target_model = config.target_model
    runs = config.runs
    total = len(runs)
    mapped_per_run: List[List[str]] = []
    for position, run_config in enumerate(runs, start=1):
        logger.info(f"Running mapping {position}/{total}")
        outcome = _run_single_mapping(run_config, target_model)
        mapped_per_run.append(outcome["test_mapped_path_list"])
    return mapped_per_run
def collect_results_step(all_results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Print a per-run status table and summarize run results.

    A run counts as successful when its dict contains a ``"results"`` key.
    Runs are keyed by source model in ``summary``, so a later run with the
    same source model replaces an earlier one.

    Returns:
        ``{"summary": {source: run_dict}, "total_runs": int, "successful_runs": int}``.
    """
    console = Console()
    table = Table(title="Many-to-One Training Summary", show_header=True, header_style="bold magenta")
    for header, colour in (
        ("Source Model", "cyan"),
        ("Target Model", "green"),
        ("Status", "yellow"),
    ):
        table.add_column(header, style=colour)

    summary: Dict[str, Any] = {}
    successes = 0
    for run in all_results:
        src = run.get("source_model", "unknown")
        tgt = run.get("target_model", "unknown")
        succeeded = "results" in run
        if succeeded:
            successes += 1
        table.add_row(src, tgt, "success" if succeeded else "failed")
        summary[src] = run

    console.print(table)
    return {
        "summary": summary,
        "total_runs": len(all_results),
        "successful_runs": successes,
    }
def many_to_one_pipeline(config: Union[ManyToOneConfig, SingleRunConfig]) -> Dict[str, Any]:
    """Run the many-to-one pipeline: train every run's mapper, then evaluate fused retrieval.

    NOTE(review): the body reads ``config.target_model`` and ``config.runs``,
    so in practice it needs a ``ManyToOneConfig`` despite the Union annotation.

    Returns:
        Whatever ``multi_retrieve_step`` returns.
    """
    logger.info("Starting Many-to-One pipeline")
    logger.info(f"Target Model: {config.target_model}")
    logger.info(f"Number of runs: {len(config.runs)}")
    set_seed()
    return multi_retrieve_step(config, run_all_mappings_step(config))
