from __future__ import annotations

import argparse
import re
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from .config import FINAL_LOGS_DIR, PROOF_LOGS_DIR, STATEMENT_LOGS_DIR
from .log_utils import extract_tokens_used_from_file
from .metrics import finish_run, log_event, start_run


_LABEL_RE = re.compile(r'"label"\\s*:\\s*"([^"]+)"')
_LEAN_FILE_RE = re.compile(r'"lean_file"\\s*:\\s*"([^"]+)"')


def _read_head_text(path: Path, *, head_bytes: int = 200_000) -> str:
    try:
        with path.open("rb") as f:
            head = f.read(head_bytes)
    except FileNotFoundError:
        return ""
    return head.decode("utf-8", errors="replace")


def _extract_label(path: Path) -> str | None:
    """Return the first ``"label"`` value found in the head of *path*, else ``None``."""
    match = _LABEL_RE.search(_read_head_text(path))
    if match is None:
        return None
    return match.group(1)


def _extract_lean_file(path: Path) -> str | None:
    """Return the first ``"lean_file"`` value found in the head of *path*, else ``None``."""
    match = _LEAN_FILE_RE.search(_read_head_text(path))
    if match is None:
        return None
    return match.group(1)


def _parse_idx(name: str) -> int | None:
    m = re.search(r"idx(\\d+)", name)
    if m:
        return int(m.group(1))
    m = re.search(r"(?:^|_)agent_[ab]_(\\d+)(?:\\D|$)", name)
    if m:
        return int(m.group(1))
    m = re.search(r"(?:^|_)proof_agent_[abcd]_(\\d+)(?:\\D|$)", name)
    if m:
        return int(m.group(1))
    return None


def _parse_agent(name: str) -> str | None:
    m = re.search(r"(?:^|_)agent_([abcd])(?:_|\\.|$)", name)
    if m:
        return m.group(1)
    return None


def _parse_final_task_id(name: str) -> str | None:
    # legacy: final_agent_a_26_L839.log / final_agent_b_0_compile.log
    m = re.search(r"(?:^|_)(\\d+_(?:L\\d+|compile|warnings))(?:_|\\.|$)", name)
    if m:
        return m.group(1)
    # new: final_agent_a_<task>_attempt1_<...>.log
    m = re.search(r"(?:^|_)task(\\d+_(?:L\\d+|compile|warnings))(?:_|\\.|$)", name)
    if m:
        return m.group(1)
    return None


@dataclass(frozen=True, slots=True)
class WaveKey:
    stage: str
    key: str


def _collect_tokens_for_dir(
    stage: str,
    log_dir: Path,
    *,
    key_from_name,
    include_label: bool,
    include_lean_file: bool,
) -> tuple[dict[WaveKey, dict[str, Any]], int]:
    """Aggregate token usage from one stage's log files, grouped per task ("wave").

    Scans ``log_dir/*.log`` and ``log_dir/agent_*/*.log``. A file is skipped if
    ``extract_tokens_used_from_file`` returns ``None``. It is also skipped if
    ``key_from_name(file_name)`` returns ``None``. Otherwise its tokens are added
    to the wave ``WaveKey(stage, str(key))`` under the agent letter parsed from
    the file name, or ``"unknown"``.

    Args:
        stage: Stage name stored in every ``WaveKey`` and summary record.
        log_dir: Directory to scan.
        key_from_name: Callable mapping a file name to a task key or ``None``.
        include_label: Attach the first non-empty ``"label"`` found per wave.
        include_lean_file: Attach the first non-empty ``"lean_file"`` found per wave.

    Returns:
        ``(waves, total)``. ``waves`` maps each ``WaveKey`` to a summary dict with
        ``stage``, ``task``, ``tokens_used_total``, ``tokens_used_by_agent``
        (sorted by agent) and ``log_file_count``, plus optional ``label`` /
        ``lean_file``. ``total`` is the sum of tokens over all counted files.
    """
    per_agent_tokens: dict[WaveKey, dict[str, int]] = defaultdict(lambda: defaultdict(int))
    file_counts: dict[WaveKey, int] = defaultdict(int)
    labels: dict[WaveKey, str] = {}
    lean_files: dict[WaveKey, str] = {}
    grand_total = 0

    # Top-level logs plus logs one level down in agent_* folders. The set drops
    # duplicates, and sorting fixes the scan order; that order decides which
    # file's label/lean_file wins for a wave.
    log_paths = sorted({*log_dir.glob("*.log"), *log_dir.glob("agent_*/*.log")})
    for log_path in log_paths:
        used = extract_tokens_used_from_file(log_path)
        if used is None:
            continue
        agent_name = _parse_agent(log_path.name) or "unknown"
        task_key = key_from_name(log_path.name)
        if task_key is None:
            continue

        wave = WaveKey(stage=stage, key=str(task_key))
        file_counts[wave] += 1
        grand_total += used
        per_agent_tokens[wave][agent_name] += used

        # Metadata is read once per wave: only until a non-empty value is found.
        if include_label and wave not in labels:
            found_label = _extract_label(log_path)
            if found_label:
                labels[wave] = found_label
        if include_lean_file and wave not in lean_files:
            found_lean = _extract_lean_file(log_path)
            if found_lean:
                lean_files[wave] = found_lean

    summaries: dict[WaveKey, dict[str, Any]] = {}
    for wave, agent_tokens in per_agent_tokens.items():
        summary: dict[str, Any] = {
            "stage": wave.stage,
            "task": wave.key,
            "tokens_used_total": sum(agent_tokens.values()),
            "tokens_used_by_agent": dict(sorted(agent_tokens.items())),
            "log_file_count": file_counts[wave],
        }
        if include_label and wave in labels:
            summary["label"] = labels[wave]
        if include_lean_file and wave in lean_files:
            summary["lean_file"] = lean_files[wave]
        summaries[wave] = summary

    return summaries, grand_total


