"""
Process MultiHop RAG reranking JSONL to produce final QA answers and accuracy.

Usage example:
    python -m embed_trainer.process_multihop_rerank \
        --rerank_jsonl path/to/rerank_info_20250101_120000.jsonl \
        --tools_dataset_path embeddings/mhr_tools_embedded \
        --queries_dataset_path embeddings/mhr_decomposed_embedded \
        --embedding_model large \
        --output_path runs/multihop_answers.jsonl

This script:
- Reads each line from the rerank jsonl (query_id, query_text, arm_id).
- Groups lines by main question (the text before "Focus:" in each record's query_text, minus the leading "Context:" label).
- Reconstructs chosen document texts using the same sorted/filtered indexing used in training.
- Calls an LLM to answer the main question using concatenated evidence.
- Looks up the ground-truth answer at the first sub-query's (sorted/filtered) index in each group and computes case-insensitive exact-match accuracy.
- Writes a JSONL with details per main question.

Note:
- Indices in the rerank file reference the sorted/filtered datasets (as in UniversalBanditDataLoader with require_sort=True).
- To correctly map indices to texts/answers, we replicate sorting and filtering here.
"""

import argparse
import json
import os
import sys
import time
from collections import OrderedDict, defaultdict
from typing import Any, Dict, List, Tuple

from datasets import load_dataset, load_from_disk
from tqdm import tqdm

try:
    # Optional, only needed for answering
    from langchain_core.messages import HumanMessage, SystemMessage
    from langchain_openai import ChatOpenAI
except Exception:
    ChatOpenAI = None  # type: ignore
    HumanMessage = None  # type: ignore
    SystemMessage = None  # type: ignore


# Default endpoint handed to ChatOpenAI by build_llm and used as the --base_url default.
BASE_URL = "https://api.openai.com/v1"


def _load_any(path: str):
    """Load a Hugging Face dataset from a local directory or a hub id.

    For a local directory the lookup order is:

    1. ``load_from_disk(path)`` itself;
    2. each immediate sub-directory of ``path`` (in sorted order, so the
       choice is deterministic), for datasets saved one level deeper;
    3. ``load_from_disk(path)`` again, so the original error propagates.

    Any non-directory ``path`` is handed to ``load_dataset`` (hub id).
    """
    if os.path.isdir(path):
        try:
            return load_from_disk(path)
        except Exception:
            pass
        # Fallback: a dataset saved inside a child directory.
        for name in sorted(os.listdir(path)):
            subdir = os.path.join(path, name)
            if os.path.isdir(subdir):
                try:
                    # Previously this retried ``path``; the child dir is the intent.
                    return load_from_disk(subdir)
                except Exception:
                    continue
        # Last attempt: re-raise the top-level load error for the caller.
        return load_from_disk(path)
    return load_dataset(path)


def _ensure_tool_id_field(tools_dataset):
    """Ensure tools rows have an 'id' field (rename from _id or url)."""
    row0 = tools_dataset[0]
    if "id" in row0:
        return tools_dataset
    if "_id" in row0:
        return tools_dataset.rename_column("_id", "id")
    if "url" in row0:
        return tools_dataset.rename_column("url", "id")
    return tools_dataset


def _embedding_col(embedding_model: str) -> str:
    return f"embedding_{embedding_model}"


def _has_embedding(row: Dict[str, Any], embedding_model: str) -> bool:
    """Report whether ``row`` carries an embedding for ``embedding_model``.

    Two row layouts are accepted:

    * current: a non-None value in the ``embedding_<model>`` column;
    * legacy: an ``embed`` value whose ``embed_model`` equals ``embedding_model``.
    """
    column = _embedding_col(embedding_model)
    if column in row:
        if row[column] is not None:
            return True
    has_legacy_fields = "embed" in row and "embed_model" in row
    if not has_legacy_fields:
        return False
    same_model = row["embed_model"] == embedding_model
    return same_model and (row["embed"] is not None)


def _filter_tools(
    tools_dataset, embedding_model: str
) -> Tuple[List[Dict[str, Any]], Dict[str, int]]:
    """Keep only tool rows that have an embedding for ``embedding_model``.

    Returns:
        A pair ``(rows, id_to_index)``. ``rows`` holds the surviving tool rows
        in dataset order. ``id_to_index`` maps each row's ``"id"`` to its
        position in ``rows``; these positions are the arm ids.

    Raises:
        ValueError: if no tool row has a matching embedding.
    """
    kept: List[Dict[str, Any]] = []
    index_of: Dict[str, int] = {}
    for record in tools_dataset:
        if not _has_embedding(record, embedding_model):
            continue
        position = len(kept)
        kept.append(record)
        index_of[record["id"]] = position
    if len(kept) == 0:
        raise ValueError(
            f"No tools with embeddings for model '{embedding_model}'"
        )
    return kept, index_of


