import argparse
import json
import os

# Ensure repository root is on sys.path so `embed_trainer` can be imported when running via path
import sys
from pathlib import Path
from typing import Any, Dict, List

_REPO_ROOT = Path(__file__).resolve().parent.parent
# Prepend the repository root (only once) so `embed_trainer` resolves when this
# script is executed by file path rather than as a package module.
if not any(entry == str(_REPO_ROOT) for entry in sys.path):
    sys.path.insert(0, str(_REPO_ROOT))

# Heavy dependencies (torch and the embed_trainer data loaders) are imported lazily
# inside the functions that use them, so `--help` works without torch installed.


def parse_ks(ks_str: str) -> List[int]:
    """
    Parse a comma-separated list of retrieval cut-offs, e.g. ``"1,3,5,10"``.

    Blank entries (``"1,,3"`` or a trailing comma) are ignored and whitespace
    around entries is tolerated. Duplicates are dropped, keeping the order in
    which values first appear.

    Raises:
        ValueError: if an entry is not an integer, or is smaller than 1
            (a non-positive k would slice the ranked list incorrectly).
    """
    ks: List[int] = []
    for raw in ks_str.split(","):
        token = raw.strip()
        if not token:
            continue
        k = int(token)
        if k < 1:
            raise ValueError(f"k values must be positive integers, got {k}")
        if k not in ks:
            ks.append(k)
    return ks


def compute_retrieval_metrics(
    query_embeddings: Any,
    arm_embeddings: Any,
    all_correct_arms: List[List[int]],
    ks: List[int],
    device: Any,
) -> Dict[str, float]:
    """
    Compute Hit@k, Recall@k and NDCG@k from raw dot-product retrieval scores.

    Mirrors the evaluate_policy() logic in universal_contextual_bandit.py
    using raw dot-product scores between queries and arms (vanilla RAG).

    Args:
        query_embeddings: 2-D tensor with one row per query (passed to
            ``torch.mm``, so its width must match ``arm_embeddings``).
        arm_embeddings: 2-D tensor with one row per arm (tool/document).
        all_correct_arms: for each query, in the same order as the rows of
            ``query_embeddings``, the indices of its relevant arms.
        ks: cut-offs to evaluate. Values smaller than 1 or larger than the
            number of arms are skipped (no metrics are emitted for them).
        device: device on which the DCG discount table is built.

    Returns:
        Dict with ``hit_at_{k}``, ``recall_at_{k}`` and ``ndcg_at_{k}`` for each
        evaluated k. Every metric is averaged over *all* queries, so a query
        with no relevant arm contributes 0. Empty when there are no queries
        or no evaluable k.

    Raises:
        ValueError: if ``all_correct_arms`` has fewer entries than there are
            queries.
    """
    import torch

    scores = torch.mm(query_embeddings, arm_embeddings.t())  # [N, K]
    num_queries = query_embeddings.size(0)
    if num_queries == 0:
        return {}
    if len(all_correct_arms) < num_queries:
        raise ValueError(
            f"all_correct_arms has {len(all_correct_arms)} entries but there "
            f"are {num_queries} queries"
        )

    # Only cut-offs in [1, num_arms] can be evaluated; others are skipped.
    num_arms = arm_embeddings.size(0)
    valid_ks = [k for k in ks if 1 <= k <= num_arms]
    if not valid_ks:
        return {}
    max_k = max(valid_ks)

    # One top-k pass at the largest cut-off; smaller cut-offs are prefixes.
    _, topk_indices = torch.topk(scores, k=max_k, dim=1)  # [N, max_k]
    # Convert once: per-query .tolist()/.item() calls would force a
    # host-device sync on every iteration when tensors live on a GPU.
    ranked_lists: List[List[int]] = topk_indices.tolist()

    # DCG discount for 0-based rank r is 1 / log2(r + 2).
    discounts: List[float] = (
        1.0
        / torch.log2(
            torch.arange(2, max_k + 2, device=device, dtype=torch.float32)
        )
    ).tolist()

    # Relevant-arm sets do not depend on k, so build them once.
    true_sets = [set(arms) for arms in all_correct_arms[:num_queries]]

    rslt: Dict[str, float] = {}
    for k in valid_ks:
        total_hit = 0.0
        total_recall = 0.0
        total_ndcg = 0.0

        for ranked, true_set in zip(ranked_lists, true_sets):
            if not true_set:
                # No relevant arm: adds 0 but still counts in the denominator.
                continue

            relevance = [1.0 if arm in true_set else 0.0 for arm in ranked[:k]]
            num_hits = sum(relevance)

            # Hit@k (any-hit)
            if num_hits > 0:
                total_hit += 1.0

            # Recall@k (classical)
            total_recall += num_hits / len(true_set)

            # NDCG@k with binary relevance; the ideal ranking places all
            # relevant arms first, capped at k.
            dcg = sum(rel * disc for rel, disc in zip(relevance, discounts))
            idcg = sum(discounts[: min(k, len(true_set))])
            total_ndcg += dcg / idcg if idcg > 0 else 0.0

        rslt[f"hit_at_{k}"] = total_hit / num_queries
        rslt[f"recall_at_{k}"] = total_recall / num_queries
        rslt[f"ndcg_at_{k}"] = total_ndcg / num_queries

    return rslt


