
import json
import traceback
from typing import List, Optional
import logger, prompt, utils
from agent.base import BaseAgent
from agent.critic import CriticAgent
from pydantic import model_validator
from prompt.actor import WRITE_CODE_PROMPT,WRITE_CODE_SYS_PROMPT, MODEL_USAGE_CONSTRAINT
from prompt.task_types import DATA_INPUT_TASK_TYPES,DATA_OUTPUT_TASK_TYPES, TaskType, ParentTaskType
from schema import Plan, Subtask, TaskResult,AttemptDetail
from utils import parse_code, parse_res
from globals import execute_nb_code
from tool.execute_nb_code import ExecuteNbCode

class ActorAgent(BaseAgent):
    """Agent that solves one atomic subtask by having the LLM write code and executing it.

    Author's summary of the role: rewrite the atomic question, obtain data-table
    schemas, and execute tool functions.
    """

    name: str = "ActorAgent"
    description: str = (
        "负责解决原子问题的 Agent（重写原子问题、获取数据表结构、执行工具函数）"
    )

    # The subtask this actor is responsible for.
    task: Subtask
    # Optional critic used to review execution results and analyze the final result.
    critic: Optional[CriticAgent] = None
    # Avoid using the global ExecuteNbCode instance as a class-level default
    # because it may contain non-picklable resources (locks). Use None as
    # the default and initialize in a post-init validator to prevent
    # Pydantic from trying to deepcopy it.
    execute_code: ExecuteNbCode | None = None

    @model_validator(mode="after")
    def initialize_actor(self) -> "ActorAgent":
        """Fall back to the shared ``execute_nb_code`` instance when none was given."""
        if self.execute_code is None:
            self.execute_code = execute_nb_code
        return self

    def question(self) -> str:
        """Return the question text of this actor's subtask."""
        return self.task.question

    def act(self, plan, max_retry: int = 3) -> TaskResult:
        """Solve the subtask by generating and executing code, retrying on failure.

        Each attempt asks the LLM for code given the current plan status and runs it
        with ``self.execute_code``. When result review is enabled in
        ``utils.module_config`` and a critic is set, the critic reviews successful
        output. All attempts so far are then reported to ``plan.process_task_result``.
        The loop stops at the first successful attempt.

        Args:
            plan: object providing ``get_plan_status()`` and
                ``process_task_result(attempts)``.
            max_retry: maximum number of write/execute attempts (default 3).

        Returns:
            TaskResult holding every attempt keyed by its 0-based index and the final
            success flag. On success, and when a critic is set, it also holds the
            critic's new knowledge and suggested next steps.
        """
        # default=str: this is a diagnostic trace only. A non-JSON value in the
        # context store must not abort the task.
        logger.trace(f"【Actor Context Store】: {json.dumps(self.context.store, ensure_ascii=False, default=str)}")

        success = False
        attempts_data = {}
        last_attempt = None
        for attempt_no in range(max_retry):
            # Re-read every attempt: process_task_result below may update the plan.
            plan_status = plan.get_plan_status()
            # TODO: pass intermediate variables between code executed by different tasks.
            code = self._write_code(attempt_no, plan_status)

            result, success = self.execute_code.run(code)
            logger.info(f"【代码执行结果】\n{result}")

            if success and getattr(utils.module_config, 'enable_result_review', False) and self.critic:
                review_passed, result = self._review_result(result)
                if not review_passed:
                    success = False

            last_attempt = AttemptDetail(code=code, result=result)
            attempts_data[attempt_no] = last_attempt
            plan.process_task_result(attempts_data)

            if success:
                break

        # Analyze the final result (only when it succeeded).
        new_knowledge = {}
        suggested_next_steps = []
        if success and last_attempt is not None and self.critic:
            new_knowledge, suggested_next_steps = self._analyze_result(last_attempt.result)

        return TaskResult(
            attempts=attempts_data,
            is_success=success,
            new_knowledge=new_knowledge,
            suggested_next_steps=suggested_next_steps
        )

    def _review_result(self, result: str) -> tuple[bool, str]:
        """Have the critic review a successful execution result.

        Returns ``(passed, result)``. On a failed review the critic's reason is
        appended to ``result``. Errors raised by the critic are logged and treated
        as a pass, so the review stays best-effort.
        """
        logger.info("【开始结果审查】")
        try:
            review_data = self.critic.review_result(self.task, result)
            review_success = review_data.get("success", True)
            review_reason = review_data.get("reason", "")
            if not review_success:
                logger.warning(f"【结果审查未通过】{review_reason}")
                return False, result + f"\n\n[Result Review Failed]: {review_reason}"
            logger.info(f"【结果审查通过】{review_reason}")
        except Exception as e:
            logger.warning(f"【结果审查出错】{e}")
            logger.debug(traceback.format_exc())
        return True, result

    def _analyze_result(self, result: str) -> tuple[dict, list]:
        """Ask the critic for new knowledge and suggested next steps from a final result.

        Returns ``(new_knowledge, suggested_next_steps)``. Values that were already
        obtained are kept if a later step fails. The defaults are ``{}`` and ``[]``.
        """
        new_knowledge = {}
        suggested_next_steps = []
        try:
            logger.info("【开始结果分析】")
            analysis_data = self.critic.analyze_result(self.task, result)
            new_knowledge = analysis_data.get("new_knowledge", {})
            suggested_next_steps = analysis_data.get("suggested_next_steps", [])
            logger.info(f"【结果分析完成】获得新知识: {json.dumps(new_knowledge, ensure_ascii=False, default=str)}，建议后续步骤: {suggested_next_steps}")
        except Exception as e:
            logger.warning(f"【结果分析失败】{e}")
            logger.debug(traceback.format_exc())
        return new_knowledge, suggested_next_steps

    def _write_code(self, retry_count: int, plan_status: str):
        """Ask the LLM to write code for the subtask and return the parsed code.

        The user prompt combines the overall requirement, the plan status, and the
        key intermediate results accumulated in the context. For task types whose
        name contains a model-related keyword, it also includes
        ``MODEL_USAGE_CONSTRAINT``. Data-input tasks additionally get the data
        schema. Data-output tasks get the output requirements and submission
        example. These extras are added only when the context provides them.

        Args:
            retry_count: 0-based attempt number (used for logging only).
            plan_status: current plan status passed into the prompt.

        Returns:
            The code that ``parse_code`` extracts from the LLM response.
        """
        logger.debug(f"【开始生成代码】 任务{self.task.task_id} 重试次数: {retry_count}")

        # Key intermediate results accumulated by earlier tasks.
        key_results = self.context.get("key_intermediate_results", {})
        # default=str keeps a non-JSON value from aborting prompt construction.
        key_results_desc = (
            json.dumps(key_results, ensure_ascii=False, indent=2, default=str) if key_results else "无"
        )

        # Model-related task types (matched by keyword in the task type name) must
        # follow the model-usage constraint. `or ""` guards a missing task_type.
        task_type_name = self.task.task_type or ""
        post_training_keywords = (
            "评估", "建模", "训练", "预测", "模型", "验证", "evaluate", "predict", "submission",
        )
        is_post_training = any(kw in task_type_name for kw in post_training_keywords)

        user_prompt = WRITE_CODE_PROMPT.format(
            user_requirement=self.requirement,
            plan_status=plan_status,
            key_intermediate_results=key_results_desc,
            model_usage_constraint=MODEL_USAGE_CONSTRAINT if is_post_training else "",
        )

        if self.task.task_type in DATA_INPUT_TASK_TYPES:
            schema = self.context.get("required_data_schema", None)
            if schema:
                user_prompt += f"数据表结构信息(仅展示最多10个文件)：{schema}\n"
        elif self.task.task_type in DATA_OUTPUT_TASK_TYPES:
            output_requirements = self.context.get("final_output_requirements", None)
            if output_requirements:
                user_prompt += f"输出结果要求：{output_requirements}\n"
            submission_example = self.context.get("submission_example", None)
            if submission_example:
                user_prompt += f"输出示例：{submission_example}\n"

        messages = [
            {
                "role": "system",
                "content": WRITE_CODE_SYS_PROMPT,
            },
            {
                "role": "user",
                "content": user_prompt,
            },
        ]
        response = self.llm.ask(messages=messages)
        # NOTE: _normalize_workspace_paths is currently not applied to the code.
        return parse_code(response)

    def _normalize_workspace_paths(self, code: str) -> str:
        """Normalize workspace paths inside string literals of generated code.

        The notebook already executes with ``workspace/`` as its working directory, so:

        1. a string literal consisting only of ``workspace``, optionally followed by
           path separators (e.g. ``"workspace/"``), becomes ``"."``;
        2. a ``workspace/`` or ``workspace\\`` prefix at the start of a string
           literal is removed (``"workspace/data.csv"`` becomes ``"data.csv"``).

        Only text directly after a quote character (plus optional whitespace) is
        rewritten, so identifiers named ``workspace`` are left alone. Returns ``""``
        for empty or None input.
        """
        import re as _re
        s = code or ""
        # Step 1 runs first and also absorbs trailing separators. Otherwise step 2
        # would turn "workspace/" into the empty path "".
        s = _re.sub(r"([\'\"])\s*workspace[\\/]*\s*\1", r"\1.\1", s)
        # Step 2: drop a leading workspace/ or workspace\ prefix inside a literal.
        s = _re.sub(r"([\'\"])\s*workspace[\\/]+", r"\1", s)
        return s

    def get_knowledge(self, *args, **kwargs):
        """Delegate to the parent class implementation unchanged."""
        return super().get_knowledge(*args, **kwargs)

    def get_parent_tasks_desc(self) -> list[dict]:
        """Return ``to_simple_dict()`` for every parent task, or ``[]`` when there are none."""
        if not self.has_parent_task():
            return []
        return [task.to_simple_dict() for task in self.parent_tasks]

    def has_parent_task(self) -> bool:
        """Return True when ``self.parent_tasks`` is set and non-empty.

        ``parent_tasks`` is presumably defined on BaseAgent (not visible here).
        """
        return bool(self.parent_tasks)
