"""
Factory for the per-game-step W&B logging callback used in eval_condition.

Usage:
    reset_snap, step_callback = make_step_logger(
        backend=backend,
        cot_gen=_cot_gen,       # None if no CoT strategy
        use_wandb=use_wandb,
        global_step=_global_step,
        get_current_seed=lambda: _current_seed[0],
        get_current_ep=lambda: _current_ep[0],
    )
    # At episode start (after cot_gen.reset_stats()):
    reset_snap()
    # Pass step_callback into run_episode(... step_callback=step_callback)
"""
from __future__ import annotations

from typing import Callable

try:
    import wandb
except ImportError:
    wandb = None  # type: ignore[assignment]


def make_step_logger(
    *,
    backend,
    cot_gen,
    use_wandb: bool,
    global_step: list[int],
    get_current_seed: Callable[[], int],
    get_current_ep: Callable[[], int],
) -> tuple[Callable[[], None], Callable[[int, bool, bool], None]]:
    """
    Build the per-game logging hooks and return (reset_snap, step_callback).

    reset_snap()  — call at episode start after cot_gen.reset_stats().
        Re-baselines the token counters to the backend's current totals and
        the CoT error counters to 0.
    step_callback(game_idx, is_test, correct) — passed to run_episode.
        Logs the reward plus token/error deltas since the previous snapshot,
        advances the snapshot, and increments global_step[0].

    Args:
        backend: object whose counters are read with getattr():
            total_prompt_tokens, total_completion_tokens,
            total_cached_tokens, last_call_prompt_tokens and
            last_call_cached_tokens. A missing or None counter reads as 0.
        cot_gen: object exposing n_truncated, n_adapter_errors,
            n_re_prompt_truncated and n_format_errors, or None, in which
            case all error counts read as 0.
        use_wandb: when true (and the wandb import succeeded) each step is
            sent to wandb.log; otherwise nothing is logged.
        global_step: one-element list used as a mutable step counter. It is
            passed as ``step=`` to wandb.log and incremented on every
            step_callback call, whether or not anything was logged.
        get_current_seed: zero-argument callable read at log time.
        get_current_ep: zero-argument callable read at log time.

    The snapshot is seeded once when the factory is called, so invoking
    step_callback before the first reset_snap() diffs against the counters
    as they were at factory time instead of raising KeyError.
    """
    _snap: dict = {}

    def _read_tokens() -> tuple[int, int, int]:
        # `or 0` turns a counter that exists but is None into 0.
        pt = getattr(backend, "total_prompt_tokens", 0) or 0
        ct = getattr(backend, "total_completion_tokens", 0) or 0
        cac = getattr(backend, "total_cached_tokens", 0) or 0
        return pt, ct, cac

    def _read_errors() -> tuple[int, int, int, int]:
        if cot_gen is None:
            return 0, 0, 0, 0
        return (
            getattr(cot_gen, "n_truncated", 0) or 0,
            getattr(cot_gen, "n_adapter_errors", 0) or 0,
            getattr(cot_gen, "n_re_prompt_truncated", 0) or 0,
            getattr(cot_gen, "n_format_errors", 0) or 0,
        )

    def reset_snap() -> None:
        """Re-baseline the snapshot for a new episode."""
        pt, ct, cac = _read_tokens()
        _snap["prompt"] = pt
        _snap["compl"] = ct
        _snap["cached"] = cac
        # CoT errors are reset per-episode by reset_stats(); baseline is 0.
        _snap["trunc"] = 0
        _snap["adap"] = 0
        _snap["reprmt"] = 0
        _snap["fmt"] = 0

    def step_callback(game_idx: int, is_test: bool, correct: bool) -> None:
        """Log one game's metrics as deltas against the last snapshot."""
        pt, ct, cac = _read_tokens()
        trunc, adap, reprmt, fmt = _read_errors()

        dpt = pt - _snap["prompt"]
        dct = ct - _snap["compl"]
        dcac = cac - _snap["cached"]
        dtrunc = trunc - _snap["trunc"]
        dadap = adap - _snap["adap"]
        dreprmt = reprmt - _snap["reprmt"]
        dfmt = fmt - _snap["fmt"]

        # Hit rate is taken from the backend's last-call counters, not from
        # the deltas above; guard against a zero-token prompt.
        _lcp = getattr(backend, "last_call_prompt_tokens", 0) or 0
        _lcc = getattr(backend, "last_call_cached_tokens", 0) or 0
        cache_hit_rate = _lcc / _lcp if _lcp > 0 else 0.0

        # NOTE: if use_wandb is true but the wandb import failed, logging is
        # skipped silently; the snapshot and step counter still advance.
        if use_wandb and wandb is not None:
            wandb.log(
                {
                    "game/reward":                  1.0 if correct else 0.0,
                    "game/is_test":                 1 if is_test else 0,
                    "game/seed":                    get_current_seed(),
                    "game/episode_idx":             get_current_ep(),
                    "game/game_idx_in_episode":     game_idx,
                    "game/tokens/prompt_delta":     dpt,
                    "game/tokens/completion_delta": dct,
                    "game/tokens/cached_delta":     dcac,
                    "game/tokens/cache_hit_rate":   cache_hit_rate,
                    "game/errors/n_truncated":      dtrunc,
                    "game/errors/n_adapter_errors": dadap,
                    "game/errors/n_re_prompt":      dreprmt,
                    "game/errors/n_format_errors":  dfmt,
                },
                step=global_step[0],
            )

        # Advance snapshot so the next call reports only its own deltas.
        _snap["prompt"] = pt
        _snap["compl"] = ct
        _snap["cached"] = cac
        _snap["trunc"] = trunc
        _snap["adap"] = adap
        _snap["reprmt"] = reprmt
        _snap["fmt"] = fmt

        global_step[0] += 1

    # Seed the baseline up front: without this, a step_callback issued
    # before the first reset_snap() would hit an empty dict and raise
    # KeyError. A later reset_snap() simply overwrites these values.
    reset_snap()

    return reset_snap, step_callback