def build_loader(
    dataset: str,
    tools_path: str,
    queries_path: str,
    subset: str,
    device: Any,
    max_queries: int | None,
):
    """
    Instantiate the project data loader for ``dataset`` (case-insensitive).

    A falsy ``tools_path`` / ``queries_path`` falls back to the default
    ``embeddings/<dataset>_{tools,queries}_embedded`` location. ``subset`` is
    forwarded only to the ToolRet loader (defaulting to ``"code"``); the other
    loaders do not receive it.

    Raises:
        ValueError: if ``dataset`` is not one of toolret, ultratool or fiqa.
    """
    # Imported here rather than at module scope so `--help` does not pull in
    # the project package (and, per the module note, torch).
    from embed_trainer.universal_data_loader import (
        create_fiqa_data_loader,
        create_toolret_data_loader,
        create_ultratool_data_loader,
    )

    name = dataset.lower()
    # name -> (loader factory, default tools path, default queries path)
    registry = {
        "toolret": (
            create_toolret_data_loader,
            "embeddings/toolret_tools_embedded",
            "embeddings/toolret_queries_embedded",
        ),
        "ultratool": (
            create_ultratool_data_loader,
            "embeddings/ultratool_tools_embedded",
            "embeddings/ultratool_queries_embedded",
        ),
        "fiqa": (
            create_fiqa_data_loader,
            "embeddings/fiqa_tools_embedded",
            "embeddings/fiqa_queries_embedded",
        ),
    }
    if name not in registry:
        raise ValueError(
            f"Unsupported dataset '{name}'. Choose from: toolret, ultratool, fiqa"
        )

    factory, default_tools, default_queries = registry[name]
    loader_kwargs = {
        "tools_dataset_path": tools_path or default_tools,
        "queries_dataset_path": queries_path or default_queries,
        "device": device,
        "max_queries": max_queries,
    }
    if name == "toolret":
        loader_kwargs["subset"] = subset or "code"
    return factory(**loader_kwargs)


