"""
Hydra + LiteLLM pipeline to classify responses into novelty axes using the
prompt below. Reads JSONL (optionally gzip-compressed) or CSV input, calls
OpenRouter via LiteLLM, parses the JSON reply, writes results, aggregates
counts, and plots frequencies.
"""

from __future__ import annotations

import json
import logging
import os
import random
import re
import sys
import time
from collections import Counter
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple

import hydra
import litellm
from hydra.core.hydra_config import HydraConfig


def classify_prompt(request, response) -> str:
    """Build the evaluator prompt for one request/response pair.

    The prompt lists every classification axis with its allowed category IDs,
    the optional FLAGS vocabulary, tie-breaking rules, and the JSON shape the
    model must return. ``request`` and ``response`` are interpolated verbatim
    into the REQUEST and TARGET TEXT sections at the end.

    Literal braces of the JSON template are doubled (``{{`` / ``}}``) because
    the whole prompt is an f-string.

    NOTE(review): inside the prompt text the "Coverage" axis block is indented
    four spaces deeper than the other axes. It is left untouched here because
    the prompt is runtime text sent to the model.

    Args:
        request: Text placed in the REQUEST section.
        response: Text placed in the TARGET TEXT section (the text to classify).

    Returns:
        The full prompt with leading/trailing whitespace stripped.
    """
    prompt = f"""
You are an evaluator. Read TARGET TEXT and assign exactly ONE category per axis.
Respond with JSON ONLY (no extra text). Keep each "reason" ≤ 10 words.
Base decisions ONLY on explicit or strongly implied evidence in TARGET TEXT.

AXES AND “WHEN TO PICK” RULES

- Mapping  (how the seed is turned into decisions)
  ["SUMMOD","ROLL","CHUNK","STRIDE","HASH","PRNG","WEIGHTED"]
  • SUMMOD — If it adds codes/length and takes %N to index.  
  • ROLL — If it updates a running value (h=h*B+ord, etc.).  
  • CHUNK — If different contiguous seed blocks drive different parts.  
  • STRIDE — If it uses every k‑th / interleaved characters per part.  
  • HASH — If it names a hash (SHA/MD5/CRC) or “digest”.  
  • PRNG — If it seeds a random generator (LCG/XorShift/PCG/random()).  
  • WEIGHTED — If it computes weights/probabilities and selects by them.

- Generation  (how content is built)
  ["SLOT","ATTR","RULE","LIST","PROC","PLAN"]
  • SLOT — If it fills a fixed template with slots/placeholders.  
  • ATTR — If it composes orthogonal attributes (e.g., Domain×Feature×Style).  
  • RULE — If it enforces a formal scheme (schema/5‑7‑5/regex).  
  • LIST — If it selects items from a predefined set of options.  
  • PROC — If it procedurally synthesizes items by rules (phonology/grammar).  
  • PLAN — If it outlines first, then writes from that plan.

- Pool  (where candidate options come from)
  ["STATIC","GENERATED","HYBRID"]
  • STATIC — If it uses a fixed, hand‑made list declared in the text.  
  • GENERATED — If the list/options are created on the fly in the text.  
  • HYBRID — If it clearly mixes fixed and on‑the‑fly options.

- SeedInterpretation  (how the seed is treated)
  ["COMP_MAP","INTERP_MAP","BOTH"]
  • COMP_MAP — If characters are treated as numbers deterministically.  
  • INTERP_MAP — If characters are treated symbolically/metaphorically.  
  • BOTH — If both numeric and symbolic interpretations are used.

- DiversityTarget  (what varies across seeds)
  ["STRUCTURAL","CONTENT","BOTH","NONE"]
  • STRUCTURAL — If form/outline/rhyme/section count change by seed.  
  • CONTENT — If concrete items/names/details change while form stays.  
  • BOTH — If both structure and content vary with the seed.  
  • NONE — If no seed‑driven variation is apparent.

- SeedUsage  (how much of the seed actually drives output)
  ["NONE","PART","FULL","SAT"]
  • NONE — If seed is mentioned but not used.  
  • PART — If only a small subset (e.g., first chars) affects output.  
  • FULL — If most parts of the seed influence main decisions.  
  • SAT — If the seed is overused (repetition/forced echoes).

- MappingExposure  (how openly the mapping is described)
  ["HIDE","TRACE","EXPL","TOOL"]
  • HIDE — If it claims to use a seed but gives no method.  
  • TRACE — If it hints at “string→number” without specifics.  
  • EXPL — If it states the exact formula/steps or pseudo‑code.  
  • TOOL — If it hands mapping to an external tool/script.

- RandomnessScope  (where randomness is applied)
  ["GLOBAL","LOCAL","HYBRID"]
  • GLOBAL — If one global draw sets theme/style and that’s it.  
  • LOCAL — If separate draws per attribute/line/token are used.  
  • HYBRID — If both a global choice and local per‑part draws appear.

- PlannerType  (how the response is organized)
  ["DIRECT","PLAN","PROG","AGENT"]
  • DIRECT — If it writes immediately with no planning.  
  • PLAN — If it lists steps/outline before writing.  
  • PROG — If it uses pseudo‑code/grammar/PCFG/rules to generate.  
  • AGENT — If roles are split (planner vs writer) within the text.

- RandomMechanism  (how choices are selected)
  ["ARGMAX","SCORE","SAMPLE","REJECT"]
  • ARGMAX — If it deterministically picks the top option (no sampling).  
  • SCORE — If it scores/weights options but takes the best deterministically.  
  • SAMPLE — If it samples from a distribution (e.g., top‑k/alias/CDF).  
  • REJECT — If it resamples until constraints are satisfied.

- ConstraintHandling  (how constraints are met)
  ["STRICT","HEUR","POST","IGNORE"]
  • STRICT — If it explicitly validates counts/length/dedup/schema.  
  • HEUR — If it applies heuristics to likely satisfy constraints.  
  • POST — If it fixes violations after generating (self‑check/corrections).  
  • IGNORE — If it states constraints but doesn’t enforce them.

- SeedLocality  (where the seed is consumed)
  ["STRUCT","CONTENT","BOTH","META"]
  • STRUCT — If seed drives structure/ordering/rhyme/sectioning.  
  • CONTENT — If seed drives lexical choices/entities/numbers.  
  • BOTH — If it drives both structure and content.  
  • META — If seed mainly affects the explanation/meta commentary.

- Determinism  (reproducibility policy as stated)
  ["DET","QUASI","NONDET"]
  • DET — If same seed ⇒ same output is promised.  
  • QUASI — If same seed gives same outline but details may drift.  
  • NONDET — If behavior is not reproducible or contradicts claims.

    - Coverage  (how much of the seed is said to be used)
      ["LOW","MID","HIGH"]
      • LOW — If one small part drives decisions (<30%).  
      • MID — If multiple parts but not most (30–70%).  
      • HIGH — If most parts are consumed (>70%) or per‑part use is clear.

FLAGS  (add zero or more when clearly warranted)
["STRUCT","CHK","COH","FIX","SAFE","NONUSE","LEAK","META","HACK","TOOLCALL","DETFAIL","OVERFIT","CREATIVE","SEEDANCHOR","DECOMP"]
• STRUCT: Explicit structural formatting present.  
• CHK: Formal checks/dedup/counting executed.  
• COH: Consistency of chosen attributes is maintained.  
• FIX: Post‑hoc correction step exists.  
• SAFE: Safety/ethics sanitization applied.  
• NONUSE: Seed not actually used.  
• LEAK: Hidden reasoning/keys/forbidden info revealed.  
• META: Method explanation dominates the answer.  
• HACK: Boilerplate “ASCII sum % N” style template.  
• TOOLCALL: External tool/API is assumed or invoked.  
• DETFAIL: Claims determinism but behavior contradicts it.  
• OVERFIT: Repetitive vocabulary or templated phrasing.
• CREATIVE: Output is a creative narrative/story/poem.
• SEEDANCHOR: Story ties motifs/themes/elements to seed parts.
• DECOMP: Story is decomposed into elements and seeded per element.

TIE‑BREAKING
- Prefer more specific evidence: EXPL > TRACE > HIDE (for exposure).
- Mapping specificity order: HASH > PRNG > ROLL > CHUNK > STRIDE > WEIGHTED > SUMMOD.
- If two categories still fit, pick the one most emphasized by the text.


OUTPUT FORMAT (return JSON exactly in this shape; FLAGS can be empty [])
{{
  "mapping": {{"reason":"<≤10 words>","choice":"<ID>"}},
  "generation": {{"reason":"<≤10 words>","choice":"<ID>"}},
  "pool": {{"reason":"<≤10 words>","choice":"<ID>"}},
  "seedInterpretation": {{"reason":"<≤10 words>","choice":"<ID>"}},
  "diversityTarget": {{"reason":"<≤10 words>","choice":"<ID>"}},
  "seedUsage": {{"reason":"<≤10 words>","choice":"<ID>"}},
  "mappingExposure": {{"reason":"<≤10 words>","choice":"<ID>"}},
  "randomnessScope": {{"reason":"<≤10 words>","choice":"<ID>"}},
  "plannerType": {{"reason":"<≤10 words>","choice":"<ID>"}},
  "randomMechanism": {{"reason":"<≤10 words>","choice":"<ID>"}},
  "constraintHandling": {{"reason":"<≤10 words>","choice":"<ID>"}},
  "seedLocality": {{"reason":"<≤10 words>","choice":"<ID>"}},
  "determinism": {{"reason":"<≤10 words>","choice":"<ID>"}},
  "coverage": {{"reason":"<≤10 words>","choice":"<ID>"}},
  "flags": ["<FLAG1>","<FLAG2>"]
}}

REQUEST
<<<
{request}
>>>

TARGET TEXT
<<<
{response}
>>>
"""
    # Drop the newline that follows the opening quotes and the one before the
    # closing quotes.
    return prompt.strip()


