import json
import re
from typing import Literal

from metagpt.actions.add_requirement import UserRequirement
from metagpt.const import MESSAGE_ROUTE_TO_ALL
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.utils.common import remove_white_spaces
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed

from .actions import Execute, Feedback, Skip, Update
from .assistant_agent import ExecuteCode, ProcessText
from .prompts import REACT_THINK_PROMPT
from .utils import custom_json_parser

# Ordered (pattern, replacement) pairs that turn a first-person profile into second
# person. "I'm" must be handled before the bare "I". Word boundaries keep capitals
# inside words (e.g. "AI", "Implement") and "my" inside words (e.g. "economy") intact.
_SECOND_PERSON_SUBSTITUTIONS = (
    (re.compile(r"\bI['\u2019]m\b"), "You're"),
    (re.compile(r"\bI\b"), "You"),
    (re.compile(r"\bmy\b"), "your"),
    (re.compile(r"\bMy\b"), "Your"),
)


def _to_second_person(text: str) -> str:
    """Rewrite the standalone first-person words "I", "I'm", "my" and "My" into second person."""
    for pattern, replacement in _SECOND_PERSON_SUBSTITUTIONS:
        text = pattern.sub(replacement, text)
    return text


class MultiAgent(Role):
    """A role whose behaviour in each round is selected by the externally set ``phase``.

    * ``"update"``    - run ``Update`` to rewrite ``profile`` and the LLM system prompt.
    * ``"act"``       - ask the LLM (``REACT_THINK_PROMPT``) whether to ``Execute`` or ``Skip``.
    * ``"reflexion"`` - run ``Feedback`` on the latest execution.
    """

    react_mode: Literal["plan_and_act", "react"] = "react"
    max_react_loop: int = 1
    phase: Literal["update", "act", "reflexion"] = "update"
    # Output of the most recent Execute action ("" when there is none).
    latest_execution: str = ""
    # Content of the last ExecuteCode/ProcessText message addressed to this agent.
    latest_execution_result: str = ""
    # Output of the last Feedback action. `_act` stores the raw string here, and
    # `_resolve_feedback` parses it into a dict on first use.
    feedback: dict = {}
    is_valid_answer: bool = False
    is_finished: bool = False
    # Round counter read by `_think` (round 0 always proceeds); not advanced in this class.
    rnd: int = 0

    def __init__(self, name: str, profile: str, goal: str, constraints: str = "", **kwargs):
        super().__init__(name=name, profile=profile, goal=goal, constraints=constraints, **kwargs)
        # `_think` selects actions by index: 0=Update, 1=Execute, 2=Skip, 3=Feedback.
        self.set_actions([Update(), Execute(), Skip(), Feedback()])
        self._watch([Execute, UserRequirement, ExecuteCode, ProcessText])
        self._set_react_mode(react_mode=self.react_mode, max_react_loop=self.max_react_loop)

    def _should_add_to_memory(self, msg: Message) -> bool:
        """Decide whether ``msg`` should be stored in this agent's memory.

        The first entry of the criteria table whose key is a substring of
        ``msg.cause_by`` decides. Note that "Execute" also matches "ExecuteCode", and
        both map to True. Otherwise, messages addressed to this agent or to everyone
        are kept.
        """
        criteria = {
            "Skip": False,
            "Update": False,
            "Feedback": False,
            "Execute": True,
            "UserRequirement": True,
            "ExecuteCode": True,
            "ProcessText": True,
            "AnalyzeAgents": False,
        }

        logger.info(f"Message cause_by: {msg.cause_by}")
        logger.info(f"Message send_to: {msg.send_to}")

        for cause, should_add in criteria.items():
            if cause in msg.cause_by:
                logger.info(f"Message caused by {cause}, {'adding to' if should_add else 'not adding to'} memory")
                return should_add

        if self.name in msg.send_to or MESSAGE_ROUTE_TO_ALL in msg.send_to:
            logger.info(f"Message is sent to {self.name if self.name in msg.send_to else 'all'}, adding to memory")
            return True

        logger.info("Message does not meet any criteria for adding to memory")
        return False

    async def _observe(self, ignore_memory=False) -> int:
        """Pull new messages, store the relevant ones in memory and rebuild ``rc.news``.

        Args:
            ignore_memory: when True, the existing memory is not consulted, so no
                message is treated as already seen.

        Returns:
            The number of messages placed in ``self.rc.news``.
        """
        # After a recovery, replay the last observed message instead of reading the buffer.
        news = []
        if self.recovered:
            news = [self.latest_observed_msg] if self.latest_observed_msg else []
        if not news:
            news = self.rc.msg_buffer.pop_all()

        self.rc.news = []
        # Snapshot taken before adding anything, so both loops below de-duplicate
        # against the same memory state.
        old_messages = [] if ignore_memory else self.rc.memory.get()

        for msg in news:
            is_valid = self._should_add_to_memory(msg)
            logger.info(f"Is valid: {is_valid}")
            if is_valid and msg not in old_messages:
                self.rc.memory.add(msg)
            addressed_to_me = self.name in msg.send_to
            if ("ExecuteCode" in msg.cause_by or "ProcessText" in msg.cause_by) and addressed_to_me:
                self.latest_execution_result = msg.content
            # Update and AnalyzeAgents messages are not watched, but the Update phase needs them.
            if ("AnalyzeAgents" in msg.cause_by and addressed_to_me) or "Update" in msg.cause_by:
                self.rc.news.append(msg)

        # Add unseen messages from watched actions that are addressed to this agent or to everyone.
        for n in news:
            if (n.cause_by in self.rc.watch and (self.name in n.send_to or MESSAGE_ROUTE_TO_ALL in n.send_to)) and n not in old_messages:
                self.rc.news.append(n)
        logger.info(f"Act phase observed: {len(self.rc.news)}")

        self.latest_observed_msg = self.rc.news[-1] if self.rc.news else None  # record the latest observed msg

        news_text = [f"{i.role}: {i.content[:20]}..." for i in self.rc.news]
        if news_text:
            logger.debug(f"{self._setting} observed: {news_text}")

        return len(self.rc.news)

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_fixed(1),
        retry=retry_if_exception_type(Exception),
        reraise=True,
    )
    async def _think(self) -> bool:
        """Schedule the next action for the current phase.

        Returns:
            In the "update" and "reflexion" phases: True if a todo was scheduled.
            In the "act" phase: False for an unrecognised action, otherwise the
            LLM-reported ``state`` (forced to True on round 0). If the LLM response
            cannot be handled, True is returned (best effort).
        """
        if self.phase == "update":
            if self.rc.news:
                self.set_todo(self.actions[0])  # Update
                return True
            return False

        if self.phase == "reflexion":
            if self.latest_execution != "":
                self.set_todo(self.actions[3])  # Feedback
                return True
            return False

        memories = self.get_memories()
        user_requirement = memories[0].content if memories else ""
        context = "\n".join(f"{msg.role}: {msg.content}" for msg in memories[1:])

        prompt = REACT_THINK_PROMPT.format(user_requirement=user_requirement, context=context)
        rsp = await self.llm.aask(prompt, response_format={"type": "json_object"})

        try:
            rsp_dict = custom_json_parser(rsp)
            logger.info(f"Agent {self.name} react response: {rsp_dict}")

            self.rc.memory.add(Message(content=rsp_dict["thoughts"], role=self.name))

            choice = remove_white_spaces(rsp_dict["action"]).upper()
            if choice == "EXECUTE":
                self.set_todo(self.actions[1])  # Execute
            elif choice == "SKIP":
                self.set_todo(self.actions[2])  # Skip
            else:
                return False

            # The first round always proceeds; later rounds defer to the LLM's verdict.
            state = True if self.rnd == 0 else rsp_dict["state"]
            logger.info(f"Agent {self.name} state: {state}")
            self.is_finished = not state
        except Exception as e:
            logger.error(f"Failed to handle react response: {e}")
            # NOTE(review): best effort. If parsing failed before set_todo, the
            # previously scheduled todo (if any) is what _act will run.
            return True

        return state

    def _resolve_feedback(self) -> dict:
        """Normalise ``self.feedback`` to a dict and return the feedback to pass to an action.

        A raw string in ``self.feedback`` is parsed with ``custom_json_parser`` and
        stored back. An empty string, unparsable JSON or a non-dict result becomes ``{}``.

        Returns:
            ``self.feedback`` when it has a non-empty "self_reflection" entry, otherwise
            ``{"self_reflection": "No feedback available"}``.
        """
        if isinstance(self.feedback, str):
            raw = self.feedback
            if not raw.strip():
                # `_act` stores "" after non-Feedback actions; there is nothing to parse.
                self.feedback = {}
            else:
                try:
                    self.feedback = custom_json_parser(raw)
                except json.JSONDecodeError:
                    logger.error(f"Failed to parse JSON from feedback: {raw}")
                    self.feedback = {}
        if not isinstance(self.feedback, dict):
            logger.error(f"Feedback is not a JSON object: {self.feedback}")
            self.feedback = {}

        if self.feedback.get("self_reflection", "") != "":
            logger.info(f"Feedback: {self.feedback}")
            return self.feedback
        logger.info("No feedback found")
        return {"self_reflection": "No feedback available"}

    async def _run_update(self, todo: Update, memories: list) -> str:
        """Run ``Update`` on the observed messages and adopt the result as this agent's profile.

        Side effects: sets ``self.profile`` and refreshes ``self.llm.system_prompt``.

        Returns:
            The updated profile, rewritten into second person.
        """
        user_requirement = memories[0].content if memories else ""
        previous_updates = []
        update_message = ""

        for msg in self.rc.news:
            if "Update" in msg.cause_by:
                previous_updates.append({"agent": msg.role, "profile": msg.content})
            if "AnalyzeAgents" in msg.cause_by and self.name in msg.send_to:
                update_message = msg.content

        logger.info(f"Update message: {update_message}")

        if self.latest_execution != "":
            latest_execution = self.latest_execution
            logger.info(f"Latest execution: {latest_execution}")
        else:
            latest_execution = None
            logger.info("No latest execution found")

        # NOTE(review): the execution result is only logged here; it is not passed to Update.run.
        if self.latest_execution_result != "":
            logger.info(f"Latest execution result: {self.latest_execution_result}")
        else:
            logger.info("No latest execution result found")

        if not previous_updates:
            logger.info(f"No previous updates found for agent {self.name}")
        logger.info(f"Previous updates: {previous_updates}")

        feedback = self._resolve_feedback()

        updated_profile = await todo.run(user_requirement, update_message, previous_updates, latest_execution, feedback)
        updated_profile = _to_second_person(updated_profile)

        self.profile = updated_profile  # Update self-description
        self.llm.system_prompt = self._get_prefix()
        logger.info(f"Updated system prompt: {self.llm.system_prompt}")
        return updated_profile

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_fixed(1),
        retry=retry_if_exception_type(Exception),
        reraise=True,
    )
    async def _act(self) -> Message:
        """Run ``self.rc.todo`` and wrap its output in a ``Message``.

        Side effects:
            * Execute sets ``is_valid_answer``, and its message is added to memory.
            * In the "act" phase, ``latest_execution`` tracks the Execute output.
            * Outside the "update" phase, ``feedback`` holds the Feedback output, or ""
              after any other action.
        """
        todo = self.rc.todo
        memories = self.get_memories()
        context = "\n".join(f"{msg.role}: {msg.content}" for msg in memories)
        logger.info(f"Todo: {todo}")

        if isinstance(todo, Execute):
            feedback = self._resolve_feedback()
            action_result = await todo.run(feedback["self_reflection"], context=context)
            logger.info(f"Action result: {action_result}")

            try:
                result_dict = custom_json_parser(action_result)
                execution_content = result_dict["execution_content"]
                self.is_valid_answer = result_dict["is_valid"]
                action_result = execution_content
            except json.JSONDecodeError:
                logger.error(f"Failed to parse JSON from action result: {action_result}")
                action_result = "Invalid execution result format"

        elif isinstance(todo, Update):
            action_result = await self._run_update(todo, memories)

        elif isinstance(todo, Skip):
            # NOTE(review): this early return bypasses the state bookkeeping below, so
            # latest_execution/feedback from an earlier round survive a Skip.
            return Message(content="Skipped", role=self.name, cause_by=Skip)

        elif isinstance(todo, Feedback):
            action_result = await todo.run(context=context, execution=self.latest_execution, execution_result=self.latest_execution_result)
            logger.info(f"Feedback result: {action_result}")

        else:
            action_result = "Invalid action"

        msg = Message(content=action_result, role=self.name, cause_by=type(todo))

        if self.phase == "act":
            if "Execute" in msg.cause_by:
                self.latest_execution = action_result
            else:
                self.latest_execution = ""
                self.latest_execution_result = ""

        if self.phase != "update":
            # The raw Feedback string is parsed lazily by `_resolve_feedback`.
            self.feedback = action_result if "Feedback" in msg.cause_by else ""

        logger.info(f"Action result message: {msg}")
        if "Execute" in msg.cause_by:
            self.rc.memory.add(msg)

        return msg

    def set_phase(self, phase: Literal["update", "act", "reflexion"]):
        """Set the phase that ``_think`` dispatches on."""
        self.phase = phase