# Copyright (c) 2026
# Research-grade: Ver@K retry interaction with verifier feedback.
#
# Key properties:
# - Multi-turn episode with up to K assistant attempts
# - Verifier check after each attempt
# - Early stop on success (standard Ver@K objective)
# - Feedback message includes history of wrong attempts (configurable)
# - Stores rich metadata per turn so you can later implement per-turn reward shaping

from __future__ import annotations

import hashlib
import importlib
import logging
import os
import re
from typing import Any, Dict, List, Optional, Tuple
from uuid import uuid4

from verl.interactions.base import BaseInteraction

# Optional: reuse verl's GSM8K verifier if available
try:
    from verl.utils.reward_score import gsm8k as gsm8k_reward
except Exception:  # pragma: no cover
    gsm8k_reward = None


# Module-level logger; the attempt/init records emitted by VerKRetryInteraction go through it.
logger = logging.getLogger(__name__)
# File handler installed by VerKRetryInteraction._maybe_attach_file_handler().
# It is a module global so the handler is attached at most once, no matter how
# many interaction objects are constructed; it stays None until a log_file is configured.
_file_handler: Optional[logging.Handler] = None
# Module-wide count of instances already selected for attempt logging;
# compared against `log_attempts_num_samples` in VerKRetryInteraction._pick_logging().
_log_samples_seen: int = 0