# ----------------------- Config (Hydra dataclasses) -------------------------


@dataclass
class DataConfig:
    """Input-data settings: which file to read and which fields to pull from it."""

    # Path of the input file; the record reader raises FileNotFoundError if it
    # does not exist.
    input_path: str = "test.json"
    input_format: str = "auto"  # auto|jsonl|csv
    # Key/column holding the text to classify; rows without it are skipped.
    text_field: str = "response"
    # Key/column holding the originating request; falsy disables the lookup.
    request_field: Optional[str] = "request"
    # Key/column holding the record id; the reader falls back to "id" and then
    # to a running counter when no id is found.
    id_field: Optional[str] = "id"
    # Upper bound on the number of records yielded; None means no limit.
    max_records: Optional[int] = None
    # NOTE(review): not read anywhere in the visible part of the file;
    # presumably toggles duplicate-record removal — confirm against the caller.
    dedupe: bool = True


@dataclass
class ModelConfig:
    """Identifies the model used for classification.

    NOTE(review): none of these fields is read in the visible part of the file;
    the per-field notes below are assumptions to confirm against the caller.
    """

    # Presumably a short display/label name for the model.
    name: str = "gpt-5-mini"
    # Presumably the model id handed to LiteLLM (provider-prefixed).
    litellm_model_name: str = "openrouter/openai/gpt-5-mini"
    # Presumably an optional override of the API endpoint.
    api_base: Optional[str] = None
    # Presumably an optional explicit API key.
    api_key: Optional[str] = None


@dataclass
class SamplingConfig:
    """Decoding settings for the classifier call.

    NOTE(review): not read in the visible part of the file; presumably
    forwarded to the completion request — confirm against the caller.
    """

    temperature: float = 0.0
    max_tokens: int = 600


@dataclass
class RuntimeConfig:
    """Execution and output settings for a classification run."""

    # NOTE(review): the next five fields are not read in the visible part of
    # the file; presumably they control batching, request timeout/retries,
    # worker concurrency, and a random start delay — confirm against the caller.
    batch_size: int = 1000
    timeout: int = 300
    num_retries: int = 3
    max_workers: int = 100
    random_jitter_sec: int = 30
    # Sub-directory created under the Hydra output directory for all outputs.
    output_dir_suffix: str = "novelty_classify"
    # File names resolved inside that output sub-directory.
    results_filename: str = "results.jsonl"
    errors_filename: str = "errors.jsonl"
    summary_filename: str = "summary.json"
    # NOTE(review): not read in the visible part of the file; presumably gate
    # reuse of an existing results file — confirm against the caller.
    resume: bool = True
    skip_existing: bool = True