def main() -> None:
    """CLI entry point: backfill per-task token usage into a new metrics run.

    Scans the statement, proof and final log directories in that order. It
    then logs one ``task_tokens`` event per wave, sorted by stage name and then
    by task string. Note that the sort is lexicographic, so ``"10"`` comes
    before ``"2"``. Finally it closes the run with per-stage wave counts and
    token totals.
    """
    cli = argparse.ArgumentParser(
        description="Backfill per-task token usage by parsing historical Codex log files."
    )
    cli.add_argument(
        "--name-tag",
        type=str,
        default="from_logs",
        help="Name tag for the metrics run (default: from_logs).",
    )
    options = cli.parse_args()

    run_id = start_run(
        "token_backfill",
        stage=0,
        name_tag=options.name_tag,
        data_file="orchestrator/*_logs",
    )

    # (stage, log directory, key parser, include_label, include_lean_file),
    # scanned in this order.
    stage_specs = (
        ("statement", STATEMENT_LOGS_DIR, _parse_idx, True, False),
        ("proof", PROOF_LOGS_DIR, _parse_idx, True, False),
        ("final", FINAL_LOGS_DIR, _parse_final_task_id, False, True),
    )
    waves_by_stage: dict[str, dict[WaveKey, dict[str, Any]]] = {}
    tokens_by_stage: dict[str, int] = {}
    for stage_name, log_dir, parse_key, want_label, want_lean in stage_specs:
        waves_by_stage[stage_name], tokens_by_stage[stage_name] = _collect_tokens_for_dir(
            stage_name,
            log_dir,
            key_from_name=parse_key,
            include_label=want_label,
            include_lean_file=want_lean,
        )

    merged: dict[WaveKey, dict[str, Any]] = {}
    for stage_waves in waves_by_stage.values():
        merged.update(stage_waves)
    ordered = sorted(merged.values(), key=lambda rec: (rec["stage"], str(rec["task"])))
    for record in ordered:
        log_event(run_id, "task_tokens", record)

    finish_run(
        run_id,
        {
            "pipeline": "token_backfill",
            "statement_waves": len(waves_by_stage["statement"]),
            "proof_waves": len(waves_by_stage["proof"]),
            "final_waves": len(waves_by_stage["final"]),
            "statement_tokens_used": tokens_by_stage["statement"],
            "proof_tokens_used": tokens_by_stage["proof"],
            "final_tokens_used": tokens_by_stage["final"],
            "total_tokens_used": sum(tokens_by_stage.values()),
        },
    )


# Script entry point. This module uses package-relative imports (``from .config``
# and similar), so launch it as a module (``python -m ...``), not by file path.
if __name__ == "__main__":
    main()
