from __future__ import annotations

import json
from typing import Literal

from pydantic import Field, model_validator

from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.utils.common import CodeParser

from di_project.actions.math_output_answer import MathOutputAnswer
from di_project.actions.math_write_code import MathWriteCode
from di_project.actions.execute_nb_code import ExecuteNbCode
from di_project.actions.write_analysis_code import CheckData, WriteAnalysisCode
from di_project.schema import Task, TaskResult
from di_project.strategy.planner import Planner
from di_project.strategy.task_type import TaskType
from di_project.tools.tool_recommend import BM25ToolRecommender, ToolRecommender
from di_project.prompts.write_analysis_code import DATA_INFO

# Prompt used by DataInterpreter._think in 'react' mode. It is filled with str.format
# using ``user_requirement`` and ``context``; the doubled braces ``{{`` / ``}}`` are
# str.format escapes that render as literal ``{`` / ``}`` in the JSON template.
# _think parses the LLM reply as JSON and reads its "thoughts" and "state" keys.
REACT_THINK_PROMPT = """
# User Requirement
{user_requirement}
# Context
{context}

Output a json following the format:
```json
{{
    "thoughts": str = "Thoughts on current situation, reflect on how you should proceed to fulfill the user requirement",
    "state": bool = "Decide whether you need to take more actions to complete the user requirement. Return true if you think so. Return false if you think the requirement has been completely fulfilled."
}}
```
"""


