from __future__ import annotations

import re
import textwrap
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

# Sentinel token used as the UPSTREAM value when there is no feedback to pass
# on to earlier components (it is also the fallback when that section is empty).
STOP_GRADIENT = "STOP_GRADIENT"

# Instructions for the failure-analysis model: decide whether a failure is the
# component's own fault (LOCAL) or was caused by insufficient input (UPSTREAM).
# NOTE(review): presumably sent as the system message of the backward call; the
# call site is not visible in this part of the file.
BACKWARD_SYSTEM_PROMPT = """\
You are an expert failure analyst for a multi-step AI system.
You will analyze a component's trace. Your goal is to pinpoint exactly WHY it failed.
CRITICAL ANALYSIS STEPS:
1. Did the component strictly follow its instructions? If no -> LOCAL fault.
2. Was the input physically insufficient to produce the desired output?  -> UPSTREAM fault.
3. AVOID "HINDSIGHT BIAS": Do not blame the component for not knowing facts that were not provided in the context.
4. Distinguish between "Format Error", "Hallucination", and "Missing Info".
Output strictly in the requested format.
"""
# Instructions for the prompt-rewriting model: improve the prompt from (possibly
# noisy) feedback and return the result only between the required tags.
# Plain string on purpose: the text has no placeholders, so an f-string prefix
# would only turn any future literal brace into a formatting error.
OPTIMIZER_SYSTEM_PROMPT = textwrap.dedent(
    """\
    You are part of an optimization system that improves text (the prompt).
    You will receive feedback and context, and use them to improve the prompt.
    The feedback may be noisy; identify what is important and correct.
    Pay attention to the role description and the context where the prompt is used.
    This is very important: You MUST return the improved prompt only between the required tags.
    """
)

# Output-format rules for the backward (feedback) call: the model must answer
# with a LOCAL section followed by an UPSTREAM section, which is the layout
# parse_gradient_response reads. The STOP_GRADIENT token is interpolated when
# the module is imported; interpolating before dedent is safe here because the
# token is a single line with no leading whitespace.
DEFAULT_BACKWARD_OUTPUT_CONSTRAINTS = textwrap.dedent(
    f"""\
    Output exactly TWO sections, in this exact order:

    LOCAL:
    (feedback for improving THIS component prompt only; or leave empty)

    UPSTREAM:
    ({STOP_GRADIENT} OR feedback for upstream components)

    Rules:
    - You MAY use either single-fault attribution OR partial-blame attribution:
      (A) LOCAL-only:
          - Write actionable LOCAL feedback about improving THIS component prompt/behavior.
          - Set UPSTREAM to {STOP_GRADIENT}.
      (B) UPSTREAM-only:
          - Leave LOCAL empty.
          - Write UPSTREAM feedback describing what is missing/broken in <LM_INPUT> and what earlier components should provide/change.
      (C) PARTIAL-BLAME (both):
          - Write actionable LOCAL feedback (how THIS component should behave better given imperfect/noisy inputs).
          - Write UPSTREAM feedback (what earlier components must change so the required info becomes available/correct).
          - In this mode, UPSTREAM MUST NOT be {STOP_GRADIENT}.
    - LOCAL must be generic (no example-specific entities/answers).
    - If the objective is satisfied / "no feedback", leave LOCAL empty and set UPSTREAM to {STOP_GRADIENT}.
    - If UPSTREAM is {STOP_GRADIENT}, output ONLY the token {STOP_GRADIENT} (no punctuation, no extra words).
    - Do NOT output anything outside these two sections.
    """
)


