"""
Unit tests for per-game-step callback in run_episode and make_step_logger.
No network, no real W&B — all external objects are faked.
"""
from __future__ import annotations

import types
from unittest.mock import MagicMock, patch

import pytest


# ---------------------------------------------------------------------------
# Minimal fakes for run_episode
# ---------------------------------------------------------------------------

def _make_rb_speaker():
    """Build a stand-in rule-based speaker backed by ``MagicMock``.

    ``reset()`` returns ``None`` and ``get_utterance()`` returns the fixed
    string ``"utterance"``; any other attribute is an auto-created mock.
    """
    speaker = MagicMock()
    speaker.configure_mock(**{
        "reset.return_value": None,
        "get_utterance.return_value": "utterance",
    })
    return speaker


def _fake_lm(prompt_text: str) -> str:
    """Stand-in language model: ignore the prompt and always answer ``"42"``."""
    canned_answer = "42"
    return canned_answer


# ---------------------------------------------------------------------------
# Helper: run run_episode with run_one_game stubbed out
# ---------------------------------------------------------------------------

def _run_episode_with_callback(n_games, rewards=None, callback=None):
    """
    Call run_episode with run_one_game stubbed so no real env is needed.

    The stub plays exactly ``n_games`` games.  Call ``i`` returns reward
    ``rewards[i]`` (1.0 for every game when ``rewards`` is None), reports
    ``done`` on the last game, and hands back infos whose mode is "test" for
    even ``i`` and "train" for odd ``i``.  The initial infos supplied through
    the patched ``_env_reset`` are in "train" mode.

    run_episode is expected to read is_test from the infos it holds BEFORE
    each game (this is what test_callback_receives_correct_args asserts), so
    the infos returned by game ``i`` describe game ``i + 1``.  The callback
    therefore sees: game 0 → train, game 1 → test, game 2 → train, …

    Args:
        n_games: number of games the stub plays before reporting done.
        rewards: optional per-game rewards; must have length ``n_games``.
        callback: forwarded to run_episode as ``step_callback``.

    Returns:
        Whatever run_episode returns.
    """
    from meta_rg.game_loop import run_episode
    import meta_rg.game_loop as gl

    rw = rewards if rewards is not None else [1.0] * n_games
    assert len(rw) == n_games

    def initial_reset():
        # Build fresh objects on every call so the env mock and the
        # _env_reset patch never share mutable state.
        return [types.SimpleNamespace()], [{"mode": "train"}, {}]

    env = MagicMock()
    env.reset.return_value = initial_reset()
    rb = _make_rb_speaker()

    games_played = 0

    def fake_rog(env, obs, infos, rb_speaker, lm_generate, rb_listener=None, **kw):
        nonlocal games_played
        i = games_played
        games_played += 1
        done = games_played >= n_games
        new_infos = [{"mode": "test" if i % 2 == 0 else "train"}, {}]
        return rw[i], done, obs, new_infos

    # Patches are entered in the same order as before: run_one_game, then _env_reset.
    stub_game = patch.object(gl, "run_one_game", side_effect=fake_rog)
    stub_reset = patch.object(gl, "_env_reset", return_value=initial_reset())
    with stub_game, stub_reset:
        return run_episode(
            env, rb, _fake_lm,
            step_callback=callback,
        )


# ---------------------------------------------------------------------------
# Test 1: callback fires exactly N times
# ---------------------------------------------------------------------------

def test_callback_fires_per_game():
    """The step callback runs once per game, receiving indices 0..3 in order."""
    seen_indices = []

    _run_episode_with_callback(
        4,
        callback=lambda game_idx, is_test, correct: seen_indices.append(game_idx),
    )

    assert seen_indices == list(range(4))


# ---------------------------------------------------------------------------
# Test 2: callback receives correct args
# ---------------------------------------------------------------------------

def test_callback_receives_correct_args():
    """Each callback call carries the expected (game_idx, is_test, correct) triple."""
    observed = []

    def record(game_idx, is_test, correct):
        observed.append((game_idx, is_test, correct))

    # is_test comes from the infos held BEFORE run_one_game is called:
    #   - game 0 uses the initial infos (train);
    #   - game 1 uses the infos returned by game 0 (i=0 is even → test);
    #   - and so on, alternating train / test.
    # Rewards [1, 0, 1, 0] map to correct = [True, False, True, False].
    _run_episode_with_callback(4, rewards=[1.0, 0.0, 1.0, 0.0], callback=record)

    expected = [
        (0, False, True),
        (1, True, False),
        (2, False, True),
        (3, True, False),
    ]
    for position, want in enumerate(expected):
        assert observed[position] == want