def _extract_correct_arms_multihop(
    query: Dict[str, Any], tool_id_to_index: Dict[str, int]
) -> List[int]:
    """Replicates MultihopQueryArmMapper.extract_correct_arms."""
    tool_id = query.get("evidence_list")
    if not tool_id:
        return []
    potential_tool_ids = []
    try:
        for _ in tool_id:
            if isinstance(_, dict) and "url" in _:
                potential_tool_ids.append(_["url"])
    except Exception:
        pass
    rslt: List[int] = []
    for t in potential_tool_ids:
        if t in tool_id_to_index:
            rslt.append(tool_id_to_index[t])
    return rslt


def _filter_queries(
    queries_dataset, embedding_model: str, tool_id_to_index: Dict[str, int]
) -> List[Dict[str, Any]]:
    """Keep the queries that are usable for evaluation, in dataset order.

    A query survives when two conditions hold:

    * it has an embedding for ``embedding_model``;
    * at least one of its evidence URLs resolves to a filtered tool index.

    Raises:
        ValueError: if no query survives filtering.
    """
    kept = [
        query
        for query in queries_dataset
        if _has_embedding(query, embedding_model)
        and _extract_correct_arms_multihop(query, tool_id_to_index)
    ]
    if len(kept) == 0:
        raise ValueError("No valid queries found after filtering for multihop")
    return kept


def _sort_dataset_if_needed(dataset, text_field: str):
    try:
        return dataset.sort(text_field)
    except Exception:
        return dataset


def parse_rerank_jsonl(path: str) -> List[Dict[str, Any]]:
    """Read rerank records from the JSONL file at ``path``.

    Blank lines, unparsable lines, and JSON values lacking any of
    ``query_id``, ``query_text`` or ``arm_id`` are skipped silently.

    Raises:
        ValueError: if no line yields a valid record.
    """
    required = ("query_id", "query_text", "arm_id")
    parsed: List[Dict[str, Any]] = []
    with open(path, "r", encoding="utf-8") as handle:
        for raw in handle:
            stripped = raw.strip()
            if stripped == "":
                continue
            try:
                candidate = json.loads(stripped)
                keep = all(key in candidate for key in required)
            except Exception:
                continue
            if keep:
                parsed.append(candidate)
    if len(parsed) == 0:
        raise ValueError(f"No valid records found in {path}")
    return parsed


def extract_main_question(query_text: str) -> str:
    """Return the main question embedded in a decomposed sub-query string.

    Expected input shape: ``"Context: <main> | Focus: <sub>"``. The steps are:

    1. Discard everything from the first ``"Focus:"`` onward.
    2. Remove a leading ``"Context:"`` label.
    3. Remove the trailing ``"|"`` separator.

    Text without these markers is returned stripped.

    Only the leading label and trailing separator are removed. A main
    question that itself contains ``"|"`` or ``"Context:"`` is therefore
    preserved; the previous global ``str.replace`` calls mangled it.
    """
    marker = "Focus:"
    label = "Context:"
    # split(..., 1)[0] yields the whole string when the marker is absent.
    prefix = query_text.split(marker, 1)[0].strip()
    if prefix.startswith(label):
        prefix = prefix[len(label):].lstrip()
    # Drop the " | " separator that sat just before "Focus:".
    return prefix.rstrip().rstrip("|").strip()


