"""
Hydra + LiteLLM pipeline to classify large volumes of text using the
PIF prompt defined below. Uses OpenRouter (or any LiteLLM-supported model),
supports batching, retries, resume, and JSONL outputs.
"""

from __future__ import annotations

import json
import logging
import os
import random
import re
import sys
import time
from collections import Counter, defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple

import hydra
import litellm
from hydra.core.hydra_config import HydraConfig


def build_instruction(request: str, response: str) -> str:
    """Render the PIF classification prompt for a single response.

    The prompt spells out the JSON schema the model must emit (``E``, ``U``,
    ``M``, ``flags``, ``evidence``, ``confidence``), the rules for picking
    each label, an example answer, and finally the text to classify.

    Args:
        request: Original request text. When it is ``None`` or contains only
            whitespace, the whole request section is left out of the prompt.
        response: Response text to classify; inserted verbatim between the
            ``<<<RESPONSE BEGIN>>>`` / ``<<<RESPONSE END>>>`` markers.

    Returns:
        The complete prompt string.
    """
    # Optional request section; collapses to "" when there is no request.
    request_txt = (
        f"""
The request is as follows:
<<<REQUEST BEGIN>>>
{request}
<<<REQUEST END>>>
"""
        if request is not None and len(request.strip()) > 0
        else ""
    )
    # Doubled braces ({{ and }}) in the f-string below render as literal
    # braces in the prompt; single braces are substitution fields.
    return f"""
You are a strict classifier. Read ONE response that explains how a seed becomes a choice.

Output JSON ONLY with this exact schema:

{{
  "E": "crypto_hash|poly_rolling_hash|named_noncrypto_hash|linear_reduce|baseN_integer|bitstream_accum|prng_from_seed|analytic_mash|other_or_error",
  "U": "raw_mod|pow2_scale|big_integer_to_unit|decimal_scale|rejection_or_lemire|hash_to_float|none|other",
  "M": "cdf_thresholds|alias_method|hierarchical_split|rank_permute_then_index|direct_mod_bucket|inverse_transform|other_or_error",
  "flags": ["partial_seed_only","low_bits_only","decimal_scaling_bias","modulo_bias","boundary_unspecified","simulated_numbers","arithmetic_inconsistency","overflow_or_precision","nondeterministic_salt","uses_full_seed"],
  "evidence": ["short quotes..."],
  "confidence": 0.0-1.0
}}

GENERAL RULES
- Classify by what is WRITTEN (keywords / formulas), not by what would be ideal.
- If multiple hints appear, use the PRIORITY lists below. If still unclear, choose the closest label and lower confidence.
- Always include 1–3 short evidence snippets (exact phrases or code fragments from the input).
- Do NOT explain your reasoning. Output JSON only.

--------------------------------
E (Entropy extraction) — WHEN TO CHOOSE
--------------------------------
Pick ONE that best matches how the seed is turned into a number (independent of the number of options).

E=crypto_hash
- Choose when: Mentions SHA-256/sha256/blake2/md5/“cryptographic hash”, “digest”, “hex”.
- Typical phrases: “SHA-256 hash”, “first/last N bytes”, “hex → int”.
- Don’t choose if: Only “djb2/fnv/murmur/crc” (→ named_noncrypto_hash).
- Flags: add "uses_full_seed" if it says “use entire seed”; "partial_seed_only" if only first/last bytes.

E=poly_rolling_hash
- Choose when: (h = h*B + ord(byte)) with base 31/131/257… and a big modulus (2**32, 2**64, 1e9+7).
- Phrases: “*31 + ord”, “% 2**32 / % 2**64 / % 1e9+7”, “rolling hash”.
- Don’t choose if: It’s djb2/fnv/crc (→ named_noncrypto_hash), or only sums/XOR (→ linear_reduce).

E=named_noncrypto_hash
- Choose when: Specifically names djb2, FNV(-1a), Murmur, CRC32.
- Phrases: “djb2 5381”, “FNV-1a”, “CRC32”.
- Don’t choose if: Also mentions SHA-256 (→ crypto_hash has priority).

E=linear_reduce
- Choose when: Plain sum/average/XOR/parity/popcount of code points/bytes.
- Phrases: “sum of ASCII/Unicode”, “XOR of bytes”, “parity”.
- Notes: Lightweight but weaker diffusion. Often pairs with U=raw_mod or decimal scales.

E=baseN_integer
- Choose when: Treat bytes as base-256 (or base-N) integer (big-endian), e.g., “n = n*256 + b”.
- Don’t choose if: It’s a rolling hash with modulus (→ poly_rolling_hash).
- Flags: "partial_seed_only" if only last byte/last 8 bytes, "low_bits_only" if only LSBs used.

E=bitstream_accum
- Choose when: Bit-level shift/XOR/LFSR/CRC polynomial steps over bits.
- Phrases: “shift/xor per bit”, “LFSR”, “CRC polynomial update”.

E=prng_from_seed
- Choose when: LCG/PCG/XorShift seeded by the string.
- Phrases: “Linear Congruential”, “a=1664525, c=1013904223”, “PCG”, “XorShift”.

E=analytic_mash
- Choose when: Uses π, φ(0.618…), sin/cos, fractional parts as a mixer.
- Phrases: “multiply by golden ratio; take frac”.
- Flags: often lower quality (no automatic flag, but U/M may add bias flags).

E=other_or_error
- Choose when: method omitted/contradictory, “example/simulated only”, network/log noise.

-----------------------------
U (Uniformization / Debias)
-----------------------------
Pick ONE describing how the number is normalized/debiased before mapping.

U=raw_mod
- Choose when: Direct “% M” (no bias correction).
- Flags: "modulo_bias".
U=pow2_scale
- Choose when: Divide by 2**k (2^32, 2^64) to get [0,1).
U=big_integer_to_unit
- Choose when: Divide by general R (e.g., /1e9+7, /max_int) to get [0,1).
U=decimal_scale
- Choose when: “/1000”, “/10000” etc.
- Flags: "decimal_scaling_bias".
U=rejection_or_lemire
- Choose when: Says “rejection sampling”, “Lemire’s method”, “mulhi”, “unbiased modulo”.
U=hash_to_float
- Choose when: Mentions float/double/“53-bit safe” conversion specifically.
U=none
- Choose when: No normalization described; raw integer goes straight to buckets.
U=other
- Choose when: Anything else.

--------------------
M (Mapping to p(x))
--------------------
Pick ONE describing how [0,1) (or an int range) is mapped into the final categorical distribution.

M=cdf_thresholds
- Choose when: “cumulative thresholds”, “prefix sums”, “binary search in CDF”.
M=alias_method
- Choose when: “alias table”, “prob table”, O(1) sampling from discrete weights.
M=hierarchical_split
- Choose when: staged split (“if u < p0 then … else … then …”).
M=rank_permute_then_index
- Choose when: “hash each option, sort by seeded key, take first/top‑k”.
M=direct_mod_bucket
- Choose when: “floor(M*u)”, “index % M” to pick a bucket directly.
M=inverse_transform
- Choose when: “inverse CDF”, “quantile function” (continuous or discrete via CDF inverse).
M=other_or_error
- Choose when: Not enough info or inconsistent.

M PRIORITY: cdf_thresholds > alias_method > hierarchical_split > rank_permute_then_index > direct_mod_bucket > inverse_transform > other_or_error.

-------------------
Flags (add as many)
-------------------
- partial_seed_only → “last/first byte(s) only”, “first 8 bytes only”
- low_bits_only → “LSB / low nibble only”
- decimal_scaling_bias → “/1000 /10000”
- modulo_bias → “% M” without unbiased fix
- boundary_unspecified → no interval policy mentioned
- simulated_numbers → “example/simulated/placeholder”
- arithmetic_inconsistency → contradictions, off-by-one, mismatched index-word
- overflow_or_precision → “32-bit overflow”, “float 53-bit issue”
- nondeterministic_salt → time(), randomness added
- uses_full_seed → explicitly processes the entire string/bytes

CONFIDENCE HEURISTICS
- 0.9–1.0: ≥2 strong keywords and no contradictions.
- 0.6–0.8: 1 strong keyword or minor ambiguity.
- 0.3–0.5: weak/indirect hints or mixed signals.

An example JSON:
{{
  "E": "crypto_hash",
  "U": "big_integer_to_unit",
  "M": "cdf_thresholds",
  "flags": ["uses_full_seed"],
  "evidence": ["SHA-256", "first 8 bytes / 2^64", "cumulative thresholds"],
  "confidence": 0.96
}}

{request_txt}
Classify this approach:

<<<RESPONSE BEGIN>>>
{response}
<<<RESPONSE END>>>

"""


