from typing import Literal, Dict, List, Optional
from pydantic import Field
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.actions import Action
from metagpt.environment import Environment
from metagpt.actions.add_requirement import UserRequirement
from metagpt.actions.di.execute_nb_code import ExecuteNbCode
from metagpt.utils.common import remove_white_spaces
from metagpt.logs import logger
import json
import asyncio

from .actions import Execute
from .prompts import ASSISTANT_REACT_PROMPT
from .metrics import AgentAnalyzer 
from .utils import custom_json_parser

class ExecuteCode(Action):
    """Action that runs a code snippet through a notebook executor with a time limit."""

    name: str = "ExecuteCode"

    async def run(self, code: str, nb_executor: ExecuteNbCode, timeout: float = 10):
        """Run ``code`` via ``nb_executor`` and report the outcome.

        Args:
            code: Source code handed to ``nb_executor.run``.
            nb_executor: Notebook executor that runs the code and owns the
                notebook (``nb_executor.nb``) whose cells are counted.
            timeout: Maximum number of seconds to wait for the execution
                before giving up. Defaults to 10.

        Returns:
            A 4-tuple ``(code, outputs, success, cell_index)`` where
            ``cell_index`` is the index of the last cell in the notebook
            after the run. On timeout ``success`` is ``False`` and
            ``outputs`` is a short message stating the time limit.
        """
        try:
            outputs, success = await asyncio.wait_for(nb_executor.run(code), timeout=timeout)
        except asyncio.TimeoutError:
            # Surface the reason for the failure instead of returning an empty
            # output, so whoever reads the result can tell a timeout apart from
            # code that simply printed nothing.
            logger.warning(f"Code execution timed out after {timeout} seconds")
            outputs = f"Execution timed out after {timeout} seconds"
            success = False

        # Index of the last notebook cell at this point.
        cell_index = len(nb_executor.nb.cells) - 1
        return code, outputs, success, cell_index


class ProcessText(Action):
    """Action that stores a piece of text as a markdown cell in the notebook."""

    name: str = "ProcessText"

    async def run(self, text: str, nb_executor: ExecuteNbCode):
        """Add ``text`` to the notebook as a markdown cell.

        Args:
            text: Content passed to ``nb_executor.add_markdown_cell``.
            nb_executor: Notebook executor owning the notebook being extended.

        Returns:
            A 4-tuple ``(text, status_message, True, cell_index)`` where
            ``cell_index`` is the index of the last notebook cell after the
            markdown cell was added.
        """
        nb_executor.add_markdown_cell(text)

        status_message = "Markdown cell added successfully"
        last_cell_index = len(nb_executor.nb.cells) - 1
        return text, status_message, True, last_cell_index

class AnalyzeAgents(Action):
    """Action that scores every agent profile using an ``AgentAnalyzer``."""

    name: str = "AnalyzeAgents"

    async def run(self, env: Environment, current_profiles: Dict[str, Dict]):
        """Build a per-agent record holding the profile and its scores.

        Args:
            env: Environment handed to ``AgentAnalyzer``.
            current_profiles: Mapping of agent name to that agent's profile.

        Returns:
            A dict keyed by agent name; each value has the form
            ``{'profile': ..., 'scores': {'clarity', 'differentiation',
            'alignment'}}``. The differentiation and alignment scores are
            computed once and shared by every agent; clarity is computed
            per profile.
        """
        analyzer = AgentAnalyzer(env)
        logger.info(f"Current profiles: {current_profiles}")

        # Scores computed once for the whole environment and reused per agent.
        env_wide_scores = {
            'differentiation': analyzer.calculate_role_differentiation(),
            'alignment': analyzer.calculate_overall_task_role_alignment(),
        }

        return {
            agent_name: {
                'profile': profile,
                'scores': {
                    'clarity': analyzer.calculate_role_clarity(profile)['total_score'],
                    **env_wide_scores,
                },
            }
            for agent_name, profile in current_profiles.items()
        }

