"""
Core game-loop primitives for a single Meta-RG episode.

A Meta-RG episode consists of:
  - Supporting phase (train mode): agent learns the episode's semantic structure
  - Querying / ZSCT phase (test mode): novel compositional combinations

One complete RG consists of nbr_communication_rounds + 1 env steps
(+ 1 optional listener-feedback step).

Action routing:
  - Agent 0 (speaker): positionally-disentangled rule-based agent
  - Agent 1 (listener): LLM backend queried via its generate() method
"""
from __future__ import annotations

import re
from typing import Callable

import numpy as np

from meta_rg.env_utils import (
    ensure_action_shape,
    get_intro_prompt_text,
    get_prompt_text,
    get_step_prompt_text,
    int_comm_to_ohe,
    is_test_mode,
    no_op_action,
)
from meta_rg.exceptions import ContextOverflowError
from meta_rg.stats import token_stats

try:
    import weave as _weave
    _weave_op = _weave.op
except ImportError:
    def _weave_op(fn=None, *, name=None, **_kw):
        """Identity stand-in for ``weave.op`` used when weave is not installed.

        Supports both decorator spellings: ``@_weave_op`` (the function is
        passed directly) and ``@_weave_op(name=...)`` (a decorator is
        returned). Either way the decorated function comes back unchanged;
        ``name`` and any other keyword options are accepted and ignored.
        """
        if fn is None:
            # Called with options only: hand back a pass-through decorator.
            return lambda f: f
        return fn

# Whitelist of `infos` fields copied into log records by `_slim_info`.
# Prompt text is deliberately absent; the experiment latents are kept.
_INFOS_LOG_KEYS = (
    # Episode / game position.
    "mode",
    "round_idx",
    "round_id",
    "stimulus_idx",
    "step_idx",
    "end_of_mode",
    # Running score.
    "nbr_games",
    "nbr_successes",
    "running_accuracy",
    "nbr_communication_rounds",
    # Latents for each agent.
    "speaker_exp_latents",
    "listener_exp_latents",
)


def _to_py(v):
    """Return `v.tolist()` when `v` offers a `tolist` method, else `v` itself."""
    if hasattr(v, "tolist"):
        return v.tolist()
    return v


def _slim_info(infos: list) -> dict:
    """Build a loggable subset of the first entry of `infos`.

    Only the keys named in `_INFOS_LOG_KEYS` that are actually present are
    kept, each converted with `_to_py`. An empty `infos` gives an empty dict.
    """
    if not infos:
        return {}
    first = infos[0]
    slim = {}
    for key in _INFOS_LOG_KEYS:
        if key in first:
            slim[key] = _to_py(first[key])
    return slim


_MAX_STIMULUS_BYTES = 4_096


def _exceeds_byte_budget(value, budget: int) -> bool:
    """Return True when `value`, including nested lists/tuples, exceeds `budget` bytes.

    `sys.getsizeof` is shallow: for a nested list (what `tolist()` yields for
    a multi-dimensional array) it reports only the outermost list. This walks
    every nested list/tuple and sums their sizes, stopping as soon as the
    budget is exceeded. For a flat list or a scalar the result is the same
    as a plain `sys.getsizeof(value) > budget`.
    """
    import sys

    total = 0
    pending = [value]
    while pending:
        item = pending.pop()
        total += sys.getsizeof(item)
        if total > budget:
            return True
        if isinstance(item, (list, tuple)):
            # Only containers are followed; scalar leaves are already
            # accounted for by their slot in the parent container.
            pending.extend(
                child for child in item if isinstance(child, (list, tuple))
            )
    return False


def _slim_obs(obs: list) -> dict:
    """Extract stimuli from speaker and listener observations.

    Index 0 of `obs` is treated as the speaker and index 1 as the listener.
    An entry is skipped unless it is a dict containing a ``"stimulus"`` key.

    Returns a dict with up to two keys, ``"speaker_stimulus"`` and
    ``"listener_stimulus"``. Each value is the stimulus converted by
    `_to_py`, or a ``"<truncated shape=...>"`` placeholder when the converted
    value is larger than `_MAX_STIMULUS_BYTES`.
    """
    result = {}
    for role, idx in (("speaker", 0), ("listener", 1)):
        if idx < len(obs) and isinstance(obs[idx], dict) and "stimulus" in obs[idx]:
            stimulus = obs[idx]["stimulus"]
            raw = _to_py(stimulus)
            if _exceeds_byte_budget(raw, _MAX_STIMULUS_BYTES):
                # Keep the log small: record only the shape (None when the
                # stimulus has no `shape` attribute).
                shape = getattr(stimulus, "shape", None)
                result[f"{role}_stimulus"] = f"<truncated shape={shape}>"
            else:
                result[f"{role}_stimulus"] = raw
    return result