# ----------------------- Config (Hydra dataclasses) -------------------------


@dataclass
class DataConfig:
    """Input-data settings: where the records live and which fields to read."""

    # Path of the input file; the record reader raises FileNotFoundError
    # when it does not exist.
    input_path: str = "test.json"
    input_format: str = "auto"  # auto|jsonl|csv
    # Name of the field holding the text to classify; records without it are
    # skipped by the record reader.
    text_field: str = "response"
    request_field: Optional[str] = None  # optional request text
    # Name of the id field. When unset or absent on a record, the "id" field
    # is tried next, then a running counter is used.
    id_field: Optional[str] = None
    # Upper bound on the number of records yielded (None = no limit).
    max_records: Optional[int] = None
    # When true, main() skips records whose (text, request) pair was already
    # seen earlier in the same run.
    dedupe: bool = True


@dataclass
class ModelConfig:
    """Which model to call and how to reach it."""

    # Short label for the model.
    # NOTE(review): not read anywhere in the visible code — confirm its use.
    name: str = "deepseek_r1_0528"
    # Passed as `model` to litellm.batch_completion; substrings of this name
    # (e.g. "openrouter") also drive provider-specific settings.
    litellm_model_name: str = "openrouter/deepseek/deepseek-r1-0528"
    # Forwarded as `api_base` to litellm.batch_completion.
    api_base: Optional[str] = None
    # Forwarded as `api_key` to litellm.batch_completion; for OpenRouter
    # models main() also exports it as OPENROUTER_API_KEY if that is unset.
    api_key: Optional[str] = None


@dataclass
class SamplingConfig:
    """Generation parameters forwarded to litellm.batch_completion."""

    temperature: float = 0.2  # passed as `temperature`
    max_tokens: int = 512  # passed as `max_tokens`


@dataclass
class RuntimeConfig:
    """Execution settings: batching, output file names and resume behaviour."""

    # Number of records sent per litellm.batch_completion call
    # (main() clamps it to at least 1).
    batch_size: int = 32
    # Request timeout in seconds.
    # NOTE(review): confirm this value is actually forwarded to the LLM call.
    timeout: int = 300
    # Retry count for failed requests.
    # NOTE(review): confirm this value is actually forwarded to the LLM call.
    num_retries: int = 3
    # Passed as `max_workers` to litellm.batch_completion.
    max_workers: int = 66
    # Upper bound (seconds) for a random start-up delay.
    # NOTE(review): the start-up jitter appears to be disabled in main() — confirm.
    random_jitter_sec: int = 30
    # Sub-directory of the Hydra output directory that receives all outputs.
    output_dir_suffix: str = "pif_classify"
    results_filename: str = "results.jsonl"  # successful classifications (JSONL, appended)
    errors_filename: str = "errors.jsonl"  # failed records (JSONL, appended)
    summary_filename: str = "summary.json"  # run totals, rewritten each run
    # The index of already-processed records is loaded only when both
    # `resume` and `skip_existing` are true.
    resume: bool = True
    # Also gates the per-record "already done" check in main().
    skip_existing: bool = True


