# PhilAgent/actions/A4_context_handle.py

import os
import re
import json
from typing import ClassVar, Dict, Any, Union
from metagpt.actions import Action
from metagpt.logs import logger

from utils.global_vars import PROMPT_BASE_DIR


class A4_ContextHandle(Action):
    """
    A4: Convert natural language context into formal logic (FOL) premises,
    for use in downstream reasoning tasks (Step 3).

    The prompt template is chosen by ``ablation_type`` (see ``PROMPT_PATHS``),
    filled with the question and context, sent to the LLM via ``_aask``, and
    the LLM reply is parsed as JSON.
    """

    name: str = "A4ContextHandle"

    # Prompt file name (without directory; resolved against PROMPT_BASE_DIR)
    # for each ablation_type.
    # NOTE(review): the "Thandle" spelling presumably matches the files on
    # disk -- do not correct it here without renaming those files.
    PROMPT_PATHS: ClassVar[Dict[str, str]] = {
        "full": "A4_ContextThandle.txt",
        "woStatement": "A4_ContextThandle_woStatement.txt",
    }

    async def run(
        self,
        inputs: Union[Dict[str, Any], None] = None,
        context: str = "",
        ablation_type: str = "full",
        question: str = "",
    ) -> dict:
        """
        Build the A4 prompt, query the LLM, and parse its JSON reply.

        Args:
            inputs: Optional dict (Agent mode). When non-empty, its "question",
                "context" and "ablation_type" keys override the keyword
                arguments of the same name; missing keys fall back to them.
            context: Natural-language context to formalize.
            ablation_type: Key into PROMPT_PATHS; unknown values fall back to
                the "full" prompt.
            question: Question inserted into the prompt template.

        Returns:
            The parsed JSON value, or a dict with "error" and "raw_response"
            keys when the LLM reply cannot be parsed.

        Raises:
            RuntimeError: If the prompt template file cannot be read.
        """
        # Support dict-style input used in Agent mode.
        if inputs:
            question = inputs.get("question", question)
            context = inputs.get("context", context)
            ablation_type = inputs.get("ablation_type", ablation_type)

        # An explicit None in `inputs` must not crash `.strip()` / formatting.
        question = question or ""
        context = context or ""

        prompt_template = self._load_prompt_template(ablation_type)
        # Pass the question both positionally and by name so templates using
        # "{}"/"{0}" or "{question}" are all filled; str.format ignores
        # arguments the template does not reference.
        prompt = prompt_template.format(
            question, question=question, context=context.strip()
        )

        logger.info("==== [A4] Prompt to LLM ====")
        logger.debug(prompt)

        rsp = await self._aask(prompt)
        return self._parse_json_response(rsp)

    @classmethod
    def _load_prompt_template(cls, ablation_type: str) -> str:
        """
        Read the prompt template for `ablation_type` from PROMPT_BASE_DIR.

        Unknown ablation types fall back to the "full" template (with a warning).

        Raises:
            RuntimeError: If the file cannot be opened or decoded as UTF-8;
                the original exception is chained as the cause.
        """
        filename = cls.PROMPT_PATHS.get(ablation_type)
        if filename is None:
            logger.warning(
                f"[A4] Unknown ablation_type {ablation_type!r}; "
                f"falling back to 'full'."
            )
            filename = cls.PROMPT_PATHS["full"]
        path = os.path.join(PROMPT_BASE_DIR, filename)

        try:
            with open(path, "r", encoding="utf-8") as f:
                return f.read()
        except (OSError, UnicodeDecodeError) as e:
            raise RuntimeError(
                f"[A4] Failed to load prompt: {path} | Error: {e}"
            ) from e

    def _parse_json_response(self, rsp: str) -> dict:
        """
        Parse the LLM reply as JSON.

        If the reply contains a ```json ... ``` fenced block (tag matched
        case-insensitively), only its contents are parsed; otherwise the whole
        reply is parsed. On failure, an error dict with the raw reply is
        returned instead of raising.
        """
        text = rsp or ""  # tolerate a None/empty reply instead of a TypeError
        match = re.search(r"```json(.*?)```", text, re.DOTALL | re.IGNORECASE)
        json_str = match.group(1).strip() if match else text.strip()
        logger.debug("==== [A4] Parsed JSON ====")
        logger.debug(json_str)
        try:
            return json.loads(json_str)
        except ValueError as e:  # json.JSONDecodeError subclasses ValueError
            logger.error(f"[A4] JSON parse failed: {e}")
            return {
                "error": f"Parsing failed: {str(e)}",
                "raw_response": rsp,
            }
