"""DSPy-based prompt strategies: zero-shot CoT and few-shot CoT for the listener task."""
from __future__ import annotations

import contextlib
import re

try:
    import dspy
except ImportError as e:
    raise ImportError("pip install dspy>=3.0  (or: pip install 'meta-rg-s2b[cot]')") from e

try:
    import weave as _weave
    _weave_op = _weave.op
except ImportError:
    def _weave_op(fn=None, *, name=None, **_kw):
        """Stand-in for ``weave.op`` when weave is not installed.

        Works both bare (``@_weave_op``) and with arguments
        (``@_weave_op(name=...)``); in every case the decorated function is
        returned unchanged, so tracing silently becomes a no-op.
        """
        if fn is not None:
            return fn
        return lambda f: f


@_weave_op(name="llm_call")
def _log_llm_call(prompt_text: str, content: str) -> str:
    """Identity op whose only job is to give weave (when installed) an
    ``llm_call`` trace entry per backend call.

    Only plain strings are passed in — never ``self`` or the response wrapper —
    which keeps the traced inputs small and serialisable. Returns *content*
    unchanged.
    """
    echoed = content
    return echoed


@_weave_op(name="discussion_chat_call")
def _log_discussion_chat_call(messages: list, response: str) -> str:
    """Identity op that gives weave (when installed) a ``discussion_chat_call``
    trace holding the full native message list and the raw CoT response.

    Returns *response* unchanged.
    """
    echoed = response
    return echoed

from meta_rg.backends.base import BaseBackend


# ── DSPy Signatures ────────────────────────────────────────────────────────────

class ListenerSignature(dspy.Signature):
    """You are a listener agent in a referential game. Study the training examples
    to understand what the speaker's messages mean, then select the correct object."""

    # NOTE: DSPy presents a Signature's docstring (as task instructions) and each
    # field's ``desc`` to the model, so the texts in this class are prompt
    # content — editing them changes what the model sees.

    # Receives the raw S2B prompt: CotGenerator calls the module with
    # game_description=prompt_text.
    game_description: str = dspy.InputField(
        desc="Game setup, training examples, and the query to answer"
    )
    # Free-text field; CotGenerator counts an answer with fewer than 4 integers
    # as a format error.
    answer: str = dspy.OutputField(
        desc=(
            "Decision integer followed by 3 communication token integers, "
            "space-separated. Example: '0 2 3 1'"
        )
    )


class ListenerDiscussionSignature(dspy.Signature):
    """Listener in a referential game played over multiple rounds.
    Use the conversation history to learn the speaker's private code across games,
    then decide whether your current stimulus shares the same latent meaning."""

    # NOTE(review): not referenced anywhere in this file — DiscussionCotBackend
    # calls the backend's generate_chat directly instead of going through DSPy.
    # Confirm no external caller still uses this signature before removing it.
    # DSPy presents the docstring and the field descs below to the model, so
    # they are prompt content, not just documentation.

    # Prior games of the current episode as alternating user/assistant turns.
    history: dspy.History = dspy.InputField(
        desc="Alternating user/assistant turns from prior games in this episode"
    )
    # Text for the current game only.
    game_step: str = dspy.InputField(
        desc="Current game: your stimulus, the speaker's message, and the question"
    )
    # Single 0/1 decision.
    answer: str = dspy.OutputField(
        desc="Your decision as a single integer: 0 if same latent meaning, 1 if different. Example: 0"
    )


# ── BackendLM: wraps BaseBackend as a DSPy-compatible LM ─────────────────────

