import argparse
import json
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

# Make project-local packages (e.g. `embed_trainer`, imported inside
# build_loader) importable when this file is executed directly by path.
# `parents[1]` is the directory one level above the folder that contains this
# script. NOTE(review): that directory is assumed to be the repository root --
# confirm if the script is ever moved.
_REPO_ROOT = Path(__file__).resolve().parents[1]
if str(_REPO_ROOT) not in sys.path:
    # Prepend (rather than append) so this directory is searched first.
    sys.path.insert(0, str(_REPO_ROOT))


def parse_ks(ks_str: str) -> List[int]:
    """Turn a comma-separated string such as ``"1,3,5,10"`` into a list of ints.

    Pieces that are empty or whitespace-only (for example from a trailing or
    doubled comma) are dropped. Every remaining piece is handed to ``int``,
    which tolerates surrounding whitespace and raises ``ValueError`` for
    anything non-numeric. Input order and duplicates are preserved.
    """
    values: List[int] = []
    for piece in ks_str.split(","):
        if not piece.strip():
            # Nothing between two commas (or only blanks): skip it.
            continue
        values.append(int(piece))
    return values


def build_loader(
    dataset: str,
    tools_path: str,
    queries_path: str,
    subset: str,
    device: Any,
    max_queries: Optional[int],
):
    """Create the data loader for ``dataset`` (mirrors compute_vanilla_rag).

    Args:
        dataset: Dataset name, matched case-insensitively against
            ``toolret``, ``ultratool`` and ``fiqa``.
        tools_path: Tools dataset path; a falsy value selects the built-in
            ``embeddings/<dataset>_tools_embedded`` default.
        queries_path: Queries dataset path; a falsy value selects the built-in
            ``embeddings/<dataset>_queries_embedded`` default.
        subset: Only forwarded for ``toolret``; a falsy value becomes ``"code"``.
        device: Passed through unchanged to the loader factory.
        max_queries: Passed through unchanged to the loader factory.

    Returns:
        Whatever the matching ``create_*_data_loader`` factory returns.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Imported lazily so the module can be loaded without the project package.
    from embed_trainer.universal_data_loader import (
        create_fiqa_data_loader,
        create_toolret_data_loader,
        create_ultratool_data_loader,
    )

    # dataset name -> (factory, default tools path, default queries path)
    registry = {
        "toolret": (
            create_toolret_data_loader,
            "embeddings/toolret_tools_embedded",
            "embeddings/toolret_queries_embedded",
        ),
        "ultratool": (
            create_ultratool_data_loader,
            "embeddings/ultratool_tools_embedded",
            "embeddings/ultratool_queries_embedded",
        ),
        "fiqa": (
            create_fiqa_data_loader,
            "embeddings/fiqa_tools_embedded",
            "embeddings/fiqa_queries_embedded",
        ),
    }

    key = dataset.lower()
    entry = registry.get(key)
    if entry is None:
        raise ValueError(
            f"Unsupported dataset '{key}'. Choose from: toolret, ultratool, fiqa"
        )

    factory, default_tools, default_queries = entry
    factory_kwargs = {
        "tools_dataset_path": tools_path or default_tools,
        "queries_dataset_path": queries_path or default_queries,
        "device": device,
        "max_queries": max_queries,
    }
    if key == "toolret":
        # Only the ToolRet factory is given a subset argument.
        factory_kwargs["subset"] = subset or "code"
    return factory(**factory_kwargs)


def compute_metrics_from_rankings(
    topk_indices: List[List[int]],
    all_correct_arms: List[List[int]],
    ks: List[int],
) -> Dict[str, float]:
    """Compute Hit@k, Recall@k and NDCG@k (binary relevance) for ranked predictions.

    Args:
        topk_indices: For each query, predicted arm indices ranked from best to
            worst. Each ranking is assumed to contain no duplicate indices; a
            repeated index would be counted once per occurrence.
        all_correct_arms: For each query, the correct arm indices. Must have at
            least as many entries as ``topk_indices``.
        ks: Cut-off values to evaluate. A non-positive ``k`` is treated as an
            empty cut-off, so all of its metrics are ``0.0``.

    Returns:
        A dict with keys ``hit_at_{k}``, ``recall_at_{k}`` and ``ndcg_at_{k}``
        for every ``k`` in ``ks``. Each value is averaged over *all* queries:
        queries with an empty ground-truth list contribute zero but still count
        in the denominator. Returns ``{}`` when there are no queries or when no
        ``k`` is positive.

    Raises:
        IndexError: If ``all_correct_arms`` is shorter than ``topk_indices``.
    """
    import math

    num_queries = len(topk_indices)
    if num_queries == 0:
        return {}

    max_k = max(ks) if ks else 0
    if max_k <= 0:
        return {}

    # discounts[r] is the DCG weight 1/log2(r + 2) of the item at 0-based rank r.
    discounts = [1.0 / math.log2(rank + 2) for rank in range(max_k)]

    # Build each ground-truth set once instead of once per k. Index explicitly
    # (rather than zip) so a too-short ground-truth list fails loudly with
    # IndexError instead of being silently truncated.
    truth_sets = [set(all_correct_arms[i]) for i in range(num_queries)]

    results: Dict[str, float] = {}
    for k in ks:
        # A negative k would make `ranking[:k]` slice from the end of the list,
        # scoring the wrong cut-off (and possibly overrunning `discounts`).
        cutoff = max(k, 0)

        hit_total = 0.0
        recall_total = 0.0
        ndcg_total = 0.0
        for ranking, relevant in zip(topk_indices, truth_sets):
            if not relevant:
                # No ground truth: contributes 0 to every metric.
                continue

            is_hit = [pred in relevant for pred in ranking[:cutoff]]
            hit_count = sum(is_hit)

            # Hit@k: at least one relevant arm inside the cut-off.
            if hit_count > 0:
                hit_total += 1.0

            # Recall@k: fraction of the relevant arms that were retrieved.
            recall_total += hit_count / len(relevant)

            # NDCG@k: DCG normalised by the best achievable DCG at this cut-off.
            dcg = sum(discounts[pos] for pos, hit in enumerate(is_hit) if hit)
            idcg = sum(discounts[j] for j in range(min(cutoff, len(relevant))))
            ndcg_total += (dcg / idcg) if idcg > 0 else 0.0

        results[f"hit_at_{k}"] = hit_total / num_queries
        results[f"recall_at_{k}"] = recall_total / num_queries
        results[f"ndcg_at_{k}"] = ndcg_total / num_queries

    return results


def bm25_retrieve(
    corpus: List[str],
    queries: List[str],
    k: int,
    stopwords: Optional[str] = "en",
    use_stemmer: bool = False,
) -> Tuple[List[List[int]], List[List[float]]]:
    """Index ``corpus`` with bm25s and retrieve the top ``k`` documents per query.

    Args:
        corpus: Document texts to index.
        queries: Query texts to search with.
        k: Number of documents requested per query.
        stopwords: Forwarded to ``bm25s.tokenize`` for both corpus and queries.
        use_stemmer: When true, try to build an English PyStemmer stemmer; if
            that fails a warning is printed and retrieval runs without stemming.

    Returns:
        ``(doc_indices_per_query, scores_per_query)`` as plain Python lists,
        one row per entry of ``queries``.

    Raises:
        RuntimeError: If ``bm25s`` cannot be imported.
    """
    try:
        import bm25s  # type: ignore
    except Exception as e:
        raise RuntimeError(
            'bm25s is not installed. Install with: pip install "bm25s[full]"'
        ) from e

    def _make_stemmer():
        # Best effort: any failure falls back to "no stemming" with a warning.
        try:
            import Stemmer  # type: ignore

            return Stemmer.Stemmer("english")
        except Exception:
            print(
                "Warning: PyStemmer not available. Proceeding without stemming.\n"
                "Install with: pip install PyStemmer"
            )
            return None

    def _row_as_list(matrix, row, cast):
        # Objects exposing `.shape` are indexed as 2-D arrays; anything else is
        # treated as a sequence of rows.
        if hasattr(matrix, "shape"):
            return [cast(matrix[row, col]) for col in range(matrix.shape[1])]
        return [cast(value) for value in matrix[row]]

    stemmer = _make_stemmer() if use_stemmer else None

    # Index the corpus.
    corpus_tokens = bm25s.tokenize(corpus, stopwords=stopwords, stemmer=stemmer)
    retriever = bm25s.BM25()
    retriever.index(corpus_tokens)

    # Tokenize all queries in one batch with the same settings, then search.
    query_tokens = bm25s.tokenize(queries, stopwords=stopwords, stemmer=stemmer)
    results, scores = retriever.retrieve(query_tokens, k=k)

    # Convert to plain Python lists. The results are expected to have one row
    # per query with k columns (NOTE(review): confirm against the installed
    # bm25s version).
    doc_lists: List[List[int]] = []
    score_lists: List[List[float]] = []
    for query_pos in range(len(queries)):
        doc_row = _row_as_list(results, query_pos, int)
        score_row = _row_as_list(scores, query_pos, float)
        doc_lists.append(doc_row)
        score_lists.append(score_row)

    return doc_lists, score_lists


def _build_arg_parser() -> argparse.ArgumentParser:
    """Declare the command-line options of the BM25 baseline script."""
    cli = argparse.ArgumentParser(
        description="Compute BM25 baseline retrieval metrics on embedded datasets."
    )
    cli.add_argument(
        "--dataset",
        type=str,
        required=True,
        choices=["toolret", "ultratool", "fiqa"],
        help="Dataset to evaluate",
    )
    cli.add_argument(
        "--tools_path",
        type=str,
        default="",
        help="Optional override for tools dataset path (datasets.load_from_disk)",
    )
    cli.add_argument(
        "--queries_path",
        type=str,
        default="",
        help="Optional override for queries dataset path (datasets.load_from_disk)",
    )
    cli.add_argument(
        "--subset",
        type=str,
        default="",
        help="Subset for datasets that support it (e.g., ToolRet: code|web|customized)",
    )
    cli.add_argument(
        "--embedding_model",
        type=str,
        default="large",
        help=(
            "Embedding model key used only for consistent filtering via loader "
            "(expects column 'embedding_<model>')"
        ),
    )
    cli.add_argument(
        "--ks",
        type=str,
        default="1,3,5,10",
        help="Comma-separated list of k values for Recall@k and NDCG@k",
    )
    cli.add_argument(
        "--max_queries",
        type=int,
        default=None,
        help="Limit the number of queries for a quick run",
    )
    cli.add_argument(
        "--device",
        type=str,
        default="auto",
        help="cpu | cuda | auto (only used by loader; BM25 runs on CPU)",
    )
    cli.add_argument(
        "--stopwords",
        type=str,
        default="en",
        help="Stopwords language for bm25s.tokenize (use '' to disable)",
    )
    cli.add_argument(
        "--use_stemmer",
        action="store_true",
        help="Use PyStemmer for stemming (requires 'pip install PyStemmer')",
    )
    cli.add_argument(
        "--output",
        type=str,
        default="",
        help="Optional path to save metrics JSON (stdout always prints)",
    )
    return cli


def main():
    """Command-line entry point: run the BM25 baseline and report its metrics.

    Steps: parse the arguments, pick the torch device handed to the loader,
    load the data through ``build_loader``, run ``bm25_retrieve`` over the
    loader's ``tool_texts`` / ``query_texts``, score the rankings with
    ``compute_metrics_from_rankings``, print the metrics and, when ``--output``
    is given, also write them to that path as JSON.

    Raises:
        RuntimeError: If the loader exposes no ``tool_texts`` or ``query_texts``.
    """
    args = _build_arg_parser().parse_args()

    # Imported after argument parsing, as in the original flow.
    import torch

    wants_cuda = args.device == "cuda" or (
        args.device == "auto" and torch.cuda.is_available()
    )
    device = torch.device("cuda" if wants_cuda else "cpu")

    # Run summary; optional settings are only listed when they were provided.
    banner = [
        "=== BM25 Baseline ===",
        f"Dataset: {args.dataset}",
        f"Embedding model (for filtering): {args.embedding_model}",
        f"Device (loader tensors): {device}",
    ]
    if args.max_queries:
        banner.append(f"Max queries: {args.max_queries}")
    if args.tools_path:
        banner.append(f"Tools path override: {args.tools_path}")
    if args.queries_path:
        banner.append(f"Queries path override: {args.queries_path}")
    if args.subset:
        banner.append(f"Subset: {args.subset}")
    banner.append(f"Stopwords: {args.stopwords!r}")
    banner.append(f"Use stemmer: {args.use_stemmer}")
    for line in banner:
        print(line)

    # Go through the loader so queries, tools and the ground-truth mapping are
    # the same subsets the loader produces elsewhere.
    loader = build_loader(
        dataset=args.dataset,
        tools_path=args.tools_path,
        queries_path=args.queries_path,
        subset=args.subset,
        device=device,
        max_queries=args.max_queries,
    )

    # Only the fourth returned value (ground truth per query) is used here.
    _, _, _, query_correct_arms = loader.load_data(
        embedding_model=args.embedding_model,
        true_embedding_model=args.embedding_model,
        add_noise=False,
    )

    # Raw texts that the loader is expected to have stored on itself.
    corpus = getattr(loader, "tool_texts", None)
    queries = getattr(loader, "query_texts", None)
    if corpus is None or queries is None:
        raise RuntimeError(
            "Loader did not provide texts. Ensure universal_data_loader stores tool_texts and query_texts."
        )

    print(f"Loaded texts: queries={len(queries)}, corpus={len(corpus)}")

    ks = parse_ks(args.ks)
    # Never request more documents than the corpus holds.
    max_k = min(max(ks) if ks else 0, len(corpus))
    if max_k <= 0:
        print("No valid k values or empty corpus. Exiting.")
        return

    # An empty --stopwords string disables stopword removal.
    doc_lists, _score_lists = bm25_retrieve(
        corpus=corpus,
        queries=queries,
        k=max_k,
        stopwords=args.stopwords or None,
        use_stemmer=args.use_stemmer,
    )

    metrics = compute_metrics_from_rankings(
        topk_indices=doc_lists,
        all_correct_arms=query_correct_arms,
        ks=ks,
    )

    print("\n=== Metrics (BM25) ===")
    for k in ks:
        labelled = (
            ("hit", metrics.get(f"hit_at_{k}")),
            ("recall", metrics.get(f"recall_at_{k}")),
            ("ndcg", metrics.get(f"ndcg_at_{k}")),
        )
        parts = [
            f"{label}@{k}: {value:.4f}"
            for label, value in labelled
            if value is not None
        ]
        if parts:
            print("\t".join(parts))

    if not args.output:
        return

    report = {
        "dataset": args.dataset,
        "embedding_model": args.embedding_model,
        "ks": ks,
        "metrics": metrics,
        "bm25": {
            "stopwords": args.stopwords,
            "use_stemmer": args.use_stemmer,
        },
    }
    # A bare filename has no directory part; fall back to the current directory.
    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
    with open(args.output, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)
    print(f"Saved metrics JSON to: {args.output}")

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