@_weave_op(name="env_reset")
def _log_env_reset(infos: dict, stimuli: dict) -> None:
    """Logging hook for an env reset; does nothing itself.

    The call exists only so the ``env_reset`` op wrapper applied by
    `_weave_op` is invoked with these arguments.
    """


@_weave_op(name="env_step")
def _log_env_step(
    listener_decision: int,
    reward: float,
    done: bool,
    infos: dict,
    stimuli: dict,
) -> None:
    """Logging hook for one env step; does nothing itself.

    The call exists only so the ``env_step`` op wrapper applied by
    `_weave_op` is invoked with these arguments.
    """



def _env_reset(env):
    """Reset `env`, log a slimmed view of the result, and return `(obs, infos)`."""
    obs, infos = env.reset()
    slim_infos = _slim_info(infos)
    slim_stimuli = _slim_obs(obs)
    _log_env_reset(infos=slim_infos, stimuli=slim_stimuli)
    return obs, infos


def _env_step(env, actions: list) -> tuple:
    """Step `env` with `actions`, log the outcome, and return the step result.

    The logged listener decision is read from ``actions[1]["decision"]`` and
    falls back to -1 when there is no second action, it is not a dict, or it
    has no ``"decision"`` entry. The logged reward is the first element of
    `rewards` when `rewards` is iterable, otherwise `rewards` itself.

    Returns the `(obs, rewards, done, infos)` produced by `env.step`.
    """
    obs, rewards, done, infos = env.step(actions)

    listener_action = actions[1] if len(actions) > 1 else {}
    if isinstance(listener_action, dict):
        raw_decision = listener_action.get("decision", -1)
        listener_decision = int(np.asarray(raw_decision).flat[0])
    else:
        listener_decision = -1

    if hasattr(rewards, "__iter__"):
        logged_reward = float(rewards[0])
    else:
        logged_reward = float(rewards)

    _log_env_step(
        listener_decision=listener_decision,
        reward=logged_reward,
        done=bool(done),
        infos=_slim_info(infos),
        stimuli=_slim_obs(obs),
    )
    return obs, rewards, done, infos


# ── Action parsing ─────────────────────────────────────────────────────────────

# "Question #N ...:" labels carry ordinals that must not be read as answers.
_QUESTION_LABEL_RE = re.compile(r"Question\s*#\d+[^:]*:\s*")
_DIGIT_RUN_RE = re.compile(r"\d+")


def _clamped_int(digits: str, upper: int) -> int:
    """Return ``min(int(digits), upper)`` without failing on huge digit runs.

    Python 3.11+ raises ValueError when `int()` is given a decimal string
    longer than the interpreter's int/str conversion limit. A degenerate LLM
    completion (one very long run of digits) would otherwise abort the whole
    episode even though the value is clamped immediately afterwards.
    """
    try:
        return min(int(digits), upper)
    except ValueError:
        pass
    # Drop leading zeros (in any decimal script) and retry, so a long run of
    # zeros followed by a small number still parses exactly.
    start = next((i for i, ch in enumerate(digits) if int(ch)), len(digits))
    try:
        return min(int(digits[start:] or "0"), upper)
    except ValueError:
        # Still too long to convert: the value has more significant digits
        # than the conversion limit allows, so it is above any usable bound.
        return upper