class BackendLM(dspy.BaseLM):
    """DSPy ``BaseLM`` adapter around a project ``BaseBackend``.

    ``forward`` sends chat messages to ``backend.generate_chat`` when the
    backend provides it (otherwise flattens them into one prompt for
    ``backend.generate``). It records per-call token lengths when the backend
    exposes ``tokenize_fn`` and traces the call through ``_log_llm_call``. It
    returns the minimal OpenAI-style object built by ``_make_response``.
    """

    def __init__(self, backend: BaseBackend) -> None:
        # Sync DSPy's max_tokens with the backend's actual limit so truncation
        # warnings and any DSPy-side budget logic use the right number.
        max_tokens = getattr(backend, "max_new_tokens", None) or getattr(backend, "max_tokens", 1000)
        super().__init__(model="custom-backend", model_type="chat",
                         cache=False, max_tokens=max_tokens)
        self._backend = backend
        # Optional token counter; when None, _track_tokens records nothing.
        self._tokenize_fn = getattr(backend, "tokenize_fn", None)
        self.prompt_token_lengths: list[int] = []
        self.completion_token_lengths: list[int] = []

    def to_dict(self) -> dict:
        # dictify() in weave's DSPy callback checks to_dict() first; returning a compact
        # dict here prevents it from walking self.history, which stores the full messages
        # list of every LM call and grows O(N²) over a long discussion_cot episode,
        # overflowing weave's 3.6 MB trace-input limit.
        # Note: BaseLM stores max_tokens inside self.kwargs, not as self.max_tokens.
        # __class__ is required by get_op_name_for_callback in dspy_utils.py.
        return {
            "__class__": {
                "module": self.__class__.__module__,
                "qualname": self.__class__.__qualname__,
                "name": self.__class__.__name__,
            },
            "model": self.model,
            "model_type": self.model_type,
            "max_tokens": self.kwargs.get("max_tokens"),
            "backend": type(self._backend).__name__,
        }

    def reset_token_stats(self) -> None:
        """Clear the per-call prompt/completion token-length lists."""
        self.prompt_token_lengths = []
        self.completion_token_lengths = []

    def _track_tokens(self, prompt_text: str, content: str) -> None:
        """Record prompt and completion lengths via the backend's tokenize_fn, if any."""
        if self._tokenize_fn is not None:
            self.prompt_token_lengths.append(self._tokenize_fn(prompt_text))
            self.completion_token_lengths.append(self._tokenize_fn(content))

    def forward(self, prompt=None, messages=None, **kwargs):
        """Run one LM call for DSPy and wrap the text in a response object.

        Args:
            prompt: Plain prompt text, used when *messages* is empty/None.
            messages: Chat messages (dicts with ``role``/``content``); entries
                missing either key (or with empty content) are dropped.
            **kwargs: Per-call options from DSPy; only ``max_tokens`` is
                honoured, as a temporary override of the backend's budget.
        """
        with self._max_tokens_override(kwargs.get("max_tokens")):
            if messages:
                # For multi-turn conversations use the native chat template when available.
                chat_msgs = [
                    m for m in messages
                    if isinstance(m, dict) and m.get("role") and m.get("content")
                ]
                prompt_text = "\n\n".join(m["content"] for m in chat_msgs)
                if chat_msgs and hasattr(self._backend, "generate_chat"):
                    content = self._backend.generate_chat(chat_msgs)
                else:
                    content = self._backend.generate(prompt_text)
            else:
                prompt_text = prompt or ""
                content = self._backend.generate(prompt_text)

        self._track_tokens(prompt_text, content)
        _log_llm_call(prompt_text=prompt_text, content=content)
        return _make_response(content, prompt_text)

    @contextlib.contextmanager
    def _max_tokens_override(self, value: int | None):
        """Temporarily set the backend's generation budget to *value*.

        ``None`` means "no override". The backend's ``max_new_tokens`` is used
        if present, else ``max_tokens``; a backend with neither is untouched.
        The attribute's exact original value — including ``None`` — is
        restored on exit, even when the wrapped call raises. (The previous
        ``_orig_max is not None`` guard skipped restoration when the original
        value was ``None``.)
        """
        backend = self._backend
        attr = None
        if value is not None:
            attr = next(
                (a for a in ("max_new_tokens", "max_tokens") if hasattr(backend, a)),
                None,
            )
        if attr is None:
            yield
            return
        original = getattr(backend, attr)
        setattr(backend, attr, value)
        try:
            yield
        finally:
            setattr(backend, attr, original)


