import argparse
import json
import math
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import torch

# Ensure repository root is on sys.path so `embed_trainer` can be imported when running via path.
# `parents[1]` is the directory one level above the directory that contains this file;
# it is inserted at the front of sys.path only if it is not already present.
_REPO_ROOT = Path(__file__).resolve().parents[1]
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))


# -----------------------------
# Utilities
# -----------------------------


def parse_ks(ks_str: str) -> List[int]:
    """Turn a comma-separated string such as ``"1,3,5"`` into a list of ints.

    Fields that are empty or contain only whitespace are skipped. Any other
    field is passed to ``int``, which raises ``ValueError`` if it is not a
    valid integer literal.
    """
    cutoffs: List[int] = []
    for field in ks_str.split(","):
        if not field.strip():
            continue
        cutoffs.append(int(field))
    return cutoffs


def chunked(iterable: Sequence[Any], n: int) -> Iterable[Sequence[Any]]:
    """Yield consecutive slices of *iterable*, each holding at most *n* items.

    The last slice is shorter when ``len(iterable)`` is not a multiple of *n*.
    """
    starts = range(0, len(iterable), n)
    for lo in starts:
        hi = lo + n
        yield iterable[lo:hi]


def compute_metrics_from_rankings(
    topk_indices: List[List[int]],
    all_correct_arms: List[List[int]],
    ks: List[int],
) -> Dict[str, float]:
    """Compute Hit@k, Recall@k and NDCG@k from ranked predictions.

    Args:
        topk_indices: One ranked list of predicted item ids per query, best
            first. Only the first ``k`` entries are used for each cutoff.
        all_correct_arms: Ground-truth item ids per query, indexed like
            ``topk_indices``.
        ks: Cutoffs to evaluate. Non-positive cutoffs are ignored because a
            top-0 (or negative) prefix has no meaningful metric.

    Returns:
        A dict with keys ``hit_at_{k}``, ``recall_at_{k}`` and ``ndcg_at_{k}``
        for every positive ``k`` in ``ks``. Returns ``{}`` when there are no
        queries or no positive cutoff.

    Notes:
        Queries whose ground-truth set is empty contribute zero to every sum
        but are still counted in the denominator (the number of queries).

    Raises:
        IndexError: If ``all_correct_arms`` has fewer entries than
            ``topk_indices``.
    """
    num_queries = len(topk_indices)
    max_k = max(ks) if ks else 0
    if num_queries == 0 or max_k <= 0:
        return {}

    # discounts[r] is the DCG weight of the item at zero-based rank r.
    discounts = [1.0 / math.log2(rank + 2) for rank in range(max_k)]
    # Build each ground-truth set once rather than once per (k, query) pair.
    truth_sets = [set(all_correct_arms[i]) for i in range(num_queries)]

    rslt: Dict[str, float] = {}
    for k in ks:
        if k <= 0:
            continue
        hit_sum = 0.0
        recall_sum = 0.0
        ndcg_sum = 0.0
        for ranked, truth in zip(topk_indices, truth_sets):
            if not truth:
                continue
            # Discount weights of the relevant items found in the top-k prefix.
            gains = [
                discounts[pos]
                for pos, item in enumerate(ranked[:k])
                if item in truth
            ]
            if gains:
                hit_sum += 1.0
            recall_sum += len(gains) / len(truth)
            # Ideal DCG: all relevant items (up to k of them) ranked first.
            # It is >= 1.0 here because k >= 1 and truth is non-empty.
            ideal = sum(discounts[: min(k, len(truth))])
            ndcg_sum += sum(gains) / ideal
        rslt[f"hit_at_{k}"] = hit_sum / num_queries
        rslt[f"recall_at_{k}"] = recall_sum / num_queries
        rslt[f"ndcg_at_{k}"] = ndcg_sum / num_queries
    return rslt


# -----------------------------
# Data loader factory (reuse existing baseline pattern)
# -----------------------------


