import os
import json
import hashlib
import pickle
from typing import Any, Dict, List, Tuple
import numpy as np
from loguru import logger
from ..config.models import TransitivityConfig, SingleRunConfig
from ..embeddings.memmap_dataset import MultiMemmapDatasetLoader, MemmapEmbeddingDataset
from ..mapper.strategy import create_mapper
from .multi_train_pipeline import set_seed
def stable_hash(obj: Dict[str, Any], n: int = 12) -> str:
    """Return the first ``n`` hex characters of the SHA-1 digest of ``obj``.

    ``obj`` is serialised to JSON with sorted keys, so two dictionaries with
    the same items in a different insertion order produce the same hash.
    Values that JSON cannot encode natively are converted with ``str``.
    """
    canonical = json.dumps(obj, sort_keys=True, default=str)
    digest = hashlib.sha1(canonical.encode("utf-8")).hexdigest()
    return digest[:n]
def edge(src: str, tgt: str) -> str:
    """Build the dictionary key that names the directed mapping ``src`` -> ``tgt``."""
    return "{}__TO__{}".format(src, tgt)
def apply_mapper(mapper_bytes: bytes, x: np.ndarray, bs: int = 8192) -> np.ndarray:
    """Unpickle a mapper and run it over ``x`` in row batches of size ``bs``.

    The callable applied to each batch is ``mapper.transform_embeddings`` when
    that attribute exists, otherwise ``mapper.transform``, otherwise the
    unpickled object itself. The input and every batch result are converted
    to float32, and the batch results are stacked with ``np.vstack``.

    NOTE: ``pickle.loads`` can execute arbitrary code, so ``mapper_bytes``
    must come from a trusted producer.
    """
    mapper = pickle.loads(mapper_bytes)

    # Pick the first transformation entry point the mapper exposes.
    for attr in ("transform_embeddings", "transform"):
        if hasattr(mapper, attr):
            transform = getattr(mapper, attr)
            break
    else:
        transform = mapper

    rows = np.asarray(x, dtype=np.float32)
    batches = [
        np.asarray(transform(rows[lo : lo + bs]), dtype=np.float32)
        for lo in range(0, rows.shape[0], bs)
    ]
    return np.vstack(batches)
def recall_at_k(ranked: np.ndarray, q2a: Dict[int, List[int]], k: int) -> float:
    """Fraction of queries with at least one positive among their top-``k`` ids.

    ``ranked`` is indexed as ``ranked[qid, :k]``, so each key of ``q2a`` is
    used as a row index. Queries whose positive list is empty are ignored.
    Returns ``0.0`` when no query has any positives.
    """
    judged = [(qid, positives) for qid, positives in q2a.items() if positives]
    if not judged:
        return 0.0

    hits = 0
    for qid, positives in judged:
        shortlist = ranked[qid, :k]
        if any(p in shortlist for p in positives):
            hits += 1
    return hits / len(judged)
