import copy
import json
import os
import random
from typing import Dict, Union, List

from metagpt.logs import logger
from metagpt.environment import Environment
from metagpt.schema import Message
from metagpt.actions.add_requirement import UserRequirement
from metagpt.const import MESSAGE_ROUTE_TO_ALL
from metagpt.config2 import Config
from pydantic import Field

from .agents import MultiAgent
from .assistant_agent import AssistantAgent, AnalyzeAgents
from .prompts import generate_update_profile_prompt
from .utils import get_all_eval_prompt, save_profiles


class MultiAgentEnvironment(Environment):
    """Environment that drives several ``MultiAgent`` workers plus one ``AssistantAgent``.

    Each round consists of a profile-update phase followed by a task-execution
    phase; an optional warmup phase runs profile updates only.
    """
    # Worker agents; populated by __init__.
    agents: list[MultiAgent] = Field(default_factory=list)
    # Assistant role that evaluates agent profiles and assembles the final solution.
    assistant: AssistantAgent = Field(default_factory=AssistantAgent)
    # Prompt of the task currently being run (set from task["complete_prompt"] in run()).
    task: str = ""
    # Identifier of the current task (set from task["task_id"] in run()).
    task_id: str = ""
    # Not referenced by any method visible in this file.
    task_type: str = ""
    # Upper bound on the number of update/execute rounds in run().
    max_rounds: int = 10
    # Passed to save_profiles() when profile evaluations are persisted.
    output_dir: str = ""
    # Zero-based index of the round in progress; updated by run().
    current_round: int = 0
    # Minimum clarity/differentiation/alignment score every agent must reach
    # for warmup to stop early.
    threshold: float = 0.1
    # Switches in the order (clarity, differentiation, alignment); a False entry
    # drops that metric's feedback from the profile-update prompt.
    metrics: tuple = (True, True, True)
    # Probability that an agent skips its profile update or its execution turn.
    failure_prob: float = 0.0

    def __init__(
        self,
        num_agents: int = 3,
        agent_llm_config: Union[Dict, List[Dict]] = None,
        assistant_llm_config: Dict = None,
        max_rounds: int = 10,
        output_dir: str = "",
        threshold: float = 0.1,
        metrics: tuple = (True, True, True),
        failure_prob: float = 0.0,
        **kwargs,
    ):
        """Create the worker agents and the assistant and register them as roles.

        Args:
            num_agents: Number of ``MultiAgent`` workers to create.
            agent_llm_config: Either one LLM config dict used for every worker,
                or a list holding exactly one dict per worker. Any other value
                (including ``None``) gives every worker an empty config.
            assistant_llm_config: LLM config dict for the assistant; when falsy
                the assistant is created with ``config=None``.
            max_rounds: Maximum number of rounds executed by ``run``.
            output_dir: Directory handed to ``save_profiles``.
            threshold: Score every metric must reach for warmup to end early.
            metrics: Enable flags for (clarity, differentiation, alignment).
            failure_prob: Probability that an agent skips a step.
            **kwargs: Forwarded to the parent ``Environment`` constructor.

        Raises:
            ValueError: If ``agent_llm_config`` is a list whose length differs
                from ``num_agents``.
        """
        super().__init__(**kwargs)

        self.max_rounds = max_rounds
        self.output_dir = output_dir
        self.threshold = threshold
        self.metrics = metrics
        self.failure_prob = failure_prob

        # Resolve the raw LLM settings to one entry per worker.
        if isinstance(agent_llm_config, dict):
            # A single dict is shared by all workers.
            llm_settings = [agent_llm_config for _ in range(num_agents)]
        elif isinstance(agent_llm_config, list):
            # A list must line up one-to-one with the workers.
            if len(agent_llm_config) != num_agents:
                raise ValueError(f"Number of agent configs ({len(agent_llm_config)}) must match number of agents ({num_agents})")
            llm_settings = agent_llm_config
        else:
            # Nothing usable supplied: every worker gets an empty config.
            llm_settings = [{} for _ in range(num_agents)]
        agent_configs = [Config.from_llm_config(settings) for settings in llm_settings]

        for idx, worker_config in enumerate(agent_configs):
            worker = MultiAgent(
                name=f"Agent_{idx}",
                profile=f"Agent_{idx}: collaborative agent with unique perspective",
                goal="Collaborate to complete tasks",
                config=worker_config,
            )
            self.agents.append(worker)
            logger.info(f"Agent {idx}'s LLM: {worker.llm.config}")

        if assistant_llm_config:
            assistant_config = Config.from_llm_config(assistant_llm_config)
        else:
            assistant_config = None
        self.assistant = AssistantAgent(name="Assistant", config=assistant_config)
        logger.info(f"Assistant's LLM: {self.assistant.llm.config}")
        logger.info(f"Assistant's cost manager: {self.assistant.llm.cost_manager}")

        self.add_roles([*self.agents, self.assistant])

    async def warmup(self, max_warmup_rounds: int = 3):
        """Run profile-update rounds until every agent meets the threshold.

        Args:
            max_warmup_rounds: Maximum number of warmup rounds to attempt.

        Stops as soon as ``all_agents_meet_threshold`` returns True; if the
        round budget is used up first, a warning is logged instead.
        """
        logger.info("Starting warmup phase")
        for round_idx in range(max_warmup_rounds):
            logger.info(f"Warmup round: {round_idx + 1}")

            # Same update procedure as Phase 1 of a regular round.
            await self.update_agent_profiles(is_warmup=True, warmup_round=round_idx)

            # Finish early once every agent has reached the threshold.
            if self.all_agents_meet_threshold():
                logger.info("All agents have met the threshold. Warmup complete.")
                return

        logger.warning(f"Warmup ended after {max_warmup_rounds} rounds without all agents meeting the threshold.")
    
    def all_agents_meet_threshold(self):
        """Return True when no agent has a metric score below ``self.threshold``.

        Scores are read from the assistant's ``agent_profiles_evaluation``.
        The first agent found below the threshold on clarity, differentiation
        or alignment is logged and False is returned immediately.
        """
        metric_names = ('clarity', 'differentiation', 'alignment')
        for agent in self.agents:
            scores = self.assistant.agent_profiles_evaluation[agent.name]['scores']
            below_threshold = any(scores[metric] < self.threshold for metric in metric_names)
            if below_threshold:
                logger.info(f"Agent {agent.name} did not meet threshold: clarity={scores['clarity']}, differentiation={scores['differentiation']}, alignment={scores['alignment']}")
                return False
        return True
    
    async def update_agent_profiles(self, is_warmup: bool = False, warmup_round: int = 0):
        """Have the assistant re-evaluate all profiles, then ask each agent to update its own.

        Args:
            is_warmup: True when called from the warmup phase; in that case the
                previous evaluation is not persisted with ``save_profiles``.
            warmup_round: Zero-based warmup round index (only meaningful when
                ``is_warmup`` is True).

        For the very first evaluation (warmup round 0), or when an agent has no
        previous evaluation entry, the update prompt is built from the current
        scores alone. Otherwise it also carries the score deltas and the
        agent's previous profile.
        """
        self.assistant.set_phase("update")

        round_message = Message(
            content=f"{'Warmup' if is_warmup else 'Round'} {warmup_round if is_warmup else 'update'}",
            send_to=self.assistant.name,
            cause_by=UserRequirement,
        )
        # Snapshot the evaluation BEFORE the assistant runs. Holding a plain
        # reference would alias the live object, so an in-place update by the
        # assistant would make every "old vs new" delta come out as zero.
        old_profiles_eval = copy.deepcopy(self.assistant.agent_profiles_evaluation)

        if not is_warmup:
            save_profiles(self.output_dir, self.task_id, old_profiles_eval, self.current_round)

        self.assistant.send_to = MESSAGE_ROUTE_TO_ALL
        await self.assistant.run(with_message=round_message)

        role_evaluation = self.assistant.agent_profiles_evaluation
        first_evaluation = is_warmup and warmup_round == 0
        metric_names = ("clarity", "differentiation", "alignment")

        for agent in self.agents:
            logger.info(f"Updating profile for: {agent.name}")
            # Simulated failure: the agent keeps its current profile this round.
            if random.random() < self.failure_prob:
                logger.info(f"Agent {agent.name} failed to update profile.")
                continue

            # Fetched per agent (not hoisted) because earlier agents in this
            # loop may already have changed their profiles.
            new_profiles = self.assistant.get_agent_profiles()
            agent.set_phase("update")

            new_scores = role_evaluation[agent.name]['scores']
            # No baseline exists on the first evaluation or for an agent that
            # was never evaluated before; fall back to a delta-free prompt
            # instead of failing with a KeyError.
            previous = None if first_evaluation else (old_profiles_eval or {}).get(agent.name)

            if previous is None:
                deltas = dict.fromkeys(metric_names)
                prompt_kwargs = {}
            else:
                old_scores = previous['scores']
                deltas = {name: new_scores[name] - old_scores[name] for name in metric_names}
                prompt_kwargs = {"old_profile": previous['profile']}

            clarity_eval, differentiation_eval, alignment_eval = get_all_eval_prompt(
                new_scores['clarity'], deltas['clarity'],
                new_scores['differentiation'], deltas['differentiation'],
                new_scores['alignment'], deltas['alignment'],
                profiles=new_profiles,
                agent_name=agent.name,
            )

            # Drop feedback for metrics that are switched off.
            if not self.metrics[0]:
                clarity_eval = None
                logger.info(f"Clarity evaluation is disabled for {agent.name}")
            if not self.metrics[1]:
                differentiation_eval = None
                logger.info(f"Differentiation evaluation is disabled for {agent.name}")
            if not self.metrics[2]:
                alignment_eval = None
                logger.info(f"Alignment evaluation is disabled for {agent.name}")

            prompt = generate_update_profile_prompt(
                clarity_eval_prompt=clarity_eval,
                differentiation_eval_prompt=differentiation_eval,
                alignment_eval_prompt=alignment_eval,
                **prompt_kwargs,
            )

            logger.info(f"Prompt for {agent.name}: {prompt}")
            update_message = Message(content=prompt, send_to=agent.name, cause_by=AnalyzeAgents)

            await agent.run(with_message=update_message)

        logger.info(f"New profiles after assistant update: {json.dumps(self.assistant.agent_profiles_evaluation, indent=2)}")

    
    async def run(self, task, is_warmup: bool = True) -> Dict:
        """Run one task through up to ``max_rounds`` update/execute rounds.

        Args:
            task: Mapping with at least ``"complete_prompt"`` (the task text)
                and ``"task_id"``.
            is_warmup: When True, run the warmup phase before the first round.

        Returns:
            A dict with ``task_id``, ``attempt_answer`` (the assistant's
            solution, or an empty placeholder if none could be obtained) and
            ``finished_round``, followed by every item of ``task`` (task items
            win on key collisions). ``finished_round`` is the zero-based index
            of the last round executed, or -1 if no round ran.
        """
        self.task = task["complete_prompt"]
        self.task_id = task["task_id"]

        self.publish_message(
            Message(
                role="Human",
                content=self.task,
                cause_by=UserRequirement,
                send_to=MESSAGE_ROUTE_TO_ALL,
            ),
            peekable=False,
        )
        self.assistant.user_requirement = self.task

        if is_warmup:
            logger.info("Starting warmup phase")
            await self.warmup()
            logger.info("Warmup phase completed")

        # Pre-bind the loop variable: with max_rounds <= 0 the loop body never
        # runs and `rnd` would otherwise be undefined when building the result.
        rnd = -1
        for rnd in range(self.max_rounds):
            logger.info("=" * 30 + f"Round: {rnd + 1}/{self.max_rounds}" + "=" * 30)

            # Phase 1: Profile Update
            logger.info("=" * 30 + "Phase 1: Profile Update" + "=" * 30)
            self.current_round = rnd
            await self.update_agent_profiles()

            # Phase 2: Task Execution
            logger.info("="*30 + "Phase 2: Task Execution" + "="*30)

            self.assistant.set_phase("act")
            self.publish_message(
                Message(
                    role="Human",
                    # NOTE: the round number in this message is zero-based.
                    content=f"Round {rnd} of {self.max_rounds}",
                    cause_by=UserRequirement,
                    send_to=MESSAGE_ROUTE_TO_ALL,
                ),
                peekable=False,
            )
            for agent in self.agents:
                agent.rnd = rnd
                agent.is_valid_answer = False
                logger.info("=" * 30 + f"Running role: {agent.name}" + "=" * 30)
                agent.set_phase("act")
                # Simulated failure: the agent skips this round and is marked
                # as unfinished so the run cannot terminate early.
                if random.random() < self.failure_prob:
                    agent.is_finished = False
                    logger.info(f"Agent {agent.name} failed to execute the task.")
                    continue

                await agent.run()

                # Guard against an empty memory instead of raising IndexError.
                memories = agent.get_memories()
                agent_message = memories[-1] if memories else None
                if agent_message and "Execute" in agent_message.cause_by:
                    # Let the assistant review the execution, then have the
                    # agent reflect on that feedback.
                    self.assistant.is_valid_answer = agent.is_valid_answer
                    self.assistant.send_to = agent.name
                    await self.assistant.run()
                    agent.set_phase("reflexion")
                    await agent.run()

            # Stop early once every agent reports it is finished.
            is_idle = True
            for agent in self.agents:
                logger.info(f"Agent {agent.name} is idle: {agent.is_finished}")
                if not agent.is_finished:
                    is_idle = False
                    break

            if is_idle:
                break

        # Placeholder used when the assistant cannot provide a solution.
        solution = {"code": "", "output": "", "finished_round": rnd}
        try:
            tmp_solution = self.assistant.get_complete_solution()
            if tmp_solution is not None:
                solution = tmp_solution
                solution["finished_round"] = rnd
            logger.info("Final Notebook Content:")
            logger.info(solution)
        except Exception as e:
            # Best effort: keep whatever solution is available and still return.
            logger.error(f"Error getting complete solution: {e}")
        finally:
            await self.assistant.execute_nb.terminate()

        logger.info(f"Assistant cost: {self.assistant.llm.cost_manager.get_costs()}")
        for agent in self.agents:
            logger.info(f"{agent.name} cost: {agent.llm.cost_manager.get_costs()}")

        return {
            "task_id": self.task_id,
            "attempt_answer": solution,
            "finished_round": solution["finished_round"],
            **task
        }
    
    async def run_evolving_tasks(self, task_set: dict):
        """Run every task of a task set in order, reusing the same agents.

        Args:
            task_set: Mapping with ``"set_name"`` and ``"tasks"``; each task is
                a mapping with at least ``"name"`` and ``"description"``.

        Returns:
            A list with the result dict of ``run`` for each task, in order.

        Only the first task is run with the warmup phase. The task mappings
        passed in by the caller are left unmodified.
        """
        set_name = task_set['set_name']
        tasks = task_set['tasks']
        results = []

        for task_index, task in enumerate(tasks):
            # Warm up only before the first task of the set.
            is_warmup = task_index == 0
            task_name = task['name']
            task_description = task['description']

            logger.info(f"Starting task {task_index + 1} in set '{set_name}': {task_name}")

            # Point the assistant at the new task.
            self.assistant.user_requirement = f"Task: {task_name}\n\nDescription: {task_description}"

            # Clear every agent's memory; the agent object (and its profile)
            # is kept.
            for agent in self.agents:
                await agent.run()
                agent.rc.memory.clear()

            await self.assistant.run()
            # Clear the assistant's memory as well.
            self.assistant.rc.memory.clear()

            # Build the payload for a single run. Any "complete_prompt" supplied
            # by the task is replaced by the generated one; the task is copied
            # (filtered) rather than edited in place so the caller's data stays
            # intact. Other task keys, including "task_id", still take
            # precedence over the generated values.
            task_data = {
                "task_id": f"{set_name}-{task_index}",
                "complete_prompt": f"Task: {task_name}\n\nDescription: {task_description}",
                **{key: value for key, value in task.items() if key != "complete_prompt"},
            }

            result = await self.run(task_data, is_warmup)

            # Collect the result of this task.
            results.append(result)

            logger.info(f"Task {task_name} completed.")

        return results
    