def build_backward_context_prompt(
    *,
    variable_desc: str,
    variable_value: str,
    lm_input: str,
    lm_output: str,
    objective_feedback: str,
    response_desc: str = "the next component input",
    system_prompt: str = "",
    variable_short_max: int = 400,
) -> str:
    """Construct TextGrad-style context for a backward feedback call.

    Args:
        variable_desc: Role description of the prompt being critiqued; placed
            inside ``<ROLE>``.
        variable_value: Current text of that prompt; truncated with
            ``short_text`` and placed inside ``<VARIABLE>``.
        lm_input: Input the component received; placed inside ``<LM_INPUT>``.
        lm_output: Output the component produced; placed inside
            ``<LM_OUTPUT>``.
        objective_feedback: Downstream complaint/objective text; placed inside
            ``<OBJECTIVE_FUNCTION>``.
        response_desc: Phrase describing how the output was used later.
        system_prompt: System prompt of the traced call; placed inside
            ``<LM_SYSTEM_PROMPT>``.
        variable_short_max: Character budget passed to ``short_text``.

    Returns:
        The assembled prompt text, with the template at column 0 and every
        supplied value inserted verbatim.
    """
    # Dedent the template BEFORE substituting values. textwrap.dedent strips
    # the whitespace common to *all* lines, so dedenting after interpolation
    # breaks as soon as a value spans several lines: a value line at column 0
    # leaves the whole template indented, and an indented value (e.g. code)
    # loses part of its own indentation.
    template = textwrap.dedent(
        """\
        You will give feedback to a prompt with the following role:
        <ROLE>{variable_desc}</ROLE>

        Here is a conversation with a language model:
        <LM_SYSTEM_PROMPT>{system_prompt}</LM_SYSTEM_PROMPT>
        <LM_INPUT>{lm_input}</LM_INPUT>
        <LM_OUTPUT>{lm_output}</LM_OUTPUT>

        This conversation is part of a larger system. The <LM_OUTPUT> was later used as {response_desc}.
        Treat <LM_OUTPUT> as this component's incremental contribution at this step.
        Attribute responsibility by asking: could changing <LM_OUTPUT> (given the same <LM_INPUT>) reasonably fix the objective?

        <OBJECTIVE_FUNCTION>{objective_feedback}</OBJECTIVE_FUNCTION>

        We are interested in giving feedback to the following span of text:
        <VARIABLE>{variable_short}</VARIABLE>

        Given the above history, route feedback into LOCAL vs UPSTREAM to improve the objective.
        """
    )
    # str.format does not re-scan substituted values, so braces inside the
    # inputs (JSON, code, ...) are inserted literally.
    return template.format(
        variable_desc=variable_desc,
        system_prompt=system_prompt,
        lm_input=lm_input,
        lm_output=lm_output,
        response_desc=response_desc,
        objective_feedback=objective_feedback,
        variable_short=short_text(variable_value, variable_short_max),
    )


@dataclass
class GradientSignal:
    """Incoming textual complaint handed to a node's backward operator.

    Attributes:
        feedback: The complaint text itself.
        context: String-keyed data that accompanies the complaint; its schema
            is not defined in this part of the file.
    """

    feedback: str
    context: Dict[str, Any]


@dataclass
class BackwardResponse:
    """Outcome of one node's backward reasoning step.

    Attributes:
        local_fix: Feedback aimed at the node itself.
        upstream_grad: Feedback to hand to earlier nodes, or the
            ``STOP_GRADIENT`` sentinel.
        debug: Optional string-keyed diagnostic data.
    """

    local_fix: str
    upstream_grad: str
    debug: Optional[Dict[str, Any]] = None

    def pruned(self) -> "BackwardResponse":
        """Return a copy whose upstream feedback is replaced by the sentinel.

        ``local_fix`` and ``debug`` are carried over unchanged (``debug`` is
        the same object, not a copy).
        """
        stopped = BackwardResponse(self.local_fix, STOP_GRADIENT, self.debug)
        return stopped


def apply_unified_diff(source: str, search_block: str, replace_block: str) -> Tuple[str, bool]:
    """Swap the first literal occurrence of ``search_block`` for ``replace_block``.

    Despite the name, no diff syntax is interpreted: this is a plain substring
    search-and-replace limited to the first match.

    Returns:
        tuple[str, bool]: the resulting text and ``True`` when a replacement
        happened; the untouched ``source`` and ``False`` when ``search_block``
        is empty or absent.
    """
    # Guard clause: nothing to look for, or nothing found.
    if not search_block or search_block not in source:
        return source, False

    # partition splits around the first match only, mirroring a count-1 replace.
    before, _, after = source.partition(search_block)
    return before + replace_block + after, True


