"""LLM-based prompt optimizer for TextResNet nodes.

Supports batch-style prompt updates to reduce overfitting to individual samples.
"""

from __future__ import annotations

from typing import Any, Dict, List, Optional

from core.utils.api import get_llm_output
from core.utils.textresnet import (
    IMPROVED_PROMPT_END,
    IMPROVED_PROMPT_START,
    OPTIMIZER_SYSTEM_PROMPT,
    PROMPT_UPDATE_TEMPLATE,
    STOP_GRADIENT,
    parse_tagged_block,
    short_text,
)


class PromptOptimizer:
    """Wraps a lightweight LLM call to improve component prompts.

    The optimizer formats ``PROMPT_UPDATE_TEMPLATE`` with the current prompt and
    the collected feedback, sends it to ``get_llm_output`` with
    ``OPTIMIZER_SYSTEM_PROMPT``, and extracts the rewritten prompt from between
    ``IMPROVED_PROMPT_START`` and ``IMPROVED_PROMPT_END``. If there is no
    feedback, or no usable tagged block is found in the reply, the current
    prompt is returned unchanged.
    """

    def __init__(
        self, model: str = "openai/gpt-4o-mini", max_tokens: int = 512, temperature: float = 0.7
    ):
        """Store the LLM settings used for every optimizer call.

        Args:
            model: Model identifier forwarded to ``get_llm_output``.
            max_tokens: Forwarded to ``get_llm_output`` as ``max_new_tokens``.
            temperature: Sampling temperature forwarded to ``get_llm_output``.
        """
        self.model = model
        self.max_tokens = max_tokens
        self.temperature = temperature

    @staticmethod
    def _feedback_sections(local_fix: Any, upstream_grad: Any) -> List[str]:
        """Render non-empty local/upstream feedback as labelled text sections."""
        sections: List[str] = []
        local_text = str(local_fix).strip() if local_fix else ""
        if local_text:
            sections.append(f"Local Feedback:\n{local_text}")
        # STOP_GRADIENT (the default for ``upstream_grad``) is treated as
        # "no upstream feedback" and is never forwarded to the optimizer.
        if upstream_grad and upstream_grad != STOP_GRADIENT:
            upstream_text = str(upstream_grad).strip()
            if upstream_text:
                sections.append(f"Upstream Feedback:\n{upstream_text}")
        return sections

    def improve(
        self,
        variable_desc: str,
        current_prompt: str,
        variable_context: str,
        local_fix: str = "",
        upstream_grad: Optional[str] = STOP_GRADIENT,
    ) -> str:
        """Ask the optimizer LLM for an improved version of ``current_prompt``.

        Args:
            variable_desc: Human-readable description of the component. Falls
                back to ``"component prompt"`` when empty.
            current_prompt: The prompt to improve.
            variable_context: Feedback/context text, for example the batch
                summary built by :meth:`improve_batch`.
            local_fix: Optional component-specific feedback. It is appended to
                the context under a ``Local Feedback:`` heading.
            upstream_grad: Optional upstream feedback. It is appended under an
                ``Upstream Feedback:`` heading unless it equals ``STOP_GRADIENT``.

        Returns:
            The improved prompt parsed from the LLM reply. Returns
            ``current_prompt`` instead when there is no feedback at all, or when
            the reply contains no non-blank tagged block.
        """
        sections: List[str] = []
        context_text = (variable_context or "").strip()
        if context_text:
            sections.append(context_text)
        # The template is only given a single ``variable_context`` field, so
        # local/upstream feedback is folded into it. Otherwise those arguments
        # would never reach the LLM.
        sections.extend(self._feedback_sections(local_fix, upstream_grad))
        if not sections:
            return current_prompt

        prompt = PROMPT_UPDATE_TEMPLATE.format(
            variable_desc=variable_desc or "component prompt",
            variable_short=short_text(current_prompt, 400),
            current_prompt=current_prompt,
            variable_context="\n\n".join(sections),
            start_tag=IMPROVED_PROMPT_START,
            end_tag=IMPROVED_PROMPT_END,
        )

        updated = get_llm_output(
            message=prompt,
            model=self.model,
            max_new_tokens=self.max_tokens,
            temperature=self.temperature,
            system_prompt=OPTIMIZER_SYSTEM_PROMPT,
        )
        candidate = parse_tagged_block(updated, IMPROVED_PROMPT_START, IMPROVED_PROMPT_END)
        # Reject empty or whitespace-only extractions so that a malformed reply
        # can never blank out the prompt.
        if isinstance(candidate, str) and candidate.strip():
            return candidate
        return current_prompt

    def improve_batch(
        self, *, variable_desc: str, current_prompt: str, feedback_list: List[Dict[str, Any]]
    ) -> str:
        """Batch optimization: summarize multiple feedback items into one optimizer call.

        Each item may carry ``context``, ``local_fix`` and ``upstream_grad``
        keys. Items with none of these populated are skipped, and the remaining
        items are numbered consecutively. If no item carries feedback, the LLM
        is not called.

        Args:
            variable_desc: Human-readable description of the component.
            current_prompt: The prompt to improve.
            feedback_list: Per-example feedback dictionaries.

        Returns:
            The improved prompt, or ``current_prompt`` when there is nothing to
            act on.
        """
        if not feedback_list:
            return current_prompt

        examples: List[str] = []
        for item in feedback_list:
            parts: List[str] = []
            ctx = item.get("context", "")
            ctx_text = str(ctx).strip() if ctx else ""
            if ctx_text:
                parts.append(f"Context Summary:\n{ctx_text}")
            parts.extend(
                self._feedback_sections(item.get("local_fix", ""), item.get("upstream_grad", ""))
            )
            if not parts:
                # An item with no feedback would only add an empty header and
                # would still force an LLM call.
                continue
            header = f"--- Example {len(examples) + 1} ---"
            examples.append("\n".join([header, *parts]))

        if not examples:
            return current_prompt

        batch_feedback_str = "\n\n".join(examples)
        return self.improve(
            variable_desc=variable_desc,
            current_prompt=current_prompt,
            variable_context=batch_feedback_str,
        )
