"""
Collect raw LLM outputs from RSP probability experiment artifacts into a JSONL
file suitable as input for scripts/classify_pif.py.

It follows the same discovery/parsing logic as scripts/plot_divergence_vs_n.py:
- Read base run directories from a paths file (e.g., paths_n_vs_div.txt)
- Prefer result JSONs (parameters + raw_outputs) if requested; otherwise use logs
- Parse config (words/probabilities/method/model/temperature) and raw outputs

Supports subsampling by max-per-run, sampling rate, take-every-k, and max-total.

Usage examples:
  python scripts/make_pif_input.py \
    --paths-file paths_n_vs_div.txt \
    --out outputs/pif_inputs.jsonl \
    --max-per-run 500 --sample-rate 0.5 --shuffle --max-total 20000

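Output JSONL fields per line:
- id: run leaf directory name + sample index, formatted "<leaf>:<index>"
- response: raw model output (full text)
- request: reconstructed task description from config (currently NOT emitted;
  the code that would build it is commented out)
- meta: {
    model, method, temperature, words, probabilities,
    run_log|results_json,
    source_field: "raw_outputs"|"raw_reasonings" (result-JSON records only),
    group_key: string "model=...|method=...|temp=...|words=[...]",
    group_key_fields: {model, method, temperature, words},
    parsed_answer: lower-cased <answer> content, or null when none was found
  }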
"""

from __future__ import annotations

import argparse
import json
import random
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, List, Optional, Tuple

from pcot.parameters import get_default_temperature

# ---------------- I/O helpers (similar to plot_divergence_vs_n) --------------


def read_paths_file(paths_file: Path) -> List[Path]:
    """Load the list of base run directories named in `paths_file`.

    Blank lines and lines beginning with ``#`` (after stripping whitespace) are
    skipped. Absolute entries are returned unchanged; relative entries are
    anchored at the directory containing `paths_file` and resolved.
    """
    anchor = paths_file.parent
    resolved: List[Path] = []
    for raw_line in paths_file.read_text().splitlines():
        entry = raw_line.strip()
        if entry == "" or entry.startswith("#"):
            continue
        candidate = Path(entry)
        if candidate.is_absolute():
            resolved.append(candidate)
        else:
            resolved.append((anchor / entry).resolve())
    return resolved


def discover_run_logs(base_dirs: Iterable[Path]) -> List[Path]:
    """Recursively collect every ``rsp_probability_exp.log`` file under `base_dirs`.

    Base directories that do not exist are ignored, as are matches that are not
    regular files. The combined result is returned in sorted order.
    """
    hits: List[Path] = [
        candidate
        for root in base_dirs
        if root.exists()
        for candidate in root.rglob("rsp_probability_exp.log")
        if candidate.is_file()
    ]
    hits.sort()
    return hits


def discover_run_jsons(base_dirs: Iterable[Path]) -> List[Path]:
    """Locate result JSON files inside every directory named ``results``.

    For each such directory, files matching ``results__*.json`` are taken (sorted);
    if there are none, all ``*.json`` files in that directory are taken instead.
    Missing base directories are skipped.
    """
    collected: List[Path] = []
    for root in base_dirs:
        if not root.exists():
            continue
        for candidate in root.rglob("results"):
            if not candidate.is_dir():
                continue
            preferred = sorted(candidate.glob("results__*.json"))
            # Fall back to any JSON only when no canonical results file exists.
            collected.extend(preferred or sorted(candidate.glob("*.json")))
    return collected


# ---------------- Parsing helpers (trimmed from plot_divergence_vs_n) --------

# Matches the start of a timestamped log record, i.e. a line beginning with
# "[YYYY-MM-DD ". Lines that do NOT match are treated as continuation lines of
# the previous record (multi-line YAML config or multi-line model output).
_LOG_LINE_PREFIX = re.compile(r"^\[\d{4}-\d{2}-\d{2} ")
# Matches the header of a raw-output record. Group 1 is the sample number,
# group 2 is whatever text follows the colon on the same line (may be empty).
_RAW_OUT_RE = re.compile(r"Raw LLM Output \(Sample (\d+)\):\s*(.*)$")
# Same <answer> extractor as in plot_divergence_vs_n.py (case-insensitive, last/innermost)
# - the captured body may not contain another "<answer" opening tag (innermost);
# - the trailing negative lookahead rejects a match that is followed by any later
#   "</answer>" (so only the final closing tag can match).
# The inline "(?is)" and the explicit flags= argument request the same two flags.
AnswerRegex = re.compile(
    r"(?is)<answer\b[^>]*>((?:(?!<answer\b).)*?)</answer\b>(?!.*</answer\b>)",
    flags=re.DOTALL | re.IGNORECASE,
)


def _strip_yaml_quotes(token: str) -> str:
    """Remove one pair of matching YAML quotes around `token`, if present.

    Single-quoted scalars also get the YAML ``''`` escape collapsed to ``'``.
    Backslash escapes inside double-quoted scalars are left untouched.
    """
    if len(token) >= 2 and token[0] == token[-1] and token[0] in ("'", '"'):
        inner = token[1:-1]
        return inner.replace("''", "'") if token[0] == "'" else inner
    return token


def _parse_config_from_log_block(
    yaml_block: str,
) -> Tuple[
    Optional[str], List[str], List[str], Optional[float], Optional[str], Optional[str]
]:
    """Minimal YAML-like parser for the printed Hydra config block.

    Only three top-level sections are inspected: ``prompt`` (``type``, ``words``,
    ``probabilities``), ``sampling`` (``temperature``) and ``model``
    (``litellm_model_name``, ``name``). A non-empty line that does not start with
    a space is treated as a top-level key; every other line belongs to the most
    recent top-level section.

    Args:
        yaml_block: Text of the config block, one YAML line per text line.

    Returns:
        (prompt_type, words, probabilities, temperature, litellm_model_name,
        model_name). Words have surrounding YAML quotes removed (e.g. ``'no'``
        becomes ``no``) so they compare equal to the unquoted values stored in
        result JSONs. Probabilities are the raw strings found after ``- ``.
        ``temperature`` is None when missing, null/none/empty, or not a float.
    """
    prompt_type: Optional[str] = None
    words: List[str] = []
    probs: List[str] = []
    temperature: Optional[float] = None
    litellm_model_name: Optional[str] = None
    model_name: Optional[str] = None

    section: Optional[str] = None  # "prompt", "sampling", "model" or None
    active_list: Optional[str] = None  # "words", "probs" or None (prompt only)

    for line in yaml_block.splitlines():
        if line and not line.startswith(" "):
            # New top-level key: leave whatever section/list we were in.
            head = line.strip()
            section = None
            for name in ("prompt", "sampling", "model"):
                if head.startswith(name + ":"):
                    section = name
                    break
            active_list = None
            continue

        s = line.strip()
        if section == "prompt":
            if s.startswith("- "):
                item = s[2:].strip()
                if active_list == "words":
                    words.append(_strip_yaml_quotes(item))
                elif active_list == "probs":
                    probs.append(item)
            elif s.startswith("type:"):
                prompt_type = s.split(":", 1)[1].strip() or None
                active_list = None
            elif s.startswith("words:"):
                active_list = "words"
            elif s.startswith("probabilities:"):
                active_list = "probs"
            elif s and not s.startswith("-") and ":" in s:
                # Any other key ends the current list, so items of an unrelated
                # list under `prompt:` are not mistaken for words/probabilities.
                active_list = None
        elif section == "sampling":
            if s.startswith("temperature:"):
                val = s.split(":", 1)[1].strip()
                if val.lower() in ("null", "none", ""):
                    temperature = None
                else:
                    try:
                        temperature = float(val)
                    except ValueError:
                        temperature = None
        elif section == "model":
            if s.startswith("litellm_model_name:"):
                litellm_model_name = s.split(":", 1)[1].strip()
            elif s.startswith("name:"):
                model_name = s.split(":", 1)[1].strip()

    return prompt_type, words, probs, temperature, litellm_model_name, model_name


def parse_log_file(log_path: Path) -> Tuple[Optional[dict], List[str]]:
    """Read one experiment log and pull out its configuration and raw LLM outputs.

    Returns a pair ``(config_dict, raw_outputs)``:

    - ``config_dict`` is None when the log has no "Loaded configuration:" line.
      Otherwise it holds the keys 'prompt_type', 'words', 'probabilities',
      'temperature', 'litellm_model_name' and 'model_name'; individual values may
      be None/empty when the log does not provide them.
    - ``raw_outputs`` lists, in file order, the non-empty text of every
      "Raw LLM Output (Sample N):" record, including its continuation lines.
    """
    lines = log_path.read_text(errors="ignore").splitlines()

    # --- Configuration: the block that follows the first marker line ----------
    marker_idx = next(
        (idx for idx, ln in enumerate(lines) if "Loaded configuration:" in ln),
        None,
    )
    config_dict: Optional[dict] = None
    if marker_idx is not None:
        block_lines: List[str] = []
        for ln in lines[marker_idx + 1 :]:
            # The next timestamped record terminates the printed config.
            if _LOG_LINE_PREFIX.match(ln):
                break
            block_lines.append(ln)
        parsed = _parse_config_from_log_block("\n".join(block_lines))
        prompt_type, words, probs, temperature, litellm_model_name, model_name = parsed
        if temperature is None and litellm_model_name:
            # Best effort: fall back to the model's default temperature.
            try:
                temperature = float(get_default_temperature(litellm_model_name))
            except Exception:
                pass
        config_dict = {
            "prompt_type": prompt_type,
            "words": words,
            "probabilities": probs,
            "temperature": temperature,
            "litellm_model_name": litellm_model_name,
            "model_name": model_name,
        }

    # --- Raw outputs: header line plus all following non-timestamped lines ----
    raw_outputs: List[str] = []
    pending: Optional[List[str]] = None  # lines of the record being collected
    for ln in lines:
        if pending is not None:
            if not _LOG_LINE_PREFIX.match(ln):
                pending.append(ln)
                continue
            # A new timestamped record closes the current output.
            text = "\n".join(pending).strip()
            if text:
                raw_outputs.append(text)
            pending = None
        hit = _RAW_OUT_RE.search(ln)
        if hit:
            head = hit.group(2)
            pending = [head] if head else []
    if pending is not None:
        text = "\n".join(pending).strip()
        if text:
            raw_outputs.append(text)

    return config_dict, raw_outputs


# --------------------- Utility to build request text -------------------------


def build_request_from_config(cfg: Optional[dict]) -> str:
    """Reconstruct a one-sentence task description from a parsed config.

    Returns an empty string when `cfg` is falsy or when either its 'words' or
    'probabilities' entry is missing/empty.
    """
    if not cfg:
        return ""
    options = cfg.get("words") or []
    weights = cfg.get("probabilities") or []
    if not (options and weights):
        return ""
    quoted = ", ".join(f'"{opt}"' for opt in options)
    weight_text = ", ".join(map(str, weights))
    count = len(options)
    return (
        f"Choose between {quoted}. "
        f"Select one of these {count} options with probabilities: {weight_text}."
    )


def _extract_answer(raw_text: str) -> Optional[str]:
    """Return the stripped, lower-cased body of the final <answer> tag.

    Yields None when no tag is found or when `raw_text` cannot be searched at
    all (for example, it is not a string).
    """
    try:
        found = AnswerRegex.search(raw_text)
        if found is None:
            return None
        return found.group(1).strip().lower()
    except Exception:
        return None


# ------------------------ Group key helpers (like plot script) ---------------


def normalize_model_name(raw: Optional[str]) -> Optional[str]:
    """Map a raw model name onto its canonical display name.

    None is passed through. Known aliases are translated via a fixed table;
    any other name has surrounding whitespace removed and underscores replaced
    by hyphens.
    """
    if raw is None:
        return None
    aliases = {
        "deepseek_r1": "deepseek-r1",
        "deepseek_r1_0528": "deepseek-r1",
        "deepseek-r1-0528": "deepseek-r1",
        "deepseek-v3": "deepseek-v3",
        "deepseek_v3": "deepseek-v3",
        "gpt-4o-2024-08-06": "gpt-4o",
        "o4mini_high": "o4-mini-high",
        "qwq32B": "QwQ-32B",
    }
    name = raw.strip()
    if name in aliases:
        return aliases[name]
    return name.replace("_", "-")


@dataclass(frozen=True)
class GroupKey:
    """Identity of an experiment group: (model, method, temperature, words).

    The dataclass is frozen (and therefore hashable), which lets instances serve
    as dictionary keys when records are bucketed by group.
    """

    model: str  # normalized model name, or "unknown" when the config has none
    method: str  # the config's prompt type, or "unknown"
    temperature: Optional[float]  # None when absent or not convertible to float
    words: tuple[str, ...]  # candidate words, in config order


def build_group_key(cfg: Optional[dict]) -> GroupKey:
    """Derive the grouping key for a run from its (possibly missing) config.

    Missing model/method fall back to "unknown"; a temperature that is absent or
    cannot be converted to float becomes None; words become a tuple.
    """
    source = cfg or {}

    raw_model = source.get("model_name")
    model = normalize_model_name(raw_model)
    if not model:
        model = raw_model or "unknown"

    method = source.get("prompt_type") or "unknown"

    raw_temp = source.get("temperature")
    if raw_temp is None:
        temperature = None
    else:
        try:
            temperature = float(raw_temp)
        except Exception:
            temperature = None

    words = tuple(list(source.get("words") or []))
    return GroupKey(model=model, method=method, temperature=temperature, words=words)


def group_key_to_string(key: GroupKey) -> str:
    """Render `key` as ``model=...|method=...|temp=...|words=[a,b,...]``."""
    joined_words = ",".join(map(str, key.words))
    parts = (
        f"model={key.model}",
        f"method={key.method}",
        f"temp={key.temperature}",
        f"words=[{joined_words}]",
    )
    return "|".join(parts)


# ------------------------------- Main ---------------------------------------


# Runs whose config lists this many candidate words (or more) are skipped.
_WORD_COUNT_LIMIT = 128


def _as_dict(value: object) -> dict:
    """Return `value` when it is a dict, otherwise an empty dict."""
    return value if isinstance(value, dict) else {}


def _config_from_results_json(data: dict) -> dict:
    """Build a config dict (same keys as parse_log_file's) from a results JSON."""
    params = _as_dict(data.get("parameters"))
    prompt = _as_dict(params.get("prompt"))
    sampling = _as_dict(params.get("sampling"))
    model = _as_dict(params.get("model"))
    return {
        "prompt_type": prompt.get("type"),
        "words": prompt.get("words") or [],
        "probabilities": prompt.get("probabilities") or [],
        "temperature": sampling.get("temperature"),
        "litellm_model_name": model.get("litellm_model_name"),
        "model_name": model.get("name"),
    }


def _run_selected(gkey: "GroupKey", args: argparse.Namespace) -> bool:
    """Tell whether a run with group key `gkey` passes the CLI filters."""
    if args.model_contains and args.model_contains not in gkey.model:
        return False
    if args.temperature is not None and gkey.temperature != args.temperature:
        return False
    if args.method is not None and gkey.method != args.method:
        return False
    # Runs without words cannot be validated; very large word lists are skipped.
    return 0 < len(gkey.words) < _WORD_COUNT_LIMIT


def _records_for_run(
    raws: list,
    cfg: dict,
    gkey: "GroupKey",
    run_leaf: str,
    source_meta: dict,
    args: argparse.Namespace,
) -> List[dict]:
    """Turn one run's raw outputs into output records, applying answer checks.

    `source_meta` carries the source-specific meta entries (``run_log`` or
    ``results_json``/``source_field``); they are placed between
    ``probabilities`` and ``group_key`` to keep the emitted key order stable.
    """
    allowed = {str(w).lower() for w in (cfg.get("words") or [])}
    run_name = Path(run_leaf).name  # portable replacement for split('/')[-1]
    key_text = group_key_to_string(gkey)

    records: List[dict] = []
    for i, raw in enumerate(raws):
        ans = _extract_answer(raw)
        if ans is None:
            if args.require_answer:
                continue
        elif args.require_in_words and ans not in allowed:
            continue
        meta = {
            "model": cfg.get("model_name"),
            "method": cfg.get("prompt_type"),
            "temperature": cfg.get("temperature"),
            "words": cfg.get("words"),
            "probabilities": cfg.get("probabilities"),
            **source_meta,
            "group_key": key_text,
            "group_key_fields": {
                "model": gkey.model,
                "method": gkey.method,
                "temperature": gkey.temperature,
                "words": list(gkey.words),
            },
            "parsed_answer": ans,
        }
        records.append({"id": f"{run_name}:{i}", "response": raw, "meta": meta})
    return records


def _select_indices(
    count: int, args: argparse.Namespace, rng: random.Random
) -> List[int]:
    """Pick which of a group's `count` records to keep, per the sampling flags."""
    idxs = list(range(count))
    if args.shuffle:
        rng.shuffle(idxs)
    if args.every and args.every > 1:
        idxs = idxs[:: args.every]
    if args.sample_rate is not None:
        rate = max(0.0, min(1.0, float(args.sample_rate)))
        idxs = [i for i in idxs if rng.random() < rate]
    if args.max_per_run is not None:
        idxs = idxs[: args.max_per_run]
    return idxs


def main() -> None:
    """CLI entry point: gather raw outputs, group them, subsample, write JSONL.

    Phase 1 reads result JSONs (when requested) and then logs for runs not
    already covered by a JSON, and buckets the surviving records by GroupKey.
    Phase 2 subsamples each group and writes one JSON object per line.

    Changes relative to earlier revisions of this script:
    - ``--shuffle`` is now honoured; without it records keep discovery order
      (previously every group was shuffled regardless of the flag).
    - ``--no-require-answer`` / ``--no-require-in-words`` allow disabling the
      two checks, which default to on.
    - ``--seed`` makes shuffling and ``--sample-rate`` reproducible.
    - ``--max-total`` values <= 0 write nothing (previously one record leaked).
    """
    ap = argparse.ArgumentParser(
        description="Collect raw_outputs (or raw_reasonings) into JSONL for PIF classification"
    )
    ap.add_argument("--paths-file", type=Path, default=Path("paths_n_vs_div.txt"))
    ap.add_argument("--out", type=Path, default=Path("outputs/pif_inputs.jsonl"))
    ap.add_argument(
        "--max-per-run",
        type=int,
        default=None,
        help="Keep at most this many samples per group (model/method/temperature/words)",
    )
    ap.add_argument(
        "--sample-rate",
        type=float,
        default=None,
        help="0<r<=1; keep this fraction from each group",
    )
    ap.add_argument(
        "--every", type=int, default=None, help="Keep every k-th sample from each group"
    )
    ap.add_argument("--max-total", type=int, default=None)
    ap.add_argument(
        "--shuffle",
        action="store_true",
        help="Shuffle each group's records before subsampling",
    )
    ap.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Seed for --shuffle / --sample-rate (default: nondeterministic)",
    )
    ap.add_argument(
        "--include-json-results",
        action="store_true",
        help="Also read raw_outputs from result JSON files",
    )
    ap.add_argument(
        "--use-raw-reasonings",
        action="store_true",
        help=(
            "Read responses from 'raw_reasonings' in result JSONs. "
            "Skips null/empty entries and ignores logs."
        ),
    )
    # Optional filters (default: keep all)
    ap.add_argument("--model-contains", type=str, default=None)
    ap.add_argument("--method", type=str, default=None)
    ap.add_argument("--temperature", type=float, default=None)
    # Parsing controls. Both checks default to on; the --no-* variants are the
    # only way to switch them off (a store_true flag with default=True cannot).
    ap.add_argument(
        "--require-answer",
        dest="require_answer",
        action="store_true",
        default=True,
        help="Keep only samples with <answer>…</answer> present",
    )
    ap.add_argument(
        "--no-require-answer",
        dest="require_answer",
        action="store_false",
        help="Also keep samples without a parsable <answer> tag",
    )
    ap.add_argument(
        "--require-in-words",
        dest="require_in_words",
        action="store_true",
        default=True,
        help="Require parsed <answer> to be one of config words",
    )
    ap.add_argument(
        "--no-require-in-words",
        dest="require_in_words",
        action="store_false",
        help="Keep samples whose parsed <answer> is not one of the config words",
    )
    ap.add_argument(
        "--print-stats", action="store_true", help="Print simple counts per phase"
    )
    args = ap.parse_args()

    base_dirs = read_paths_file(args.paths_file)
    # If using reasonings, we must source from JSONs; logs are not used in this mode.
    logs = [] if args.use_raw_reasonings else discover_run_logs(base_dirs)
    jsons = (
        discover_run_jsons(base_dirs)
        if (args.include_json_results or args.use_raw_reasonings)
        else []
    )
    if not logs and not jsons:
        print(f"No logs or result JSONs found from {args.paths_file}")
        return

    # ---- Phase 1: collect valid answers grouped by GroupKey -----------------
    groups: dict = {}
    json_run_leaves: set = set()
    source_field = "raw_reasonings" if args.use_raw_reasonings else "raw_outputs"

    for js_path in jsons:
        try:
            # json.loads on bytes auto-detects UTF-8/16/32 instead of relying on
            # the platform's default text encoding.
            data = json.loads(js_path.read_bytes())
        except (OSError, ValueError):
            continue
        if not isinstance(data, dict):
            # The "*.json" fallback glob may pick up unrelated JSON documents.
            continue
        raws = list(data.get(source_field) or [])
        if args.use_raw_reasonings:
            # Filter out null/empty reasoning entries.
            raws = [r for r in raws if isinstance(r, str) and r.strip()]
        if not raws:
            continue
        cfg = _config_from_results_json(data)
        gkey = build_group_key(cfg)
        if not _run_selected(gkey, args):
            continue
        run_leaf = str(js_path.parent.parent.parent)
        json_run_leaves.add(run_leaf)
        recs = _records_for_run(
            raws,
            cfg,
            gkey,
            run_leaf,
            {"results_json": str(js_path), "source_field": source_field},
            args,
        )
        if recs:
            groups.setdefault(gkey, []).extend(recs)

    # Fallback to logs for runs without result JSONs (`logs` is empty in
    # reasonings mode, so this loop is then a no-op).
    for log_path in logs:
        run_leaf = str(log_path.parent.parent)
        if run_leaf in json_run_leaves:
            continue
        cfg, raws = parse_log_file(log_path)
        if not raws:
            continue
        gkey = build_group_key(cfg)
        if not _run_selected(gkey, args):
            continue
        # _run_selected guarantees non-empty words, hence cfg is a real dict here.
        recs = _records_for_run(
            raws, cfg or {}, gkey, run_leaf, {"run_log": str(log_path)}, args
        )
        if recs:
            groups.setdefault(gkey, []).extend(recs)

    if args.print_stats:
        total_groups = sum(len(v) for v in groups.values())
        print(f"Collected groups: {len(groups)} records: {total_groups}")

    # ---- Phase 2: per-key sampling and write out ----------------------------
    rng = random.Random(args.seed)
    out_path = args.out
    out_path.parent.mkdir(parents=True, exist_ok=True)
    cnt_total = 0
    with out_path.open("w", encoding="utf-8") as wf:
        for gkey, records in groups.items():
            if args.max_total is not None and cnt_total >= args.max_total:
                break
            print(
                len(gkey.words),
                gkey.method,
                gkey.temperature,
                gkey.model,
                cnt_total,
                len(records),
            )
            idxs = _select_indices(len(records), args, rng)
            if args.max_total is not None:
                # Never exceed the global budget, even for max_total <= 0.
                idxs = idxs[: max(0, args.max_total - cnt_total)]
            for i in idxs:
                wf.write(json.dumps(records[i], ensure_ascii=False) + "\n")
                cnt_total += 1

    print(f"Wrote {cnt_total} examples to {out_path}")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