@dataclass
class AppConfig:
    """Root configuration object grouping the four config sections.

    NOTE(review): no ConfigStore registration is visible in this file, so the
    object handed to main() may be a plain Hydra config loaded from YAML
    rather than an instance of this class — confirm.
    """

    data: DataConfig = field(default_factory=DataConfig)  # input settings
    model: ModelConfig = field(default_factory=ModelConfig)  # model / endpoint
    sampling: SamplingConfig = field(default_factory=SamplingConfig)  # generation params
    runtime: RuntimeConfig = field(default_factory=RuntimeConfig)  # batching / outputs


# ------------------------------ IO helpers ----------------------------------


def _iter_records(cfg: AppConfig) -> Iterable[dict]:
    """Yield normalized input records read from ``cfg.data.input_path``.

    Every yielded dict contains all original fields of the input row plus
    three normalized keys that always win over same-named input fields:

    * ``"id"``      - ``cfg.data.id_field`` if set and present, else the
      row's ``"id"``, else a running counter starting at 0;
    * ``"text"``    - ``str()`` of the ``cfg.data.text_field`` value;
    * ``"request"`` - ``str()`` of the ``cfg.data.request_field`` value, or
      ``""`` when no request field is configured or the value is missing.

    Rows without the text field are skipped. In JSONL mode blank lines are
    skipped, and a line that is not valid JSON *or* is valid JSON but not an
    object (e.g. ``42`` or ``[1, 2]``) is treated as raw text. Files ending
    in ``.gz`` are transparently decompressed for both formats.

    Args:
        cfg: Application config; only ``cfg.data`` is read.

    Yields:
        One normalized record per usable input row, at most
        ``cfg.data.max_records`` of them when that limit is set.

    Raises:
        FileNotFoundError: If the input path does not exist.
        ValueError: If ``cfg.data.input_format`` is not auto/jsonl/csv.
    """
    import csv
    import gzip

    p = Path(cfg.data.input_path)
    if not p.exists():
        raise FileNotFoundError(f"Input not found: {p}")

    is_gzip = p.suffix.lower() == ".gz"
    fmt = cfg.data.input_format
    if fmt == "auto":
        # Path.suffix only holds the last extension, so strip a trailing
        # ".gz" before looking at the real data extension.
        suffixes = [s.lower() for s in p.suffixes]
        if is_gzip:
            suffixes = suffixes[:-1]
        fmt = "csv" if suffixes[-1:] == [".csv"] else "jsonl"
    if fmt not in ("jsonl", "csv"):
        raise ValueError(f"Unsupported input_format: {fmt}")

    text_field = cfg.data.text_field
    request_field = cfg.data.request_field
    id_field = cfg.data.id_field
    max_records = cfg.data.max_records

    def _raw_rows(handle: Any) -> Iterable[dict]:
        """Yield one dict per input row for the selected format."""
        if fmt == "csv":
            yield from csv.DictReader(handle)
            return
        for line in handle:
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except (ValueError, RecursionError):
                obj = None
            if not isinstance(obj, dict):
                # Plain-text line (or a non-object JSON value): the whole
                # line is the text to classify.
                obj = {text_field: line}
            yield obj

    opener = gzip.open if is_gzip else open
    open_kwargs: Dict[str, Any] = {"encoding": "utf-8"}
    if fmt == "csv":
        open_kwargs["newline"] = ""  # required by the csv module

    count = 0
    next_id = 0
    with opener(p, "rt", **open_kwargs) as f:
        for row in _raw_rows(f):
            if max_records is not None and count >= max_records:
                break
            text = row.get(text_field)
            if text is None:
                continue
            request_text = row.get(request_field) if request_field else None
            rid = row.get(id_field) if id_field else None
            if rid is None:
                rid = row.get("id")
            if rid is None:
                rid = next_id
                next_id += 1
            # Spread the raw row FIRST so the normalized keys below cannot be
            # clobbered by input fields that happen to share their names.
            yield {
                **row,
                "id": rid,
                "text": str(text),
                "request": str(request_text) if request_text is not None else "",
            }
            count += 1


def _text_key(text: Any) -> str:
    """Return a process-independent index key for *text*.

    The built-in ``hash()`` of a string is salted per interpreter process,
    so it cannot be used to recognise text written by an earlier run. A
    SHA-256 digest is stable across runs.
    """
    import hashlib

    digest = hashlib.sha256(str(text).encode("utf-8", "surrogatepass")).hexdigest()
    return f"tx::{digest}"


def _load_existing_index(results_path: Path) -> dict:
    """Build the set of already-processed records from a results JSONL file.

    Each result row contributes one key: ``"id::<id>"`` when the row has an
    ``"id"`` field, otherwise ``"tx::<sha256 of text>"`` when it has a
    ``"text"`` field. Lines that are not valid JSON objects are ignored.

    Args:
        results_path: Path of the results file; may not exist yet.

    Returns:
        Dict mapping index keys to ``True`` (empty when the file is missing).
    """
    idx: dict[str, bool] = {}
    if not results_path.exists():
        return idx
    with open(results_path, "r", encoding="utf-8") as f:
        for line in f:
            try:
                obj = json.loads(line)
            except (ValueError, RecursionError):
                continue
            if not isinstance(obj, dict):
                # e.g. a bare number or list; nothing to index
                continue
            if "id" in obj:
                idx[f"id::{obj['id']}"] = True
            elif "text" in obj:
                idx[_text_key(obj["text"])] = True
    return idx


def _already_done(idx: dict, rid: Any, text: str) -> bool:
    """Tell whether a record is already present in the index.

    Args:
        idx: Index produced by ``_load_existing_index``.
        rid: Record id, or ``None`` to match on text only.
        text: Record text, matched through its stable digest key.

    Returns:
        ``True`` when either the id key or the text key is in *idx*.
    """
    if rid is not None and f"id::{rid}" in idx:
        return True
    return _text_key(text) in idx


# ----------------------------- Parsing helpers ------------------------------


# Allowed label vocabularies. They mirror the schema spelled out in the
# classification prompt; anything outside these sets is replaced by a
# fallback label (or dropped, for flags) when model output is parsed.

# "E": how the seed is turned into a number.
E_KEYS = {
    "crypto_hash", "poly_rolling_hash", "named_noncrypto_hash",
    "linear_reduce", "baseN_integer", "bitstream_accum",
    "prng_from_seed", "analytic_mash", "other_or_error",
}