def build_loader(
    dataset: str,
    tools_path: str,
    queries_path: str,
    subset: str,
    device: Any,
    max_queries: Optional[int],
):
    """Create the data loader for *dataset* (matched case-insensitively).

    Empty ``tools_path`` / ``queries_path`` fall back to the per-dataset
    default directories under ``embeddings/``. ``subset`` is forwarded only
    for the ``toolret`` dataset and defaults to ``"code"`` when empty.

    Raises:
        ValueError: If the dataset name is not toolret, ultratool or fiqa.
    """
    from embed_trainer.universal_data_loader import (
        create_fiqa_data_loader,
        create_toolret_data_loader,
        create_ultratool_data_loader,
    )

    dataset = dataset.lower()

    # name -> (factory, default tools dir, default queries dir, extra kwargs)
    registry = {
        "toolret": (
            create_toolret_data_loader,
            "embeddings/toolret_tools_embedded",
            "embeddings/toolret_queries_embedded",
            {"subset": subset or "code"},
        ),
        "ultratool": (
            create_ultratool_data_loader,
            "embeddings/ultratool_tools_embedded",
            "embeddings/ultratool_queries_embedded",
            {},
        ),
        "fiqa": (
            create_fiqa_data_loader,
            "embeddings/fiqa_tools_embedded",
            "embeddings/fiqa_queries_embedded",
            {},
        ),
    }

    entry = registry.get(dataset)
    if entry is None:
        raise ValueError(
            f"Unsupported dataset '{dataset}'. Choose from: toolret, ultratool, fiqa"
        )

    factory, default_tools, default_queries, extra_kwargs = entry
    return factory(
        tools_dataset_path=tools_path or default_tools,
        queries_dataset_path=queries_path or default_queries,
        device=device,
        max_queries=max_queries,
        **extra_kwargs,
    )


# -----------------------------
# Qwen3 Reranker Implementation (pair scoring)
# -----------------------------


@dataclass
class QwenConfig:
    """Settings consumed by ``QwenReranker`` when loading the model and building inputs."""

    # Identifier passed to ``from_pretrained`` for both tokenizer and model.
    # NOTE: main() in this file supplies its own fallback id
    # ("Qwen/Qwen3-Reranker-4B") when --model_id is empty, so this default
    # only applies when QwenConfig is constructed without a model_id.
    model_id: str = "Qwen/Qwen3-Reranker-8B"
    # Token budget for one full prompt; QwenReranker subtracts the fixed
    # prefix/suffix token counts from it before truncating the pair text.
    max_length: int = 8192
    # "fp16"/"bf16"/"fp32" select an explicit torch dtype. Any other value
    # (including "auto" and None) results in no torch_dtype being passed to
    # ``from_pretrained``.
    dtype: Optional[str] = "auto"  # "auto" | "fp16" | "bf16" | "fp32"
    # Forwarded as ``attn_implementation`` when truthy.
    attn_impl: Optional[str] = None  # e.g., "flash_attention_2"