def _content_to_text(content: Any) -> str:
    """
    Robustly coerce message["content"] to a string.

    In some chat formats, "content" can be:
      - str
      - list of segments, e.g. [{"type":"text","text":"..."}]
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts: List[str] = []
        for seg in content:
            if isinstance(seg, str):
                parts.append(seg)
            elif isinstance(seg, dict):
                if "text" in seg:
                    parts.append(str(seg["text"]))
                elif "content" in seg:
                    parts.append(str(seg["content"]))
        return "".join(parts)
    return str(content)


def _extract_last_assistant_text(messages: List[Dict[str, Any]]) -> str:
    """
    Extract last assistant textual content from message history.
    """
    for item in reversed(messages):
        if item.get("role") == "assistant":
            return _content_to_text(item.get("content", ""))
    return ""


def _count_assistant_turns(messages: List[Dict[str, Any]]) -> int:
    return sum(1 for m in messages if m.get("role") == "assistant")


def _import_by_path(fn_path: str):
    """
    Import a symbol from a 'module.submodule:attr' or 'module.submodule.attr' path.
    """
    if ":" in fn_path:
        mod_path, attr = fn_path.split(":", 1)
    else:
        mod_path, attr = fn_path.rsplit(".", 1)
    mod = importlib.import_module(mod_path)
    return getattr(mod, attr)


class VerKRetryInteraction(BaseInteraction):
    """
    Multi-turn retry interaction that:
      - evaluates each assistant attempt with a verifier
      - provides feedback (and optionally history) on failure
      - terminates on success or after K attempts

    This implements the standard Ver@K objective when you:
      - terminate immediately on first correct attempt, and
      - compute RL reward from final attempt correctness.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)

        # Default maximum attempts; can be overridden per-sample via start_interaction(max_attempts=...)
        self.default_max_attempts: int = int(config.get("max_attempts", 5))
        if self.default_max_attempts <= 0:
            raise ValueError("config.max_attempts must be >= 1")

        # Verifier configuration
        # - "gsm8k_strict": uses verl.utils.reward_score.gsm8k.compute_score (expects #### format)
        # - "exact_match": compares extracted answer to ground_truth (string)
        # - "callable": user provides a function path that returns (is_correct: bool, feedback: str | None)
        self.verifier_type: str = str(config.get("verifier_type", "exact_match"))
        self.verifier_fn_path: Optional[str] = config.get("verifier_fn_path", None)
        self._verifier_fn = _import_by_path(self.verifier_fn_path) if self.verifier_fn_path else None

        # Answer extraction (used for exact_match and for nicer logging)
        # Supported:
        # - "gsm8k_hashes": extract after '####'
        # - "boxed": extract LaTeX \\boxed{...}
        # - "last_line": take last non-empty line
        self.answer_extraction: str = str(config.get("answer_extraction", "boxed"))

        # Feedback templates
        fb = config.get("feedback", {}) or {}
        self.correct_msg: str = str(
            fb.get("correct", "Correct. Stop.")
        )
        self.incorrect_msg: str = str(
            fb.get(
                "incorrect",
                "Incorrect.\n"
                "Reflect on your previous attempt(s) and try again.\n"
                "IMPORTANT: Put your final answer as `\\boxed{<answer>}`.",
            )
        )
        self.max_attempts_msg: str = str(
            fb.get(
                "max_attempts_reached",
                "Incorrect and maximum attempts reached. Stop.",
            )
        )

        # History injection
        self.include_history: bool = bool(config.get("include_history", True))
        self.history_max_chars: int = int(config.get("history_max_chars", 2000))

        # Store extra metadata for future reward shaping
        self.track_turn_metadata: bool = bool(config.get("track_turn_metadata", True))
        # Optional logging of attempts for debugging
        self.log_attempts: bool = bool(config.get("log_attempts", False))
        self.log_max_chars: int = int(config.get("log_max_chars", 300))
        self.log_attempts_num_samples: int = int(config.get("log_attempts_num_samples", 0))
        self.log_attempts_max_per_instance: int = int(config.get("log_attempts_max_per_instance", 2))
        self.log_file: Optional[str] = config.get("log_file")
        self.log_file_overwrite: bool = bool(config.get("log_file_overwrite", True))
        self._maybe_attach_file_handler()

        # Per-instance state
        # instance_id -> dict with:
        #   query, ground_truth, max_attempts, attempts(list), turn_scores(list), solved(bool), solved_at(int)
        self._instances: Dict[str, Dict[str, Any]] = {}

    async def start_interaction(
        self,
        instance_id: Optional[str] = None,
        query: Optional[str] = None,
        ground_truth: Any = None,
        max_attempts: Optional[int] = None,
        **kwargs,
    ) -> str:
        if instance_id is None:
            instance_id = str(uuid4())

        eff_max_attempts = int(max_attempts) if max_attempts is not None else self.default_max_attempts
        if eff_max_attempts <= 0:
            raise ValueError("max_attempts must be >= 1")

        # Preserve any extra verifier context (e.g., maze spec) for callable verifiers.
        extra_info = kwargs.get("extra_info")
        if extra_info is None or not isinstance(extra_info, dict):
            extra_info = {}
        maze = kwargs.get("maze")
        if isinstance(maze, dict) and "maze" not in extra_info:
            extra_info["maze"] = maze

        self._instances[instance_id] = {
            "query": query,
            "ground_truth": ground_truth,
            "max_attempts": eff_max_attempts,
            "attempts": [],          # list of dicts per attempt
            "turn_scores": [],       # list[float] per attempt (not necessarily used by trainer today)
            "solved": False,
            "solved_at": None,
            "last_score": 0.0,
            "last_is_correct": False,
            "extra_info": extra_info,
            "log_attempts_enabled": self._pick_logging(instance_id),
            "log_attempts_logged": 0,
        }
        self._maybe_log_init(self._instances[instance_id])
        return instance_id

    async def generate_response(
        self,
        instance_id: str,
        messages: List[Dict[str, Any]],
        **kwargs,
    ) -> Tuple[bool, str, float, Dict[str, Any]]:
        """
        Called after each assistant generation. We verify and return:
          (should_terminate, user_feedback, score, metadata)
        """
        if instance_id not in self._instances:
            raise KeyError(f"Unknown instance_id={instance_id}")

        st = self._instances[instance_id]
        max_attempts = int(st["max_attempts"])

        attempt_idx = _count_assistant_turns(messages)  # 1-based index after first assistant message appears
        assistant_text = _extract_last_assistant_text(messages)

        extracted_answer = self._extract_answer(assistant_text)
        is_correct, score, verifier_feedback = self._verify(
            query=st.get("query"),
            assistant_text=assistant_text,
            extracted_answer=extracted_answer,
            ground_truth=st.get("ground_truth"),
            attempt_idx=attempt_idx,
            max_attempts=max_attempts,
            messages=messages,
            extra_info=st.get("extra_info"),
        )

        st["last_score"] = float(score)
        st["last_is_correct"] = bool(is_correct)

        # record attempt info
        attempt_rec = {
            "attempt_idx": attempt_idx,
            "assistant_text": assistant_text,
            "extracted_answer": extracted_answer,
            "is_correct": bool(is_correct),
            "score": float(score),
        }
        if verifier_feedback is not None:
            attempt_rec["verifier_feedback"] = verifier_feedback

        st["attempts"].append(attempt_rec)
        st["turn_scores"].append(float(score))

        metadata: Dict[str, Any] = {}
        if self.track_turn_metadata:
            metadata = {
                "attempt_idx": attempt_idx,
                "max_attempts": max_attempts,
                "is_correct": bool(is_correct),
                "score": float(score),
                "extracted_answer": extracted_answer,
                # For future step-wise reward shaping:
                "turn_scores": list(st["turn_scores"]),
            }

        # Termination logic (standard Ver@K retry)
        if is_correct:
            st["solved"] = True
            st["solved_at"] = attempt_idx
            self._maybe_log_attempt(
                st=st,
                attempt_idx=attempt_idx,
                max_attempts=max_attempts,
                is_correct=True,
                score=score,
                extracted_answer=extracted_answer,
                assistant_text=assistant_text,
                reason=None,
            )
            return True, self.correct_msg, 1.0, metadata

        if attempt_idx >= max_attempts:
            # exhausted attempts
            self._maybe_log_attempt(
                st=st,
                attempt_idx=attempt_idx,
                max_attempts=max_attempts,
                is_correct=False,
                score=score,
                extracted_answer=extracted_answer,
                assistant_text=assistant_text,
                reason="max_attempts",
            )
            return True, self.max_attempts_msg, 0.0, metadata

        # Otherwise continue with feedback prompting a retry
        self._maybe_log_attempt(
            st=st,
            attempt_idx=attempt_idx,
            max_attempts=max_attempts,
            is_correct=False,
            score=score,
            extracted_answer=extracted_answer,
            assistant_text=assistant_text,
            reason=None,
        )
        feedback = self._build_incorrect_feedback(st, verifier_feedback=verifier_feedback)
        return False, feedback, 0.0, metadata

    async def calculate_score(self, instance_id: str, **kwargs) -> float:
        """
        Some versions of verl may not use this for final RL reward (see issue #2540),
        but we keep it consistent and correct for future reward shaping / debugging.
        """
        if instance_id not in self._instances:
            raise KeyError(f"Unknown instance_id={instance_id}")
        return float(self._instances[instance_id].get("last_score", 0.0))

    async def finalize_interaction(self, instance_id: str, **kwargs) -> None:
        self._instances.pop(instance_id, None)

    # --------------------------
    # Verifier + extraction logic
    # --------------------------

    def _extract_answer(self, assistant_text: str) -> str:
        text = assistant_text.strip()

        if self.answer_extraction == "gsm8k_hashes":
            # Extract after ####
            m = re.search(r"####\s*(.+)\s*$", text, flags=re.MULTILINE)
            return self._normalize_extracted_answer(m.group(1)) if m else ""

        if self.answer_extraction == "boxed":
            matches = re.findall(r"\\boxed\s*\{([^}]*)\}", text)
            if not matches:
                return ""
            return self._normalize_extracted_answer(matches[-1])

        if self.answer_extraction == "answer_tag":
            matches = re.findall(r"<answer>\s*(.*?)\s*</answer>", text, flags=re.IGNORECASE | re.DOTALL)
            if not matches:
                return ""
            cleaned = re.sub(r"\s+", "", matches[-1])
            return self._normalize_extracted_answer(cleaned)

        if self.answer_extraction == "last_line":
            lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
            return lines[-1] if lines else ""

        # fallback
        return ""

    @staticmethod
    def _normalize_extracted_answer(answer: str) -> str:
        return answer.replace(",", "").replace("$", "").strip()

    def _verify(
        self,
        *,
        query: Optional[str],
        assistant_text: str,
        extracted_answer: str,
        ground_truth: Any,
        attempt_idx: int,
        max_attempts: int,
        messages: List[Dict[str, Any]],
        extra_info: Optional[Dict[str, Any]] = None,
    ) -> Tuple[bool, float, Optional[str]]:
        """
        Returns:
          (is_correct, score in [0,1], optional_verifier_feedback)

        Notes:
          - We do NOT leak ground_truth in feedback by default.
          - For "callable", your function may choose to produce richer feedback.
        """
        vt = self.verifier_type

        if vt == "gsm8k_strict":
            if gsm8k_reward is None:
                raise ImportError(
                    "verifier_type='gsm8k_strict' requires verl.utils.reward_score.gsm8k to be importable."
                )
            # Score expects full assistant_text in GSM8K format (#### answer)
            sc = float(
                gsm8k_reward.compute_score(
                    assistant_text,
                    ground_truth,
                    method="strict",
                    format_score=0.0,
                    score=1.0,
                )
            )
            return (sc >= 1.0), sc, None

        if vt == "exact_match":
            pred = str(extracted_answer).strip()
            gt = "" if ground_truth is None else str(ground_truth).strip()
            sc = 1.0 if (pred != "" and pred == gt) else 0.0
            return (sc >= 1.0), sc, None

        if vt == "callable":
            if self._verifier_fn is None:
                raise ValueError(
                    "verifier_type='callable' requires config.verifier_fn_path"
                )

            out = self._verifier_fn(
                query=query,
                attempt_text=assistant_text,
                extracted_answer=extracted_answer,
                ground_truth=ground_truth,
                attempt_idx=attempt_idx,
                max_attempts=max_attempts,
                messages=messages,
                extra_info=extra_info,
            )

            # Supported return formats:
            #  - bool
            #  - (bool, feedback_str_or_none)
            #  - (bool, score_float, feedback_str_or_none)
            if isinstance(out, bool):
                return out, (1.0 if out else 0.0), None
            if isinstance(out, tuple) and len(out) == 2:
                is_ok, fb = out
                return bool(is_ok), (1.0 if is_ok else 0.0), (None if fb is None else str(fb))
            if isinstance(out, tuple) and len(out) == 3:
                is_ok, sc, fb = out
                return bool(is_ok), float(sc), (None if fb is None else str(fb))

            raise TypeError(
                "Callable verifier must return bool, (bool, feedback), or (bool, score, feedback)."
            )

        raise ValueError(f"Unknown verifier_type={vt}")

    def _build_incorrect_feedback(self, st: Dict[str, Any], verifier_feedback: Optional[str]) -> str:
        """
        Build the user message after an incorrect attempt.
        Includes history of wrong attempts if configured.
        """
        base = self.incorrect_msg

        if verifier_feedback:
            base = f"{base}\n\nVerifier feedback:\n{verifier_feedback}"

        if not self.include_history:
            return base

        attempts = st.get("attempts", [])
        if not attempts:
            return base

        # Build compact history of prior wrong attempts
        hist_lines: List[str] = []
        for a in attempts:
            if a.get("is_correct"):
                continue
            idx = a.get("attempt_idx", "?")
            ans = a.get("extracted_answer", "")
            # Keep it small; we don't want to explode context
            snippet = (a.get("assistant_text", "") or "").strip()
            snippet = re.sub(r"\s+", " ", snippet)
            if len(snippet) > 500:
                snippet = snippet[:500] + " …"
            hist_lines.append(f"- Attempt {idx}: extracted_answer={ans!r}; snippet={snippet}")

        hist = "\n".join(hist_lines).strip()
        if not hist:
            return base

        # Truncate history block if too long
        if self.history_max_chars > 0 and len(hist) > self.history_max_chars:
            hist = hist[: self.history_max_chars] + "\n… (history truncated)"

        return f"{base}\n\nPrevious wrong attempts (for reflection):\n{hist}"

    def _format_snippet(self, text: str) -> str:
        snippet = re.sub(r"\s+", " ", (text or "").strip())
        if self.log_max_chars > 0 and len(snippet) > self.log_max_chars:
            return snippet[: self.log_max_chars] + " ..."
        return snippet

    def _maybe_attach_file_handler(self) -> None:
        """
        Attach a file handler once so attempt logs land in a dedicated file.
        """
        global _file_handler
        if not self.log_file or _file_handler is not None:
            return
        log_dir = os.path.dirname(self.log_file)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        mode = "w" if self.log_file_overwrite else "a"
        handler = logging.FileHandler(self.log_file, mode=mode)
        handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s"))
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
        _file_handler = handler

    def _pick_logging(self, instance_id: str) -> bool:
        if not self.log_attempts:
            return False
        # Fixed-count sampling takes precedence
        if self.log_attempts_num_samples > 0:
            global _log_samples_seen
            if _log_samples_seen >= self.log_attempts_num_samples:
                return False
            _log_samples_seen += 1
            return True
        return False

    def _maybe_log_attempt(
        self,
        *,
        st: Dict[str, Any],
        attempt_idx: int,
        max_attempts: int,
        is_correct: bool,
        score: float,
        extracted_answer: str,
        assistant_text: str,
        reason: Optional[str],
    ) -> None:
        if not st.get("log_attempts_enabled", False):
            return
        max_logs = self.log_attempts_max_per_instance
        if max_logs > 0 and st.get("log_attempts_logged", 0) >= max_logs:
            return
        st["log_attempts_logged"] = st.get("log_attempts_logged", 0) + 1
        logger.info(
            "ver_k_retry attempt %s/%s correct=%s score=%.3f answer=%r ground_truth=%r reason=%s text=%s",
            attempt_idx,
            max_attempts,
            bool(is_correct),
            float(score),
            extracted_answer,
            st.get("ground_truth"),
            (reason or ""),
            self._format_snippet(assistant_text),
        )

    def _maybe_log_init(self, st: Dict[str, Any]) -> None:
        if not st.get("log_attempts_enabled", False):
            return
        max_logs = self.log_attempts_max_per_instance
        if max_logs > 0 and st.get("log_attempts_logged", 0) >= max_logs:
            return
        st["log_attempts_logged"] = st.get("log_attempts_logged", 0) + 1
        logger.info(
            "ver_k_retry init max_attempts=%s verifier=%s answer_extraction=%s",
            st.get("max_attempts"),
            self.verifier_type,
            self.answer_extraction,
        )