def _make_response(content: str, prompt: str):
    """Build the smallest OpenAI-style chat-completion object DSPy accepts.

    DSPy reads ``response.choices[i].message.content`` and converts
    ``response.usage`` with ``dict(...)``. The usage object therefore
    implements the mapping protocol that ``dict()`` relies on (``keys()`` plus
    ``__getitem__``). Token counts are whitespace word counts, not tokenizer
    counts.
    """

    class _Msg:
        def __init__(self, c: str) -> None:
            self.content = c

    class _Choice:
        def __init__(self, c: str) -> None:
            self.message = _Msg(c)

    class _Usage:
        _FIELDS = ("prompt_tokens", "completion_tokens", "total_tokens")

        def __init__(self, n_prompt: int, n_completion: int) -> None:
            self.prompt_tokens = n_prompt
            self.completion_tokens = n_completion
            self.total_tokens = n_prompt + n_completion

        def keys(self):
            return self._FIELDS

        def __getitem__(self, k: str):
            return getattr(self, k)

    class _Resp:
        model = "custom-backend"

        def __init__(self, c: str, usage) -> None:
            self.usage = usage
            self.choices = [_Choice(c)]

    word_counts = _Usage(len(prompt.split()), len(content.split()))
    return _Resp(content, word_counts)


# ── Few-shot demos ─────────────────────────────────────────────────────────────

# Few-shot demonstrations attached to the few_shot_cot ChainOfThought module
# (build_prompt_strategy assigns a copy to module.demos). The model imitates
# each demo's ``reasoning``, so it must be consistent with that demo's own
# ``game_description``.
_DEMOS: list[dspy.Example] = [
    dspy.Example(
        game_description=(
            "Training: Message [2 3 1] → object 0 selected; "
            "Message [1 3 1] → object 0; Message [2 1 1] → object 0. "
            "Query: Message [2 3 1]."
        ),
        # [1 3 1] and [2 1 1] also map to object 0, so positions 0 and 1 are
        # NOT constant here; only position 2 (always 1) is.
        reasoning=(
            "Every training message ends in 1 and every one selects object 0. "
            "The query [2 3 1] is identical to the first training message. "
            "Select object 0 and echo the message as communication tokens."
        ),
        answer="0 2 3 1",
    ).with_inputs("game_description"),
    dspy.Example(
        game_description=(
            "Training: Message [3 1 4] → object 0 selected; "
            "Message [2 1 4] → object 0. "
            "Query: Message [3 1 4]."
        ),
        reasoning=(
            "Positions 1 and 2 are constant (1, 4). Position 0 varies but "
            "object 0 is always selected. Query [3 1 4] fits the pattern."
        ),
        answer="0 3 1 4",
    ).with_inputs("game_description"),
]


# ── CotGenerator ──────────────────────────────────────────────────────────────

# Appended to the original prompt for CotGenerator's forced-answer retry:
# asks for the 4-integer action with no reasoning.
_REPROMPT_SUFFIX = (
    "\n\nNow state ONLY the final answer as 4 space-separated integers (decision"
    " token1 token2 token3). No explanation. Example: '0 2 3 1'.\nAnswer:"
)