def merge_context(*contexts: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Combine context dictionaries into a new dict; the right-most value wins.

    ``None`` and empty entries are skipped, and the inputs are not modified.
    """
    combined: Dict[str, Any] = {}
    # filter(None, ...) drops falsy entries (None / empty dicts).
    for layer in filter(None, contexts):
        combined.update(layer)
    return combined


def format_code_complaint(execution_result: Dict[str, Any]) -> str:
    """Create a textual complaint from an executor result payload.

    Keys read from ``execution_result`` (all optional):
        passed: Truthy when the run succeeded.
        failed_tests: A list/tuple of entries, or a single value.
        trace: Trace text (or any value convertible to text).
        stderr: Standard-error text (or any value convertible to text).

    Args:
        execution_result: The payload; ``None`` or an empty dict is treated
            as a failed run with no details.

    Returns:
        ``"All tests passed. No feedback."`` when ``passed`` is truthy,
        otherwise a single-line complaint listing the failed tests, the trace
        and (when different from the trace) stderr.
    """
    import json  # local import: json is not among this module's top-level imports

    def _as_text(value: Any) -> str:
        """Best-effort conversion of a payload value to stripped text."""
        if value is None:
            return ""
        if isinstance(value, str):
            return value.strip()
        if isinstance(value, (dict, list, tuple)):
            try:
                return json.dumps(value, ensure_ascii=False).strip()
            except Exception:
                # Deliberate best effort: containers that cannot be
                # serialised fall through to their str() form below.
                pass
        return str(value).strip()

    execution_result = execution_result or {}

    # Decide first: a passing run needs none of the fields serialised.
    if execution_result.get("passed"):
        return "All tests passed. No feedback."

    failed_raw = execution_result.get("failed_tests", [])
    # Treat a lone value like a one-element list so both shapes share a path.
    entries = failed_raw if isinstance(failed_raw, (list, tuple)) else [failed_raw]
    # Convert each entry exactly once and drop the ones that end up empty.
    failed = [name for name in map(_as_text, entries) if name]

    trace = _as_text(execution_result.get("trace", ""))
    stderr = _as_text(execution_result.get("stderr", ""))

    complaint = ["Feedback: code failed unit tests."]
    if failed:
        complaint.append(f"Failed tests: {', '.join(failed)}.")
    if trace:
        complaint.append(f"Trace: {trace}")
    # Skip stderr when it merely repeats the trace.
    if stderr and stderr != trace:
        complaint.append(f"Stderr: {stderr}")
    return " ".join(complaint)


def format_qa_complaint(predicted: str, gold: str) -> str:
    """Describe how a predicted answer compares with the gold answer.

    Both answers are stripped (``None``/empty counts as ``""``) and compared
    case-insensitively. A match yields the "no feedback" message; otherwise
    the message quotes both stripped answers.
    """
    guess = predicted.strip() if predicted else ""
    target = gold.strip() if gold else ""

    answers_differ = guess.lower() != target.lower()
    if answers_differ:
        return f"Feedback: predicted answer '{guess}' differs from gold '{target}'."
    return "Answer correct. No feedback."


def parse_tagged_block(text: str, start_tag: str, end_tag: str) -> str:
    """Extract the stripped content between the first tag pair.

    The block starts after the first ``start_tag`` and ends at the first
    ``end_tag`` that follows it.

    Args:
        text: Text to search; ``None``/empty yields ``""``.
        start_tag: Literal opening marker.
        end_tag: Literal closing marker.

    Returns:
        The text between the markers with surrounding whitespace removed, or
        ``""`` when the opening marker is missing or no closing marker appears
        after it.
    """
    if not text:
        return ""

    _, opener, remainder = text.partition(start_tag)
    if not opener:
        return ""

    # Look for the closing tag only AFTER the opening tag. An end tag that
    # occurs solely before the start tag must not count, otherwise an
    # unterminated block would return everything up to the end of the text.
    content, closer, _ = remainder.partition(end_tag)
    if not closer:
        return ""
    return content.strip()


def parse_gradient_response(text: str) -> BackwardResponse:
    """Parse a backward response with LOCAL/UPSTREAM sections.

    A header is a line consisting of the word ``LOCAL`` or ``UPSTREAM``
    (case-insensitive), optionally followed by a colon and inline content,
    e.g. ``LOCAL:``, ``UPSTREAM`` or ``UPSTREAM: STOP_GRADIENT``. Inline
    content becomes the first line of that section. Every other line belongs
    to the most recent header; lines before the first header are ignored.

    Args:
        text: Raw model output; ``None``/empty yields an empty local fix with
            ``STOP_GRADIENT`` upstream.

    Returns:
        BackwardResponse with the stripped LOCAL text and the stripped
        UPSTREAM text. UPSTREAM is ``STOP_GRADIENT`` when the section is empty
        or starts with a stop-gradient token (any case; ``_``, ``-`` or
        whitespace between the two words).
    """
    if not text:
        return BackwardResponse(local_fix="", upstream_grad=STOP_GRADIENT)

    # Only a bare keyword, or a keyword directly followed by a colon, opens a
    # section. Ordinary sentences such as "Upstream components should ..." or
    # "Locally valid output ..." stay content instead of being swallowed as
    # headers.
    header_re = re.compile(r"^\s*(LOCAL|UPSTREAM)\s*(?::\s*(.*))?$", flags=re.IGNORECASE)

    sections = {"local": [], "upstream": []}
    current = None

    for line in text.strip().splitlines():
        header = header_re.match(line)
        if header:
            current = header.group(1).lower()
            inline = header.group(2)
            # Keep feedback written on the header line itself
            # ("UPSTREAM: the retriever must ...").
            if inline:
                sections[current].append(inline)
            continue
        if current:
            sections[current].append(line)

    local_fix = "\n".join(sections["local"]).strip()
    upstream_grad = "\n".join(sections["upstream"]).strip() or STOP_GRADIENT

    # Normalise spelling variants ("stop gradient", "STOP-GRADIENT.", ...) to
    # the canonical sentinel.
    if re.match(r"^\s*STOP[_\s-]*GRADIENT\b", upstream_grad, flags=re.IGNORECASE):
        upstream_grad = STOP_GRADIENT

    return BackwardResponse(local_fix=local_fix, upstream_grad=upstream_grad)


# Literal markers that delimit the rewritten prompt in the optimizer's answer.
IMPROVED_PROMPT_START = "<IMPROVED_PROMPT>"
IMPROVED_PROMPT_END = "</IMPROVED_PROMPT>"
# Instruction template for the prompt-rewriting call. It is dedented once at
# import time and is NOT an f-string, so the single-brace fields
# (variable_desc, variable_short, variable_context, start_tag, end_tag) stay
# as placeholders.
# NOTE(review): presumably filled in later with str.format, in which case the
# doubled braces around new_prompt render as a literal "{new_prompt}"; the
# call site is not visible in this part of the file.
PROMPT_UPDATE_TEMPLATE = textwrap.dedent(
    """\
    You are the **System Optimizer** for a Compound AI System.
    Your goal is to harden the System Prompt for a specific component <VARIABLE> to make it robust against a batch of observed failures.

    ### INPUT DATA
    1. **Component Role**: <ROLE>{variable_desc}</ROLE>
    2. **Current Prompt**: <VARIABLE>{variable_short}</VARIABLE>
    3. **Failure Analysis (Batch)**:
    <BATCH_FEEDBACK>
    {variable_context}
    </BATCH_FEEDBACK>

    ### OPTIMIZATION MISSION
    You must rewrite the <VARIABLE> prompt. The new prompt must:
    1. **Fix the Root Cause**: Address the recurring patterns in the feedback (e.g., "Hallucination", "Format Violation", "Lazy Reasoning").
    2. **Generalize**: Do NOT overfit to the specific examples in the batch. (e.g., If feedback says "You missed the variable 'x'", do NOT write "Always check for 'x'". Instead, write "Always check for uninitialized variables").
    3. **Enforce Constraints**: Use **Negative Constraints** (e.g., "NEVER output...") to block bad behaviors.
    4. **Preserve Scope**: Do not change the fundamental role of the component. It must still function as a Residual Operator (producing a Delta).

    ### CRITICAL THINKING PROCESS (Internal Monologue)
    1. **Cluster Failures**: Group the feedback into categories: Format, Logic, or Hallucination.
    2. **Draft Rules**: Create 1-2 generic, high-priority rules to prevent these clusters.
    3. **Structure Check**: Ensure the new prompt prioritizes these new rules (e.g., by placing them in a "CRITICAL INSTRUCTIONS" section).
    4. **Sanity Check**: Does the new prompt still strictly enforce the output format? (Loss of format is catastrophic).

    ### OUTPUT INSTRUCTIONS
    - Output the FULLY REWRITTEN prompt between {start_tag} and {end_tag}.
    - The output should be ready to use in production.
    - **Highlight** the new constraints using uppercase or bullet points for visibility.

    {start_tag}{{new_prompt}}{end_tag}
    """
)


def short_text(text: str, max_chars: int = 400) -> str:
    """Return ``text`` stripped and cut to at most ``max_chars`` characters.

    When the stripped text is longer than ``max_chars``, the first
    ``max_chars`` characters are kept and ``"..."`` is appended (so the result
    is three characters longer than the limit). ``None``/empty input gives
    ``""``.
    """
    cleaned = text.strip() if text else ""
    over_limit = len(cleaned) > max_chars
    return f"{cleaned[:max_chars]}..." if over_limit else cleaned