def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the vanilla RAG baseline script."""
    parser = argparse.ArgumentParser(
        description="Compute vanilla RAG baseline metrics using pre-embedded datasets."
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        choices=["toolret", "ultratool", "fiqa"],
        help="Dataset to evaluate",
    )
    parser.add_argument(
        "--tools_path",
        type=str,
        default="",
        help="Optional override for tools dataset path (datasets.load_from_disk)",
    )
    parser.add_argument(
        "--queries_path",
        type=str,
        default="",
        help="Optional override for queries dataset path (datasets.load_from_disk)",
    )
    parser.add_argument(
        "--subset",
        type=str,
        default="",
        help="Subset for datasets that support it (e.g., ToolRet: code|web|customized)",
    )
    parser.add_argument(
        "--embedding_model",
        type=str,
        default="large",
        help="Embedding model key to use (expects column 'embedding_<model>')",
    )
    parser.add_argument(
        "--ks",
        type=str,
        default="1,3,5,10",
        help="Comma-separated list of k values for Recall@k and NDCG@k",
    )
    parser.add_argument(
        "--max_queries",
        type=int,
        default=None,
        help="Limit the number of queries for a quick run",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        # Restrict to supported values: anything else used to fall back to CPU silently.
        choices=["cpu", "cuda", "auto"],
        help="cpu | cuda | auto",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="",
        help="Optional path to save metrics JSON (stdout always prints)",
    )
    return parser


def main():
    """
    CLI entry point: load pre-embedded queries and arms, score them by raw dot
    product, print Hit@k / Recall@k / NDCG@k per k, and optionally write the
    metrics to a JSON file.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Validate --ks before the (potentially slow) data load so a typo fails fast
    # with a usage message instead of a traceback after loading.
    try:
        ks = parse_ks(args.ks)
    except ValueError as err:
        parser.error(f"invalid --ks value {args.ks!r}: {err}")
    if not ks:
        parser.error("--ks must contain at least one value")

    import torch

    cuda_available = torch.cuda.is_available()
    if args.device == "cuda" and not cuda_available:
        parser.error("--device cuda was requested but CUDA is not available")
    use_cuda = args.device == "cuda" or (args.device == "auto" and cuda_available)
    device = torch.device("cuda" if use_cuda else "cpu")

    print("=== Vanilla RAG Baseline ===")
    print(f"Dataset: {args.dataset}")
    print(f"Embedding model: {args.embedding_model}")
    print(f"Device: {device}")
    if args.max_queries is not None:
        print(f"Max queries: {args.max_queries}")
    if args.tools_path:
        print(f"Tools path override: {args.tools_path}")
    if args.queries_path:
        print(f"Queries path override: {args.queries_path}")
    if args.subset:
        print(f"Subset: {args.subset}")

    loader = build_loader(
        dataset=args.dataset,
        tools_path=args.tools_path,
        queries_path=args.queries_path,
        subset=args.subset,
        device=device,
        max_queries=args.max_queries,
    )

    # Load data: we only need initial (raw) embeddings for arms and query embeddings
    (
        query_embeddings,
        initial_embeddings,
        _true_embeddings,
        query_correct_arms,
    ) = loader.load_data(
        embedding_model=args.embedding_model,
        true_embedding_model=args.embedding_model,  # same to avoid any mismatch
        add_noise=False,
    )

    print(
        f"Loaded tensors: queries={tuple(query_embeddings.shape)}, arms={tuple(initial_embeddings.shape)}"
    )

    metrics = compute_retrieval_metrics(
        query_embeddings=query_embeddings,
        arm_embeddings=initial_embeddings,
        all_correct_arms=query_correct_arms,
        ks=ks,
        device=device,
    )

    print("\n=== Metrics (vanilla embeddings) ===")
    for k in ks:
        parts = [
            f"{label}@{k}: {metrics[key]:.4f}"
            for label, key in (
                ("hit", f"hit_at_{k}"),
                ("recall", f"recall_at_{k}"),
                ("ndcg", f"ndcg_at_{k}"),
            )
            if key in metrics
        ]
        if parts:
            print("\t".join(parts))
        else:
            # compute_retrieval_metrics omits k values it cannot evaluate.
            print(f"k={k}: no metrics (k exceeds the number of arms or there are no queries)")

    if args.output:
        os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(
                {
                    "dataset": args.dataset,
                    "embedding_model": args.embedding_model,
                    "ks": ks,
                    "metrics": metrics,
                },
                f,
                indent=2,
            )
        print(f"Saved metrics JSON to: {args.output}")


if __name__ == "__main__":
    # Run the CLI only when executed directly, not when imported.
    main()