class AssistantAgent(Role):
    """Role that executes code, records text, and scores agent profiles.

    The role operates in one of two phases (switch with ``set_phase``):

    * ``"update"`` - ``_think`` always selects ``AnalyzeAgents``.
    * ``"act"`` - ``_think`` asks the LLM to choose between running code
      (``ExecuteCode``) and adding a markdown cell (``ProcessText``).
    """

    name: str = "Assistant"
    profile: str = "Code execution and text processing assistant"
    goal: str = "Execute code, process text, and assist with tasks"
    constraints: str = "Ensure safe code execution and accurate text processing"
    react_mode: Literal["plan_and_act", "react"] = "react"
    # Current phase; decides how ``_think`` picks the next action.
    phase: Literal["update", "act"] = "update"
    max_react_loop: int = 1
    # Latest result of ``AnalyzeAgents`` (agent name -> profile and scores).
    agent_profiles_evaluation: Dict = Field(default_factory=dict)
    # Index of the notebook cell recorded as the complete solution, if any.
    complete_solution_cell_index: Optional[int] = None
    # Inserted into the LLM prompt built in ``_think``.
    user_requirement: str = ""
    # When True, the cell produced by the next code/text action is recorded
    # as the complete solution (code cells only if execution succeeded).
    is_valid_answer: bool = False
    # Recipient used for the messages produced by ``_act``.
    send_to: str = ""

    def __init__(self, **kwargs):
        """Create the notebook executor and register actions and watched causes."""
        super().__init__(**kwargs)
        self.execute_nb = ExecuteNbCode()
        # NOTE: ``_think`` selects actions by position in this list.
        self.set_actions([ExecuteCode(), ProcessText(), AnalyzeAgents()])
        self._watch([Execute, UserRequirement, AnalyzeAgents])
        self._set_react_mode(react_mode=self.react_mode, max_react_loop=self.max_react_loop)

    async def _observe(self, ignore_memory=False) -> int:
        """Drain the message buffer and keep the messages relevant to this role.

        In the ``"act"`` phase, messages whose ``cause_by`` contains
        ``"Execute"`` are stored in memory. Messages whose cause is watched,
        or that are addressed to this role by name, become ``self.rc.news``.

        Returns:
            The number of messages kept in ``self.rc.news``.
        """
        news = self.rc.msg_buffer.pop_all()

        if self.phase == "act":
            for msg in news:
                if "Execute" in msg.cause_by:
                    self.rc.memory.add(msg)

        self.rc.news = [n for n in news if n.cause_by in self.rc.watch or self.name in n.send_to]
        self.latest_observed_msg = self.rc.news[-1] if self.rc.news else None

        return len(self.rc.news)

    async def _think(self) -> bool:
        """Choose the next action.

        In the ``"update"`` phase ``AnalyzeAgents`` is always selected.
        Otherwise the LLM is prompted with the user requirement and the
        memories, and is expected to answer with a JSON object holding
        ``"action"`` (``CODE`` or ``TEXT``), ``"action_input"`` and
        ``"state"``.

        Returns:
            ``True`` in the ``"update"`` phase; otherwise the ``"state"``
            value from the LLM response. ``False`` is returned, without
            changing the current todo, when the response cannot be decoded,
            is not an object with the expected keys, or names an unknown
            action.
        """
        if self.phase == "update":
            self.set_todo(self.actions[2])  # AnalyzeAgents
            return True

        context = "\n".join(f"{msg.role}: {msg.content}" for msg in self.get_memories())
        prompt = ASSISTANT_REACT_PROMPT.format(
            user_requirement=self.user_requirement,
            context=context,
        )

        rsp = await self.llm.aask(prompt, response_format={"type": "json_object"})
        try:
            rsp_dict = custom_json_parser(rsp)
        except json.JSONDecodeError:
            return False

        # The response comes from the LLM, so validate its shape before
        # touching any state: a missing key or wrong type is treated like an
        # undecodable response instead of raising KeyError/TypeError.
        if not isinstance(rsp_dict, dict):
            logger.warning(f"Unexpected LLM response type: {type(rsp_dict).__name__}")
            return False
        missing_keys = [key for key in ("action", "action_input", "state") if key not in rsp_dict]
        if missing_keys:
            logger.warning(f"LLM response is missing keys: {missing_keys}")
            return False
        if not isinstance(rsp_dict["action"], str):
            logger.warning(f"LLM response has a non-string action: {rsp_dict['action']!r}")
            return False

        choice = remove_white_spaces(rsp_dict["action"]).upper()
        if choice == "CODE":
            self.set_todo(self.actions[0])  # ExecuteCode
        elif choice == "TEXT":
            self.set_todo(self.actions[1])  # ProcessText
        else:
            return False

        self.action_input = rsp_dict["action_input"]
        logger.info(f"action_input: {self.action_input}")

        return rsp_dict["state"]

    async def _act(self) -> Message:
        """Run the selected action and wrap its outcome in a ``Message``.

        Code and text actions operate on ``self.action_input`` and, when
        ``is_valid_answer`` is set, record the resulting cell index as the
        complete solution (code only if execution succeeded). The analysis
        action refreshes ``agent_profiles_evaluation``. In the ``"act"``
        phase, memory is cleared afterwards except for the user requirement.

        Returns:
            A message carrying the textual result, sent to ``self.send_to``
            and caused by the type of the action that ran.
        """
        todo = self.rc.todo

        if isinstance(todo, ExecuteCode):
            # Strip double quotes wrapping the code string before running it.
            code, outputs, success, cell_index = await todo.run(self.action_input.strip('"'), self.execute_nb)

            logger.info(f"Code execution {'succeeded' if success else 'failed'}:\nCode:\n {code}\nOutput:\n{outputs}")
            response = f"Code execution {'succeeded' if success else 'failed'}:\nOutput:\n```\n{outputs}\n```"

            if success and self.is_valid_answer:
                self.complete_solution_cell_index = cell_index
                logger.info(f"Updated complete solution cell index to {cell_index}")

        elif isinstance(todo, ProcessText):
            text, result, success, cell_index = await todo.run(self.action_input, self.execute_nb)

            response = f"Text processed: {result}"
            if self.is_valid_answer:
                self.complete_solution_cell_index = cell_index
                logger.info(f"Updated complete solution cell index to {cell_index}")

        elif isinstance(todo, AnalyzeAgents):
            # The raw profiles are stored first, then replaced by the scored result.
            self.agent_profiles_evaluation = self.get_agent_profiles()
            agent_profiles_evaluation = await todo.run(env=self.rc.env, current_profiles=self.agent_profiles_evaluation)
            self.agent_profiles_evaluation = agent_profiles_evaluation
            response = "Agent profiles and scores updated"
        else:
            response = "Invalid action"

        logger.info(f"Response: {response}")

        message = Message(content=response, role=self.name, send_to=self.send_to, cause_by=type(todo))

        if self.phase == "act":
            self.clear_memory_except_requirement()

        return message

    def clear_memory_except_requirement(self):
        """Clear memory, keeping only the last message caused by ``UserRequirement``."""
        user_requirement = None
        for msg in self.rc.memory.get():
            if "UserRequirement" in msg.cause_by:
                user_requirement = msg

        self.rc.memory.clear()
        if user_requirement:
            self.rc.memory.add(user_requirement)

    async def get_notebook_content(self) -> str:
        """Return the notebook serialized with ``nbformat.writes``."""
        import nbformat
        return nbformat.writes(self.execute_nb.nb)

    def get_complete_solution(self) -> Optional[Dict[str, str]]:
        """Return the content of the cell recorded as the complete solution.

        Returns:
            ``None`` if no cell index has been recorded or the cell is
            neither code nor markdown. For a code cell, a dict with
            ``"code"`` (the source) and ``"output"`` (the ``text`` entries of
            its outputs joined by newlines, or ``None`` when the cell has no
            outputs). For a markdown cell, a dict with ``"text"``.
        """
        if self.complete_solution_cell_index is None:
            return None

        cell = self.execute_nb.nb.cells[self.complete_solution_cell_index]
        if cell.cell_type == 'code':
            if cell.outputs:
                # Only outputs that carry a 'text' entry contribute.
                output_text = '\n'.join(output.get('text', '') for output in cell.outputs if 'text' in output)
            else:
                output_text = None
            return {
                "code": cell.source,
                "output": output_text,
            }
        if cell.cell_type == 'markdown':
            return {
                "text": cell.source,
            }
        return None

    def set_phase(self, phase: Literal["update", "act"]):
        """Switch the role to the given phase."""
        self.phase = phase

    def get_agent_profiles(self):
        """Return a mapping of agent name to profile for every agent in the environment."""
        return {agent.name: agent.profile for agent in self.rc.env.agents}