class CotGenerator:
    """
    str→str callable: CoT forward pass with re-prompt fallback and error metrics.

    Counters reset with reset_stats() before each episode so run_eval.py can
    attach per-episode truncation/format stats to the episode result dict.

    Error taxonomy
    --------------
    n_truncated            : first CoT call returned an empty answer (likely the
                             token limit was hit) → re-prompt was issued
    n_adapter_errors       : the DSPy module raised during the first call
                             (typically an AdapterParseError when the output
                             didn't match the expected field/JSON format; any
                             other exception is counted here too) → re-prompt issued
    n_re_prompt_truncated  : re-prompt also returned no valid integers
                             → default "0 0 0 0" used
    n_format_errors        : final answer (post-fallback) contains < 4 integers
                             (comm tokens will be zero-padded by parse_action)
    """

    def __init__(self, module, backend: BaseBackend, backend_lm: "BackendLM") -> None:
        self._module = module
        self._backend = backend
        self._backend_lm = backend_lm
        self.n_truncated = 0
        self.n_adapter_errors = 0
        self.n_re_prompt_truncated = 0
        self.n_format_errors = 0

    def reset_stats(self) -> None:
        """Zero all error counters and the backend LM's token-length lists."""
        self.n_truncated = 0
        self.n_adapter_errors = 0
        self.n_re_prompt_truncated = 0
        self.n_format_errors = 0
        self._backend_lm.reset_token_stats()

    @property
    def prompt_token_lengths(self) -> list[int]:
        return self._backend_lm.prompt_token_lengths

    @property
    def completion_token_lengths(self) -> list[int]:
        return self._backend_lm.completion_token_lengths

    def __call__(self, prompt_text: str) -> str:
        """Return the answer string for *prompt_text*, re-prompting if it is empty."""
        adapter_error = False
        try:
            pred = self._module(game_description=prompt_text)
            answer = (pred.answer or "").strip()
        except Exception:
            # Deliberately broad so one malformed completion cannot abort an
            # episode. NOTE(review): backend/transport errors land in this
            # counter too.
            answer = ""
            adapter_error = True
            self.n_adapter_errors += 1

        if not answer and not adapter_error:
            self.n_truncated += 1

        if not answer:
            answer = self._reprompt(prompt_text)
            if not answer:
                self.n_re_prompt_truncated += 1
                answer = "0 0 0 0"

        if len(re.findall(r"\d+", answer)) < 4:
            self.n_format_errors += 1

        return answer

    def _reprompt(self, prompt_text: str) -> str:
        """Ask the backend directly for just the 4 integers; "" if none came back."""
        reprompt_text = prompt_text + _REPROMPT_SUFFIX
        # A 16-token budget leaves room for the integers only. The override is
        # undone even if generate() raises, so the backend cannot stay capped.
        with self._max_tokens_override(16):
            raw = self._backend.generate(reprompt_text).strip()
        self._backend_lm._track_tokens(reprompt_text, raw)
        nums = re.findall(r"\d+", raw)
        return " ".join(nums[:4])

    @contextlib.contextmanager
    def _max_tokens_override(self, value: int):
        """Temporarily set the backend's generation budget to *value*.

        ``max_new_tokens`` is preferred over ``max_tokens`` (same order as
        ``_set_max_tokens``); a backend with neither is left untouched. The
        attribute's exact original value — including ``None`` — is restored on
        exit, even when the wrapped call raises.
        """
        backend = self._backend
        attr = next(
            (a for a in ("max_new_tokens", "max_tokens") if hasattr(backend, a)),
            None,
        )
        if attr is None:
            yield
            return
        original = getattr(backend, attr)
        setattr(backend, attr, value)
        try:
            yield
        finally:
            setattr(backend, attr, original)

    def _get_max_tokens(self) -> int:
        return (getattr(self._backend, "max_new_tokens", None)
                or getattr(self._backend, "max_tokens", 64))

    def _set_max_tokens(self, value: int) -> None:
        if hasattr(self._backend, "max_new_tokens"):
            self._backend.max_new_tokens = value
        elif hasattr(self._backend, "max_tokens"):
            self._backend.max_tokens = value


# ── DiscussionCotBackend ───────────────────────────────────────────────────────

# NOTE(review): not referenced in this file. DiscussionCotBackend._reprompt
# sends the same wording inline (without the leading blank line) as a separate
# user turn. Presumably kept for external importers — confirm before removing.
_DISC_REPROMPT_SUFFIX = (
    "\n\nRespond with ONLY 1 integer: 0 (same latent meaning) or 1 "
    "(different). No explanation."
)