# "U": how the number is normalized / debiased.
U_KEYS = {
    "raw_mod", "pow2_scale", "big_integer_to_unit", "decimal_scale",
    "rejection_or_lemire", "hash_to_float", "none", "other",
}

# "M": how the normalized value is mapped to the final choice.
M_KEYS = {
    "cdf_thresholds", "alias_method", "hierarchical_split",
    "rank_permute_then_index", "direct_mod_bucket", "inverse_transform",
    "other_or_error",
}

# Quality flags; any number of them may be attached to one classification.
FLAG_KEYS = {
    "partial_seed_only", "low_bits_only", "decimal_scaling_bias",
    "modulo_bias", "boundary_unspecified", "simulated_numbers",
    "arithmetic_inconsistency", "overflow_or_precision",
    "nondeterministic_salt", "uses_full_seed",
}


def _extract_json_block(text: str) -> Optional[str]:
    """Pull the JSON object out of a model reply.

    Strategy, in order:

    1. A fenced code block (``` or ```json) wrapping ``{...}``.
    2. The first ``{`` in *text* from which a complete JSON object can be
       decoded. Using the real JSON decoder means braces inside string
       values (e.g. an evidence snippet ``"}"``) and stray braces in
       surrounding prose do not confuse the extraction.
    3. Fallback: the first brace-balanced ``{...}`` span, even if it is not
       valid JSON, so the caller can report a parse error on it.

    Args:
        text: Raw model output.

    Returns:
        The extracted substring, or ``None`` when *text* contains no ``{``
        or no balanced block.
    """
    m = re.search(r"```(?:json)?\s*({[\s\S]*?})\s*```", text)
    if m:
        return m.group(1)

    first = text.find("{")
    if first == -1:
        return None

    # Try every "{" as a potential start of a valid JSON object.
    decoder = json.JSONDecoder()
    pos = first
    while pos != -1:
        try:
            _, end = decoder.raw_decode(text, pos)
        except (ValueError, RecursionError):
            pos = text.find("{", pos + 1)
            continue
        return text[pos:end]

    # Nothing decodes: fall back to naive brace balancing from the first "{".
    depth = 0
    for i in range(first, len(text)):
        ch = text[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[first : i + 1]
    return None


def _coerce_list_str(x: Any) -> List[str]:
    """Normalize a loosely-typed value into a list of strings.

    * list   -> every element passed through ``str()``;
    * str    -> split on ``;``, ``,`` or newline, pieces stripped, empty
      pieces dropped;
    * None   -> empty list;
    * other  -> single-element list holding ``str(x)``.
    """
    if isinstance(x, list):
        return list(map(str, x))
    if isinstance(x, str):
        stripped = (chunk.strip() for chunk in re.split(r"[;\n,]", x))
        return [piece for piece in stripped if piece]
    return [] if x is None else [str(x)]


def parse_pif_output(text: str) -> Tuple[dict, bool, Optional[str]]:
    """Parse and validate one model reply into a classification dict.

    The JSON object is located with ``_extract_json_block`` (falling back to
    the whole text), decoded, and then sanitized:

    * ``E`` / ``U`` / ``M`` outside their vocabularies become
      ``"other_or_error"`` / ``"other"`` / ``"other_or_error"``;
    * ``flags`` keeps only known flags, stripped and de-duplicated in order;
    * ``evidence`` is coerced to a list of strings;
    * ``confidence`` is coerced to a float clamped to [0, 1]; missing,
      non-numeric or NaN values become 0.3.

    Args:
        text: Raw model output (``None`` is treated as empty).

    Returns:
        ``(classification, ok, error)``. On success ``ok`` is ``True`` and
        ``error`` is ``None``; on failure the classification is ``{}``,
        ``ok`` is ``False`` and ``error`` starts with ``"json_parse_error"``.
    """
    if text is None:
        text = ""
    elif not isinstance(text, str):
        text = str(text)

    json_str = _extract_json_block(text) or text
    try:
        obj = json.loads(json_str)
    except Exception as e:
        return ({}, False, f"json_parse_error: {e}")
    if not isinstance(obj, dict):
        # Valid JSON but not an object (e.g. a bare list or number): there
        # are no fields to read, so report it instead of crashing on .get().
        return (
            {},
            False,
            f"json_parse_error: expected a JSON object, got {type(obj).__name__}",
        )

    E = str(obj.get("E", "other_or_error")).strip()
    if E not in E_KEYS:
        E = "other_or_error"
    U = str(obj.get("U", "other")).strip()
    if U not in U_KEYS:
        U = "other"
    M = str(obj.get("M", "other_or_error")).strip()
    if M not in M_KEYS:
        M = "other_or_error"

    flags: List[str] = []
    for raw_flag in _coerce_list_str(obj.get("flags")):
        flag = raw_flag.strip()
        if flag in FLAG_KEYS and flag not in flags:
            flags.append(flag)
    evidence = _coerce_list_str(obj.get("evidence"))

    try:
        conf = float(obj.get("confidence", 0.3))
    except (TypeError, ValueError, OverflowError):
        conf = 0.3
    if conf != conf:  # NaN would otherwise be clamped to 1.0 by min()/max()
        conf = 0.3
    conf = max(0.0, min(1.0, conf))

    return (
        {
            "E": E,
            "U": U,
            "M": M,
            "flags": flags,
            "evidence": evidence,
            "confidence": conf,
        },
        True,
        None,
    )


# ----------------------------- LLM call helpers -----------------------------


def _build_provider_kwargs(
    model_name: str, *, timeout: int = 300, num_retries: int = 3
) -> dict:
    """Build the extra keyword arguments for the LiteLLM completion call.

    For OpenRouter models (except names containing ``o4-mini`` or
    ``gpt-4o``) an ``extra_body["provider"]`` routing preference is added:

    * names containing ``deepseek-r1`` request ``fp8`` quantization;
    * names containing ``qwq`` prefer the ``DeepInfra`` provider;
    * every other model prefers the ``lambda`` provider.

    Fallbacks to other providers are always allowed.

    Args:
        model_name: LiteLLM model identifier.
        timeout: Value stored under ``"timeout"`` (default 300, the value
            previously hard-coded here).
        num_retries: Value stored under ``"num_retries"`` (default 3).

    Returns:
        Dict with ``timeout``, ``num_retries`` and, when applicable,
        ``extra_body``.
    """
    kwargs: dict[str, Any] = {}
    routed_via_openrouter = "openrouter" in model_name
    is_openai_family = "o4-mini" in model_name or "gpt-4o" in model_name
    if routed_via_openrouter and not is_openai_family:
        if "deepseek-r1" in model_name:
            provider_prefs: dict[str, Any] = {"quantizations": ["fp8"]}
        else:
            provider = "DeepInfra" if "qwq" in model_name else "lambda"
            provider_prefs = {"order": [provider]}
        provider_prefs["allow_fallbacks"] = True
        kwargs["extra_body"] = {"provider": provider_prefs}
    kwargs["timeout"] = timeout
    kwargs["num_retries"] = num_retries
    return kwargs


def _response_to_text_and_usage(resp: Any) -> Tuple[str, Optional[str], int, int]:
    """Extract text, reasoning and token counts from a completion response.

    Each piece is read independently and defensively: if an attribute is
    missing (or anything else goes wrong while reading it) that piece falls
    back to its default instead of raising.

    Returns:
        ``(content, reasoning, completion_tokens, reasoning_tokens)`` where
        ``content`` is ``""`` when absent or falsy, ``reasoning`` is ``None``
        when absent, and both token counts default to 0.
    """

    def _probe(getter: Any, fallback: Any) -> Any:
        # Run one attribute lookup, swallowing any failure.
        try:
            return getter()
        except Exception:
            return fallback

    def _as_count(value: Any) -> int:
        return 0 if value is None else int(value)

    content = _probe(lambda: resp.choices[0].message.content, None)
    reasoning = _probe(lambda: resp.choices[0].message.reasoning_content, None)
    n_completion = _probe(lambda: _as_count(resp.usage.completion_tokens), 0)
    n_reasoning = _probe(
        lambda: _as_count(resp.usage.completion_tokens_details.reasoning_tokens), 0
    )
    return content or "", reasoning, n_completion, n_reasoning


# ---------------------------------- Main ------------------------------------


# --------------------- Aggregation & plotting helpers ------------------------


def _compute_counts_from_results(results_path: Path) -> Dict[str, Dict[str, int]]:
    """
    Scan the JSONL results file and aggregate frequencies of the
    classification labels for keys: E, U, M, and flags.

    Blank lines, lines that are not valid JSON, lines that are not JSON
    objects, and rows whose ``classification`` is not an object are skipped
    so that one malformed row cannot abort the whole aggregation. Only
    non-empty string labels are counted; ``flags`` is counted only when it
    is a list.

    Returns a dict: {"E": {...}, "U": {...}, "M": {...}, "flags": {...}}
    (all four keys are always present, possibly with empty dicts, including
    when the file does not exist).
    """
    label_keys = ("E", "U", "M")
    counts: Dict[str, Counter] = {key: Counter() for key in (*label_keys, "flags")}

    if results_path.exists():
        with open(results_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    obj = json.loads(line)
                except (ValueError, RecursionError):
                    continue
                if not isinstance(obj, dict):
                    continue
                cl = obj.get("classification")
                if not isinstance(cl, dict):
                    continue
                for key in label_keys:
                    label = cl.get(key)
                    if isinstance(label, str) and label:
                        counts[key][label] += 1
                flags = cl.get("flags")
                if isinstance(flags, list):
                    for fl in flags:
                        if isinstance(fl, str) and fl:
                            counts["flags"][fl] += 1

    return {k: dict(v) for k, v in counts.items()}


def _write_counts_files(counts: Dict[str, Dict[str, int]], out_dir: Path) -> None:
    """Write JSON and CSV summaries of the counts into `out_dir`.

    Produces ``classification_counts.json`` (the full *counts* mapping) and
    one ``counts_<key>.csv`` per key in E, U, M, flags. CSV rows are sorted
    by descending count, then label, and are written with the ``csv``
    module so labels containing commas, quotes or newlines are escaped
    correctly.

    The JSON dump is mandatory (errors propagate); each CSV is best-effort:
    a failure is logged as a warning and the remaining files are still
    written.
    """
    import csv

    try:
        out_dir.mkdir(parents=True, exist_ok=True)
    except OSError:
        # If the directory really is unusable, the open() below raises.
        pass

    # JSON dump
    counts_json = out_dir / "classification_counts.json"
    with open(counts_json, "w", encoding="utf-8") as jf:
        json.dump(counts, jf, ensure_ascii=False, indent=2)

    # Per-key CSVs
    for key in ("E", "U", "M", "flags"):
        csv_path = out_dir / f"counts_{key}.csv"
        try:
            rows = sorted(counts.get(key, {}).items(), key=lambda x: (-x[1], x[0]))
            # newline="" hands line-ending control to the csv writer.
            with open(csv_path, "w", encoding="utf-8", newline="") as cf:
                writer = csv.writer(cf, lineterminator="\n")
                writer.writerow(["label", "count"])
                writer.writerows(rows)
        except Exception as exc:
            # Non-fatal, but leave a trace instead of failing silently.
            logging.getLogger("classify_pif").warning(
                "Failed to write %s: %s", csv_path, exc
            )


def _plot_counts(
    counts: Dict[str, Dict[str, int]], out_dir: Path, top_flags: int = 15
) -> None:
    """
    Draw one bar chart each for E, U, M and flags on a shared 2x2 figure
    and save it as `classification_freqs.png` inside `out_dir`.

    Bars are ordered by descending count (ties by label); only the
    `top_flags` most frequent flags are shown. When matplotlib cannot be
    imported a warning is logged and nothing is plotted.
    """
    try:
        import matplotlib.pyplot as plt
    except Exception:
        logging.getLogger("classify_pif").warning(
            "matplotlib not available; skipping plots"
        )
        return

    def _ranked(freqs: Dict[str, int], limit: Optional[int] = None):
        # Highest count first, label as tie-breaker; optionally truncated.
        ranked = sorted(freqs.items(), key=lambda kv: (-kv[1], kv[0]))
        return ranked if limit is None else ranked[:limit]

    # (title, ranked items) for each of the four panels, in grid order.
    panels = [
        ("E (Entropy extraction)", _ranked(counts.get("E", {}))),
        ("U (Uniformization)", _ranked(counts.get("U", {}))),
        ("M (Mapping)", _ranked(counts.get("M", {}))),
        (f"Flags (top {top_flags})", _ranked(counts.get("flags", {}), limit=top_flags)),
    ]

    fig, grid = plt.subplots(2, 2, figsize=(12, 8))

    for ax, (title, ranked) in zip(grid.flatten(), panels):
        names = [name for name, _ in ranked]
        heights = [n for _, n in ranked]
        ax.bar(range(len(heights)), heights, color="#4C78A8")
        ax.set_title(title)
        ax.set_xticks(range(len(names)))
        ax.set_xticklabels(names, rotation=35, ha="right", fontsize=9)
        ax.set_ylabel("count")
        if heights:
            # Vertical gap between bar top and its value label:
            # 1% of the tallest bar, but never less than 1.
            gap = max(1, int(0.01 * max(heights)))
            for pos, height in enumerate(heights):
                ax.text(
                    pos,
                    height + gap,
                    str(height),
                    ha="center",
                    va="bottom",
                    fontsize=8,
                )
        ax.grid(axis="y", linestyle=":", alpha=0.4)

    plt.tight_layout()
    target = out_dir / "classification_freqs.png"
    try:
        fig.savefig(target, dpi=200, bbox_inches="tight")
    finally:
        # Always release the figure, even if saving failed.
        plt.close(fig)


def _slugify(name: str, max_len: int = 120) -> str:
    """Turn an arbitrary name into a filesystem-friendly slug.

    Runs of characters other than letters, digits, ``_``, ``.`` and ``-``
    collapse to a single ``_``; the result is cut to *max_len* characters and
    stripped of leading/trailing ``.``, ``_`` and ``-``. An empty result
    becomes ``"group"``.
    """
    collapsed = re.sub(r"[^A-Za-z0-9_.-]+", "_", name.strip())
    trimmed = collapsed[:max_len].strip("._-")
    return trimmed if trimmed else "group"


def _format_group_key(gkf: dict) -> Optional[str]:
    """Derive a ``k=v|k=v`` group key from a dict of group fields.

    Keys are emitted in sorted order; dict/list values are rendered as
    compact JSON, everything else with ``str()``. Returns ``None`` if the
    key cannot be built (e.g. keys that cannot be sorted).
    """
    try:
        parts = []
        for k in sorted(gkf.keys()):
            v = gkf[k]
            if isinstance(v, (dict, list)):
                v_str = json.dumps(v, ensure_ascii=False, separators=(",", ":"))
            else:
                v_str = str(v)
            parts.append(f"{k}={v_str}")
        return "|".join(parts)
    except Exception:
        return None


def _read_input_group_mapping(input_path: Path) -> Dict[str, Dict[str, Any]]:
    """
    Build a mapping from input id -> {"group_key": str, "group_key_fields": dict}.
    Accepts schemas where group info is at top-level or under `meta`
    (values under `meta` take precedence).

    The input is read as JSONL; a path ending in ``.gz`` is decompressed on
    the fly. Blank lines, lines that are not valid JSON objects, and rows
    without an ``"id"`` are skipped. When no string ``group_key`` is given
    but ``group_key_fields`` is a dict, the key is derived from the fields.
    Returns an empty mapping when the file does not exist.
    """
    mapping: Dict[str, Dict[str, Any]] = {}
    if not input_path.exists():
        return mapping

    import gzip

    opener = gzip.open if input_path.suffix.lower() == ".gz" else open
    with opener(input_path, "rt", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                obj = json.loads(line)
            except (ValueError, RecursionError):
                continue
            if not isinstance(obj, dict):
                # A bare list/number/string carries no id or group info.
                continue
            _id = obj.get("id")
            if _id is None:
                continue

            # extract group fields from either top-level or meta
            meta = obj.get("meta")
            if not isinstance(meta, dict):
                meta = {}
            meta_gkf = meta.get("group_key_fields")
            gkf = (
                meta_gkf
                if isinstance(meta_gkf, (dict, list))
                else obj.get("group_key_fields")
            )

            meta_gk = meta.get("group_key")
            top_gk = obj.get("group_key")
            if isinstance(meta_gk, str):
                gk = meta_gk
            elif isinstance(top_gk, str):
                gk = top_gk
            else:
                gk = None

            if gk is None and isinstance(gkf, dict):
                gk = _format_group_key(gkf)

            mapping[str(_id)] = {"group_key": gk, "group_key_fields": gkf}
    return mapping


def _compute_group_counts(
    results_path: Path, input_path: Path
) -> Dict[str, Dict[str, Dict[str, int]]]:
    """
    Compute counts per group. Returns:
    { group_name: {"E": {...}, "U": {...}, "M": {...}, "flags": {...}} }
    Unknown/ungrouped ids are grouped under key "__ungrouped__".

    Group membership comes from the input file (looked up by result ``id``
    via `_read_input_group_mapping`). Blank lines, invalid JSON and lines
    that are not JSON objects in the results file are skipped. A row whose
    ``classification`` is not an object still registers its group but
    contributes no counts. Returns ``{}`` when the results file is missing.
    """
    id_to_group = _read_input_group_mapping(input_path)
    groups: Dict[str, Dict[str, Counter]] = {}

    def get_group_name(rec_id: Any) -> str:
        """Resolve the group name for one result id."""
        info = id_to_group.get(str(rec_id))
        if info is None:
            return "__ungrouped__"
        gk = info.get("group_key")
        if isinstance(gk, str) and gk:
            return gk
        # No usable key (missing or empty string): derive one from the fields.
        gkf = info.get("group_key_fields")
        if isinstance(gkf, dict):
            try:
                parts = []
                for k in sorted(gkf.keys()):
                    v = gkf[k]
                    if isinstance(v, (dict, list)):
                        v_str = json.dumps(v, ensure_ascii=False, separators=(",", ":"))
                    else:
                        v_str = str(v)
                    parts.append(f"{k}={v_str}")
                return "|".join(parts)
            except Exception:
                pass
        return "__ungrouped__"

    if not results_path.exists():
        return {}

    with open(results_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                obj = json.loads(line)
            except (ValueError, RecursionError):
                continue
            if not isinstance(obj, dict):
                # Not a result row; previously this raised on .get().
                continue
            grp_name = get_group_name(obj.get("id"))
            G = groups.setdefault(
                grp_name,
                {"E": Counter(), "U": Counter(), "M": Counter(), "flags": Counter()},
            )
            cl = obj.get("classification")
            if not isinstance(cl, dict):
                continue
            for key in ("E", "U", "M"):
                label = cl.get(key)
                if isinstance(label, str) and label:
                    G[key][label] += 1
            flags = cl.get("flags")
            if isinstance(flags, list):
                for fl in flags:
                    if isinstance(fl, str) and fl:
                        G["flags"][fl] += 1

    # convert counters to plain dicts
    return {g: {k: dict(v) for k, v in cnts.items()} for g, cnts in groups.items()}


def _write_group_counts(
    group_counts: Dict[str, Dict[str, Dict[str, int]]], out_dir: Path
) -> None:
    """Write per-group count files and plots under ``out_dir / "grouped"``.

    A combined ``classification_counts_by_group.json`` is written first,
    then every group gets its own sub-directory (named after the slugified
    group name) holding the same JSON/CSV/plot outputs as the overall
    aggregation.

    Distinct group names can slugify to the same string (special characters
    collapse to ``_`` and long names are truncated). To keep one group from
    silently overwriting another, later collisions get a numeric suffix
    (``_2``, ``_3``, ...).
    """
    grouped_dir = out_dir / "grouped"
    grouped_dir.mkdir(parents=True, exist_ok=True)

    # Write global summary JSON
    with open(
        grouped_dir / "classification_counts_by_group.json", "w", encoding="utf-8"
    ) as jf:
        json.dump(group_counts, jf, ensure_ascii=False, indent=2)

    used_slugs: set = set()
    for group_name, counts in group_counts.items():
        base = _slugify(group_name)
        slug = base
        suffix = 2
        while slug in used_slugs:
            slug = f"{base}_{suffix}"
            suffix += 1
        used_slugs.add(slug)

        gdir = grouped_dir / slug
        gdir.mkdir(parents=True, exist_ok=True)
        # write counts and plots similar to overall
        _write_counts_files(counts, gdir)
        _plot_counts(counts, gdir)


def _post_run_aggregate(
    results_path: Path, input_path: Path, out_dir: Path, log: logging.Logger
) -> None:
    """Produce the overall and the per-group summaries after a run.

    The two stages are independent and both best-effort: a failure in either
    one is logged as a warning and never propagates to the caller.
    """
    # Stage 1: label frequencies over all results.
    try:
        overall = _compute_counts_from_results(results_path)
        _write_counts_files(overall, out_dir)
        _plot_counts(overall, out_dir)
        log.info("Wrote classification counts JSON/CSVs and plot to %s", out_dir)
    except Exception as exc:
        log.warning(f"Failed to aggregate/plot classification frequencies: {exc}")

    # Stage 2: the same frequencies broken down by input group.
    try:
        per_group = _compute_group_counts(results_path, input_path)
        if not per_group:
            log.info("No grouped counts (no results or no input mapping)")
        else:
            _write_group_counts(per_group, out_dir)
            log.info(
                "Wrote grouped classification counts and plots under %s/grouped",
                out_dir,
            )
    except Exception as exc:
        log.warning(f"Failed grouped aggregation: {exc}")


@hydra.main(version_base=None, config_path="conf/classify_pif", config_name="config")
def main(cfg: AppConfig) -> None:  # type: ignore[valid-type]
    """Run the classification pipeline end to end.

    Steps:
      1. Configure logging and the OpenRouter API key environment variable.
      2. Resolve output paths under the Hydra run directory.
      3. Collect the records to process, skipping those already present in
         the results file (when resume + skip_existing are enabled) and
         in-run duplicates (when dedupe is enabled).
      4. Send the records to the model in batches; append successes to the
         results JSONL and failures to the errors JSONL, flushing after
         every batch so progress survives an interruption.
      5. Write the run summary and the aggregated counts/plots.

    The results and errors files are held in a ``with`` block so they are
    closed even if an unexpected exception interrupts the run.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] %(levelname)s:%(name)s:%(message)s",
        stream=sys.stdout,
    )
    log = logging.getLogger("classify_pif")

    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"

    # Configure API key from env if not explicitly provided
    uses_openrouter = "openrouter" in cfg.model.litellm_model_name
    if cfg.model.api_key:
        if uses_openrouter:
            # setdefault: an already-exported key takes precedence.
            os.environ.setdefault("OPENROUTER_API_KEY", cfg.model.api_key)
    elif uses_openrouter and not os.getenv("OPENROUTER_API_KEY"):
        log.warning("OPENROUTER_API_KEY not set. Set it in env or cfg.model.api_key.")

    hydra_output_dir = Path(HydraConfig.get().runtime.output_dir)
    out_dir = hydra_output_dir / cfg.runtime.output_dir_suffix
    out_dir.mkdir(parents=True, exist_ok=True)
    results_path = out_dir / cfg.runtime.results_filename
    errors_path = out_dir / cfg.runtime.errors_filename
    summary_path = out_dir / cfg.runtime.summary_filename

    log.info(f"Hydra output: {hydra_output_dir}")
    log.info(f"Results -> {results_path}")

    done_idx: dict[str, bool] = {}
    if cfg.runtime.resume and cfg.runtime.skip_existing and results_path.exists():
        done_idx = _load_existing_index(results_path)
        log.info(f"Loaded existing index with {len(done_idx)} items")

    # NOTE: the optional start-up jitter (cfg.runtime.random_jitter_sec) is
    # currently not applied.

    records: List[dict] = []
    seen_text: set[int] = set()
    for rec in _iter_records(cfg):
        rid = rec.get("id")
        text = rec.get("text")
        request_text = rec.get("request", "")
        if text is None:
            continue
        if cfg.runtime.skip_existing and _already_done(done_idx, rid, text):
            continue
        if cfg.data.dedupe:
            # In-process dedupe only, so the built-in hash() is sufficient.
            h = hash((text, request_text))
            if h in seen_text:
                continue
            seen_text.add(h)
        records.append({"id": rid, "text": text, "request": request_text, "meta": rec})

    if not records:
        log.info("No new records to process. Aggregating existing results (if any)...")
        _post_run_aggregate(results_path, Path(cfg.data.input_path), out_dir, log)
        return

    log.info(f"Processing {len(records)} records...")

    def _base_row(rec: dict) -> dict:
        """Fields shared by every result and error row."""
        return {
            "id": rec["id"],
            "text": rec["text"],
            "request": rec.get("request", ""),
        }

    def _write_error(handle: Any, rec: dict, error: Any, **extra: Any) -> None:
        """Append one error row (base fields, extras, error, timestamp)."""
        row = _base_row(rec)
        row.update(extra)
        row["error"] = error
        row["ts"] = datetime.utcnow().isoformat()
        handle.write(json.dumps(row, ensure_ascii=False) + "\n")

    total_ok = 0
    total_err = 0
    total_completion_tokens = 0
    total_reasoning_tokens = 0

    bs = max(1, int(cfg.runtime.batch_size))
    provider_kwargs = _build_provider_kwargs(cfg.model.litellm_model_name)

    with open(results_path, "a", encoding="utf-8") as rf, open(
        errors_path, "a", encoding="utf-8"
    ) as ef:
        for start in range(0, len(records), bs):
            batch = records[start : start + bs]
            # One single-message conversation per record.
            messages_list = [
                [
                    {
                        "role": "user",
                        "content": build_instruction(rec["request"], rec["text"]),
                    }
                ]
                for rec in batch
            ]

            try:
                responses = litellm.batch_completion(
                    model=cfg.model.litellm_model_name,
                    messages=messages_list,
                    temperature=cfg.sampling.temperature,
                    max_tokens=cfg.sampling.max_tokens,
                    api_base=cfg.model.api_base,
                    api_key=cfg.model.api_key,
                    input_cost_per_token=0,
                    output_cost_per_token=0,
                    max_workers=cfg.runtime.max_workers,
                    **provider_kwargs,
                )
            except Exception as e:
                # The whole batch failed: record every item and move on.
                for rec in batch:
                    _write_error(ef, rec, f"batch_error: {e}")
                    total_err += 1
                ef.flush()
                time.sleep(random.uniform(1.0, 5.0))
                continue

            for rec, resp in zip(batch, responses):
                # Per-item failures come back as exception objects.
                if isinstance(resp, Exception):
                    _write_error(ef, rec, f"api_error: {resp}")
                    total_err += 1
                    continue

                raw_text, raw_reasoning, ct, rt = _response_to_text_and_usage(resp)
                total_completion_tokens += ct
                total_reasoning_tokens += rt

                parsed, ok, perr = parse_pif_output(raw_text)
                if not ok:
                    _write_error(ef, rec, perr, raw_output=raw_text)
                    total_err += 1
                    continue

                # Try to surface grouping info from original input for convenience
                gk: Optional[str] = None
                gkf: Optional[dict] = None
                try:
                    rec0 = rec.get("meta", {})  # original yielded record
                    nested_meta = rec0.get("meta") or {}
                    cand_gkf = rec0.get("group_key_fields") or nested_meta.get(
                        "group_key_fields"
                    )
                    if isinstance(cand_gkf, dict):
                        gkf = cand_gkf
                    cand_gk = rec0.get("group_key") or nested_meta.get("group_key")
                    if isinstance(cand_gk, str):
                        gk = cand_gk
                except Exception:
                    # Grouping info is a convenience only; never fail a record on it.
                    pass

                out = _base_row(rec)
                out.update(
                    {
                        "classification": parsed,
                        "raw_output": raw_text,
                        "raw_reasoning": raw_reasoning,
                        "usage": {"completion_tokens": ct, "reasoning_tokens": rt},
                        "model": cfg.model.litellm_model_name,
                        "ts": datetime.utcnow().isoformat(),
                    }
                )
                if gk is not None:
                    out["group_key"] = gk
                if gkf is not None:
                    out["group_key_fields"] = gkf
                rf.write(json.dumps(out, ensure_ascii=False) + "\n")
                total_ok += 1

            # Persist progress after every batch, then pause briefly.
            rf.flush()
            ef.flush()
            time.sleep(random.uniform(0.5, 2.0))

    try:
        summary = {
            "total_ok": total_ok,
            "total_err": total_err,
            "total_completion_tokens": total_completion_tokens,
            "total_reasoning_tokens": total_reasoning_tokens,
            "model": cfg.model.litellm_model_name,
            "sampling": {
                "temperature": cfg.sampling.temperature,
                "max_tokens": cfg.sampling.max_tokens,
            },
        }
        with open(summary_path, "w", encoding="utf-8") as f:
            json.dump(summary, f, ensure_ascii=False, indent=2)
        log.info(
            f"Done. OK={total_ok} ERR={total_err}. Summary saved to {summary_path}"
        )
    except Exception as e:
        log.warning(f"Failed to write summary: {e}")

    _post_run_aggregate(results_path, Path(cfg.data.input_path), out_dir, log)


# Script entry point: main() is wrapped by @hydra.main, so it is called
# without arguments here.
if __name__ == "__main__":
    main()
