"""
Smoke tests for the S2B game loop.

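These tests run fast (no LLM required) and verify that:
  - The env builds and steps correctly
  - Rule-based agents can complete full episodes
  - A mock LLM listener completes episodes and metrics are collected
  - Action parsing handles normal and edge-case LLM outputs
  - Action-shape helpers and seed-level metric aggregation behave as expected
  - Discussion mode emits step/intro prompts and drives a growing chat history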

Run with:  pytest tests/
"""
import numpy as np
import pytest

from meta_rg.s2b_import import ensure_s2b_importable
ensure_s2b_importable()

from meta_rg.env_utils import (
    build_env, no_op_action, ensure_action_shape,
    get_step_prompt_text, get_intro_prompt_text,
)
from meta_rg.agents.rule_based import build_rule_based_agents
from meta_rg.game_loop import parse_action, run_episode
from meta_rg.metrics import aggregate_seeds


# ── Shared fixtures ────────────────────────────────────────────────────────────

# Game configuration handed to run_episode(game_kwargs=...) throughout this file.
# The vocab/sentence sizes mirror the literals used by TestParseAction._parse.
GAME_KW = {
    "nbr_communication_rounds": 1,
    "nbr_distractors": 0,
    "descriptive": True,
    "vocab_size": 6,            # token ids 0..5
    "max_sentence_length": 3,   # tokens per message
    "provide_listener_feedback": True,
}


@pytest.fixture(scope="module")
def env_and_agents():
    env = build_env(
        nbr_latents=3,
        min_nbr_values_per_latent=2,
        max_nbr_values_per_latent=5,
        nbr_object_centric_samples=1,
        sampling_strategy="component-focused-1shot",
        seed=42,
    )
    speaker, listener = build_rule_based_agents()
    return env, speaker, listener


# ── Action parsing ─────────────────────────────────────────────────────────────

class TestParseAction:
    """Check parse_action's decoding of raw LLM text into a clamped env action."""

    def _parse(self, text, agent_idx=1):
        # Uses the same game shape as GAME_KW: descriptive, no distractors,
        # a 6-token vocabulary and 3-token sentences.
        settings = dict(
            nbr_distractors=0,
            descriptive=True,
            vocab_size=6,
            max_sentence_length=3,
        )
        return parse_action(text, agent_idx, **settings)

    def test_typical_response(self):
        action = self._parse("0 1 2 3")
        assert action["decision"] == 0
        tokens = action["communication_channel"][0]
        assert list(tokens) == [1, 2, 3]

    def test_decision_clamped(self):
        # max for descriptive=True, nbr_distractors=0
        assert self._parse("99 1 2 3")["decision"] <= 1

    def test_comm_tokens_clamped(self):
        tokens = self._parse("0 99 99 99")["communication_channel"][0]
        # vocab_size-1=5 is the largest legal token id
        assert not any(tok > 5 for tok in tokens)

    def test_empty_response(self):
        action = self._parse("")
        assert action["decision"] == 0
        assert action["communication_channel"].shape == (1, 3)

    def test_partial_response(self):
        action = self._parse("1")
        assert action["decision"] == 1
        tokens = action["communication_channel"][0]
        assert list(tokens) == [0, 0, 0]

    def test_speaker_decision_clamped(self):
        # speaker decision in {0,1}
        assert self._parse("5 1 2 3", agent_idx=0)["decision"] <= 1


# ── Helpers ────────────────────────────────────────────────────────────────────

class TestActionHelpers:
    """Shape helpers: for length 3, actions end up with a (1, 3) channel."""

    def test_no_op_shape(self):
        noop = no_op_action(3)
        assert noop["decision"] == 0
        assert noop["communication_channel"].shape == (1, 3)

    def test_ensure_action_shape_1d(self):
        flat = dict(decision=0, communication_channel=np.array([1, 2, 3]))
        reshaped = ensure_action_shape(flat, 3)
        assert reshaped["communication_channel"].shape == (1, 3)

    def test_ensure_action_shape_2d(self):
        already_2d = dict(decision=0, communication_channel=np.zeros((1, 3)))
        reshaped = ensure_action_shape(already_2d, 3)
        assert reshaped["communication_channel"].shape == (1, 3)


# ── Episode loop ───────────────────────────────────────────────────────────────

class TestEpisodeLoop:
    """Run full episodes through run_episode with rule-based and mocked listeners."""

    def test_rule_based_episode_completes(self, env_and_agents):
        env, speaker, rb_listener = env_and_agents
        env.seed(100)
        result = run_episode(
            env, speaker,
            lm_generate=None,
            rb_listener=rb_listener,
            game_kwargs=GAME_KW,
        )
        for key in ("zsct_acc", "n_test"):
            assert key in result
        assert result["n_test"] > 0
        assert result["n_train"] > 0
        assert 0.0 <= result["zsct_acc"] <= 1.0

    def test_mock_llm_episode_completes(self, env_and_agents):
        env, speaker, _ = env_and_agents
        env.seed(200)
        result = run_episode(
            env, speaker,
            lm_generate=lambda prompt: "0 1 2 3",
            rb_listener=None,
            game_kwargs=GAME_KW,
        )
        assert result["n_test"] > 0
        assert 0.0 <= result["zsct_acc"] <= 1.0

    def test_multiple_episodes_different_seeds(self, env_and_agents):
        env, speaker, _ = env_and_agents

        def fixed_reply(p):
            return "1 3 2 1"

        results = []
        for seed in [300 + 50 * k for k in range(3)]:  # 300, 350, 400
            env.seed(seed)
            results.append(
                run_episode(env, speaker, lm_generate=fixed_reply,
                            rb_listener=None, game_kwargs=GAME_KW)
            )
        assert all(r["n_test"] > 0 for r in results)


# ── Metrics ────────────────────────────────────────────────────────────────────

class TestMetrics:
    """Check aggregate_seeds: accuracies come back in percent, and ``n`` counts episodes."""

    def test_aggregate_seeds_single(self):
        one_seed = [{"zsct_acc": acc, "n_test": 10} for acc in (0.8, 0.6)]
        agg = aggregate_seeds([one_seed])
        assert agg["mean"] == pytest.approx(70.0, abs=0.1)
        assert agg["n"] == 2  # n_episodes, not n_seeds
        # The 14.1 target is the sample (ddof=1) std of [80, 60]: sqrt(200) ~= 14.14.
        assert agg["std"] == pytest.approx(14.1, abs=0.1)

    def test_aggregate_seeds_multi(self):
        per_seed = [[{"zsct_acc": acc, "n_test": 5}] for acc in (1.0, 0.0)]
        agg = aggregate_seeds(per_seed)
        assert agg["mean"] == pytest.approx(50.0, abs=0.1)
        assert agg["n"] == 2  # one episode per seed, two seeds
        assert len(agg["per_seed"]) == 2

    def test_aggregate_empty(self):
        agg = aggregate_seeds([])
        assert agg["mean"] == 0.0
        assert agg["n"] == 0


# ── Discussion mode ────────────────────────────────────────────────────────────

class _MockDiscussionBackend:
    """Minimal backend that records generate_chat call argument counts and last user content."""
    def __init__(self):
        self.call_msg_counts: list[int] = []
        self.last_user_contents: list[str] = []

    def generate(self, text: str) -> str:
        return "0 1 2 3"

    def generate_chat(self, messages: list[dict]) -> str:
        self.call_msg_counts.append(len(messages))
        user_msgs = [m["content"] for m in messages if m.get("role") == "user"]
        self.last_user_contents.append(user_msgs[-1] if user_msgs else "")
        return "0 1 2 3"


class TestDiscussionMode:
    """Discussion mode: env-side prompt keys and multi-turn chat wiring in run_episode.

    Several tests iterate over what the mock backend recorded. Each one first
    asserts that something was recorded. Otherwise a run that never reaches
    ``generate_chat`` would pass vacuously instead of failing.
    """

    def test_env_emits_step_and_intro_prompts(self):
        env = build_env(discussion_mode=True, seed=42)
        obs, infos = env.reset()
        speaker, listener = build_rule_based_agents()
        a0 = ensure_action_shape(speaker.next_action(state=obs[0], infos=infos[0]), 3)
        a1 = ensure_action_shape(listener.next_action(state=obs[1], infos=infos[1]), 3)
        _, _, _, infos2 = env.step([a0, a1])

        assert "step_prompt" in infos2[1], "step_prompt missing from listener info"
        assert "intro_prompt" in infos2[1], "intro_prompt missing from listener info"
        step = get_step_prompt_text(infos2[1])
        intro = get_intro_prompt_text(infos2[1])
        assert len(intro) > 50, "intro should contain game rules"
        assert "At game" in step or "Starting game" in step, "step should reference current game"
        assert "Question #1" in step, "step should contain question"

    def test_env_no_extra_keys_when_discussion_off(self):
        env = build_env(discussion_mode=False, seed=42)
        obs, infos = env.reset()
        speaker, listener = build_rule_based_agents()
        a0 = ensure_action_shape(speaker.next_action(state=obs[0], infos=infos[0]), 3)
        a1 = ensure_action_shape(listener.next_action(state=obs[1], infos=infos[1]), 3)
        _, _, _, infos2 = env.step([a0, a1])

        assert "step_prompt" not in infos2[1], "step_prompt should be absent by default"
        assert "intro_prompt" not in infos2[1], "intro_prompt should be absent by default"
        assert "prompt" in infos2[1], "full prompt must still be present"

    def test_run_episode_uses_generate_chat(self):
        env = build_env(discussion_mode=True, seed=99)
        speaker, _ = build_rule_based_agents()
        backend = _MockDiscussionBackend()

        result = run_episode(
            env, speaker,
            lm_generate=backend.generate,
            game_kwargs=GAME_KW,
            discussion_backend=backend,
        )

        assert result["n_test"] > 0
        assert len(backend.call_msg_counts) > 0, "generate_chat was never called"

    def test_conversation_grows_alternating(self):
        """First call has 1 message (intro+step); each subsequent adds 2 (assistant+user)."""
        env = build_env(discussion_mode=True, seed=7)
        speaker, _ = build_rule_based_agents()
        backend = _MockDiscussionBackend()

        run_episode(
            env, speaker,
            lm_generate=backend.generate,
            game_kwargs=GAME_KW,
            discussion_backend=backend,
        )

        counts = backend.call_msg_counts
        # Without this guard, counts[0] below raises a bare IndexError.
        assert counts, "generate_chat was never called"
        assert counts[0] == 1, f"first call should have 1 message, got {counts[0]}"
        for i, (prev, cur) in enumerate(zip(counts, counts[1:]), start=1):
            assert cur == prev + 2, (
                f"call {i}: expected {prev+2} messages, got {cur}"
            )

    def test_step_prompt_contains_speaker_message(self):
        """Each generate_chat call must include the speaker's communication channel."""
        env = build_env(discussion_mode=True, seed=15)
        speaker, _ = build_rule_based_agents()
        backend = _MockDiscussionBackend()

        run_episode(
            env, speaker,
            lm_generate=backend.generate,
            game_kwargs=GAME_KW,
            discussion_backend=backend,
        )

        contents = backend.last_user_contents
        # Guard: an empty list would make the loop below pass vacuously.
        assert contents, "generate_chat was never called"
        for i, content in enumerate(contents):
            # Every turn must relay the speaker's message. The very first turn
            # may instead carry the game-start text.
            has_speaker_msg = "partner has sent" in content
            is_opening_turn = (i == 0) and ("Starting game" in content)
            assert has_speaker_msg or is_opening_turn, (
                f"Call {i}: step_text missing speaker message. Content[:200]: {content[:200]!r}"
            )

    def test_feedback_appears_in_listener_turn(self):
        """Feedback from game N must be visible in the LLM's turn for game N+1 (not lost at step_id=0)."""
        env = build_env(discussion_mode=True, seed=21, provide_listener_feedback=True)
        speaker, _ = build_rule_based_agents()
        backend = _MockDiscussionBackend()

        run_episode(
            env, speaker,
            lm_generate=backend.generate,
            game_kwargs=GAME_KW,
            discussion_backend=backend,
        )

        contents = backend.last_user_contents
        assert len(contents) > 1, (
            f"expected more than one LLM turn to check carried-over feedback, got {len(contents)}"
        )
        # After the first LLM turn, some user turn must carry the listener-feedback
        # block. That block is the phrase "exact stimulus" (the speaker's stimulus
        # revealed to the listener) together with the previous game's "won"/"lost"
        # outcome. The first turn is skipped: it has no earlier game to report on.
        feedback_seen = [
            ("exact stimulus" in c and ("won" in c or "lost" in c))
            for c in contents
        ]
        assert any(feedback_seen[1:]), (
            "No LLM turn (after game 0) contains listener feedback "
            "(speaker's exact stimulus + won/lost). "
            f"last_user_contents[:3] = {backend.last_user_contents[:3]}"
        )

    def test_feedback_absent_when_toggled_off(self):
        """When provide_listener_feedback=False, speaker's exact stimulus must NOT appear."""
        game_kw_no_fb = dict(GAME_KW, provide_listener_feedback=False)
        env = build_env(discussion_mode=True, seed=33, provide_listener_feedback=False)
        speaker, _ = build_rule_based_agents()
        backend = _MockDiscussionBackend()

        run_episode(
            env, speaker,
            lm_generate=backend.generate,
            game_kwargs=game_kw_no_fb,
            discussion_backend=backend,
        )

        contents = backend.last_user_contents
        # Guard: an empty list would make the negative check below pass vacuously.
        assert contents, "generate_chat was never called"
        for i, content in enumerate(contents):
            assert "exact stimulus" not in content, (
                f"Call {i}: speaker's exact stimulus should be absent when "
                f"provide_listener_feedback=False. Content[:200]: {content[:200]!r}"
            )

    def test_discussion_cot_episode_completes(self):
        """DiscussionCotBackend wires into run_episode as discussion_backend."""
        from meta_rg.prompt_strategy import DiscussionCotBackend
        from meta_rg.backends.base import BaseBackend

        class StubBackend(BaseBackend):
            def generate(self, text: str) -> str:
                return "Answer: 0"
            def generate_chat(self, messages: list) -> str:
                return "Answer: 0"

        stub = StubBackend()
        backend_lm_stub = type("LM", (), {
            "prompt_token_lengths": [],
            "completion_token_lengths": [],
            "reset_token_stats": lambda self: None,
            "_track_tokens": lambda self, p, c: None,
        })()

        # Bypass __init__ and set by hand the attributes this test gives
        # DiscussionCotBackend.
        disc_cot = DiscussionCotBackend.__new__(DiscussionCotBackend)
        disc_cot._backend = stub
        disc_cot._backend_lm = backend_lm_stub
        disc_cot._native_history = []
        disc_cot.reset_stats = lambda: None
        disc_cot.n_truncated = disc_cot.n_adapter_errors = 0
        disc_cot.n_re_prompt_truncated = disc_cot.n_format_errors = 0

        env = build_env(discussion_mode=True, seed=55)
        speaker, _ = build_rule_based_agents()
        result = run_episode(
            env, speaker,
            lm_generate=stub.generate,
            game_kwargs=GAME_KW,
            discussion_backend=disc_cot,
        )
        assert result["n_test"] > 0
        assert 0.0 <= result["zsct_acc"] <= 1.0

    def test_generate_chat_default_fallback(self):
        """BaseBackend.generate_chat flattens messages and calls generate()."""
        from meta_rg.backends.base import BaseBackend

        class MinimalBackend(BaseBackend):
            def __init__(self):
                self.received = []
            def generate(self, text: str) -> str:
                self.received.append(text)
                return "0 1 2 3"

        b = MinimalBackend()
        messages = [
            {"role": "user", "content": "intro text"},
            {"role": "assistant", "content": "0 1 2 3"},
            {"role": "user", "content": "next step"},
        ]
        result = b.generate_chat(messages)
        assert result == "0 1 2 3"
        assert len(b.received) == 1
        assert "intro text" in b.received[0]
        assert "next step" in b.received[0]