def parse_action(
    gen_text: str,
    agent_idx: int,
    nbr_distractors: int,
    descriptive: bool,
    vocab_size: int,
    max_sentence_length: int,
) -> dict:
    """
    Extract decision + communication tokens from LLM-generated text.

    The S2B env expects:
      decision:              int in [0, nbr_distractors + int(descriptive)]
      communication_channel: int64 array (1, max_sentence_length), tokens in [0, vocab_size-1]

    Parsing rules:
      - "Question #N ...:" labels are removed first.
      - The first remaining run of digits is the decision. For the listener
        (agent_idx == 1) it is capped at nbr_distractors + int(descriptive);
        for any other agent it is capped at 1. With no digits the decision
        is 0.
      - The following runs of digits, up to max_sentence_length of them, fill
        the communication channel left to right, each clamped into
        [0, vocab_size - 1]. Unfilled slots stay 0.

    Returns a dict with keys "decision" (Python int) and
    "communication_channel" (numpy int64 array of shape
    (1, max_sentence_length)).
    """
    # Strip "Question #N: ..." labels so their ordinal numbers don't pollute parsing.
    cleaned = _QUESTION_LABEL_RE.sub("", gen_text)
    numbers = _DIGIT_RUN_RE.findall(cleaned)

    decision = 0
    if numbers:
        if agent_idx == 1:  # listener
            max_dec = nbr_distractors + 1 + int(descriptive)
            decision = _clamped_int(numbers[0], max_dec - 1)
        else:  # speaker
            decision = _clamped_int(numbers[0], 1)

    comm = np.zeros((1, max_sentence_length), dtype=np.int64)
    # numbers[0] is the decision; the next max_sentence_length runs are tokens.
    for slot, digits in enumerate(numbers[1:max_sentence_length + 1]):
        comm[0, slot] = max(0, _clamped_int(digits, vocab_size - 1))

    return {"decision": decision, "communication_channel": comm}


# ── Single-game step ───────────────────────────────────────────────────────────