@dataclass
class AppConfig:
    """Root configuration grouping the data, model, sampling and runtime sections."""

    # default_factory gives every AppConfig instance its own section objects.
    data: DataConfig = field(default_factory=DataConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    sampling: SamplingConfig = field(default_factory=SamplingConfig)
    runtime: RuntimeConfig = field(default_factory=RuntimeConfig)


# ------------------------------ IO helpers ----------------------------------


def _iter_records(cfg: AppConfig) -> Iterable[dict]:
    """Yield normalized records from the configured JSONL or CSV input file.

    Every yielded dict contains all raw fields of the source row plus three
    normalized keys that always take precedence over same-named raw fields:

    * ``"id"``      - value of ``cfg.data.id_field``, else the raw ``"id"``,
      else a running counter (0, 1, ...) over rows that carry no id.
    * ``"text"``    - ``str()`` of the value of ``cfg.data.text_field``.
    * ``"request"`` - ``str()`` of the value of ``cfg.data.request_field``,
      or ``""`` when the field is disabled or missing.

    Rows whose text field is missing/None are skipped. In JSONL mode, blank
    lines are skipped, and a line that is not a JSON *object* (unparseable text,
    or a bare list/number/string) is treated as raw text for the text field.
    Files whose last suffix ends with ``.gz`` are read through gzip in JSONL
    mode.

    With ``input_format == "auto"`` a ``.csv`` suffix selects CSV and anything
    else is read as JSONL.

    Args:
        cfg: Application config; only ``cfg.data`` is read.

    Yields:
        One normalized dict per usable row, at most ``cfg.data.max_records``
        of them when that limit is set.

    Raises:
        FileNotFoundError: If the input path does not exist.
        ValueError: If the resolved format is neither ``jsonl`` nor ``csv``.
            Both are raised on first iteration, since this is a generator.
    """
    data_cfg = cfg.data
    path = Path(data_cfg.input_path)
    if not path.exists():
        raise FileNotFoundError(f"Input not found: {path}")

    fmt = data_cfg.input_format
    if fmt == "auto":
        # Path.suffix is only the last extension ("x.jsonl.gz" -> ".gz"), so
        # the only decision to make is CSV versus everything else.
        fmt = "csv" if path.suffix.lower() == ".csv" else "jsonl"
    if fmt not in ("jsonl", "csv"):
        raise ValueError(f"Unsupported input_format: {fmt}")

    text_field = data_cfg.text_field
    request_field = data_cfg.request_field
    id_field = data_cfg.id_field
    limit = data_cfg.max_records

    def _raw_rows():
        # Produce one raw dict per input row, independent of the file format.
        if fmt == "jsonl":
            import gzip

            opener = gzip.open if path.suffix.lower().endswith(".gz") else open
            with opener(path, "rt", encoding="utf-8") as fh:
                for line in fh:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        parsed = json.loads(line)
                    except (ValueError, RecursionError):
                        parsed = None
                    # Anything that is not a JSON object cannot be queried by
                    # field name, so fall back to treating the line as text.
                    yield parsed if isinstance(parsed, dict) else {text_field: line}
        else:
            import csv

            with open(path, newline="", encoding="utf-8") as fh:
                yield from csv.DictReader(fh)

    emitted = 0
    next_id = 0
    rows = _raw_rows()
    try:
        for row in rows:
            if limit is not None and emitted >= limit:
                break
            text = row.get(text_field)
            if text is None:
                continue
            request_text = row.get(request_field) if request_field else None
            rid = row.get(id_field) if id_field else None
            if rid is None:
                rid = row.get("id")
            if rid is None:
                rid = next_id
                next_id += 1
            # Raw fields first: the normalized id/text/request must win over
            # any same-named keys in the source row.
            yield {
                **row,
                "id": rid,
                "text": str(text),
                "request": str(request_text) if request_text is not None else "",
            }
            emitted += 1
    finally:
        # Close the underlying file promptly even when stopping early.
        rows.close()


def _load_existing_index(results_path: Path) -> dict:
    """Build a lookup of records already present in a results file.

    Each parseable line contributes one key: ``"id::<id>"`` when the record has
    an ``"id"`` field, otherwise ``"tx::<hash(text)>"`` when it has a ``"text"``
    field. Lines that fail to parse as JSON are ignored. A missing file yields
    an empty index.

    Args:
        results_path: Path of the JSONL results file.

    Returns:
        Dict mapping each key to ``True``.
    """
    seen: dict[str, bool] = {}
    if not results_path.exists():
        return seen
    with open(results_path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            try:
                record = json.loads(raw_line)
            except Exception:
                continue
            # The id takes priority; the text hash is only a fallback key.
            if "id" in record:
                seen[f"id::{record['id']}"] = True
            elif "text" in record:
                seen[f"tx::{hash(record['text'])}"] = True
    return seen


def _already_done(idx: dict, rid: Any, text: str) -> bool:
    """Report whether a record is already present in the index.

    A record counts as done when its ``"id::<rid>"`` key is in ``idx`` (only
    checked for a non-None ``rid``) or when its ``"tx::<hash(text)>"`` key is.
    """
    matched_by_id = rid is not None and f"id::{rid}" in idx
    return matched_by_id or f"tx::{hash(text)}" in idx


# ----------------------------- Parsing helpers ------------------------------


# Canonical axis names, spelled exactly like the keys of the JSON template in
# classify_prompt(). The list order is the order used for aggregation and for
# the subplot layout.
AXIS_KEYS = [
    "mapping",
    "generation",
    "pool",
    "seedInterpretation",
    "diversityTarget",
    "seedUsage",
    "mappingExposure",
    "randomnessScope",
    "plannerType",
    "randomMechanism",
    "constraintHandling",
    "seedLocality",
    "determinism",
    "coverage",
]

# Allowed choices per axis, aligned with classify_prompt()
# One entry per name in AXIS_KEYS; the values are sets, so they carry no order.
CHOICE_SETS = {
    "mapping": {
        "SUMMOD",
        "ROLL",
        "CHUNK",
        "STRIDE",
        "HASH",
        "PRNG",
        "WEIGHTED",
    },
    "generation": {"SLOT", "ATTR", "RULE", "LIST", "PROC", "PLAN"},
    "pool": {"STATIC", "GENERATED", "HYBRID"},
    "seedInterpretation": {"COMP_MAP", "INTERP_MAP", "BOTH"},
    "diversityTarget": {"STRUCTURAL", "CONTENT", "BOTH", "NONE"},
    "seedUsage": {"NONE", "PART", "FULL", "SAT"},
    "mappingExposure": {"HIDE", "TRACE", "EXPL", "TOOL"},
    "randomnessScope": {"GLOBAL", "LOCAL", "HYBRID"},
    "plannerType": {"DIRECT", "PLAN", "PROG", "AGENT"},
    "randomMechanism": {"ARGMAX", "SCORE", "SAMPLE", "REJECT"},
    "constraintHandling": {"STRICT", "HEUR", "POST", "IGNORE"},
    "seedLocality": {"STRUCT", "CONTENT", "BOTH", "META"},
    "determinism": {"DET", "QUASI", "NONDET"},
    "coverage": {"LOW", "MID", "HIGH"},
}


def _extract_json_block(text: str) -> Optional[str]:
    """Return the first JSON-object-looking substring of ``text``.

    A fenced block (three backticks, optionally tagged ``json``) wins when
    present. Otherwise the text is scanned from the first ``{`` to its matching
    ``}``. The scan is aware of double-quoted strings and backslash escapes, so
    braces that appear inside JSON string values do not affect the nesting
    depth.

    Args:
        text: Raw model output.

    Returns:
        The candidate JSON text (not validated), or ``None`` when there is no
        ``{`` or the first object never closes.
    """
    fenced = re.search(r"```(?:json)?\s*({[\s\S]*?})\s*```", text)
    if fenced:
        return fenced.group(1)

    start = text.find("{")
    if start == -1:
        return None

    depth = 0
    in_string = False
    escaped = False
    for pos in range(start, len(text)):
        ch = text[pos]
        if in_string:
            # Inside a string only the closing quote matters; a backslash
            # protects the character that follows it.
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
            continue
        if ch == '"':
            in_string = True
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[start : pos + 1]
    return None


def _norm_axis_key(k: str) -> str:
    """Map a key from the model's JSON onto its canonical axis name.

    Matching ignores surrounding whitespace and letter case. A few snake_case
    spellings are accepted as aliases. Keys that match no axis are returned
    stripped but otherwise unchanged.
    """
    stripped = k.strip()
    folded = stripped.lower()

    # Alternate spellings that map onto a canonical camelCase axis name.
    alias_hit = {
        "seedinterpretation": "seedInterpretation",
        "seed_interpretation": "seedInterpretation",
        "diversitytarget": "diversityTarget",
        "diversity_target": "diversityTarget",
    }.get(folded)
    if alias_hit is not None:
        return alias_hit

    # Case-insensitive match against the canonical names; fall back to the
    # stripped input when nothing matches.
    return next((axis for axis in AXIS_KEYS if axis.lower() == folded), stripped)


# Flag tokens accepted from the model; anything else is dropped.
_ALLOWED_FLAGS = frozenset(
    {
        # Core documented flags
        "STRUCT",
        "CHK",
        "COH",
        "FIX",
        "SAFE",
        "NONUSE",
        "LEAK",
        "META",
        "HACK",
        "TOOLCALL",
        "DETFAIL",
        "OVERFIT",
        # Creative writing / seed-anchoring
        "CREATIVE",
        "SEEDANCHOR",
        "DECOMP",
    }
)


def _clean_flags(val: Any) -> List[str]:
    """Normalize a raw ``flags`` value into a de-duplicated list of allowed tokens."""
    if isinstance(val, dict):
        # Tolerate the {"reason": ..., "choice": [...]} shape used for axes.
        val = val.get("choice")
    if isinstance(val, str):
        val = [val]
    if not isinstance(val, list):
        return []
    cleaned: List[str] = []
    for item in val:
        if isinstance(item, str):
            token = item.strip().upper()
            if token in _ALLOWED_FLAGS and token not in cleaned:
                cleaned.append(token)
    return cleaned


def _resolve_choice(choice: Any, valid: set) -> Optional[str]:
    """Match a raw choice against ``valid``: exactly first, then on letters only."""
    if not isinstance(choice, str):
        return None
    upper = choice.strip().upper()
    if upper in valid:
        return upper
    # Loose match: compare with every non-letter removed ("comp map" ~ "COMP_MAP").
    letters = re.sub(r"[^A-Z]", "", upper)
    for candidate in sorted(valid):
        if re.sub(r"[^A-Z]", "", candidate.upper()) == letters:
            return candidate
    return None


def parse_novelty_output(text: str) -> Tuple[dict, bool, Optional[str]]:
    """Parse the model's reply into a normalized per-axis classification.

    The reply is reduced to its JSON block (or used as-is when none is found)
    and decoded. Keys are normalized with ``_norm_axis_key``; unknown keys are
    ignored. Each axis value may be a ``{"reason", "choice"}`` dict or a bare
    string choice. ``flags`` (list, single string, or dict with ``choice``) is
    filtered to the allowed tokens; ``confidence`` is kept when it converts to
    ``float``. Both special keys are matched case-insensitively and handled
    before the generic axis logic, so string-typed values reach them.

    NOTE(review): an axis whose choice is missing or unrecognized is filled
    with the alphabetically first allowed label, so downstream counts cannot
    tell a defaulted choice from a real one.

    Args:
        text: Raw model output.

    Returns:
        ``(classification, ok, error)``. On success ``classification`` has an
        entry ``{"reason": str, "choice": str}`` for every axis (plus optional
        ``"flags"`` / ``"confidence"`` entries), ``ok`` is ``True`` and
        ``error`` is ``None``. On failure it is ``({}, False, message)`` with a
        message starting with ``"json_parse_error"``.
    """
    json_str = _extract_json_block(text) or text
    try:
        obj = json.loads(json_str)
    except Exception as e:
        return ({}, False, f"json_parse_error: {e}")
    if not isinstance(obj, dict):
        # A list/number/string decodes fine but carries no axis keys.
        return (
            {},
            False,
            f"json_parse_error: expected a JSON object, got {type(obj).__name__}",
        )

    out: Dict[str, Dict[str, Any]] = {}
    for raw_key, val in obj.items():
        k = _norm_axis_key(raw_key)
        special = k.lower()
        if special == "flags":
            out["flags"] = {"reason": "", "choice": _clean_flags(val)}
            continue
        if special == "confidence":
            try:
                conf = float(val)
            except (TypeError, ValueError, OverflowError):
                conf = None
            if conf is not None:
                out["confidence"] = {"reason": "", "choice": conf}
            continue
        if k not in AXIS_KEYS:
            continue

        if isinstance(val, dict):
            reason, choice = val.get("reason"), val.get("choice")
        elif isinstance(val, str):
            reason, choice = None, val
        else:
            reason, choice = None, None

        valid = CHOICE_SETS.get(k, set())
        c = _resolve_choice(choice, valid)
        if c is None and valid:
            c = sorted(valid)[0]
        out[k] = {"reason": str(reason) if reason is not None else "", "choice": c}

    # Ensure all axes are present. sorted() keeps the default stable across
    # runs; iterating the set directly depends on string hash randomization.
    for ax in AXIS_KEYS:
        if ax not in out:
            out[ax] = {"reason": "", "choice": sorted(CHOICE_SETS[ax])[0]}

    return out, True, None


# ----------------------------- LLM call helpers -----------------------------


def _build_provider_kwargs(
    model_name: str, *, timeout: int = 300, num_retries: int = 3
) -> dict:
    """Build extra keyword arguments for the completion call.

    For OpenRouter models other than ``o4-mini`` / ``gpt-4o`` an
    ``extra_body["provider"]`` routing hint is added:

    * ``deepseek-r1``       -> fp8 quantizations
    * ``qwq``               -> DeepInfra first
    * Google / Gemini       -> no hint
    * everything else       -> lambda first

    Every hint allows fallbacks. ``timeout`` and ``num_retries`` are always
    set.

    Args:
        model_name: Model identifier; matched by substring.
        timeout: Value stored under ``"timeout"`` (default 300).
        num_retries: Value stored under ``"num_retries"`` (default 3).

    Returns:
        Dict of keyword arguments.
    """
    kwargs: dict[str, Any] = {}

    is_openrouter = "openrouter" in model_name
    hint_exempt = "o4-mini" in model_name or "gpt-4o" in model_name
    if is_openrouter and not hint_exempt:
        # Be conservative with provider hints; let OpenRouter route Google/Gemini
        provider = None
        if "deepseek-r1" in model_name:
            provider = {"quantizations": ["fp8"], "allow_fallbacks": True}
        elif "qwq" in model_name:
            provider = {"order": ["DeepInfra"], "allow_fallbacks": True}
        elif not ("google" in model_name or "gemini" in model_name):
            # For most non-Google models, Lambda is a good default
            provider = {"order": ["lambda"], "allow_fallbacks": True}
        if provider is not None:
            kwargs["extra_body"] = {"provider": provider}

    kwargs["timeout"] = timeout
    kwargs["num_retries"] = num_retries
    return kwargs


def _response_to_text_and_usage(resp: Any) -> Tuple[str, Optional[str], int, int]:
    """Pull the text, reasoning text, and token counts out of a completion response.

    Every lookup is best-effort: a missing attribute, an empty ``choices``
    list, or a non-numeric token count falls back to the default instead of
    raising.

    Args:
        resp: Response object exposing ``choices[0].message`` and ``usage``.

    Returns:
        ``(content, reasoning_content, completion_tokens, reasoning_tokens)``.
        ``content`` is ``""`` when absent or falsy, ``reasoning_content`` is
        ``None`` when absent, and both counts default to ``0``.
    """

    def _probe(getter, fallback):
        # Run one attribute chain, swallowing any failure.
        try:
            return getter()
        except Exception:
            return fallback

    def _as_count(getter):
        value = getter()
        return 0 if value is None else int(value)

    content = _probe(lambda: resp.choices[0].message.content, None)
    reasoning = _probe(lambda: resp.choices[0].message.reasoning_content, None)
    completion_tokens = _probe(
        lambda: _as_count(lambda: resp.usage.completion_tokens), 0
    )
    reasoning_tokens = _probe(
        lambda: _as_count(
            lambda: resp.usage.completion_tokens_details.reasoning_tokens
        ),
        0,
    )
    return content or "", reasoning, completion_tokens, reasoning_tokens


# --------------------- Aggregation & plotting helpers ------------------------


def _compute_counts_from_results(results_path: Path) -> Dict[str, Dict[str, int]]:
    """Tally how often each choice was picked on every axis.

    Reads the JSONL results file line by line, skipping blank and unparseable
    lines. For each record, the ``choice`` of every axis under
    ``"classification"`` is counted when it is a non-empty string.

    Args:
        results_path: Path of the JSONL results file.

    Returns:
        ``{axis: {choice: count}}`` with an entry for every axis in
        ``AXIS_KEYS``. All entries are empty dicts when the file is missing.
    """
    tallies: Dict[str, Counter] = {axis: Counter() for axis in AXIS_KEYS}

    if results_path.exists():
        with open(results_path, "r", encoding="utf-8") as handle:
            for raw_line in handle:
                if not raw_line.strip():
                    continue
                try:
                    record = json.loads(raw_line)
                except Exception:
                    continue
                classification = record.get("classification") or {}
                for axis, tally in tallies.items():
                    # A malformed axis entry is ignored, not fatal.
                    try:
                        label = classification.get(axis, {}).get("choice")
                    except Exception:
                        continue
                    if isinstance(label, str) and label:
                        tally[label] += 1

    return {axis: dict(tally) for axis, tally in tallies.items()}


def _write_counts_files(counts: Dict[str, Dict[str, int]], out_dir: Path) -> None:
    """Write per-axis counts as one JSON file plus one CSV per axis.

    Creates ``classification_counts.json`` (the full mapping) and
    ``counts_<axis>.csv`` for every axis in ``AXIS_KEYS``. CSV rows are sorted
    by descending count, then label. The CSVs are written with the ``csv``
    module so labels containing commas, quotes or newlines are escaped
    correctly.

    The JSON dump is required and its errors propagate. The directory creation
    and each CSV are best-effort: a failure is logged as a warning and
    processing continues.

    Args:
        counts: ``{axis: {label: count}}``.
        out_dir: Directory to write into; created if needed.
    """
    import csv

    log = logging.getLogger("classify_novelty")
    try:
        out_dir.mkdir(parents=True, exist_ok=True)
    except OSError as exc:
        # The open() below will surface the real problem if this mattered.
        log.warning("could not create output directory %s: %s", out_dir, exc)

    # JSON dump
    counts_json = out_dir / "classification_counts.json"
    with open(counts_json, "w", encoding="utf-8") as jf:
        json.dump(counts, jf, ensure_ascii=False, indent=2)

    # Per-axis CSVs
    for ax in AXIS_KEYS:
        csv_path = out_dir / f"counts_{ax}.csv"
        try:
            rows = sorted(counts.get(ax, {}).items(), key=lambda kv: (-kv[1], kv[0]))
            # newline="" lets the csv module control line endings.
            with open(csv_path, "w", encoding="utf-8", newline="") as cf:
                writer = csv.writer(cf, lineterminator="\n")
                writer.writerow(["label", "count"])
                writer.writerows(rows)
        except Exception as exc:
            log.warning("could not write %s: %s", csv_path, exc)


def _plot_counts(counts: Dict[str, Dict[str, int]], out_dir: Path) -> None:
    """Draw one bar chart per axis and save them as a single PNG.

    The charts are laid out on a grid four columns wide, one panel per axis in
    ``AXIS_KEYS`` order, with bars sorted by descending count, then label.
    Positive bars are annotated with their value. Unused panels are hidden. The
    figure is saved as ``classification_freqs.png`` in ``out_dir`` and always
    closed afterwards.

    If matplotlib cannot be imported a warning is logged and nothing is drawn.

    Args:
        counts: ``{axis: {label: count}}``.
        out_dir: Existing directory that receives the PNG.
    """
    try:
        import matplotlib.pyplot as plt
    except Exception:
        logging.getLogger("classify_novelty").warning(
            "matplotlib not available; skipping plots"
        )
        return

    panel_titles = {
        "mapping": "Mapping",
        "generation": "Generation",
        "pool": "Pool",
        "seedInterpretation": "SeedInterpretation",
        "diversityTarget": "DiversityTarget",
        "seedUsage": "SeedUsage",
        "mappingExposure": "MappingExposure",
        "randomnessScope": "RandomnessScope",
        "plannerType": "PlannerType",
        "randomMechanism": "RandomMechanism",
        "constraintHandling": "ConstraintHandling",
        "seedLocality": "SeedLocality",
        "determinism": "Determinism",
        "coverage": "Coverage",
    }

    # Grid size follows the number of axes (ceil division by the column count).
    n_axes = len(AXIS_KEYS)
    n_cols = 4
    n_rows = (n_axes + n_cols - 1) // n_cols
    fig, grid = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3.5 * n_rows))
    panels = grid.flatten() if hasattr(grid, "flatten") else [grid]

    for idx, axis_name in enumerate(AXIS_KEYS):
        panel = panels[idx]
        ranked = sorted(
            counts.get(axis_name, {}).items(), key=lambda kv: (-kv[1], kv[0])
        )
        names = [name for name, _ in ranked]
        heights = [height for _, height in ranked]
        slots = range(len(heights))

        panel.bar(slots, heights, color="#4C78A8")
        panel.set_title(panel_titles.get(axis_name, axis_name))
        panel.set_xticks(slots)
        panel.set_xticklabels(names, rotation=35, ha="right", fontsize=9)
        panel.set_ylabel("count")
        for slot, height in zip(slots, heights):
            if height <= 0:
                continue
            # Lift the label slightly above the bar top (at least one unit).
            lift = max(1, int(0.02 * max(heights)))
            panel.text(
                slot,
                height + lift,
                str(height),
                ha="center",
                va="bottom",
                fontsize=8,
            )
        panel.grid(axis="y", linestyle=":", alpha=0.4)

    # Hide any extra subplot
    for spare in range(n_axes, n_rows * n_cols):
        try:
            panels[spare].axis("off")
        except Exception:
            pass

    plt.tight_layout()
    target = out_dir / "classification_freqs.png"
    try:
        fig.savefig(target, dpi=200, bbox_inches="tight")
    finally:
        plt.close(fig)