class DiscussionCotBackend:
    """
    Native multi-turn backend for discussion mode.

    Replaces the previous DSPy-based approach.  Instead of packing conversation
    history into a single growing user message (via dspy.History), this class
    maintains ``_native_history`` — a list of full role/content turns including
    the model's complete CoT reasoning — and calls ``generate_chat`` directly on
    the backend.

    Why this helps with prefix caching:
      The API sees: [system (stable)] + [prior turns (growing but fixed prefix)]
      + [current user turn (new tokens)].  The system message and all prior turns
      are identical across consecutive calls within an episode, so a provider that
      supports prefix caching can serve them from cache.  The previous DSPy approach
      packed everything into one user message that grew every game, giving the cache
      nothing stable to latch onto.

    Error counters and the stats interface mirror CotGenerator so run_eval.py's
    normalisation block works unchanged:

    n_truncated            : the CoT response had no "Answer: X" / "Decision: X"
                             verdict → a follow-up re-prompt turn was issued
    n_adapter_errors       : never incremented here (no DSPy adapter involved);
                             kept for interface parity with CotGenerator
    n_re_prompt_truncated  : the re-prompt returned no integer → "0" used
    n_format_errors        : the final answer is not exactly "0" or "1"
    """

    _SYSTEM_PROMPT = (
        "You are a listener agent in a referential game played over multiple rounds. "
        "Use the conversation history to learn the speaker's private code across games, "
        "then decide whether the current stimulus shares the same latent meaning.\n\n"
        "Think step by step about the pattern, then end your response with: "
        "\"Answer: 0\" (same latent meaning) or \"Answer: 1\" (different)."
    )

    # Explicit verdict such as "Answer: 1" / "Decision: 0". Case-insensitive and
    # tolerant of markdown emphasis ("**Answer:** 1"), which models often emit.
    _ANSWER_RE = re.compile(r"\b(?:answer|decision)\b[\s:*]+([01])\b", re.IGNORECASE)

    _VALID_ANSWERS = ("0", "1")

    def __init__(self, backend: BaseBackend, backend_lm: "BackendLM") -> None:
        self._backend = backend
        self._backend_lm = backend_lm
        self._native_history: list[dict] = []
        self.reset_stats()

    # ── stats interface (mirrors CotGenerator) ────────────────────────────────

    def reset_stats(self) -> None:
        """Zero the counters, drop the episode's chat history and token stats."""
        self.n_truncated = 0
        self.n_adapter_errors = 0
        self.n_re_prompt_truncated = 0
        self.n_format_errors = 0
        self._native_history = []
        self._backend_lm.reset_token_stats()

    @property
    def prompt_token_lengths(self) -> list[int]:
        return self._backend_lm.prompt_token_lengths

    @property
    def completion_token_lengths(self) -> list[int]:
        return self._backend_lm.completion_token_lengths

    # ── backend-passthrough helpers ───────────────────────────────────────────

    @property
    def tokenize_fn(self):
        return getattr(self._backend, "tokenize_fn", None)

    def set_seed(self, seed: int) -> None:
        self._backend.set_seed(seed)

    def close(self) -> None:
        self._backend.close()

    def generate(self, text: str) -> str:
        return self._backend.generate(text)

    # ── core interface ────────────────────────────────────────────────────────

    def generate_chat(self, messages: list[dict]) -> str:
        """
        Called by run_episode's discussion wrapper with the full role/content history.

        Extracts the current user turn from the last message, prepends the stable
        system prompt and _native_history (prior turns with full CoT responses),
        and calls generate_chat() on the backend directly.  Stores the full raw
        response in _native_history so subsequent games see rich CoT context.

        Returns the decision string (normally "0" or "1"); "0" for empty input.
        """
        if not messages:
            return "0"

        last = messages[-1]
        # `or ""` also covers an explicit None content, which would otherwise
        # break the "\n\n".join used for token tracking below.
        user_content = (last.get("content") or "") if isinstance(last, dict) else str(last)

        # [system (stable)] + [prior full turns] + [current user turn]
        native_msgs = [{"role": "system", "content": self._SYSTEM_PROMPT}]
        native_msgs.extend(self._native_history)
        native_msgs.append({"role": "user", "content": user_content})

        raw_response = self._backend.generate_chat(native_msgs)
        _log_discussion_chat_call(messages=native_msgs, response=raw_response)
        self._backend_lm._track_tokens(
            "\n\n".join(m["content"] for m in native_msgs), raw_response
        )

        answer = self._resolve_answer(native_msgs, raw_response)

        # Store full response so next game's prefix includes complete CoT context.
        self._native_history.append({"role": "user",      "content": user_content})
        self._native_history.append({"role": "assistant", "content": raw_response})

        return answer

    def _resolve_answer(self, native_msgs: list[dict], raw_response: str) -> str:
        """Turn a CoT response into the final decision, updating the error counters.

        Falls back to a bare-integer re-prompt when no explicit verdict is
        found, and to "0" when the re-prompt yields no integer either.
        """
        answer = self._extract_answer(raw_response)
        if not answer:
            self.n_truncated += 1
            answer = self._reprompt(native_msgs, raw_response)
            if not answer:
                self.n_re_prompt_truncated += 1
                answer = "0"

        # The answer is always a non-empty digit string at this point, so a
        # "contains no digits" check could never fire; test validity instead
        # (the re-prompt may return e.g. "2" or "42").
        if answer not in self._VALID_ANSWERS:
            self.n_format_errors += 1
        return answer

    def _extract_answer(self, text: str) -> str:
        """Return '0' or '1' from the last explicit 'Answer: X' / 'Decision: X', else ''."""
        matches = self._ANSWER_RE.findall(text)
        return matches[-1] if matches else ""

    def _reprompt(self, prior_msgs: list[dict], prior_response: str) -> str:
        """Follow-up turn: append assistant CoT + new user request for bare integer.

        Returns the first integer in the reply (not necessarily 0/1), or "".
        """
        reprompt_msgs = [
            *prior_msgs,
            {"role": "assistant", "content": prior_response},
            {
                "role": "user",
                "content": (
                    "Respond with ONLY 1 integer: 0 (same latent meaning) or 1 (different). "
                    "No explanation."
                ),
            },
        ]
        # The 8-token cap is undone even if generate_chat() raises, so the
        # backend's budget cannot stay capped for later games.
        with self._max_tokens_override(8):
            raw = self._backend.generate_chat(reprompt_msgs).strip()
        self._backend_lm._track_tokens(
            "\n\n".join(m["content"] for m in reprompt_msgs), raw
        )
        nums = re.findall(r"\d+", raw)
        return nums[0] if nums else ""

    @contextlib.contextmanager
    def _max_tokens_override(self, value: int):
        """Temporarily set the backend's generation budget to *value*.

        ``max_new_tokens`` is preferred over ``max_tokens`` (same order as
        ``_set_max_tokens``); a backend with neither is left untouched. The
        attribute's exact original value — including ``None`` — is restored on
        exit, even when the wrapped call raises.
        """
        backend = self._backend
        attr = next(
            (a for a in ("max_new_tokens", "max_tokens") if hasattr(backend, a)),
            None,
        )
        if attr is None:
            yield
            return
        original = getattr(backend, attr)
        setattr(backend, attr, value)
        try:
            yield
        finally:
            setattr(backend, attr, original)

    def _get_max_tokens(self) -> int:
        return getattr(self._backend, "max_new_tokens", None) or getattr(self._backend, "max_tokens", 64)

    def _set_max_tokens(self, value: int) -> None:
        if hasattr(self._backend, "max_new_tokens"):
            self._backend.max_new_tokens = value
        elif hasattr(self._backend, "max_tokens"):
            self._backend.max_tokens = value