def run_one_game(
    env,
    obs: list,
    infos: list,
    rb_speaker,
    lm_generate: Callable[[str], str] | None,
    rb_listener=None,
    nbr_communication_rounds: int = 1,
    nbr_distractors: int = 0,
    descriptive: bool = True,
    vocab_size: int = 6,
    max_sentence_length: int = 3,
    provide_listener_feedback: bool = True,
    use_step_prompt: bool = False,
    listener_capture: list | None = None,
    listener_state_capture: list | None = None,
    feedback_state_capture: list | None = None,
) -> tuple[float, bool, list, list]:
    """
    Play one complete RG (one game within an episode).

    The env is stepped `nbr_communication_rounds + 1` times, plus one more
    time when `provide_listener_feedback` is set. Each step is routed on
    `infos[0]["round_idx"]`:
      - -1: listener-feedback step, both agents send a no-op.
      - 0: speaker round, `rb_speaker.next_action` produces the speaker action.
      - `nbr_communication_rounds`: listener round, the listener decides.
      - On every non-speaker round the speaker sends a no-op.

    Listener action selection:
      - On the listener round, if `lm_generate` is not None, the prompt built
        from `infos[1]` is sent to it and the reply is decoded with
        `parse_action`.
      - Otherwise, if `rb_listener` is not None, its `next_action` is used
        (this also covers non-listener rounds when an LLM is configured).
      - Otherwise a no-op is sent.

    Args:
        env: environment stepped with a two-element action list
            `[speaker_action, listener_action]`.
        obs, infos: current per-agent observations / infos (index 0 speaker,
            index 1 listener).
        rb_speaker: rule-based speaker exposing `next_action(state=, infos=)`.
        lm_generate: prompt -> completion callable, or None.
        rb_listener: optional rule-based listener; if it has
            `observe_feedback`, that is called with `infos[1]` on the
            feedback step.
        nbr_distractors, descriptive, vocab_size, max_sentence_length:
            forwarded to `parse_action` / action shaping helpers.
        use_step_prompt: choose `get_step_prompt_text` over `get_prompt_text`.
        listener_capture: if given, one record per rule-based listener
            decision is appended (prompt text, message, latents, decision).
        listener_state_capture: if given, it is reset to a single
            `{"obs1", "infos1"}` record right before `lm_generate` is called.
        feedback_state_capture: if given, it is reset to a single record of
            the listener's `listener_exp_latents` and `listener_exp_text` on
            the feedback step.

    Returns:
        `(reward, done, obs, infos)` after the game completes, where `reward`
        is `rewards[0]` from the last non-feedback step.
    """
    n_steps = nbr_communication_rounds + 1 + int(provide_listener_feedback)
    final_reward = 0.0
    done = False

    for _ in range(n_steps):
        current_round = infos[0]["round_idx"]

        # Listener-feedback step: both agents no-op.
        # _gen_reward() fires on round_idx==-1 (the listener step), so final_reward
        # is already correct at this point — don't overwrite it here.
        if current_round == -1:
            if rb_listener is not None and hasattr(rb_listener, "observe_feedback"):
                rb_listener.observe_feedback(infos[1])
            if feedback_state_capture is not None:
                feedback_state_capture.clear()
                # Capture both fields that run_episode's warmup reads back;
                # previously only the latents were stored, so the text it
                # asked for was always None.
                feedback_state_capture.append(
                    {
                        "listener_exp_latents": infos[1].get("listener_exp_latents"),
                        "listener_exp_text": infos[1].get("listener_exp_text"),
                    }
                )
            a0 = no_op_action(max_sentence_length)
            a1 = no_op_action(max_sentence_length)
            obs, _, done, infos = _env_step(env, [a0, a1])
            continue

        is_speaker_round = current_round == 0
        is_listener_round = current_round == nbr_communication_rounds

        # ── Speaker (agent 0): always rule-based ──────────────────────────────
        if is_speaker_round:
            a0 = rb_speaker.next_action(state=obs[0], infos=infos[0])
        else:
            a0 = no_op_action(max_sentence_length)

        # ── Listener (agent 1): LLM or rule-based fallback ────────────────────
        if is_listener_round and lm_generate is not None:
            prompt_text = (get_step_prompt_text(infos[1]) if use_step_prompt
                           else get_prompt_text(infos[1]))
            if listener_state_capture is not None:
                # Must be filled BEFORE lm_generate runs: callers may read it
                # from inside the generate call.
                listener_state_capture.clear()
                listener_state_capture.append({"obs1": obs[1], "infos1": dict(infos[1])})
            gen_text = lm_generate(prompt_text)
            a1 = parse_action(
                gen_text, 1,
                nbr_distractors, descriptive, vocab_size, max_sentence_length,
            )
        elif rb_listener is not None:
            # Call on EVERY round so the agent can track round_idx internally
            # (round_idx=0 initialises its per_round_decision list).
            listener_infos = dict(infos[1])
            if not is_speaker_round:
                listener_infos["communication_channel"] = int_comm_to_ohe(
                    obs[1]["communication_channel"], vocab_size, max_sentence_length
                )
            a1 = rb_listener.next_action(state=obs[1], infos=listener_infos)
            if is_listener_round and listener_capture is not None:
                step_fn = get_step_prompt_text if use_step_prompt else get_prompt_text
                # np.asarray keeps this working when the channel / decision
                # arrive as plain lists or ints rather than ndarrays.
                listener_capture.append({
                    "step_text":        step_fn(infos[1]),
                    "msg_int":          np.asarray(obs[1]["communication_channel"]).flatten().tolist(),
                    "listener_latents": infos[1].get("listener_exp_latents"),
                    "decision":         int(np.asarray(a1["decision"]).flat[0]),
                })
        else:
            a1 = no_op_action(max_sentence_length)

        a0 = ensure_action_shape(a0, max_sentence_length)
        a1 = ensure_action_shape(a1, max_sentence_length)
        obs, rewards, done, infos = _env_step(env, [a0, a1])
        final_reward = float(rewards[0])

    return final_reward, done, obs, infos


# ── Episode-level evaluation ───────────────────────────────────────────────────

