from typing import List
import logger
from agent.base import BaseAgent
from agent.planner import PlannerAgent
from agent.critic import CriticAgent
from datetime import datetime
import json
from pathlib import Path
import os
import nbformat
from globals import execute_nb_code
from schema import Context
import logger
import utils


class DSagent(BaseAgent):
    """Data-science agent: entry point that launches and runs a data-science task."""

    name: str = "DSagent"
    description: str = (
        "数据科学 Agent，负责数据科学任务的启动运行"
    )
    run_id: str = ""
    complexity: str = "auto"

    def act(self, requirement: str, run_id: str = "", complexity: str = "auto", extra_context: dict = None) -> tuple:
        """Run a data-science task end to end.

        :param requirement: natural-language description of the task.
        :param run_id: identifier used for the output directory; if empty, a
            timestamp is generated when the history is saved.
        :param complexity: ``"auto"``, ``"simple"`` or ``"complex"`` (case- and
            whitespace-insensitive). Currently only ``"auto"`` adjusts the
            split/clarification flags (see NOTE below).
        :param extra_context: optional dict used to build the shared ``Context``.
        :return: ``(plan_dict, summary)`` -- the plan converted with ``.dict()`` and
            the summary text (empty string when summaries are disabled).
        """
        self.requirement = requirement
        self.run_id = run_id
        self.complexity = (complexity or "auto").strip().lower()
        self.context = Context.from_dict(extra_context or {})

        # Very long requirements are shortened for the log line and for the
        # complexity classifier only; the planner still gets self.requirement.
        rq = requirement
        if len(rq) > 2000:
            rq = rq[:1000] + "......(已截断)" + rq[-1000:]
        logger.info("【数据科学任务启动】" + rq + (" run_id 为 " + run_id if run_id else "") + f" complexity={self.complexity}")

        summary = ""
        prev_flags = None
        try:
            # Apply the complexity policy (affects clarification and task splitting);
            # the previous flags are restored in `finally`.
            if self.complexity == "auto":
                # NOTE(review): explicit "simple"/"complex" never reach
                # _apply_complexity_policy, so they leave the config untouched even
                # though that method supports them -- confirm this is intended.
                prev_flags = self._apply_complexity_policy(rq)
                if prev_flags[1]:
                    # Clarification already enabled in the config stays on, even if
                    # the classifier judged it unnecessary.
                    utils.module_config.need_clarification = True
            plan = self._plan_and_act()
            if getattr(utils.module_config, 'enable_summary', False):
                from agent.summary import SummaryAgent
                summary_type = getattr(utils.module_config, 'summary_type', 'report')
                summary = SummaryAgent().act(plan=plan, summary_type=summary_type)
                logger.info(f"【任务总结】\n{summary}")
            plan = plan.dict()
        finally:
            if prev_flags is not None:
                self._restore_complexity_policy(prev_flags)
        return plan, summary

    def _plan_and_act(self):
        """Plan and execute the task, then optionally persist the history.

        :return: the plan object returned by ``PlannerAgent.act()``.
        """
        critic = CriticAgent(context=self.context)
        planner = PlannerAgent(requirement=self.requirement, run_id=self.run_id, context=self.context, critic=critic)
        plan = planner.act()
        # Missing config or missing attribute both default to saving.
        if getattr(utils.module_config, 'save_history', True):
            self.save_history(plan)
            logger.success("【代码已保存】")
        return plan

    def save_history(self, plan, save_dir: str = ""):
        """Persist the plan as ``plan.json`` and the executed notebook as ``code.ipynb``.

        :param plan: plan object exposing ``.dict()``.
        :param save_dir: target directory; defaults to ``output/results/<run_id>``.
            Both files are written into this same directory.
        """
        if not self.run_id:
            self.run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
        save_path = Path(save_dir) if save_dir else Path("output/results") / f"{self.run_id}"
        save_path.mkdir(parents=True, exist_ok=True)

        # Serialize before opening the file so a failing .dict() leaves no empty file.
        plan_dict = plan.dict()
        with open(save_path / "plan.json", "w", encoding="utf-8") as plan_file:
            json.dump(plan_dict, plan_file, indent=4, ensure_ascii=False)
        logger.success("【任务规划已保存】")

        self.save_code_file(
            name=Path(self.run_id),
            code_context=execute_nb_code.nb,
            file_format="ipynb",
            save_dir=save_path,
        )

    def save_code_file(self, name: "str | Path", code_context, file_format: str = "py",
                       save_dir: "str | Path | None" = None) -> None:
        """Write generated code to ``<dir>/code.<file_format>``.

        :param name: sub-directory under ``output/results`` (used when ``save_dir``
            is not given).
        :param code_context: source text for ``"py"``; a notebook object accepted by
            ``nbformat.write`` for ``"ipynb"``.
        :param file_format: ``"py"`` or ``"ipynb"``.
        :param save_dir: explicit target directory, overriding ``output/results/<name>``.
        :raises ValueError: for any other ``file_format``; raised before any
            directory or file is created.
        """
        if file_format not in ("py", "ipynb"):
            raise ValueError(f"Unsupported file format {file_format!r}. Please choose 'py' or 'ipynb'.")
        target_dir = Path(save_dir) if save_dir else Path("output/results") / str(name)
        target_dir.mkdir(parents=True, exist_ok=True)
        file_path = target_dir / f"code.{file_format}"
        if file_format == "py":
            file_path.write_text(code_context + "\n\n", encoding="utf-8")
        else:
            nbformat.write(code_context, file_path)

    def get_knowledge(self, *args, **kwargs) -> str:
        """Not implemented for this agent; currently returns ``None``."""
        pass

    def _apply_complexity_policy(self, rq: str):
        """Adjust the split/clarification flags on the module config per ``self.complexity``.

        ``"simple"`` disables both flags and ``"complex"`` enables both. ``"auto"``
        (and any unrecognised value) asks the LLM once for a JSON verdict, with no
        local heuristics. If the LLM call itself fails, both flags are disabled
        (same as ``"simple"``).

        :param rq: requirement text sent to the classifier.
        :return: ``(need_further_split, need_clarification)`` as they were before the
            change, for ``_restore_complexity_policy``.
        """
        m = utils.module_config or utils.load_module_config()
        prev = (bool(getattr(m, 'need_further_split', False)), bool(getattr(m, 'need_clarification', False)))

        comp = (self.complexity or "auto").lower()
        if comp in {"simple", "complex"}:
            enabled = comp == "complex"
            m.need_further_split = enabled
            m.need_clarification = enabled
            logger.info(f"【复杂度策略】显式 complexity={comp} -> split={m.need_further_split}, clarify={m.need_clarification}")
            return prev

        # auto: let the LLM decide.
        try:
            system_prompt = (
                "你是任务复杂度分类助手。根据用户的自然语言数据科学需求，判断任务复杂度以及是否需要进一步澄清需求或进行数据探索。\n"
                "请返回 JSON 格式，包含以下字段：\n"
                "- complexity: 'simple' 或 'complex'\n"
                "- need_clarification: boolean, true 表示需要需求澄清或数据字段及文件含义探索，false 表示不需要\n"
                "定义：simple=单文件或已给出数据路径，且仅需基本统计/可视化；complex=需要多步骤特征工程/建模/评估/多数据源/调参等。\n"
                "仅返回 JSON 字符串，不要包含 Markdown 标记。"
            )
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": rq},
            ]
            resp = self.llm.ask(messages=messages, tools=[])
            content = (resp.choices[0].message.content or "").strip() if hasattr(resp, 'choices') else ""

            comp_result, need_clarification = self._parse_complexity_reply(content)
            m.need_further_split = comp_result == "complex"
            m.need_clarification = need_clarification

            logger.info(f"【复杂度策略(auto)】LLM 判定为 {comp_result}, need_clarification={need_clarification} -> split={m.need_further_split}, clarify={m.need_clarification}")
        except Exception as e:
            # Best-effort classification: any failure falls back to the simple policy.
            logger.warning(f"【复杂度策略(auto)】LLM 判定失败，回退 simple: {e}")
            m.need_further_split = False
            m.need_clarification = False
        return prev

    @classmethod
    def _parse_complexity_reply(cls, content: str) -> tuple:
        """Parse the classifier's reply into ``(complexity, need_clarification)``.

        ``complexity`` is always ``"simple"`` or ``"complex"``, and
        ``need_clarification`` is always a real ``bool``.
        """
        import re

        text = (content or "").strip()
        # Strip an optional Markdown fence (```json ... ``` or ``` ... ```),
        # tolerating an upper-case language tag.
        if text.startswith("```"):
            text = text[3:]
            if text[:4].lower() == "json":
                text = text[4:]
        if text.endswith("```"):
            text = text[:-3]
        text = text.strip()

        result = None
        try:
            result = json.loads(text)
        except json.JSONDecodeError:
            # The model may wrap the object in prose; retry on the outermost {...} span.
            start, end = text.find("{"), text.rfind("}")
            if 0 <= start < end:
                try:
                    result = json.loads(text[start:end + 1])
                except json.JSONDecodeError:
                    result = None

        if isinstance(result, dict):
            comp = str(result.get("complexity") or "simple").strip().lower()
            comp = "complex" if comp == "complex" else "simple"
            return comp, cls._as_bool(result.get("need_clarification", False))

        logger.warning(f"【复杂度策略(auto)】JSON 解析失败，内容: {text}")
        # Text fallback. Read the value attached to the "complexity" key when present.
        # Otherwise match "complex" only as a whole word, so that the key name
        # "complexity" itself is never mistaken for a "complex" verdict.
        lowered = text.lower()
        comp_match = re.search(r"complexity\W*(simple|complex)\b", lowered)
        if comp_match:
            comp = comp_match.group(1)
        else:
            comp = "complex" if re.search(r"\bcomplex\b", lowered) else "simple"
        clar_match = re.search(r"need_clarification\W*(true|false)\b", lowered)
        need_clarification = (clar_match.group(1) == "true") if clar_match else (comp == "complex")
        return comp, need_clarification

    @staticmethod
    def _as_bool(value) -> bool:
        """Coerce a JSON-ish flag (bool, number, or a "true"/"false" string) to ``bool``."""
        if isinstance(value, str):
            return value.strip().lower() in {"true", "yes", "1"}
        return bool(value)

    def _restore_complexity_policy(self, prev_flags):
        """Restore the ``(need_further_split, need_clarification)`` pair saved by ``_apply_complexity_policy``."""
        m = utils.module_config or utils.load_module_config()
        m.need_further_split, m.need_clarification = prev_flags