def build_groups(records: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Bucket rerank records by their main question.

    Returns a dict keyed by ``extract_main_question(query_text)``, in
    first-seen order. Each value holds parallel lists:

    * ``query_ids``: ints;
    * ``arm_ids``: ints;
    * ``query_texts``: raw strings.

    The lists follow the record order.
    """
    grouped: Dict[str, Dict[str, Any]] = {}
    for rec in records:
        key = extract_main_question(rec["query_text"])
        if key not in grouped:
            grouped[key] = {"query_ids": [], "arm_ids": [], "query_texts": []}
        bucket = grouped[key]
        bucket["query_ids"].append(int(rec["query_id"]))
        bucket["arm_ids"].append(int(rec["arm_id"]))
        bucket["query_texts"].append(rec["query_text"])
    return grouped


def unique_preserve_order(seq: List[int]) -> List[int]:
    """Drop repeated items from ``seq``, keeping first-occurrence order."""
    # dict keys are unique and keep insertion order (Python 3.7+).
    return list(dict.fromkeys(seq))


def build_llm(
    model: str = "gpt-4o-mini",
    temperature: float = 0.0,
    base_url: str = BASE_URL,
):
    """Construct the ChatOpenAI client used for answering.

    Raises:
        RuntimeError: if the optional LangChain OpenAI imports failed at
            module load, leaving ``ChatOpenAI`` unset.
    """
    if ChatOpenAI is not None:
        return ChatOpenAI(model=model, temperature=temperature, base_url=base_url)
    raise RuntimeError(
        "LangChain ChatOpenAI is not installed in this environment."
    )


def _content_to_text(content: Any) -> str:
    """Flatten a chat-model ``content`` payload into a stripped string.

    Accepted shapes:

    * ``None`` becomes ``""``, so it is not mistaken for the answer "None";
    * plain strings are stripped;
    * lists of parts are joined, where each part is a string or a dict with
      a string ``"text"`` entry.

    Anything else falls back to ``str(content)``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content.strip()
    if isinstance(content, list):
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                pieces.append(part["text"])
        if pieces:
            return "".join(pieces).strip()
    return str(content).strip()


def answer_with_llm(
    llm,
    main_question: str,
    evidence_texts: List[str],
    max_retries: int = 5,
    retry_delay: float = 1.0,
) -> str:
    """Ask ``llm`` for a short final answer to ``main_question``.

    Args:
        llm: Object with an ``invoke(messages)`` method, e.g. from ``build_llm``.
        main_question: The question to answer.
        evidence_texts: Documents shown to the model, numbered in order.
        max_retries: Total attempts (at least one) before giving up.
        retry_delay: Base delay in seconds; doubled after each failed attempt.

    Returns:
        The model's stripped answer text. Returns ``""`` if every attempt
        raised; in that case the last error is reported on stderr.
    """
    sys_prompt = (
        "You are a concise QA assistant. Given a main question and evidence documents, "
        "provide the final short answer only. If uncertain, provide your best effort."
    )
    docs_joined = "\n\n".join(
        f"--- Evidence {i+1} ---\n{t}" for i, t in enumerate(evidence_texts)
    )
    human_prompt = f"""
Main Question:
{main_question}

Evidence Documents:
{docs_joined}

Instruction: Provide the final answer only with no explanation.
"""
    messages = [
        SystemMessage(content=sys_prompt),
        HumanMessage(content=human_prompt),
    ]
    attempts = max(1, max_retries)
    last_error: Any = None
    for attempt in range(attempts):
        try:
            rsp = llm.invoke(messages)
        except Exception as err:  # API/network/rate-limit errors: retry
            last_error = err
            if attempt + 1 < attempts:
                # Exponential backoff; no pointless sleep after the final try.
                time.sleep(retry_delay * (2 ** attempt))
            continue
        return _content_to_text(getattr(rsp, "content", None))
    print(
        f"[answer_with_llm] giving up after {attempts} attempt(s): {last_error!r}",
        file=sys.stderr,
    )
    return ""


def main():
    """CLI entry point: answer each main question from reranked evidence and score it.

    Steps:

    1. Parse the rerank JSONL.
    2. Load the tools and queries datasets.
    3. Sort and filter them the way the training loader does, so that rerank
       indices line up.
    4. Group records by main question.
    5. Ask the LLM for an answer using the selected documents.
    6. Write one JSON line per main question and print exact-match accuracy.
    """
    parser = argparse.ArgumentParser(
        description="Process MultiHop RAG rerank JSONL and compute QA accuracy."
    )
    parser.add_argument(
        "--rerank_jsonl",
        type=str,
        required=True,
        help="Path to rerank_info_*.jsonl",
    )
    parser.add_argument(
        "--tools_dataset_path",
        type=str,
        default="embeddings/mhr_tools_embedded",
    )
    parser.add_argument(
        "--queries_dataset_path",
        type=str,
        default="embeddings/mhr_decomposed_embedded",
    )
    parser.add_argument(
        "--embedding_model",
        type=str,
        default="large",
        choices=["ada", "small", "large"],
    )
    parser.add_argument("--output_path", type=str, default="")
    # A value <= 0 disables the cap (see the slice in the main loop).
    parser.add_argument(
        "--max_evidence",
        type=int,
        default=20,
        help="Max evidence docs per main question (to bound prompt size)",
    )
    parser.add_argument(
        "--model", type=str, default="gpt-4o", help="LLM model name"
    )
    parser.add_argument(
        "--base_url", type=str, default=BASE_URL, help="LLM base URL"
    )
    args = parser.parse_args()

    # Derive output path if not specified
    if not args.output_path:
        ts = time.strftime("%Y%m%d_%H%M%S")
        base = os.path.splitext(os.path.basename(args.rerank_jsonl))[0]
        args.output_path = f"runs/multihop_answers_{base}_{ts}.jsonl"
    os.makedirs(os.path.dirname(args.output_path) or ".", exist_ok=True)

    # 1) Load and parse rerank jsonl
    records = parse_rerank_jsonl(args.rerank_jsonl)

    # 2) Load datasets (a DatasetDict is reduced to its "train" split)
    tools_ds = _load_any(args.tools_dataset_path)
    queries_ds = _load_any(args.queries_dataset_path)
    if isinstance(tools_ds, dict) and "train" in tools_ds:
        tools_ds = tools_ds["train"]
    if isinstance(queries_ds, dict) and "train" in queries_ds:
        queries_ds = queries_ds["train"]

    # 3) Sort like the loader (require_sort=True for multihop)
    tools_ds = _ensure_tool_id_field(tools_ds)
    tools_ds = _sort_dataset_if_needed(tools_ds, "body")
    queries_ds = _sort_dataset_if_needed(queries_ds, "formatted_query")

    # 4) Filter like the loader
    filtered_tools, tool_id_to_index = _filter_tools(
        tools_ds, args.embedding_model
    )
    valid_queries = _filter_queries(
        queries_ds, args.embedding_model, tool_id_to_index
    )

    def _none_to_empty(value: Any) -> Any:
        # An explicit None in a column would otherwise surface as the string
        # "None" (in prompts via f-string, in ground truth via str()).
        return "" if value is None else value

    # 5) Build quick accessors (positions match arm ids / rerank query ids)
    tool_texts = [_none_to_empty(row.get("body")) for row in filtered_tools]
    answers = [_none_to_empty(row.get("answer")) for row in valid_queries]

    # 6) Group rerank records by main question
    groups = build_groups(records)

    # 7) Prepare LLM
    llm = build_llm(model=args.model, base_url=args.base_url)

    total = 0
    correct = 0
    missing_gt = 0

    with open(args.output_path, "w", encoding="utf-8") as outf:
        for main_q, info in tqdm(groups.items()):
            query_ids: List[int] = info["query_ids"]
            arm_ids: List[int] = unique_preserve_order(info["arm_ids"])
            if args.max_evidence > 0:
                arm_ids = arm_ids[: args.max_evidence]

            # Retrieve evidence texts from filtered tools by arm_id;
            # out-of-range ids are skipped.
            evidence_texts: List[str] = [
                tool_texts[aid] for aid in arm_ids if 0 <= aid < len(tool_texts)
            ]

            # Ground truth: use answer at the first query_id index
            gt_answer = ""
            if query_ids:
                qidx = int(query_ids[0])
                if 0 <= qidx < len(answers):
                    gt_answer = str(answers[qidx])
            if not gt_answer:
                missing_gt += 1

            # LLM answer
            pred_answer = answer_with_llm(llm, main_q, evidence_texts)

            # Case-insensitive exact match; no ground truth counts as wrong.
            is_correct = (
                (pred_answer.strip().lower() == gt_answer.strip().lower())
                if gt_answer
                else False
            )

            out_obj = {
                "main_question": main_q,
                "sub_queries_count": len(info["query_texts"]),
                "sub_query_ids": query_ids,
                "selected_arm_ids": arm_ids,
                "ground_truth_answer": gt_answer,
                "predicted_answer": pred_answer,
                "correct": is_correct,
            }
            outf.write(json.dumps(out_obj, ensure_ascii=False) + "\n")

            total += 1
            correct += int(is_correct)

    acc = (correct / total) if total else 0.0
    print(
        f"Processed {total} main questions. Accuracy (exact match): {acc:.4f}"
    )
    if missing_gt:
        print(
            f"Note: {missing_gt} main question(s) had no ground-truth answer "
            "and were scored as incorrect."
        )
    print(f"Output written to: {args.output_path}")


if __name__ == "__main__":
    main()