def cosine_dist(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Row-wise cosine distance ``1 - cos(a[i], b[i])`` between two 2-D arrays.

    Both inputs are converted to float32 and each row is divided by its L2
    norm plus ``1e-12``, so an all-zero row yields a distance of 1 instead of
    a division by zero.
    """

    def _unit_rows(m: np.ndarray) -> np.ndarray:
        m = np.asarray(m, dtype=np.float32)
        return m / (np.linalg.norm(m, axis=1, keepdims=True) + 1e-12)

    similarity = (_unit_rows(a) * _unit_rows(b)).sum(axis=1)
    return 1.0 - similarity
def build_usearch(corpus: np.ndarray, metric: str = "cos"):
    """Create a usearch ``Index`` over the rows of ``corpus``.

    ``corpus`` is converted to a float32 2-D array; every row is added under
    a key equal to its row position (``0 .. n-1``, stored as uint32). The
    index is created with ``dtype="f32"`` and the given ``metric``.
    """
    # Imported lazily so the module can be loaded without usearch installed.
    from usearch.index import Index

    vectors = np.asarray(corpus, dtype=np.float32)
    num_rows, dim = vectors.shape
    idx = Index(ndim=dim, metric=metric, dtype="f32")
    row_ids = np.arange(num_rows, dtype=np.uint32)
    idx.add(row_ids, vectors)
    return idx
def search_usearch(index, queries: np.ndarray, k: int, exact: bool = True) -> np.ndarray:
    """Return the top-``k`` keys for each query as an ``(n_queries, k)`` int64 array.

    Every row of ``queries`` (converted to float32) is searched on its own via
    ``index.search(row.reshape(1, -1), k, exact=exact)`` and the ``keys`` of
    the result are written to the matching output row.

    If the index returns fewer than ``k`` keys for a query (for example when
    it holds fewer than ``k`` vectors), the remaining columns of that row are
    filled with ``-1``. Previously such a row raised a shape-mismatch
    ``ValueError``, or silently repeated a single key across all ``k`` columns.
    """
    q = np.asarray(queries, dtype=np.float32)
    # Pre-fill with -1 so slots the index could not fill never hold stale memory.
    out = np.full((q.shape[0], k), -1, dtype=np.int64)
    for i in range(q.shape[0]):
        matches = index.search(q[i].reshape(1, -1), k, exact=exact)
        # Flatten so both 1-D and (1, m) key layouts are handled the same way.
        keys = np.asarray(matches.keys, dtype=np.int64).reshape(-1)[:k]
        out[i, : keys.shape[0]] = keys
    return out
def train_edges(cfg: TransitivityConfig) -> Dict[str, bytes]:
    """Fit one mapper per distinct source->target edge and return them pickled.

    Edges are gathered from ``run_ab``, ``run_bc`` and ``run_ac`` of every
    case in ``cfg.cases``. When several runs describe the same edge, only the
    first run encountered is trained. The result maps each edge key (see
    ``edge``) to ``pickle.dumps`` of the fitted mapper.
    """
    cfg.share_settings_to_runs()

    # Deduplicate runs by edge key; the first run seen for an edge wins.
    runs_by_edge: Dict[str, SingleRunConfig] = {}
    for case in cfg.cases:
        for run in (case.run_ab, case.run_bc, case.run_ac):
            key = edge(run.model.source_model, run.model.target_model)
            if key not in runs_by_edge:
                runs_by_edge[key] = run

    pickled: Dict[str, bytes] = {}
    for key, run in runs_by_edge.items():
        logger.info(f"[TRAIN] {key}")
        loader = MultiMemmapDatasetLoader(
            config=run,
            dataset_names=run.dataset.train_dataset_list,
            batch_size=run.mapper.gating_moe.mapper_config.batch_size,
            shuffle=False,
            num_workers=4,
            device="cuda",
            pin_memory=True,
        )
        # Mapper dimensions are taken from the loaded training data.
        mapper = create_mapper(
            run.mapper,
            input_dim=loader.source_embedding_dim,
            output_dim=loader.target_embedding_dim,
        )
        mapper.fit(loader)
        pickled[key] = pickle.dumps(mapper)
    return pickled
def transform_corpus(
    cfg: TransitivityConfig,
    trained: Dict[str, bytes],
    cache_dir: str = "./cache/transitivity",
) -> Dict[str, Dict[str, Dict[str, str]]]:
    """Map each test corpus A->C directly and via A->B->C, saving both to disk.

    For every case and every dataset in ``cfg.dataset.test_dataset_list`` the
    source embeddings are transformed with the pickled mappers in ``trained``
    and stored as ``.npy`` files below ``cache_dir``. A file that already
    exists is not recomputed.

    Returns a nested dict ``{"case_<i>": {dataset_name: info}}`` where
    ``info`` holds the model names (``a``, ``b``, ``c``) and the paths of the
    ``direct`` and ``composed`` corpus files.

    NOTE: the cache directory name is derived only from the case index, the
    three model names and the dataset name, so files written by an earlier
    run are reused even if the mappers in ``trained`` have changed.
    """
    cfg.share_settings_to_runs()
    os.makedirs(cache_dir, exist_ok=True)

    result: Dict[str, Dict[str, Dict[str, str]]] = {}
    for case_idx, case in enumerate(cfg.cases):
        a = case.run_ab.model.source_model
        b = case.run_ab.model.target_model
        c = case.run_ac.model.target_model
        key_ab = edge(a, b)
        key_bc = edge(b, c)
        key_ac = edge(a, c)

        per_dataset: Dict[str, Dict[str, str]] = {}
        result[f"case_{case_idx}"] = per_dataset

        for dataset_name in cfg.dataset.test_dataset_list:
            dataset = MemmapEmbeddingDataset(
                config=cfg,
                dataset_name=dataset_name,
                source_model=a,
                target_model=c,
                d0_ratio=1 / 3,
                split_strategy="random",
            )
            source_corpus = dataset.source_embeddings

            cache_key = stable_hash(
                {"case": case_idx, "a": a, "b": b, "c": c, "dataset": dataset_name}
            )
            target_dir = os.path.join(cache_dir, cache_key)
            os.makedirs(target_dir, exist_ok=True)
            direct_file = os.path.join(target_dir, f"{a}_to_{c}_corpus.npy")
            composed_file = os.path.join(target_dir, f"{a}_to_{b}_to_{c}_corpus.npy")

            # Direct mapping A -> C.
            if not os.path.exists(direct_file):
                direct = apply_mapper(trained[key_ac], source_corpus)
                np.save(direct_file, direct)

            # Composed mapping A -> B -> C.
            if not os.path.exists(composed_file):
                via_b = apply_mapper(trained[key_ab], source_corpus)
                np.save(composed_file, apply_mapper(trained[key_bc], via_b))

            per_dataset[dataset_name] = {
                "a": a,
                "b": b,
                "c": c,
                "direct": direct_file,
                "composed": composed_file,
            }
    return result
def eval_transitivity(
    cfg: TransitivityConfig,
    mapped: Dict[str, Dict[str, Dict[str, str]]],
    metric: str = "cos",
    exact: bool = True,
) -> Dict[str, Any]:
    """Compare the direct (A->C) and composed (A->B->C) corpora per case and dataset.

    For each entry of ``mapped`` the two saved corpora are loaded, indexed
    with usearch and queried with the dataset's target query embeddings.
    The report stores, per dataset:

    * ``embedding``: mean and 95th percentile of the row-wise cosine distance
      between the direct and composed corpora;
    * ``retrieval``: ``recall@k`` for both corpora and their difference
      (composed minus direct) for every ``k`` in the case's ``k_list``.

    The case index is parsed from the ``case_<i>`` key. The finished report
    is printed to stdout and returned.
    """
    cfg.share_settings_to_runs()

    case_reports: Dict[str, Any] = {}
    report: Dict[str, Any] = {"cases": case_reports}

    for case_key, per_ds in mapped.items():
        ds_reports: Dict[str, Any] = {}
        case_reports[case_key] = ds_reports

        case_index = int(case_key.split("_")[1])
        k_list = cfg.cases[case_index].k_list
        max_k = max(k_list)  # search once at the largest k, slice for smaller ones

        for dataset_name, info in per_ds.items():
            a = info["a"]
            b = info["b"]
            c = info["c"]
            logger.info(f"[EVAL] {case_key} {dataset_name}: {a}->{c} vs {a}->{b}->{c}")

            corpus_direct = np.load(info["direct"])
            corpus_composed = np.load(info["composed"])

            dataset = MemmapEmbeddingDataset(
                config=cfg,
                dataset_name=dataset_name,
                source_model=a,
                target_model=c,
                d0_ratio=1 / 3,
                split_strategy="random",
            )
            queries_c = dataset.target_query_embeddings
            q2a = dataset.q2a

            index_direct = build_usearch(corpus_direct, metric=metric)
            index_composed = build_usearch(corpus_composed, metric=metric)
            ranked_direct = search_usearch(index_direct, queries_c, max_k, exact=exact)
            ranked_composed = search_usearch(index_composed, queries_c, max_k, exact=exact)

            # Embedding-level agreement between the two mapped corpora.
            distances = cosine_dist(corpus_direct, corpus_composed)
            embedding_stats = {
                "mean": float(distances.mean()),
                "p95": float(np.percentile(distances, 95)),
            }

            # Retrieval-level agreement at every requested cutoff.
            retrieval: Dict[str, float] = {}
            for k in k_list:
                recall_direct = recall_at_k(ranked_direct, q2a, k)
                recall_composed = recall_at_k(ranked_composed, q2a, k)
                retrieval.update(
                    {
                        f"recall@{k}_direct": float(recall_direct),
                        f"recall@{k}_composed": float(recall_composed),
                        f"delta@{k}": float(recall_composed - recall_direct),
                    }
                )

            ds_reports[dataset_name] = {
                "a": a,
                "b": b,
                "c": c,
                "embedding": embedding_stats,
                "retrieval": retrieval,
            }

    print(report)
    return report
def transitivity_pipeline(cfg: TransitivityConfig) -> Dict[str, Any]:
    """Run the full experiment: seed, train edge mappers, map corpora, evaluate.

    Returns the report produced by ``eval_transitivity``.
    """
    set_seed()
    edge_mappers = train_edges(cfg)
    corpus_paths = transform_corpus(cfg, edge_mappers)
    return eval_transitivity(cfg, corpus_paths)