# ---------------------------------------------------------------------------
# Test 3: no callback is a no-op (backward compat)
# ---------------------------------------------------------------------------

def test_no_callback_is_noop():
    """run_episode still returns a result containing "zsct_acc" without a callback."""
    assert "zsct_acc" in _run_episode_with_callback(3, callback=None)


# ---------------------------------------------------------------------------
# Helpers for StepLogger tests
# ---------------------------------------------------------------------------

def _make_fake_backend(prompt=0, completion=0, cached=0,
                       last_call_prompt=0, last_call_cached=0):
    """Return a mutable namespace holding fake backend token counters.

    The keyword arguments seed the ``total_*`` and ``last_call_*`` attributes;
    tests reassign those attributes between callback invocations.
    """
    counters = {
        "total_prompt_tokens": prompt,
        "total_completion_tokens": completion,
        "total_cached_tokens": cached,
        "last_call_prompt_tokens": last_call_prompt,
        "last_call_cached_tokens": last_call_cached,
    }
    return types.SimpleNamespace(**counters)


def _make_fake_cot_gen(trunc=0, adap=0, reprmt=0, fmt=0):
    """Return a mutable namespace holding fake CoT-generator error counters.

    The arguments seed ``n_truncated``, ``n_adapter_errors``,
    ``n_re_prompt_truncated`` and ``n_format_errors`` respectively.
    """
    error_counts = {
        "n_truncated": trunc,
        "n_adapter_errors": adap,
        "n_re_prompt_truncated": reprmt,
        "n_format_errors": fmt,
    }
    return types.SimpleNamespace(**error_counts)


def _make_logger(backend, cot_gen=None, use_wandb=True):
    """Create a step logger wired to fixed seed/episode getters.

    Returns a pair ``(logger, step_counter)`` where ``logger`` is whatever
    ``make_step_logger`` returns (callers unpack it as ``(reset_snap, cb)``)
    and ``step_counter`` is the one-element list passed as ``global_step``.
    """
    from meta_rg.step_logger import make_step_logger

    step_counter = [0]
    logger = make_step_logger(
        backend=backend,
        cot_gen=cot_gen,
        use_wandb=use_wandb,
        global_step=step_counter,
        get_current_seed=lambda: 0,
        get_current_ep=lambda: 0,
    )
    return logger, step_counter


# ---------------------------------------------------------------------------
# Test 4: token deltas are per-game, not cumulative
# ---------------------------------------------------------------------------

def test_token_deltas_correct():
    """Logged token metrics are per-game deltas of the backend's running totals.

    The fake backend's cumulative counters are advanced before each callback
    call; the values captured from ``wandb.log`` must reflect only the change
    since the previous call, not the cumulative totals.
    """
    backend = _make_fake_backend()
    logged = []

    (reset_snap, cb), _ = _make_logger(backend, use_wandb=True)

    with patch("meta_rg.step_logger.wandb") as mock_wb:
        # Copy each logged dict so later mutation by the logger cannot alter it.
        mock_wb.log.side_effect = lambda m, step=None: logged.append(dict(m))

        reset_snap()

        # Game 0: totals move from 0 to 100 prompt / 20 completion, nothing cached.
        backend.total_prompt_tokens      = 100
        backend.total_completion_tokens  = 20
        backend.total_cached_tokens      = 0
        backend.last_call_prompt_tokens  = 100
        backend.last_call_cached_tokens  = 0
        cb(0, True, True)

        # Game 1: totals move to 300 prompt / 50 completion / 150 cached.
        backend.total_prompt_tokens      = 300
        backend.total_completion_tokens  = 50
        backend.total_cached_tokens      = 150
        backend.last_call_prompt_tokens  = 200
        backend.last_call_cached_tokens  = 150
        cb(1, False, False)

    assert logged[0]["game/tokens/prompt_delta"]     == 100
    assert logged[0]["game/tokens/completion_delta"] == 20
    assert logged[0]["game/tokens/cached_delta"]     == 0
    assert logged[0]["game/tokens/cache_hit_rate"]   == 0.0

    assert logged[1]["game/tokens/prompt_delta"]     == 200
    assert logged[1]["game/tokens/completion_delta"] == 30
    assert logged[1]["game/tokens/cached_delta"]     == 150
    assert abs(logged[1]["game/tokens/cache_hit_rate"] - 0.75) < 1e-9