class DataInterpreter(Role):
    """Role that writes and executes notebook code to fulfil a user requirement.

    Two reaction modes are supported (see ``react_mode``):

    * ``"plan_and_act"``: a ``Planner`` builds a task plan, and each task is solved by
      writing code and running it with ``execute_code``. When ``run_math_prompt`` is
      set, the math-specific loop ``_plan_and_act_math`` replaces the base
      ``Role._plan_and_act``.
    * ``"react"``: ``_think`` asks the LLM whether more work is needed, and ``_act``
      performs one write-and-execute round.
    """

    name: str = "David"
    profile: str = "DataInterpreter"
    # Forwarded to the react-mode setup and to the Planner.
    auto_run: bool = True
    # Recomputed by the validator as ``react_mode == "plan_and_act"``; any passed-in value is overwritten.
    use_plan: bool = True
    # Enables reflection when code writing is retried (only from the second attempt on).
    use_reflection: bool = False
    # If True, AddNewTrajectories is run on the planner after a non-math plan-and-act run.
    use_experience: bool = False
    # Notebook code executor; excluded from serialization.
    execute_code: ExecuteNbCode = Field(default_factory=ExecuteNbCode, exclude=True)
    tools: list[str] = []  # Use special symbol ["<all>"] to indicate use of all registered tools
    # Built as BM25ToolRecommender(tools=self.tools) by the validator when tools are given and none is passed.
    tool_recommender: ToolRecommender = None
    react_mode: Literal["plan_and_act", "react"] = "plan_and_act"
    max_react_loop: int = 10  # used for react mode
    # Switches to the math-specific code writing and answer self-verification path.
    run_math_prompt: bool = False
    # Maximum number of attempts in the math loop (values below 1 are treated as 1);
    # each failed attempt triggers a plan update.
    plan_retry_count: int = 1
    # Latest answer / csv_result produced in math mode (read for evaluation).
    answer: str = ''
    csv_result: str = ''

    @model_validator(mode="after")
    def set_plan_and_tool(self) -> DataInterpreter:
        """Finish initialization after pydantic field validation.

        Configures the react mode, derives ``use_plan`` from ``react_mode``, builds a
        default BM25 tool recommender when ``tools`` is given without one, registers
        ``WriteAnalysisCode`` as the only action, and creates the ``Planner`` that
        shares this role's working memory.
        """
        self._set_react_mode(react_mode=self.react_mode, max_react_loop=self.max_react_loop, auto_run=self.auto_run)
        # create a flag for convenience, overwrite any passed-in value
        self.use_plan = self.react_mode == "plan_and_act"
        if self.tools and not self.tool_recommender:
            self.tool_recommender = BM25ToolRecommender(tools=self.tools)
        self.set_actions([WriteAnalysisCode])
        self._set_state(0)
        self.planner = Planner(
            auto_run=self.auto_run,
            use_math_prompt=self.run_math_prompt,
            working_memory=self.rc.working_memory,
        )
        return self

    @property
    def working_memory(self):
        """Shortcut to ``self.rc.working_memory`` (also shared with the planner)."""
        return self.rc.working_memory

    async def _think(self) -> bool:
        """Decide, in 'react' mode, whether another action is needed.

        On the first call (empty working memory), the first stored message (the user
        requirement) is copied into working memory and True is returned without
        querying the LLM. Otherwise the LLM is prompted with ``REACT_THINK_PROMPT``.
        The ``thoughts`` of its JSON reply are appended to working memory, and its
        ``state`` is returned.

        Returns:
            The need-action flag. The role state is set to 0 when it is truthy and
            to -1 otherwise.
        """
        first_message = self.get_memories()[0]
        user_requirement = first_message.content
        context = self.working_memory.get()

        if not context:
            # just started the run, we need action certainly
            self.working_memory.add(first_message)  # add user requirement to working memory
            self._set_state(0)
            return True

        prompt = REACT_THINK_PROMPT.format(user_requirement=user_requirement, context=context)
        rsp = await self.llm.aask(prompt)
        rsp_dict = json.loads(CodeParser.parse_code(block=None, text=rsp))
        self.working_memory.add(Message(content=rsp_dict["thoughts"], role="assistant"))
        need_action = rsp_dict["state"]
        self._set_state(0 if need_action else -1)

        return need_action

    async def _act(self) -> Message:
        """Useful in 'react' mode: run one write-and-execute round.

        Returns:
            A Message carrying the generated code, conforming to the Role._act interface.
        """
        code, _, _ = await self._write_and_exec_code()
        return Message(content=code, role="assistant", cause_by=WriteAnalysisCode)

    async def _plan_and_act(self) -> Message:
        """First plan, then execute an action sequence.

        Dispatches to ``_plan_and_act_math`` when ``run_math_prompt`` is set.
        Otherwise it uses the base ``Role._plan_and_act``, followed by trajectory
        recording when ``use_experience`` is set.

        ``execute_code`` is terminated on every exit path: success or failure, math or
        non-math. Previously the successful math path skipped termination and leaked
        the notebook executor.
        """
        try:
            if self.run_math_prompt:
                return await self._plan_and_act_math()
            rsp = await super()._plan_and_act()
            if self.use_experience:
                from di_project.actions.use_experience import AddNewTrajectories
                await AddNewTrajectories().run(self.planner)
            return rsp
        finally:
            await self.execute_code.terminate()

    async def _plan_and_act_math(self) -> Message:
        """First plan, then solve the current task, re-planning after each failed attempt.

        The plan is created from the latest user requirement in memory. The current task
        is attempted until it succeeds or ``plan_retry_count`` attempts have been made
        (at least one attempt is always made). Working memory is cleared afterwards.

        Returns:
            A Message (cause_by=MathWriteCode) with the last attempt's execution result,
            which is also added to the role's memory.
        """
        # create initial plan and update it until confirmation
        goal = self.rc.memory.get()[-1].content  # retrieve latest user requirement
        await self.planner.update_plan(goal=goal)

        # Guarantee at least one attempt so that ``task_result`` is always bound below,
        # even if plan_retry_count was configured as 0 or negative.
        max_attempts = max(self.plan_retry_count, 1)
        found_answer = False
        plan_count = 0
        while not found_answer and plan_count < max_attempts:
            task = self.planner.current_task
            task_result = await self._act_on_task(task)
            found_answer = task_result.is_success
            if not found_answer:
                await self.planner.update_plan()
                plan_count += 1

        self.working_memory.clear()

        msg = Message(content=task_result.result, cause_by=MathWriteCode)
        self.rc.memory.add(msg)
        return msg

    async def _act_on_task(self, current_task: Task) -> TaskResult:
        """Solve the current plan task by writing and executing code.

        Useful in 'plan_and_act' mode. ``current_task`` is not read here; the code
        writer takes plan context from ``self.planner``.

        In math mode, the execution output is additionally turned into an answer by
        ``MathOutputAnswer`` and self-verified via ``post_process``, whose verdict
        replaces the execution success flag. On failure, the execution output plus the
        verifier's suggestion is added to the planner's working memory.
        ``self.answer`` and ``self.csv_result`` are always updated for evaluation.

        Returns:
            A TaskResult with the final code, the raw execution output, and the success flag.
        """
        # Experience retrieval is not performed here; an empty string is passed through.
        experiences = ""
        code, exe_result, is_success = await self._write_and_exec_code(experiences=experiences)
        if self.run_math_prompt:
            processor = MathOutputAnswer()
            rsp = Message(content="Runtime solve result: " + exe_result, cause_by=MathWriteCode)
            result = await processor.run([rsp])
            # self-verification
            is_success, answer, csv_result, suggestion = await processor.post_process(plan=self.planner.plan,
                                                                                      answer=result, code=code,
                                                                                      result=exe_result,
                                                                                      execute_nb_code=self.execute_code)
            if not is_success:
                self.planner.working_memory.add(
                    Message(
                        content=exe_result + '\n' + '--------' + '\n' + 'suggestion : ' + suggestion,
                        role="user",
                        cause_by=ExecuteNbCode
                    )
                )
            # for evaluation
            self.answer = answer
            self.csv_result = csv_result
        task_result = TaskResult(code=code, result=exe_result, is_success=is_success)
        return task_result

    async def _write_and_exec_code(self, max_retry: int = 3, experiences: str = ""):
        """Write code and run it in the notebook, retrying up to ``max_retry`` times.

        Before the loop, this gathers the plan status (plan mode only), recommended
        tool info (if a tool recommender is configured), and possibly refreshed data
        info (via ``_check_data``). Each attempt appends the generated code and its
        execution output to working memory, so later attempts can see them.
        ``experiences`` is only passed to the first non-math attempt.

        Args:
            max_retry: Maximum number of attempts; expected to be >= 1.
            experiences: Extra context for the first code-writing attempt.

        Returns:
            Tuple ``(code, result, success)`` from the last attempt.
        """
        counter = 0
        success = False

        # plan info
        plan_status = self.planner.get_plan_status() if self.use_plan else ""

        # tool info
        if self.tool_recommender:
            # thoughts from _think stage in 'react' mode
            memories = self.working_memory.get()
            context = memories[-1].content if memories else ""
            plan = self.planner.plan if self.use_plan else None
            tool_info = await self.tool_recommender.get_recommended_tool_info(context=context, plan=plan)
        else:
            tool_info = ""

        # data info
        await self._check_data()

        while not success and counter < max_retry:
            ### write code ###
            if self.run_math_prompt:
                code, cause_by = await self._write_math_code(counter)
            else:
                code, cause_by = await self._write_code(
                    counter, plan_status, tool_info, experiences=experiences if counter == 0 else ""
                )

            self.working_memory.add(Message(content=code, role="assistant", cause_by=cause_by))

            ### execute code ###
            result, success = await self.execute_code.run(code)
            print(result)

            self.working_memory.add(Message(content=result, role="user", cause_by=ExecuteNbCode))

            ### process execution result ###
            counter += 1

            if not success and counter >= max_retry:
                logger.info("coding failed!")

        return code, result, success

    async def _write_code(self, counter: int, plan_status: str = "", tool_info: str = "", experiences: str = ""):
        """Generate analysis code with the current todo action (WriteAnalysisCode).

        Reflection is used only when ``use_reflection`` is enabled and this is not the
        first attempt (``counter > 0``).

        Returns:
            Tuple ``(code, todo)``; the todo action is used as the message's cause_by.
        """
        todo = self.rc.todo  # todo is WriteAnalysisCode
        logger.info(f"ready to {todo.name}")
        use_reflection = counter > 0 and self.use_reflection  # only use reflection after the first trial

        user_requirement = self.get_memories()[0].content

        code = await todo.run(
            user_requirement=user_requirement,
            plan_status=plan_status,
            tool_info=tool_info,
            working_memory=self.working_memory.get(),
            use_reflection=use_reflection,
            experiences=experiences,
        )

        return code, todo

    async def _write_math_code(self, counter: int):
        """Generate code with MathWriteCode from the last few planner memories.

        When reflecting (``use_reflection`` and ``counter > 0``), the previous
        MathWriteCode entries from working memory are passed as ``code``; otherwise
        an empty string is passed.
        NOTE(review): ``get_by_action`` presumably returns messages rather than a code
        string; confirm this matches what MathWriteCode.run expects.

        Returns:
            Tuple ``(code, MathWriteCode)``.
        """
        context = self.planner.get_last_useful_memories(num=3)
        use_reflection = counter > 0 and self.use_reflection
        cause_by = MathWriteCode
        code = self.working_memory.get_by_action(cause_by) if use_reflection else ""
        code = await MathWriteCode().run(context=context, plan=self.planner.plan,
                                         exec_result=self.working_memory.get(),
                                         code=code,
                                         use_reflection=use_reflection)
        return code, cause_by

    async def _check_data(self):
        """Inspect the (possibly updated) data before writing code for data-centric tasks.

        This only runs in plan mode, once at least one task has finished, and when the
        current task's type is data preprocessing, feature engineering or model
        training. ``CheckData`` produces inspection code. If that code is non-empty, it
        is executed. On success, the output is printed and added to working memory
        formatted with ``DATA_INFO``.
        """
        if not self.use_plan or not self.planner.plan.get_finished_tasks():
            return
        data_task_types = (
            TaskType.DATA_PREPROCESS.type_name,
            TaskType.FEATURE_ENGINEERING.type_name,
            TaskType.MODEL_TRAIN.type_name,
        )
        if self.planner.plan.current_task.task_type not in data_task_types:
            return
        logger.info("Check updated data")
        code = await CheckData().run(self.planner.plan)
        if not code.strip():
            return
        result, success = await self.execute_code.run(code)
        if success:
            print(result)
            data_info = DATA_INFO.format(info=result)
            self.working_memory.add(Message(content=data_info, role="user", cause_by=CheckData))