# ── Factory ────────────────────────────────────────────────────────────────────

def build_prompt_strategy(
    name: str,
    backend: BaseBackend | None = None,
) -> "CotGenerator | DiscussionCotBackend | None":
    """
    Build the prompt strategy called *name* around *backend*.

    Returns:
        None for 'none'. For 'zero_shot_cot' / 'few_shot_cot', a CotGenerator
        (str → str callable with error metrics). For 'discussion_cot' /
        'few_shot_discussion_cot', a DiscussionCotBackend.

    The CotGenerator passes the raw S2B prompt through DSPy's ChainOfThought
    module and returns only the answer field, so parse_action receives clean
    integers. If the first call yields no answer it re-prompts the backend
    directly with a short forced-answer suffix (max 16 new tokens). The
    discussion backend re-prompts with a follow-up chat turn (max 8 new tokens).

    Side effect: for every non-'none' strategy, a BackendLM wrapping *backend*
    is installed as DSPy's global LM via dspy.configure.

    backend.set_seed() still works because BackendLM holds a reference to the same
    backend object; seeding it before each episode propagates through DSPy.

    Raises:
        ValueError: if *backend* is missing for a CoT strategy, or *name* is
            unknown. Both checks run before DSPy's global config is touched.
    """
    if name == "none":
        return None

    if backend is None:
        raise ValueError("backend is required for CoT strategies")

    cot_names = ("zero_shot_cot", "few_shot_cot")
    discussion_names = ("discussion_cot", "few_shot_discussion_cot")
    # Validate before dspy.configure so a typo cannot reconfigure DSPy globally.
    if name not in cot_names and name not in discussion_names:
        raise ValueError(f"Unknown prompt strategy: {name!r}")

    backend_lm = BackendLM(backend)
    dspy.configure(lm=backend_lm)

    if name in discussion_names:
        # NOTE(review): 'few_shot_discussion_cot' currently builds the same
        # backend as 'discussion_cot' (no demos are attached) — confirm intended.
        return DiscussionCotBackend(backend, backend_lm)

    module = dspy.ChainOfThought(ListenerSignature)
    if name == "few_shot_cot":
        module.demos = list(_DEMOS)
    return CotGenerator(module, backend, backend_lm)