def _extract_flags(cl: Dict[str, Any]) -> List[str]:
    """Return the upper-cased, de-duplicated flags stored in a classification.

    Expects ``cl["flags"]`` to be a dict whose ``"choice"`` is a list of strings
    (a single string is treated as a one-item list). Non-string and blank items
    are dropped; first-occurrence order is kept. Any other shape, or any error
    along the way, yields an empty list.
    """
    try:
        entry = cl.get("flags")
        if not isinstance(entry, dict):
            return []
        raw = entry.get("choice")
        if isinstance(raw, str):
            raw = [raw]
        if not isinstance(raw, list):
            return []
        tokens = (item.strip().upper() for item in raw if isinstance(item, str) and item)
        # dict.fromkeys de-duplicates while keeping first-seen order.
        return list(dict.fromkeys(token for token in tokens if token))
    except Exception:
        return []


def _compute_flag_counts_from_results(results_path: Path) -> Dict[str, int]:
    """Count, per flag, how many result records carry it.

    Blank and unparseable lines of the JSONL results file are skipped. Each
    flag is counted at most once per record.

    Args:
        results_path: Path of the JSONL results file.

    Returns:
        ``{flag: number_of_records}``; empty when the file is missing.
    """
    tally: Counter = Counter()
    if not results_path.exists():
        return {}
    with open(results_path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue
            try:
                record = json.loads(raw_line)
            except Exception:
                continue
            record_flags = _extract_flags(record.get("classification") or {})
            # A set guarantees one increment per flag per record.
            tally.update(set(record_flags))
    return dict(tally)


def _write_flag_counts_files(counts: Dict[str, int], out_dir: Path) -> None:
    """Write flag counts as ``flag_counts.json`` and ``counts_flags.csv``.

    CSV rows are sorted by descending count, then label, and written with the
    ``csv`` module so labels containing commas, quotes or newlines are escaped
    correctly.

    The JSON dump is required and its errors propagate. The directory creation
    and the CSV are best-effort: a failure is logged as a warning.

    Args:
        counts: ``{flag: count}``.
        out_dir: Directory to write into; created if needed.
    """
    import csv

    log = logging.getLogger("classify_novelty")
    try:
        out_dir.mkdir(parents=True, exist_ok=True)
    except OSError as exc:
        # The open() below will surface the real problem if this mattered.
        log.warning("could not create output directory %s: %s", out_dir, exc)

    # JSON dump
    counts_json = out_dir / "flag_counts.json"
    with open(counts_json, "w", encoding="utf-8") as jf:
        json.dump(counts, jf, ensure_ascii=False, indent=2)

    # CSV dump
    csv_path = out_dir / "counts_flags.csv"
    try:
        rows = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
        # newline="" lets the csv module control line endings.
        with open(csv_path, "w", encoding="utf-8", newline="") as cf:
            writer = csv.writer(cf, lineterminator="\n")
            writer.writerow(["label", "count"])
            writer.writerows(rows)
    except Exception as exc:
        log.warning("could not write %s: %s", csv_path, exc)


def _plot_flag_counts(counts: Dict[str, int], out_dir: Path) -> None:
    """Draw a bar chart of flag counts and save it as ``flag_freqs.png``.

    Bars are sorted by descending count, then label, and positive bars are
    annotated with their value. Nothing happens for empty ``counts``. If
    matplotlib cannot be imported a warning is logged and nothing is drawn. The
    figure is always closed after the save attempt.

    Args:
        counts: ``{flag: count}``.
        out_dir: Existing directory that receives the PNG.
    """
    if not counts:
        return
    try:
        import matplotlib.pyplot as plt
    except Exception:
        logging.getLogger("classify_novelty").warning(
            "matplotlib not available; skipping flag plots"
        )
        return

    ranked = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
    names = [name for name, _ in ranked]
    heights = [height for _, height in ranked]
    slots = range(len(heights))

    # Widen the figure with the number of flags, but never below 6 inches.
    fig_width = max(6, 0.5 * len(names))
    fig, ax = plt.subplots(figsize=(fig_width, 4))
    ax.bar(slots, heights, color="#72B7B2")
    ax.set_title("Flags")
    ax.set_xticks(slots)
    ax.set_xticklabels(names, rotation=35, ha="right", fontsize=9)
    ax.set_ylabel("count")
    for slot, height in zip(slots, heights):
        if height <= 0:
            continue
        # Lift the label slightly above the bar top (at least one unit).
        lift = max(1, int(0.02 * max(heights)))
        ax.text(
            slot,
            height + lift,
            str(height),
            ha="center",
            va="bottom",
            fontsize=8,
        )
    ax.grid(axis="y", linestyle=":", alpha=0.4)
    plt.tight_layout()

    target = out_dir / "flag_freqs.png"
    try:
        fig.savefig(target, dpi=200, bbox_inches="tight")
    finally:
        plt.close(fig)


def _compute_group_flag_counts(
    results_path: Path, input_path: Path
) -> Dict[str, Dict[str, int]]:
    """Count flags per input group.

    Result records are assigned to a group by looking up their ``"id"`` in the
    mapping built from the input file. Records without a non-empty string
    group key fall under ``"__ungrouped__"``. Each flag is counted at most once
    per record, and records without flags do not create a group entry.

    Args:
        results_path: Path of the JSONL results file.
        input_path: Path of the input file the group mapping is read from.

    Returns:
        ``{group: {flag: number_of_records}}``; empty when the results file is
        missing.
    """
    group_lookup = _read_input_group_mapping(input_path)

    def _group_of(record_id: Any) -> str:
        entry = group_lookup.get(str(record_id))
        label = entry.get("group_key") if entry is not None else None
        if isinstance(label, str) and label:
            return label
        return "__ungrouped__"

    if not results_path.exists():
        return {}

    per_group: Dict[str, Counter] = {}
    with open(results_path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue
            try:
                record = json.loads(raw_line)
            except Exception:
                continue
            group = _group_of(record.get("id"))
            record_flags = _extract_flags(record.get("classification") or {})
            if not record_flags:
                continue
            per_group.setdefault(group, Counter()).update(set(record_flags))

    return {group: dict(tally) for group, tally in per_group.items()}


def _write_group_flag_counts(
    group_counts: Dict[str, Dict[str, int]], out_dir: Path
) -> None:
    """Write flag counts for every group under ``<out_dir>/grouped``.

    Produces ``flag_counts_by_group.json`` with the whole mapping, then a
    sub-directory per group (named by the slugified group name) holding that
    group's count files and plot. Groups whose names slugify to the same value
    share one directory.

    Args:
        group_counts: ``{group: {flag: count}}``.
        out_dir: Base output directory.
    """
    root = out_dir / "grouped"
    root.mkdir(parents=True, exist_ok=True)

    combined_path = root / "flag_counts_by_group.json"
    with open(combined_path, "w", encoding="utf-8") as jf:
        json.dump(group_counts, jf, ensure_ascii=False, indent=2)

    for label in group_counts:
        flag_tally = group_counts[label]
        target = root / _slugify(label)
        target.mkdir(parents=True, exist_ok=True)
        _write_flag_counts_files(flag_tally, target)
        _plot_flag_counts(flag_tally, target)


def _slugify(name: str, max_len: int = 120) -> str:
    """Turn an arbitrary name into a filesystem-friendly slug.

    Runs of characters outside ``A-Za-z0-9_.-`` become a single underscore, the
    result is cut to ``max_len`` characters, and leading/trailing ``.``, ``_``
    and ``-`` are removed. An empty result becomes ``"group"``.
    """
    cleaned = re.sub(r"[^A-Za-z0-9_.-]+", "_", str(name).strip())
    clipped = cleaned[:max_len] if len(cleaned) > max_len else cleaned
    trimmed = clipped.strip("._-")
    return trimmed if trimmed else "group"


def _read_input_group_mapping(input_path: Path) -> Dict[str, Dict[str, Any]]:
    """
    For novelty inputs, use 'category' field if present to build a grouping.
    Returns: {id_str: {"group_key": <category>, "group_key_fields": {"category": <category>}}}

    The input is read as JSONL, through gzip when the path ends in ``.gz``.
    Blank lines, lines that are not valid JSON, lines that decode to something
    other than a JSON object, and objects without an ``"id"`` are skipped.
    Records without a ``"category"`` map to ``None`` for both values. A
    missing file yields an empty mapping.
    """
    mapping: Dict[str, Dict[str, Any]] = {}
    if not input_path.exists():
        return mapping

    import gzip

    opener = gzip.open if input_path.suffix.lower() == ".gz" else open
    with opener(input_path, "rt", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                obj = json.loads(line)
            except (ValueError, RecursionError):
                continue
            # A bare number, string or list parses fine but has no fields.
            if not isinstance(obj, dict):
                continue
            _id = obj.get("id")
            if _id is None:
                continue
            cat = obj.get("category")
            gk = str(cat) if cat is not None else None
            gkf = {"category": cat} if cat is not None else None
            mapping[str(_id)] = {"group_key": gk, "group_key_fields": gkf}
    return mapping


def _compute_group_counts(
    results_path: Path, input_path: Path
) -> Dict[str, Dict[str, Dict[str, int]]]:
    """Tally per-axis choices separately for every input group.

    Result records are assigned to a group by looking up their ``"id"`` in the
    mapping built from the input file. Records without a non-empty string
    group key fall under ``"__ungrouped__"``. A group gets an entry as soon as
    one of its records is seen, even if that record contributes no choice.

    Args:
        results_path: Path of the JSONL results file.
        input_path: Path of the input file the group mapping is read from.

    Returns:
        ``{group: {axis: {choice: count}}}``; empty when the results file is
        missing.
    """
    group_lookup = _read_input_group_mapping(input_path)

    def _group_of(record_id: Any) -> str:
        entry = group_lookup.get(str(record_id))
        label = entry.get("group_key") if entry is not None else None
        if isinstance(label, str) and label:
            return label
        return "__ungrouped__"

    if not results_path.exists():
        return {}

    per_group: Dict[str, Dict[str, Counter]] = {}
    with open(results_path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue
            try:
                record = json.loads(raw_line)
            except Exception:
                continue
            group = _group_of(record.get("id"))
            axis_tallies = per_group.setdefault(
                group, {axis: Counter() for axis in AXIS_KEYS}
            )
            classification = record.get("classification") or {}
            for axis in AXIS_KEYS:
                # A malformed axis entry is ignored, not fatal.
                try:
                    label = classification.get(axis, {}).get("choice")
                except Exception:
                    continue
                if isinstance(label, str) and label:
                    axis_tallies[axis][label] += 1

    return {
        group: {axis: dict(tally) for axis, tally in axis_tallies.items()}
        for group, axis_tallies in per_group.items()
    }


def _write_group_counts(
    group_counts: Dict[str, Dict[str, Dict[str, int]]], out_dir: Path
) -> None:
    """Persist grouped classification counts beneath ``out_dir / "grouped"``.

    Writes one combined JSON file holding every group, then gives each group
    its own sub-directory (named by ``_slugify``) containing that group's
    count files and plots.
    """
    root = out_dir / "grouped"
    root.mkdir(parents=True, exist_ok=True)

    # Combined view: all groups in a single JSON document.
    combined_path = root / "classification_counts_by_group.json"
    with combined_path.open("w", encoding="utf-8") as handle:
        json.dump(group_counts, handle, ensure_ascii=False, indent=2)

    # Per-group view: one directory per group.
    for name in group_counts:
        per_axis = group_counts[name]
        target = root / _slugify(name)
        target.mkdir(parents=True, exist_ok=True)
        _write_counts_files(per_axis, target)
        _plot_counts(per_axis, target)


# ---------------------------------- Main ------------------------------------


@hydra.main(
    version_base=None, config_path="conf/classify_novelty", config_name="config"
)
def main(cfg: AppConfig) -> None:  # type: ignore[valid-type]
    """Classify pending records in batches, then aggregate and plot results.

    Steps:
      1. Resolve the results/errors/summary paths under the Hydra run dir.
      2. Collect records from ``_iter_records(cfg)``, skipping records without
         text, records already present in the results file (when resuming),
         and duplicates (when ``cfg.data.dedupe`` is set).
      3. Send each batch through ``litellm.batch_completion``. If the number
         of responses does not match the batch size, the batch is retried
         item by item with ``litellm.completion``.
      4. Append successes to the results JSONL and failures to the errors
         JSONL, write a run summary, then aggregate and plot overall and
         grouped counts.

    Each aggregation stage is best-effort: a failure is logged as a warning
    and the remaining stages still run.
    """
    # Function-scope import: the module header only imports `datetime` itself.
    from datetime import timezone

    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] %(levelname)s:%(name)s:%(message)s",
        stream=sys.stdout,
    )
    log = logging.getLogger("classify_novelty")

    # NOTE(review): litellm is imported at module load, before this line runs.
    # If LiteLLM reads this variable at import time the assignment has no
    # effect here - confirm against the LiteLLM documentation.
    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"

    # Configure API key from env if not explicitly provided
    # (currently disabled; cfg.model.api_key is passed straight to LiteLLM)
    # if cfg.model.api_key:
    #     pass
    #     # if "openrouter" in cfg.model.litellm_model_name:
    #     #     os.environ.setdefault("OPENROUTER_API_KEY", cfg.model.api_key)
    # else:
    #     if "openrouter" in cfg.model.litellm_model_name and not os.getenv(
    #         "OPENROUTER_API_KEY"
    #     ):
    #         log.warning(
    #             "OPENROUTER_API_KEY not set. Set it in env or cfg.model.api_key."
    #         )

    hydra_output_dir = Path(HydraConfig.get().runtime.output_dir)
    out_dir = hydra_output_dir / cfg.runtime.output_dir_suffix
    out_dir.mkdir(parents=True, exist_ok=True)
    results_path = out_dir / cfg.runtime.results_filename
    errors_path = out_dir / cfg.runtime.errors_filename
    summary_path = out_dir / cfg.runtime.summary_filename

    log.info(f"Hydra output: {hydra_output_dir}")
    log.info(f"Results -> {results_path}")

    # ------------------------- aggregation stages --------------------------

    def overall_counts() -> None:
        counts = _compute_counts_from_results(results_path)
        _write_counts_files(counts, out_dir)
        _plot_counts(counts, out_dir)

    def overall_flags() -> None:
        fcounts = _compute_flag_counts_from_results(results_path)
        _write_flag_counts_files(fcounts, out_dir)
        _plot_flag_counts(fcounts, out_dir)

    def grouped_counts(announce: bool) -> None:
        # Path(...) is built here so a bad input_path is caught by best_effort.
        gcounts = _compute_group_counts(results_path, Path(cfg.data.input_path))
        if gcounts:
            _write_group_counts(gcounts, out_dir)
            if announce:
                log.info(
                    "Wrote grouped classification counts and plots under %s/grouped",
                    out_dir,
                )
        elif announce:
            log.info("No grouped counts (no results or no input mapping)")

    def grouped_flags(announce: bool) -> None:
        gfcounts = _compute_group_flag_counts(results_path, Path(cfg.data.input_path))
        if gfcounts:
            _write_group_flag_counts(gfcounts, out_dir)
            if announce:
                log.info(
                    "Wrote grouped flag counts and plots under %s/grouped",
                    out_dir,
                )
        elif announce:
            log.info("No grouped flag counts (no results or no input mapping)")

    def best_effort(failure_prefix: str, stage: Any, *args: Any) -> None:
        # One failing stage must not prevent the others from running.
        try:
            stage(*args)
        except Exception as e:
            log.warning(f"{failure_prefix}: {e}")

    # --------------------------- record selection ---------------------------

    done_idx: dict[str, bool] = {}
    if cfg.runtime.resume and cfg.runtime.skip_existing and results_path.exists():
        done_idx = _load_existing_index(results_path)
        log.info(f"Loaded existing index with {len(done_idx)} items")

    records: List[dict] = []
    # Store the (text, request) pairs themselves rather than their hash() so
    # that a hash collision can never drop a distinct record.
    seen_pairs: set = set()
    for rec in _iter_records(cfg):
        rid = rec.get("id")
        text = rec.get("text")
        request_text = rec.get("request", "")
        if text is None:
            continue
        if cfg.runtime.skip_existing and _already_done(done_idx, rid, text):
            continue
        if cfg.data.dedupe:
            pair = (text, request_text)
            if pair in seen_pairs:
                continue
            seen_pairs.add(pair)
        records.append({"id": rid, "text": text, "request": request_text, "meta": rec})

    if not records:
        log.info("No new records to process. Aggregating existing results (if any)...")
        best_effort("Aggregation failed", overall_counts)
        best_effort("Flag aggregation failed", overall_flags)
        best_effort("Grouped aggregation failed", grouped_counts, False)
        best_effort("Grouped flag aggregation failed", grouped_flags, False)
        return

    log.info(f"Processing {len(records)} records...")

    # ---------------------------- record writers ----------------------------

    tally = {"ok": 0, "err": 0, "completion_tokens": 0, "reasoning_tokens": 0}

    def utc_stamp() -> str:
        # Same naive-UTC ISO string that datetime.utcnow().isoformat() gave,
        # without calling the deprecated utcnow().
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    def append_jsonl(fh: Any, obj: dict) -> None:
        fh.write(json.dumps(obj, ensure_ascii=False) + "\n")

    def fail(ef: Any, rec: dict, error: Any, **extra: Any) -> None:
        """Append one error record and count it."""
        err_obj = {
            "id": rec["id"],
            "text": rec["text"],
            "request": rec.get("request", ""),
        }
        # Extra fields (e.g. raw_output) go before "error" to keep key order.
        err_obj.update(extra)
        err_obj["error"] = error
        err_obj["ts"] = utc_stamp()
        append_jsonl(ef, err_obj)
        tally["err"] += 1

    def handle_response(rf: Any, ef: Any, rec: dict, resp: Any) -> None:
        """Parse one model response and append a result or an error record."""
        raw_text, raw_reasoning, ct, rt = _response_to_text_and_usage(resp)
        # Tokens are counted even when the output later fails to parse.
        tally["completion_tokens"] += ct
        tally["reasoning_tokens"] += rt

        parsed, ok, perr = parse_novelty_output(raw_text)
        if not ok:
            fail(ef, rec, perr, raw_output=raw_text)
            return

        out = {
            "id": rec["id"],
            "text": rec["text"],
            "request": rec.get("request", ""),
            "classification": parsed,
            "raw_output": raw_text,
            "raw_reasoning": raw_reasoning,
            "usage": {"completion_tokens": ct, "reasoning_tokens": rt},
            "model": cfg.model.litellm_model_name,
            "ts": utc_stamp(),
        }
        # Surface the input category (if any) so results can be grouped later.
        cat = rec.get("category")
        meta = rec.get("meta")
        if not cat and isinstance(meta, dict):
            cat = meta.get("category")
        if isinstance(cat, str) and cat:
            out["category"] = cat
            out["group_key"] = cat
            out["group_key_fields"] = {"category": cat}

        append_jsonl(rf, out)
        tally["ok"] += 1

    # ------------------------------ main loop -------------------------------

    bs = max(1, int(cfg.runtime.batch_size))
    # NOTE(review): provider_kwargs is only passed to the per-item fallback
    # call below, not to batch_completion, and only the fallback adds a system
    # message. Confirm whether that asymmetry is intended.
    provider_kwargs = _build_provider_kwargs(cfg.model.litellm_model_name)

    # `with` guarantees both files are closed even if a batch raises.
    with open(results_path, "a", encoding="utf-8") as rf, open(
        errors_path, "a", encoding="utf-8"
    ) as ef:
        for start in range(0, len(records), bs):
            batch = records[start : start + bs]
            inputs = [
                [
                    {
                        "role": "user",
                        "content": classify_prompt(
                            rec.get("request", ""), rec.get("text", "")
                        ),
                    }
                ]
                for rec in batch
            ]

            try:
                responses = litellm.batch_completion(
                    model=cfg.model.litellm_model_name,
                    messages=inputs,
                    temperature=cfg.sampling.temperature,
                    max_tokens=cfg.sampling.max_tokens,
                    api_base=cfg.model.api_base,
                    api_key=cfg.model.api_key,
                    max_workers=100,
                    num_retries=3,
                )
                # Coerce to list to avoid generators/iterables that might be
                # empty on reuse.
                try:
                    responses = list(responses)
                except Exception:
                    responses = []
            except Exception as e:
                for rec in batch:
                    fail(ef, rec, f"batch_error: {e}")
                ef.flush()
                time.sleep(random.uniform(1.0, 5.0))
                continue

            # A batch is never empty, so a length mismatch also covers the
            # "no responses at all" case.
            if len(responses) != len(batch):
                log.warning(
                    "batch_completion returned %s responses for %s inputs; falling back to per-item calls",
                    len(responses),
                    len(batch),
                )
                for rec in batch:
                    try:
                        single_resp = litellm.completion(
                            model=cfg.model.litellm_model_name,
                            messages=[
                                {
                                    "role": "system",
                                    "content": "You are a JSON-only classifier. Output JSON only.",
                                },
                                {
                                    "role": "user",
                                    "content": classify_prompt(
                                        rec.get("request", ""), rec.get("text", "")
                                    ),
                                },
                            ],
                            temperature=cfg.sampling.temperature,
                            max_tokens=cfg.sampling.max_tokens,
                            api_base=cfg.model.api_base,
                            api_key=cfg.model.api_key,
                            **provider_kwargs,
                        )
                    except Exception as e:
                        fail(ef, rec, f"api_error: {e}")
                        continue
                    handle_response(rf, ef, rec, single_resp)
            else:
                for rec, resp in zip(batch, responses):
                    # Per-item failures come back as exception objects in the
                    # response list rather than being raised.
                    if isinstance(resp, Exception):
                        fail(ef, rec, f"api_error: {resp}")
                        continue
                    handle_response(rf, ef, rec, resp)

            rf.flush()
            ef.flush()
            # Random pause between batches.
            time.sleep(random.uniform(0.5, 2.0))

    total_ok = tally["ok"]
    total_err = tally["err"]

    try:
        summary = {
            "total_ok": total_ok,
            "total_err": total_err,
            "total_completion_tokens": tally["completion_tokens"],
            "total_reasoning_tokens": tally["reasoning_tokens"],
            "model": cfg.model.litellm_model_name,
            "sampling": {
                "temperature": cfg.sampling.temperature,
                "max_tokens": cfg.sampling.max_tokens,
            },
        }
        with open(summary_path, "w", encoding="utf-8") as f:
            json.dump(summary, f, ensure_ascii=False, indent=2)
        log.info(
            f"Done. OK={total_ok} ERR={total_err}. Summary saved to {summary_path}"
        )
    except Exception as e:
        log.warning(f"Failed to write summary: {e}")

    # Aggregate and plot: overall counts, flags, then the grouped variants.
    best_effort("Failed to aggregate/plot classification frequencies", overall_counts)
    best_effort("Failed to aggregate/plot flag frequencies", overall_flags)
    best_effort("Failed grouped aggregation", grouped_counts, True)
    best_effort("Failed grouped flag aggregation", grouped_flags, True)

# Script entry point: run the Hydra-decorated main() only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()