def run_episode(
    env,
    rb_speaker,
    lm_generate: Callable[[str], str] | None,
    rb_listener=None,
    game_kwargs: dict | None = None,
    tokenize_fn: Callable[[str], int] | None = None,
    discussion_backend=None,
    step_callback=None,   # Callable[[int, bool, bool], None] | None
    n_few_shot_games: int = 0,
    few_shot_domain: str = "SCS",
    few_shot_nbr_object_centric: int = 1,
) -> dict:
    """
    Reset the env and play all games until `done=True`.
    Returns per-phase accuracy and token-length stats for the episode.

    If tokenize_fn is provided, prompt and completion lengths are measured in
    tokens; otherwise the token stats are omitted from the result dict.

    Args:
        env: environment to reset and play.
        rb_speaker: rule-based speaker; `reset()` is called once up front.
        lm_generate: prompt -> completion callable for the listener, or None
            to let `rb_listener` play.
        rb_listener: optional rule-based listener; `reset()` is called once,
            and `update_from_reward(reward)` after every scored game when the
            method exists.
        game_kwargs: extra keyword arguments forwarded to `run_one_game`.
        tokenize_fn: text -> token count, used for the token statistics.
        discussion_backend: when given together with `lm_generate`, the
            listener is driven through `discussion_backend.generate_chat`
            with an accumulating message history, and step prompts are used.
        step_callback: called as `(game_idx, is_test, success)` after every
            scored game.
        n_few_shot_games: number of warmup games played before scoring
            (discussion mode only). Warmup games are not counted in the
            accuracies and do not trigger `step_callback`.
        few_shot_domain, few_shot_nbr_object_centric: forwarded to the
            tracker's `verbalize_pre_result` during warmup.

    Returns:
        dict with "zsct_acc", "support_acc", "n_test", "n_train",
        "context_overflow", and, when any token lengths were recorded,
        "prompt_tokens" and "completion_tokens".
    """
    gkw = game_kwargs or {}
    obs, infos = _env_reset(env)
    rb_speaker.reset()
    if rb_listener is not None:
        rb_listener.reset()

    train_correct = train_total = test_correct = test_total = 0
    done = False

    # Token-length recording lists (shared with discussion wrapper if active).
    prompt_token_lengths: list[int] = []
    completion_token_lengths: list[int] = []

    # Discussion mode: replace lm_generate with a stateful multi-turn wrapper.
    if discussion_backend is not None and lm_generate is not None:
        intro = get_intro_prompt_text(infos[1])
        _conv_history: list[dict] = []
        _first_turn = True
        _disc_tok = tokenize_fn  # snapshot; we disable the outer wrapper below

        def lm_generate(step_text: str) -> str:  # noqa: F811
            nonlocal _first_turn
            # The intro is sent once, prepended to the first step prompt.
            content = (intro + "\n\n" + step_text) if _first_turn else step_text
            _first_turn = False
            _conv_history.append({"role": "user", "content": content})
            response = discussion_backend.generate_chat(_conv_history)
            _conv_history.append({"role": "assistant", "content": response})
            if _disc_tok is not None:
                # Prompt length = everything in the history before this reply.
                full_prompt = "\n\n".join(m["content"] for m in _conv_history[:-1])
                prompt_token_lengths.append(_disc_tok(full_prompt))
                completion_token_lengths.append(_disc_tok(response))
            return response

        gkw = dict(gkw)
        gkw["use_step_prompt"] = True
        tokenize_fn = None  # disable outer wrapper; counting handled inside above

    # ── Few-shot warmup (few_shot_discussion_cot) ─────────────────────────────
    # Route warmup games through the REAL lm_generate (so all Weave ops, token
    # tracking, and _native_history updates fire identically to a real LM game).
    # The only substitution is at the lowest level: discussion_backend._backend.generate_chat
    # is temporarily replaced with the rb_listener verbalizer for each warmup game.
    if n_few_shot_games > 0 and discussion_backend is not None and lm_generate is not None and not done:
        from symbolic_behaviour_benchmark.rule_based_agents.verbalize import (
            PosdisHypothesisTracker,
        )
        # Share the tracker with rb_listener when it owns one (HypothesisListenerAgent);
        # otherwise create a standalone tracker for verbalization only.
        if hasattr(rb_listener, "tracker"):
            tracker = rb_listener.tracker
        else:
            tracker = PosdisHypothesisTracker()
        listener_state_cap: list = []   # populated by run_one_game before lm_generate fires
        feedback_cap: list = []         # populated by run_one_game during the feedback step
        warmup_cache: dict = {}         # holds rb_listener decision+tokens for post-game update

        real_backend_gen_chat = discussion_backend._backend.generate_chat
        _vocab_size       = gkw.get("vocab_size", 6)
        _max_sent_len     = gkw.get("max_sentence_length", 3)

        for warmup_idx in range(n_few_shot_games):
            if done:
                break

            # Start every warmup game with empty scratch state so the
            # post-game tracker update below can never consume values left
            # over from an earlier game.
            warmup_cache.clear()
            feedback_cap.clear()

            def _warmup_backend_chat(native_msgs: list, _idx=warmup_idx) -> str:
                if not listener_state_cap:
                    return "Answer: 0"
                obs1   = listener_state_cap[0]["obs1"]
                infos1 = listener_state_cap[0]["infos1"]

                listener_latents  = infos1.get("listener_exp_latents")
                listener_exp_text = infos1.get("listener_exp_text")
                msg_tokens        = list(np.asarray(obs1["communication_channel"]).flatten())

                listener_infos = dict(infos1)
                listener_infos["communication_channel"] = int_comm_to_ohe(
                    obs1["communication_channel"], _vocab_size, _max_sent_len
                )
                a1_rb    = rb_listener.next_action(state=obs1, infos=listener_infos)
                decision = int(np.asarray(a1_rb["decision"]).flat[0])

                warmup_cache["msg_tokens"]       = msg_tokens
                warmup_cache["decision"]         = decision
                warmup_cache["listener_latents"] = listener_latents

                return tracker.verbalize_pre_result(
                    game_idx=_idx,
                    msg_tokens=msg_tokens,
                    listener_latents=listener_latents,
                    decision=decision,
                    domain=few_shot_domain,
                    o_centric=few_shot_nbr_object_centric,
                    slim=getattr(rb_listener, "slim", True),
                    inductive=getattr(rb_listener, "inductive", False),
                    listener_exp_text=listener_exp_text,
                )

            discussion_backend._backend.generate_chat = _warmup_backend_chat
            try:
                reward, done, obs, infos = run_one_game(
                    env, obs, infos, rb_speaker,
                    lm_generate, rb_listener,
                    listener_state_capture=listener_state_cap,
                    feedback_state_capture=feedback_cap,
                    **gkw,
                )
            finally:
                # Always put the real backend back, even if the game raised;
                # otherwise the caller's backend would stay stubbed out.
                discussion_backend._backend.generate_chat = real_backend_gen_chat

            if warmup_cache:
                # When rb_listener owns the tracker (HypothesisListenerAgent),
                # observe_feedback was already called inside run_one_game — skip it here
                # to avoid double-counting. Fall back to feedback_cap for standalone tracker.
                if feedback_cap and not hasattr(rb_listener, "observe_feedback"):
                    tracker.observe_feedback(
                        warmup_cache["msg_tokens"],
                        feedback_cap[0].get("listener_exp_latents"),
                        feedback_cap[0].get("listener_exp_text"),
                    )
                tracker.update_from_reward(warmup_cache["msg_tokens"], reward)

    # Wrap lm_generate to record prompt and completion token lengths per call.
    if lm_generate is not None and tokenize_fn is not None:
        _orig_generate = lm_generate
        def lm_generate(prompt_text: str) -> str:  # noqa: F811
            completion = _orig_generate(prompt_text)
            prompt_token_lengths.append(tokenize_fn(prompt_text))
            completion_token_lengths.append(tokenize_fn(completion))
            return completion

    context_overflow = False
    game_idx = 0
    try:
        while not done:
            # Read the phase BEFORE playing: infos is replaced by the game.
            is_test = is_test_mode(infos[0])
            reward, done, obs, infos = run_one_game(
                env, obs, infos, rb_speaker, lm_generate, rb_listener, **gkw
            )
            if is_test:
                test_total += 1
                test_correct += int(reward > 0)
            else:
                train_total += 1
                train_correct += int(reward > 0)
            if rb_listener is not None and hasattr(rb_listener, "update_from_reward"):
                rb_listener.update_from_reward(reward)
            if step_callback is not None:
                step_callback(game_idx, is_test, bool(reward > 0))
            game_idx += 1
    except ContextOverflowError as exc:
        # Best-effort: keep the scores gathered so far and flag the episode.
        context_overflow = True
        print(f"[ContextOverflow] Episode terminated early: {exc}")

    result: dict = {
        # max(..., 1) avoids division by zero when a phase had no games.
        "zsct_acc": test_correct / max(test_total, 1),
        "support_acc": train_correct / max(train_total, 1),
        "n_test": test_total,
        "n_train": train_total,
        "context_overflow": context_overflow,
    }
    # Per-episode stats aggregated over games (individual LM calls).
    if prompt_token_lengths:
        result["prompt_tokens"] = token_stats(prompt_token_lengths)
        result["completion_tokens"] = token_stats(completion_token_lengths)
    return result