# ---------------------------------------------------------------------------
# Test 5: CoT error deltas are per-game
# ---------------------------------------------------------------------------

def test_cot_error_deltas_correct():
    """Logged CoT error metrics reflect only the errors added since the last game."""
    cot_gen = _make_fake_cot_gen()
    backend = _make_fake_backend()
    records = []

    def capture(m, step=None):
        # Store a copy of every dict handed to wandb.log.
        records.append(dict(m))

    logger_pair, _ = _make_logger(backend, cot_gen=cot_gen, use_wandb=True)
    reset_snap, on_game_end = logger_pair

    with patch("meta_rg.step_logger.wandb") as mock_wb:
        mock_wb.log.side_effect = capture

        reset_snap()

        # One truncation occurs during game 0.
        cot_gen.n_truncated = 1
        on_game_end(0, True, True)

        # Two format errors occur during game 1; truncations stay at 1.
        cot_gen.n_format_errors = 2
        on_game_end(1, False, True)

    expectations = [
        (0, "game/errors/n_truncated", 1),
        (0, "game/errors/n_format_errors", 0),
        (1, "game/errors/n_truncated", 0),
        (1, "game/errors/n_format_errors", 2),
    ]
    for call_no, key, want in expectations:
        assert records[call_no][key] == want


# ---------------------------------------------------------------------------
# Test 6: reset_snap zeros error baseline, advances token baseline
# ---------------------------------------------------------------------------

def test_snap_reset_between_episodes():
    """After reset_snap, token deltas continue from the running totals while
    error deltas start from the freshly zeroed error counters."""
    backend = _make_fake_backend(prompt=500, completion=100, cached=50)
    cot_gen = _make_fake_cot_gen(trunc=3, adap=1)
    records = []

    def capture(m, step=None):
        records.append(dict(m))

    logger_pair, _ = _make_logger(backend, cot_gen=cot_gen, use_wandb=True)
    reset_snap, on_game_end = logger_pair

    with patch("meta_rg.step_logger.wandb") as mock_wb:
        mock_wb.log.side_effect = capture

        # Simulate start of new episode: cot_gen.reset_stats() zeroes errors
        cot_gen.n_truncated = 0
        cot_gen.n_adapter_errors = 0
        reset_snap()

        backend.total_prompt_tokens = 550
        cot_gen.n_truncated = 1
        on_game_end(0, True, True)

    first = records[0]
    assert first["game/tokens/prompt_delta"] == 50  # delta from 500, not 0
    assert first["game/errors/n_truncated"] == 1    # delta from 0 (reset)


# ---------------------------------------------------------------------------
# Test 7: non-HF backend (no token attrs) produces 0 deltas
# ---------------------------------------------------------------------------

def test_non_hf_backend_skips_token_deltas():
    """A backend without any token-counter attributes logs zero token deltas."""
    bare_backend = types.SimpleNamespace()   # no total_* attributes
    records = []

    def capture(m, step=None):
        records.append(dict(m))

    logger_pair, _ = _make_logger(bare_backend, use_wandb=True)
    reset_snap, on_game_end = logger_pair

    with patch("meta_rg.step_logger.wandb") as mock_wb:
        mock_wb.log.side_effect = capture
        reset_snap()
        on_game_end(0, True, True)

    zero_keys = (
        "game/tokens/prompt_delta",
        "game/tokens/completion_delta",
        "game/tokens/cached_delta",
    )
    for key in zero_keys:
        assert records[0][key] == 0


# ---------------------------------------------------------------------------
# Test 8: use_wandb=False → wandb.log never called, global_step still advances
# ---------------------------------------------------------------------------

def test_wandb_not_logged_when_disabled():
    """With use_wandb=False nothing reaches wandb.log, yet the shared step
    counter still advances once per callback call."""
    backend = _make_fake_backend()

    logger_pair, step_counter = _make_logger(backend, use_wandb=False)
    reset_snap, on_game_end = logger_pair

    with patch("meta_rg.step_logger.wandb") as mock_wb:
        reset_snap()
        backend.total_prompt_tokens = 100
        for call_args in ((0, True, True), (1, False, False)):
            on_game_end(*call_args)
        mock_wb.log.assert_not_called()

    assert step_counter[0] == 2