class QwenReranker:
    """Score (query, document) pairs with a causal-LM reranker checkpoint.

    Each pair is rendered into a chat-style prompt that asks for a "yes"/"no"
    judgement. The score of a pair is the softmax probability of the "yes"
    token against the "no" token at the last position of the prompt.
    """

    def __init__(self, device: str, cfg: "QwenConfig"):
        """Load tokenizer and model, move the model to *device*, and cache prompt parts.

        Args:
            device: Target device string handed to ``model.to``.
            cfg: Model id, token budget, dtype name and attention implementation.
        """
        from transformers import AutoModelForCausalLM, AutoTokenizer

        # Left padding keeps the final prompt token in the last column of
        # every row, which is the position score_pairs reads logits from.
        self.tokenizer = AutoTokenizer.from_pretrained(
            cfg.model_id, padding_side="left"
        )

        # Only the three explicit names map to a torch dtype. Anything else
        # ("auto", None, unknown strings) passes no torch_dtype, so
        # from_pretrained applies its own default.
        explicit_dtypes = {
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
            "fp32": torch.float32,
        }
        torch_dtype = explicit_dtypes.get(cfg.dtype)
        model_kwargs = (
            {"torch_dtype": torch_dtype} if torch_dtype is not None else {}
        )
        if cfg.attn_impl:
            model_kwargs["attn_implementation"] = cfg.attn_impl
        self.model = AutoModelForCausalLM.from_pretrained(
            cfg.model_id, **model_kwargs
        )
        self.model.to(device)
        self.model.eval()
        self.device = device
        self.cfg = cfg

        # prompt parts
        self.prefix = (
            "<|im_start|>system\n"
            "Judge whether the Document meets the requirements based on the Query and the Instruct provided. "
            'Note that the answer can only be "yes" or "no".'
            "<|im_end|>\n<|im_start|>user\n"
        )
        self.suffix = (
            "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        )
        # Tokenized once; prepended/appended to every pair in _process_inputs.
        self.prefix_tokens = self.tokenizer.encode(
            self.prefix, add_special_tokens=False
        )
        self.suffix_tokens = self.tokenizer.encode(
            self.suffix, add_special_tokens=False
        )
        self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
        self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")

    @staticmethod
    def format_instruction(
        instruction: Optional[str], query: str, doc: str
    ) -> str:
        """Render the user-turn text for one pair.

        A generic web-search instruction is used when *instruction* is ``None``.
        An empty string is used as given.
        """
        if instruction is None:
            instruction = "Given a web search query, retrieve relevant passages that answer the query"
        return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"

    def _process_inputs(self, pairs: List[str]):
        """Tokenize *pairs*, wrap them in the fixed prefix/suffix, pad and move to the model device."""
        # Follow demo: truncate to leave room for prefix/suffix
        inputs = self.tokenizer(
            pairs,
            padding=False,
            truncation="longest_first",
            return_attention_mask=False,
            max_length=self.cfg.max_length
            - len(self.prefix_tokens)
            - len(self.suffix_tokens),
        )
        # Prepend/append
        for i, ids in enumerate(inputs["input_ids"]):
            inputs["input_ids"][i] = (
                self.prefix_tokens + ids + self.suffix_tokens
            )
        inputs = self.tokenizer.pad(
            inputs,
            padding=True,
            return_tensors="pt",
            max_length=self.cfg.max_length,
        )
        for k in inputs:
            inputs[k] = inputs[k].to(self.model.device)
        return inputs

    @torch.no_grad()
    def score_pairs(
        self, queries: List[str], docs: List[str], instruction: Optional[str]
    ) -> List[float]:
        """Return one relevance score in [0, 1] per (query, doc) pair.

        Args:
            queries: Query texts, aligned element-wise with *docs*.
            docs: Document texts.
            instruction: Task instruction, or ``None`` for the default one.

        Raises:
            ValueError: If *queries* and *docs* differ in length. Zipping them
                would silently drop items and misalign scores with documents.
        """
        if len(queries) != len(docs):
            raise ValueError(
                f"queries and docs must have the same length "
                f"(got {len(queries)} and {len(docs)})"
            )
        if not queries:
            # Nothing to score; avoid sending an empty batch through the model.
            return []

        pairs = [
            self.format_instruction(instruction, q, d)
            for q, d in zip(queries, docs)
        ]
        inputs = self._process_inputs(pairs)
        last_logits = self.model(**inputs).logits[:, -1, :]
        yes_no = torch.stack(
            [
                last_logits[:, self.token_false_id],
                last_logits[:, self.token_true_id],
            ],
            dim=1,
        )
        # Normalise in float32: in fp16/bf16 the probabilities of strongly
        # relevant documents round to exactly 1.0 and can no longer be ranked
        # against each other. This is a no-op for float32 models.
        log_probs = torch.nn.functional.log_softmax(yes_no.float(), dim=1)
        return log_probs[:, 1].exp().tolist()


# -----------------------------
# BGE Gemma Reranker
# -----------------------------


@dataclass
class BGERerankerConfig:
    """Settings forwarded by ``BGEReranker`` to ``FlagLLMReranker``."""

    # Model identifier passed as the first positional argument.
    model_id: str = "BAAI/bge-reranker-v2-gemma"
    # Forwarded as ``use_fp16``; on by default.
    use_fp16: bool = True
    # Forwarded as ``use_bf16``; off by default.
    use_bf16: bool = False


class BGEReranker:
    """Score (query, document) pairs with a FlagEmbedding LLM reranker."""

    def __init__(self, device: str, cfg: "BGERerankerConfig"):
        """Instantiate ``FlagLLMReranker`` from *cfg*.

        NOTE(review): *device* is only recorded on the instance; it is not
        passed to ``FlagLLMReranker``, and nothing visible in this file sets
        ``CUDA_VISIBLE_DEVICES``. Device placement is therefore whatever
        FlagEmbedding chooses. Confirm this before running several BGE
        workers on different GPUs.
        """
        from FlagEmbedding import FlagLLMReranker  # type: ignore

        self.model = FlagLLMReranker(
            cfg.model_id, use_fp16=cfg.use_fp16, use_bf16=cfg.use_bf16
        )
        self.device = device
        self.cfg = cfg

    def score_pairs(
        self, queries: List[str], docs: List[str], instruction: Optional[str]
    ) -> List[float]:
        """Return one score per (query, doc) pair, always as a list of floats.

        Args:
            queries: Query texts, aligned element-wise with *docs*.
            docs: Document texts.
            instruction: Accepted for interface parity with the Qwen reranker;
                it is not used here.

        Raises:
            ValueError: If *queries* and *docs* differ in length.
        """
        if len(queries) != len(docs):
            raise ValueError(
                f"queries and docs must have the same length "
                f"(got {len(queries)} and {len(docs)})"
            )
        if not queries:
            return []

        # BGE ignores instruction; it expects [query, passage] pairs
        pairs = [[q, d] for q, d in zip(queries, docs)]
        raw = self.model.compute_score(pairs)
        # Depending on the FlagEmbedding version, a single pair may come back
        # as a bare scalar instead of a one-element list. Normalise so callers
        # always get a list whose length matches the input.
        if hasattr(raw, "tolist"):
            raw = raw.tolist()
        if isinstance(raw, (list, tuple)):
            return [float(s) for s in raw]
        return [float(raw)]


# -----------------------------
# Inference worker
# -----------------------------


@dataclass
class WorkerSpec:
    """Everything one ``worker_run`` call needs to rank its share of queries."""

    # Position of this worker in the list built by main(); not read by worker_run.
    worker_id: int
    # Device string handed to the reranker constructor.
    device_str: str
    model_kind: str  # "qwen" or "bge"
    # Keyword arguments for QwenConfig / BGERerankerConfig, depending on model_kind.
    model_cfg: Dict[str, Any]
    # The queries this worker scores, in the order their rankings are returned.
    queries: List[str]
    # Full document list; returned rankings contain indices into it.
    corpus: List[str]
    # Number of top-scoring documents kept per query.
    max_k: int
    # Number of documents scored per reranker call.
    doc_batch_size: int
    # Instruction passed through to score_pairs (None selects the reranker's default).
    instruction: Optional[str]


def worker_run(spec: "WorkerSpec") -> List[List[int]]:
    """Rank the corpus for every query in *spec* and return the top-k document indices.

    The corpus is scored in chunks of ``spec.doc_batch_size`` documents, and
    only the best ``spec.max_k`` candidates seen so far are retained between
    chunks.

    Args:
        spec: Worker description (model kind/config, device, queries, corpus,
            cutoff and batch size).

    Returns:
        One list per query, in query order, holding corpus indices sorted by
        descending score. Each list has at most ``spec.max_k`` entries. When
        ``spec.max_k <= 0`` every list is empty and no model is loaded.

    Raises:
        ValueError: If ``spec.doc_batch_size`` is not positive, if
            ``spec.model_kind`` is unknown, or if the reranker returns a
            number of scores different from the number of documents sent.
    """
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

    # Validate cheap things before paying for a model load.
    if spec.doc_batch_size <= 0:
        raise ValueError(
            f"doc_batch_size must be positive, got {spec.doc_batch_size}"
        )
    if spec.max_k <= 0:
        # np.argpartition(x, -0)[-0:] would keep everything, so handle the
        # degenerate cutoff explicitly.
        return [[] for _ in spec.queries]

    # Build reranker
    if spec.model_kind == "qwen":
        reranker = QwenReranker(spec.device_str, QwenConfig(**spec.model_cfg))
    elif spec.model_kind == "bge":
        reranker = BGEReranker(
            spec.device_str, BGERerankerConfig(**spec.model_cfg)
        )
    else:
        raise ValueError(f"Unknown model_kind: {spec.model_kind}")

    import numpy as np

    corpus = spec.corpus
    doc_total = len(corpus)
    top_k = spec.max_k
    step = spec.doc_batch_size

    topk_lists: List[List[int]] = []
    for query in spec.queries:
        # Running top-k candidates for this query (unordered until the end).
        best_scores = np.empty((0,), dtype=np.float32)
        best_idx = np.empty((0,), dtype=np.int32)
        for start in range(0, doc_total, step):
            docs_chunk = corpus[start : start + step]
            end = start + len(docs_chunk)
            scores = reranker.score_pairs(
                [query] * len(docs_chunk), docs_chunk, spec.instruction
            )
            # reshape(-1) turns a scalar / 0-d result (a reranker may return
            # one for a single pair) into a 1-element vector so it can be
            # concatenated.
            sc = np.asarray(scores, dtype=np.float32).reshape(-1)
            if sc.size != len(docs_chunk):
                raise ValueError(
                    f"Reranker returned {sc.size} scores for "
                    f"{len(docs_chunk)} documents"
                )
            idx = np.arange(start, end, dtype=np.int32)

            # merge with previous best and keep top-K
            merged_scores = np.concatenate([best_scores, sc], axis=0)
            merged_idx = np.concatenate([best_idx, idx], axis=0)
            if merged_scores.size > top_k:
                keep = np.argpartition(merged_scores, -top_k)[-top_k:]
                best_scores = merged_scores[keep]
                best_idx = merged_idx[keep]
            else:
                best_scores = merged_scores
                best_idx = merged_idx

        # sort final top-K descending
        order = np.argsort(-best_scores)
        topk_lists.append(best_idx[order].tolist())
    return topk_lists


# -----------------------------
# Main
# -----------------------------


# Model ids used when --model_id is not given; shared by the help text and main().
_DEFAULT_MODEL_IDS = {
    "qwen": "Qwen/Qwen3-Reranker-4B",
    "bge": "BAAI/bge-reranker-v2-gemma",
}


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the reranker baseline."""
    parser = argparse.ArgumentParser(
        description="Compute retrieval metrics using reranker models (Qwen3/BGE Gemma)."
    )
    # Data
    parser.add_argument(
        "--dataset", required=True, choices=["toolret", "ultratool", "fiqa"]
    )
    parser.add_argument("--tools_path", type=str, default="")
    parser.add_argument("--queries_path", type=str, default="")
    parser.add_argument("--subset", type=str, default="")
    parser.add_argument("--embedding_model", type=str, default="large")
    parser.add_argument("--max_queries", type=int, default=None)
    parser.add_argument("--ks", type=str, default="1,3,5,10")
    # Model
    parser.add_argument(
        "--model",
        type=str,
        default="qwen",
        choices=["qwen", "bge"],
        help=(
            "Which reranker to use: "
            f"qwen={_DEFAULT_MODEL_IDS['qwen']}, bge={_DEFAULT_MODEL_IDS['bge']}"
        ),
    )
    parser.add_argument(
        "--model_id",
        type=str,
        default="",
        help="Optional override of model id. If empty, uses default for --model",
    )
    parser.add_argument(
        "--instruction",
        type=str,
        default="",
        help="Optional instruction for Qwen reranker; if empty, uses default",
    )
    # Performance
    parser.add_argument(
        "--devices",
        type=str,
        default="cuda:0",
        help="Comma-separated CUDA devices to use, e.g., cuda:0,cuda:1,cuda:2,cuda:3",
    )
    parser.add_argument("--doc_batch_size", type=int, default=64)
    parser.add_argument(
        "--num_workers", type=int, default=1, help="Use up to len(devices)"
    )
    # Qwen-specific
    parser.add_argument("--qwen_max_length", type=int, default=8192)
    parser.add_argument(
        "--qwen_dtype",
        type=str,
        default="auto",
        choices=["auto", "fp16", "bf16", "fp32"],
    )
    parser.add_argument("--qwen_attn_impl", type=str, default="")
    # BGE-specific
    # --bge_fp16 is on by default, so a store_true flag alone could never turn
    # it off; --no_bge_fp16 writes to the same destination.
    parser.add_argument(
        "--bge_fp16",
        dest="bge_fp16",
        action="store_true",
        default=True,
        help="Run the BGE reranker in fp16 (default)",
    )
    parser.add_argument(
        "--no_bge_fp16",
        dest="bge_fp16",
        action="store_false",
        help="Disable fp16 for the BGE reranker",
    )
    parser.add_argument(
        "--bge_bf16",
        action="store_true",
        default=False,
        help="Run the BGE reranker in bf16 (takes precedence over fp16)",
    )
    # Output
    parser.add_argument("--output", type=str, default="")
    return parser


def _split_contiguous(items: List[str], parts: int) -> List[List[str]]:
    """Split *items* into *parts* contiguous slices whose sizes differ by at most one."""
    base, extra = divmod(len(items), parts)
    slices: List[List[str]] = []
    start = 0
    for i in range(parts):
        size = base + (1 if i < extra else 0)
        slices.append(items[start : start + size])
        start += size
    return slices


def main():
    """Entry point: load data, rank the corpus per query with a reranker, report metrics.

    Queries are split into contiguous slices, one per worker. With more than
    one worker the slices run in a spawned multiprocessing pool. Rankings are
    concatenated in worker order, which restores the original query order.
    Metrics are printed and, if ``--output`` is set, written as JSON.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Resolve devices and workers
    device_list = [d.strip() for d in args.devices.split(",") if d.strip()]
    if not device_list:
        parser.error("--devices must name at least one device")
    if args.doc_batch_size <= 0:
        parser.error("--doc_batch_size must be a positive integer")
    num_workers = min(max(1, args.num_workers), len(device_list))

    # Cutoffs: non-positive values cannot be evaluated, so drop them up front.
    requested_ks = parse_ks(args.ks)
    ks = [k for k in requested_ks if k > 0]
    if len(ks) != len(requested_ks):
        print(f"Ignoring non-positive k values in --ks: {args.ks}")

    # Prepare model config dict
    if args.model == "qwen":
        model_cfg = {
            "model_id": args.model_id or _DEFAULT_MODEL_IDS["qwen"],
            "max_length": args.qwen_max_length,
            "dtype": args.qwen_dtype,
            "attn_impl": args.qwen_attn_impl or None,
        }
    else:
        use_bf16 = bool(args.bge_bf16)
        model_cfg = {
            "model_id": args.model_id or _DEFAULT_MODEL_IDS["bge"],
            # bf16 and fp16 are mutually exclusive; an explicit bf16 request wins.
            "use_fp16": bool(args.bge_fp16) and not use_bf16,
            "use_bf16": use_bf16,
        }

    # Loader device can be CPU; reranker runs on GPUs per worker.
    # Use the first requested device (including its index) rather than the
    # default CUDA device, so e.g. "--devices cuda:2" does not touch cuda:0.
    first_device = device_list[0]
    if "cuda" in first_device and torch.cuda.is_available():
        loader_device = torch.device(first_device)
    else:
        loader_device = torch.device("cpu")

    print("=== Reranker Baseline ===")
    print(f"Dataset: {args.dataset}")
    print(f"Model: {args.model} ({model_cfg['model_id']})")
    print(f"Devices: {device_list[:num_workers]}")
    print(f"Doc batch size: {args.doc_batch_size}")

    # Build loader and load data to ensure consistent query/tool subsets and ground-truth mapping
    loader = build_loader(
        dataset=args.dataset,
        tools_path=args.tools_path,
        queries_path=args.queries_path,
        subset=args.subset,
        device=loader_device,
        max_queries=args.max_queries,
    )

    (
        _query_embeddings,
        _initial_embeddings,
        _true_embeddings,
        query_correct_arms,
    ) = loader.load_data(
        embedding_model=args.embedding_model,
        true_embedding_model=args.embedding_model,
        add_noise=False,
    )

    corpus: List[str] = loader.tool_texts
    queries: List[str] = loader.query_texts
    print(f"Loaded texts: queries={len(queries)}, corpus={len(corpus)}")

    max_k = min(max(ks), len(corpus)) if ks else 0
    if max_k <= 0:
        print("No valid k values or empty corpus. Exiting.")
        return

    # Fail before the expensive scoring pass rather than with an IndexError
    # inside the metric computation afterwards.
    if len(query_correct_arms) < len(queries):
        raise ValueError(
            f"Ground truth has {len(query_correct_arms)} entries "
            f"but there are {len(queries)} queries"
        )

    # Split queries across workers. Never start more workers than there are
    # queries: an idle worker would still load a full model.
    num_workers = max(1, min(num_workers, len(queries)))
    q_slices = _split_contiguous(queries, num_workers)

    instruction = args.instruction or None

    # Run workers (multiprocessing to utilize multiple GPUs)
    from multiprocessing import get_context

    specs: List[WorkerSpec] = [
        WorkerSpec(
            worker_id=wi,
            device_str=device_list[wi],
            model_kind=args.model,
            model_cfg=model_cfg,
            queries=q_slices[wi],
            corpus=corpus,
            max_k=max_k,
            doc_batch_size=args.doc_batch_size,
            instruction=instruction,
        )
        for wi in range(num_workers)
    ]

    if num_workers == 1:
        topk_parts = [worker_run(specs[0])]
    else:
        ctx = get_context("spawn")
        with ctx.Pool(processes=num_workers) as pool:
            topk_parts = pool.map(worker_run, specs)

    # Reassemble in original order (pool.map preserves the order of specs)
    topk_indices: List[List[int]] = []
    for part in topk_parts:
        topk_indices.extend(part)

    # Compute metrics
    metrics = compute_metrics_from_rankings(
        topk_indices=topk_indices,
        all_correct_arms=query_correct_arms,
        ks=ks,
    )

    print("\n=== Metrics (Reranker) ===")
    for k in ks:
        hk = metrics.get(f"hit_at_{k}")
        rk = metrics.get(f"recall_at_{k}")
        nk = metrics.get(f"ndcg_at_{k}")
        parts = []
        if hk is not None:
            parts.append(f"hit@{k}: {hk:.4f}")
        if rk is not None:
            parts.append(f"recall@{k}: {rk:.4f}")
        if nk is not None:
            parts.append(f"ndcg@{k}: {nk:.4f}")
        if parts:
            print("\t".join(parts))

    if args.output:
        os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(
                {
                    "dataset": args.dataset,
                    "model": args.model,
                    "model_id": model_cfg["model_id"],
                    "ks": ks,
                    "metrics": metrics,
                    "num_workers": num_workers,
                    "devices": device_list[:num_workers],
                    "doc_batch_size": args.doc_batch_size,
                },
                f,
                indent=2,
            )
        print(f"Saved metrics JSON to: {args.output}")


# Run only when executed as a script. The guard also matters for main()'s
# "spawn" multiprocessing pool: spawned workers re-import this module and
# must not re-enter main().
if __name__ == "__main__":
    main()
