# PhilAgent/actions/A1_construct_semiotic_square.py

import re
import json
import sys
from typing import ClassVar, Optional, Union
from metagpt.context import Context
from metagpt.actions import Action
from metagpt.logs import logger
import os

from utils.global_vars import PROMPT_BASE_DIR

class A1_ConstructSemioticSquare(Action):
    """
    Step 1: Extract opposing concepts and construct the Semiotic Square.

    Supports both manual execution (``question``/``context`` passed as
    strings) and agent-style execution (a single dict carrying the
    ``question``, ``context`` and optional ``ablation_type`` keys).

    The prompt template is read from ``PROMPT_BASE_DIR`` and selected by the
    ablation type; the LLM reply is parsed into a dict with the keys
    ``S1``, ``S2``, ``not_S1``, ``not_S2`` and ``S2_type``.
    """

    name: str = "A1_ConstructSemioticSquare"

    # Prompt file name for each supported ablation type; the active one is
    # resolved against PROMPT_BASE_DIR in __init__ and, optionally, in run().
    PROMPT_PATHS: ClassVar[dict] = {
        "full": "A1_ConstructSemioticSquare.txt",
        "woFOL": "A1_ConstructSemioticSquare_woFOL.txt",
        "woStatement": "A1_ConstructSemioticSquare_woStatement.txt",
    }

    def __init__(self, context: Context, ablation_type: str = "full"):
        """
        Create the action and select the prompt template.

        Args:
            context: Context object forwarded to the base ``Action``.
            ablation_type: One of the keys of ``PROMPT_PATHS``.

        Raises:
            ValueError: If ``ablation_type`` is not a key of ``PROMPT_PATHS``.
        """
        super().__init__(context=context)
        if ablation_type not in self.PROMPT_PATHS:
            raise ValueError(f"Invalid ablation_type '{ablation_type}'")
        self.ablation_type = ablation_type
        self.prompt_file = os.path.join(PROMPT_BASE_DIR, self.PROMPT_PATHS[ablation_type])

    async def run(
        self,
        question: Union[str, dict],
        context: Optional[str] = None,
        ablation_type: Optional[str] = None
    ):
        """
        Build the prompt, query the LLM and return the parsed Semiotic Square.

        Args:
            question: Either the question text (manual mode) or a dict with
                the keys ``question``, ``context`` and optionally
                ``ablation_type`` (agent mode).
            context: Context text; only used in manual mode.
            ablation_type: Optional override of the ablation type. In agent
                mode the dict's ``ablation_type`` entry takes precedence when
                present. A truthy value also updates ``self.ablation_type``
                and ``self.prompt_file`` for subsequent calls.

        Returns:
            The dict produced by ``parse_json``: either the square fields or
            an ``{"error", "raw_response"}`` dict when parsing failed.

        Raises:
            ValueError: If the effective ``ablation_type`` is not supported.
        """
        # Distinguish agent-style input (a single dict) from the manual
        # two-argument call style.
        if isinstance(question, dict):
            logger.debug("[A1] Running in Agent mode (dict input)")
            inputs = question
            # Normalise None to "" exactly as manual mode does, so the prompt
            # never contains the literal text "None".
            question = inputs.get("question", "") or ""
            context = inputs.get("context", "") or ""
            ablation_type = inputs.get("ablation_type", ablation_type)
        else:
            logger.debug("[A1] Running in manual mode (str input)")
            question = question or ""
            context = context or ""

        # An explicitly supplied ablation_type switches the prompt file.
        if ablation_type:
            if ablation_type not in self.PROMPT_PATHS:
                raise ValueError(f"[A1] Unsupported ablation_type: {ablation_type}")
            self.ablation_type = ablation_type
            self.prompt_file = os.path.join(PROMPT_BASE_DIR, self.PROMPT_PATHS[self.ablation_type])
            logger.debug(f"[A1] Prompt path switched to {self.prompt_file} based on ablation_type={ablation_type}")

        # Build the prompt. str.format is used, so the template must only
        # contain the {question} and {context} placeholders; any literal
        # braces in the template have to be doubled.
        prompt_template = self._load_prompt_template()
        prompt = prompt_template.format(question=question, context=context)

        logger.info("==== [A1] Prompt to LLM ====")
        logger.debug(prompt)

        rsp = await self._aask(prompt)
        square = A1_ConstructSemioticSquare.parse_json(rsp)

        logger.info("==== [A1] Parsed Semiotic Square ====")
        logger.info(json.dumps(square, indent=2, ensure_ascii=False))

        return square

    def _load_prompt_template(self) -> str:
        """
        Read and return the text of the currently selected prompt file.

        Raises:
            FileNotFoundError: If ``self.prompt_file`` does not exist.
            RuntimeError: If the file cannot be read for any other reason.
        """
        try:
            with open(self.prompt_file, "r", encoding="utf-8") as f:
                return f.read()
        except FileNotFoundError as e:
            raise FileNotFoundError(f"Prompt file not found: {self.prompt_file}") from e
        except Exception as e:
            raise RuntimeError(f"Failed to load prompt template: {e}") from e

    @staticmethod
    def parse_json(rsp: str) -> dict:
        """
        Extract the Semiotic Square from an LLM response.

        The payload is taken from the first ```json fenced block if there is
        one, otherwise the whole response is treated as JSON. The square may
        be the top-level object or nested under the ``"Semiotic Square"`` key.
        The keys ``"¬S1"``/``"¬S2"`` are accepted as fallbacks for
        ``"not_S1"``/``"not_S2"``.

        Args:
            rsp: Raw text returned by the LLM.

        Returns:
            On success, a dict with exactly the keys ``S1``, ``S2``,
            ``not_S1``, ``not_S2`` (default ``{}``) and ``S2_type``
            (default ``""``). On any failure, a dict with the keys ``error``
            and ``raw_response``; this method does not raise.
        """
        try:
            match = re.search(r"```json(.*?)```", rsp, re.DOTALL)
            json_str = match.group(1).strip() if match else rsp.strip()

            logger.debug("=== [A1] Extracted JSON Payload ===")
            logger.debug(json_str)

            try:
                raw = json.loads(json_str)
            except json.JSONDecodeError:
                # Repair trailing commas (a "," directly before "}" or "]")
                # only when strict parsing fails, so that valid payloads whose
                # string values happen to contain ",}" are left untouched.
                repaired = re.sub(r",\s*([}\]])", r"\1", json_str)
                raw = json.loads(repaired)

            if not isinstance(raw, dict):
                raise ValueError(
                    f"Expected a JSON object, got {type(raw).__name__}"
                )

            # Accept the nested form { "Semiotic Square": { ... } }.
            square = raw.get("Semiotic Square", raw)
            if not isinstance(square, dict):
                raise ValueError(
                    f"Expected 'Semiotic Square' to be a JSON object, got {type(square).__name__}"
                )

            def negated(ascii_key: str, unicode_key: str):
                # Prefer the ASCII key; fall back to the unicode spelling.
                if ascii_key in square:
                    return square[ascii_key]
                return square.get(unicode_key, {})

            # Keep only the fields of the target structure.
            return {
                "S1": square.get("S1", {}),
                "S2": square.get("S2", {}),
                "not_S1": negated("not_S1", "¬S1"),
                "not_S2": negated("not_S2", "¬S2"),
                "S2_type": square.get("S2_type", "")
            }

        except Exception as e:
            # Best-effort contract: report the failure to the caller instead
            # of raising, and hand back the raw response for inspection.
            logger.error(f"=== [A1] JSON Parsing Failed: {e} ===")
            return {
                "error": f"Parsing failed: {str(e)}",
                "raw_response": rsp
            }

