# core/builder/cmp_extraction_agent.py

import json
import asyncio
from langgraph.graph import StateGraph, END
from core.utils.format import correct_json_format
from core.builder.manager.information_manager import InformationExtractor

class CMPExtractionAgent:
    """
    Wardrobe/Styling/Prop extraction agent.

    Design choices:
    - Run three extractions (wardrobe / styling / propitem) and merge records.
    - Reflection logs are plain text list (NOT JSON) to save tokens and be LLM-friendly.
    - best_result only keeps merged results + feedbacks + score.

    Workflow: a LangGraph loop ``extract_cmp -> reflect_cmp`` that repeats until
    the reflection score reaches ``config.agent.score_threshold`` or
    ``config.agent.max_retries`` reflection rounds have run. ``run``/``arun``
    return the best-scoring round (earliest one on ties) as
    ``{"results": [...], "feedbacks": [...], "score": int}``.

    Output schema for each item (aligned with DB columns):
        {
            "name": str,
            "category": str,
            "subcategory": str,
            "appearance": str,
            "status": str,
            "character": str,
            "evidence": str | list[str],
            "notes": str
        }
    """

    def __init__(self, config, llm, system_prompt, prompt_loader=None):
        """
        Args:
            config: Project config; ``config.agent.score_threshold`` and
                ``config.agent.max_retries`` control the extract/reflect loop.
            llm: Model handle forwarded to ``InformationExtractor``.
            system_prompt: System prompt passed to every extraction/reflection call.
            prompt_loader: Optional prompt loader forwarded to ``InformationExtractor``.
        """
        self.config = config
        self.extractor = InformationExtractor(config, llm, prompt_loader=prompt_loader)
        self.system_prompt = system_prompt
        self.score_threshold = self.config.agent.score_threshold
        self.max_retries = self.config.agent.max_retries
        self.graph = self._build_graph()

    # ---------------- Helpers ----------------

    @staticmethod
    def _empty_result() -> dict:
        """Return a fresh empty ``best_result`` payload (a new dict on every call)."""
        return {"results": [], "feedbacks": [], "score": 0}

    def _initial_state(self, text: str) -> dict:
        """Build the initial graph state shared by ``run`` and ``arun``."""
        return {
            "content": text,
            "retry_count": 0,
            "best_score": 0,
            "best_result": self._empty_result(),
            "reflection_results": {},
        }

    @staticmethod
    def _coerce_records(obj) -> list:
        """Return the ``results`` list of a parsed extraction as a list of dict copies.

        Anything other than ``{"results": [ {...}, ... ]}`` yields ``[]``, and
        non-dict entries inside ``results`` are dropped, so malformed LLM output
        cannot crash the merge or the reflection log formatting.
        """
        if not isinstance(obj, dict):
            return []
        records = obj.get("results")
        if not isinstance(records, list):
            return []
        return [dict(item) for item in records if isinstance(item, dict)]

    @classmethod
    def _parse_results(cls, raw) -> list:
        """Parse one raw extraction response into a list of record dicts.

        Best-effort: any failure inside ``correct_json_format`` / ``json.loads``
        yields ``[]`` rather than aborting the graph run.
        """
        try:
            obj = json.loads(correct_json_format(raw))
        except Exception:
            return []
        return cls._coerce_records(obj)

    @staticmethod
    def _format_log_line(record: dict) -> str:
        """Render one record as a single plain-text reflection log line."""
        return (
            f"- name={record.get('name','')} | "
            f"category={record.get('category','')} | "
            f"subcategory={record.get('subcategory','')} | "
            f"appearance={record.get('appearance','')} | "
            f"status={record.get('status','')} | "
            f"character={record.get('character','')} | "
            f"notes={record.get('notes','')} | "
            f"evidence={json.dumps(record.get('evidence',''), ensure_ascii=False)}"
        )

    @staticmethod
    def _parse_reflection(raw):
        """Parse a reflection response into ``(feedbacks, score)``.

        Malformed output degrades to ``([], 0)``. This covers unparsable JSON, a
        non-dict payload and a non-numeric score. A single string feedback is
        wrapped into a one-element list.
        """
        try:
            ref = json.loads(correct_json_format(raw))
        except Exception:
            ref = {}
        if not isinstance(ref, dict):
            ref = {}

        try:
            score = int(float(ref.get("score", 0)))
        except (TypeError, ValueError, OverflowError):
            score = 0

        feedbacks = ref.get("feedbacks", [])
        if isinstance(feedbacks, str):
            feedbacks = [feedbacks] if feedbacks.strip() else []
        elif not isinstance(feedbacks, list):
            feedbacks = []
        return feedbacks, score

    # ---------------- LangGraph nodes ----------------

    def extract_cmp(self, state: dict) -> dict:
        """Graph node: run the wardrobe, styling and prop extractions and merge them.

        Each extractor receives its own previous-round records plus the latest
        reflection feedbacks, so it can revise its answer on a retry.

        Args:
            state: Graph state. Must contain ``content`` (str). Optionally reads
                ``feedbacks``, ``wardrobe_results``, ``styling_results``,
                ``propitem_results``, ``score``, ``retry_count``, ``best_score``
                and ``best_result``.

        Returns:
            The updated state. It keeps the per-type results for the next round
            and puts their concatenation in ``results``.
        """
        content = state["content"].strip()
        feedbacks = state.get("feedbacks", [])

        def _history(key):
            # Per-type history fed back to the extractor on retries.
            return {"previous_results": state.get(key, []), "feedbacks": feedbacks}

        # The three extractions run sequentially, one call after another.
        w_raw = self.extractor.extract_wardrobe(
            content=content,
            system_prompt=self.system_prompt,
            reflection_results=_history("wardrobe_results"),
        )
        s_raw = self.extractor.extract_styling(
            content=content,
            system_prompt=self.system_prompt,
            reflection_results=_history("styling_results"),
        )
        p_raw = self.extractor.extract_propitem(
            content=content,
            system_prompt=self.system_prompt,
            reflection_results=_history("propitem_results"),
        )

        wardrobe_results = self._parse_results(w_raw)
        styling_results = self._parse_results(s_raw)
        propitem_results = self._parse_results(p_raw)

        # Merge results (do NOT add a "type" field unless your downstream needs it).
        # Copies keep merged records independent of the per-type lists.
        merged = [dict(item) for item in (*wardrobe_results, *styling_results, *propitem_results)]

        # Keep per-type results in state for the next round; final output returns best_result only
        return {
            "content": content,
            "results": merged,
            "wardrobe_results": wardrobe_results,
            "styling_results": styling_results,
            "propitem_results": propitem_results,
            "feedbacks": feedbacks,
            "score": state.get("score", 0),
            "retry_count": state.get("retry_count", 0),
            "best_score": state.get("best_score", 0),
            "best_result": state.get("best_result", self._empty_result()),
        }

    def reflect_cmp(self, state: dict) -> dict:
        """Graph node: score the merged records and track the best round so far.

        The merged records are rendered as plain-text log lines and sent, with
        the source text and the previous reflection, to
        ``extractor.reflect_cmp_extractions``.

        Args:
            state: Graph state produced by ``extract_cmp``.

        Returns:
            The updated state with ``feedbacks``, ``score``, ``retry_count`` + 1,
            ``best_score``/``best_result`` and ``reflection_results`` for the next round.
        """
        content = state["content"]
        results = state.get("results", [])
        reflection_results = state.get("reflection_results", {})

        # === Reflection logs use plain text lines (NOT JSON) ===
        # Example line format (English labels):
        # - name=curtain | category=prop | subcategory=curtain | appearance= | status=open |
        #   character=Liu Peiqiang | notes= | evidence="opens the curtain"
        logs = "\n".join(self._format_log_line(r) for r in results)

        raw = self.extractor.reflect_cmp_extractions(
            logs=logs,
            content=content,
            system_prompt=self.system_prompt,
            previous_reflection=reflection_results,
        )
        feedbacks, score = self._parse_reflection(raw)

        retry_count = state.get("retry_count", 0)
        best_score = state.get("best_score", 0)
        # Always record the first round. A round scored 0 (e.g. because the
        # reflection output was unparsable) then still returns its extraction
        # instead of an empty result. Later rounds only replace it with a
        # strictly better score.
        if retry_count == 0 or score > best_score:
            best_result = {
                "results": results,
                "feedbacks": feedbacks,
                "score": score,
            }
            best_score = score
        else:
            best_result = state.get("best_result", self._empty_result())

        return {
            "content": content,
            "results": results,
            "wardrobe_results": state.get("wardrobe_results", []),
            "styling_results": state.get("styling_results", []),
            "propitem_results": state.get("propitem_results", []),
            "feedbacks": feedbacks,
            "score": score,
            "retry_count": retry_count + 1,
            "best_score": best_score,
            "best_result": best_result,
            "reflection_results": {"feedbacks": feedbacks, "score": score},
        }

    def _score_check(self, state: dict) -> str:
        """Routing function: ``"good"`` at/above threshold, ``"giveup"`` when
        retries are exhausted, otherwise ``"retry"``."""
        if state["score"] >= self.score_threshold:
            return "good"
        elif state["retry_count"] >= self.max_retries:
            return "giveup"
        else:
            return "retry"

    def _build_graph(self):
        """Compile the ``extract_cmp -> reflect_cmp -> (END | extract_cmp)`` loop."""
        builder = StateGraph(dict)
        builder.add_node("extract_cmp", self.extract_cmp)
        builder.add_node("reflect_cmp", self.reflect_cmp)

        builder.set_entry_point("extract_cmp")
        builder.add_edge("extract_cmp", "reflect_cmp")
        builder.add_conditional_edges("reflect_cmp", self._score_check, {
            "good": END,
            "retry": "extract_cmp",
            "giveup": END
        })
        return builder.compile()

    # ---------------- Public API ----------------

    def run(self, text: str) -> dict:
        """Run the extract/reflect loop synchronously and return ``best_result``."""
        result = self.graph.invoke(self._initial_state(text))
        return result.get("best_result", self._empty_result())

    async def arun(
        self,
        text: str,
        timeout: int = 120,
        max_attempts: int = 3,
        backoff_seconds: int = 30,
    ) -> dict:
        """
        Async interface with outer timeout/backoff.
        The inner extract-reflect retries are controlled by the graph and score.

        Args:
            text: Source text to extract from.
            timeout: Per-attempt limit in seconds for one full graph run.
            max_attempts: Total number of graph runs to try when they time out
                (values below 1 behave like 1).
            backoff_seconds: Before attempt ``i`` (0-based, ``i > 0``) sleep
                ``backoff_seconds * i`` seconds.

        Returns:
            The graph's ``best_result``, or the raw final state if it has none.
            On failure, an empty result plus an ``"error"`` key. The key is
            ``"timeout"`` when every attempt timed out, and
            ``"<ExceptionType>: <message>"`` for any other exception (which is
            not retried).
        """
        # NOTE(review): the graph nodes are synchronous. If LangGraph runs them
        # in a worker thread under ``ainvoke``, a timed-out attempt may keep
        # running (and calling the LLM) in the background. Confirm this.
        for attempt in range(max(1, max_attempts)):
            if attempt:
                await asyncio.sleep(backoff_seconds * attempt)
            try:
                result = await asyncio.wait_for(
                    self.graph.ainvoke(self._initial_state(text)), timeout=timeout
                )
            except asyncio.TimeoutError:
                continue
            except Exception as exc:
                # Report the real failure instead of mislabelling it as a timeout,
                # and handle it the same way on every attempt.
                return {**self._empty_result(), "error": f"{type(exc).__name__}: {exc}"}
            return result.get("best_result", result)
        return {**self._empty_result(), "error": "timeout"}
