from my_agent.agent import Agent
from my_agent.agent_orchestrator import OrchestratorAgent
from my_agent.agent_planner import PlannerAgent, ExecutorAgent
from my_agent.agent import Agent

import copy
import json
from collections import defaultdict

class AgentManager:
    """Build the agents for one coordination scheme and drive them each step.

    ``agent_type`` selects the scheme:

    * ``'orchestrator'``: a single ``OrchestratorAgent`` returns the actions
      for every agent; ``self.agents`` stays empty.
    * ``'planner'``: a ``PlannerAgent`` writes a shared plan, regenerated on
      step 0 and whenever a dish is added, removed or completed, and one
      ``ExecutorAgent`` per agent turns that plan into actions.
    * ``'individual'``: one ``Agent`` per agent, each deciding on its own.

    Any other value raises ``ValueError("Invalid agent type")``.
    """

    _VALID_AGENT_TYPES = ('orchestrator', 'planner', 'individual')
    # Performance categories, in the order they are reported to the planner.
    _CATEGORIES = ('movement', 'pickup', 'placement')
    # Action-name prefix -> category; actions matching none of them are not tracked.
    _ACTION_PREFIXES = (('goto', 'movement'), ('get', 'pickup'), ('put', 'placement'))

    def __init__(self, env, model, total_num_agents, with_feedback, with_notes, look_ahead_steps, agent_type, executor_models=None):
        """Create the agents for ``agent_type``.

        Args:
            env: Environment handle, passed unchanged to every agent constructor.
            model: Model identifier for the orchestrator / planner / individual
                agents; also the executor model when ``executor_models`` is
                not given.
            total_num_agents: Number of acting agents (ids ``0..n-1``).
            with_feedback, with_notes, look_ahead_steps: Passed unchanged to
                every agent constructor.
            agent_type: ``'orchestrator'``, ``'planner'`` or ``'individual'``.
            executor_models: Planner scheme only. A model name or a sequence of
                names; executor ``i`` uses ``executor_models[i % len(...)]``.
                ``None`` or empty falls back to ``[model]``.

        Raises:
            ValueError: if ``agent_type`` is not one of the supported schemes.
        """
        # Validate up front so a bad type fails even when total_num_agents == 0.
        if agent_type not in self._VALID_AGENT_TYPES:
            raise ValueError("Invalid agent type")

        self.agents = []
        self.total_num_agents = total_num_agents
        self.agent_type = agent_type
        self.price = 0
        self.current_plan = ""
        self.all_tokens = {}
        # Costs of the most recent planner_step; pre-set so the attribute
        # exists before the first planner step has run.
        self.step_costs = {'planner': 0, 'executors': []}

        # Per-agent success/total counters keyed 'agent{id}', fed to the planner prompt.
        self.performance_history = self._new_performance_history()

        self.predictions = []

        if agent_type == 'orchestrator':
            # One central agent (id -1) acts for everyone; no per-agent objects.
            self.orchestrator = OrchestratorAgent(env, model, -1, total_num_agents, with_feedback, with_notes, look_ahead_steps)
        elif agent_type == 'planner':
            self.planner = PlannerAgent(env, model, -1, total_num_agents, with_feedback, with_notes, look_ahead_steps)
            # Normalise once: a missing/empty value previously crashed on
            # len(None) or modulo-by-zero.
            if not executor_models:
                executor_models = [model]
            elif isinstance(executor_models, str):
                executor_models = [executor_models]
            self.agents = [
                ExecutorAgent(
                    env,
                    executor_models[agent_id % len(executor_models)],  # round-robin over models
                    agent_id,
                    total_num_agents,
                    with_feedback,
                    with_notes,
                    look_ahead_steps,
                )
                for agent_id in range(total_num_agents)
            ]
        else:  # 'individual'
            self.agents = [
                Agent(env, model, agent_id, total_num_agents, with_feedback, with_notes, look_ahead_steps)
                for agent_id in range(total_num_agents)
            ]

    @classmethod
    def _new_performance_history(cls):
        """Return an empty counter store: a defaultdict creating zeroed stats per agent key."""
        categories = cls._CATEGORIES
        return defaultdict(lambda: {category: {'success': 0, 'total': 0} for category in categories})

    @staticmethod
    def _model_label(model_name):
        """Return the last '/'-separated component of a model name (the name itself if no '/')."""
        return model_name.split('/')[-1]

    def _update_performance(self, agent_id, action_type, success):
        """Count one action outcome for ``agent_id``.

        Actions are categorised by prefix ('goto' -> movement, 'get' -> pickup,
        'put' -> placement). Other actions, and non-string action values
        (e.g. ``None`` from feedback), are ignored.
        """
        if not isinstance(action_type, str):
            return
        category = next(
            (cat for prefix, cat in self._ACTION_PREFIXES if action_type.startswith(prefix)),
            None,
        )
        if category is None:
            return  # Skip other action types

        stats = self.performance_history[f'agent{agent_id}'][category]
        stats['total'] += 1
        if success:
            stats['success'] += 1

    def _format_performance_history(self):
        """Format per-agent success rates (one block per agent) for the planner prompt.

        Agents or categories with no recorded actions report 0.0%.
        """
        no_stats = {'success': 0, 'total': 0}
        lines = []
        for agent_id in range(self.total_num_agents):
            # .get avoids inserting empty entries into the defaultdict.
            agent_stats = self.performance_history.get(f'agent{agent_id}', {})
            lines.append(f"Agent{agent_id}:\n")
            for category in self._CATEGORIES:
                stats = agent_stats.get(category, no_stats)
                total = stats['total']
                success_rate = (stats['success'] / total * 100) if total > 0 else 0
                lines.append(f"- {category.capitalize()} success rate: {success_rate:.1f}%\n")
        return "".join(lines)

    def _record_feedback(self, obs):
        """Update performance counters from ``obs.feedback`` (if present and non-empty).

        ``obs.feedback`` is read as a mapping 'agent{id}' -> dict with optional
        'action' and 'error' keys.
        """
        feedback_by_agent = getattr(obs, 'feedback', None)
        if not feedback_by_agent:
            return
        for agent_id in range(self.total_num_agents):
            agent_key = f'agent{agent_id}'
            if agent_key not in feedback_by_agent:
                continue
            feedback = feedback_by_agent[agent_key]
            if not feedback:
                continue
            action = feedback.get('action', '')
            # NOTE(review): a missing 'error' key counts as a failure here;
            # confirm the feedback schema always includes 'error'.
            success = not feedback.get('error', True)
            self._update_performance(agent_id, action, success)

    def step(self, obs, step, verbose=False):
        """Run one step of the configured scheme.

        Records action feedback from ``obs``, adds the step's cost to
        ``self.price``, stores its token usage in ``self.all_tokens`` and
        returns the actions produced by the scheme.
        """
        handlers = {
            'orchestrator': self.orchestrate_step,
            'planner': self.planner_step,
            'individual': self.individual_step,
        }
        handler = handlers.get(self.agent_type)
        if handler is None:
            raise ValueError("Invalid agent type")

        plan, price, tokens = handler(obs, step, verbose=verbose)

        # Update performance based on feedback
        self._record_feedback(obs)

        self.price += price
        self.all_tokens = tokens  # tokens of this step only (overwritten each step)

        return plan

    def orchestrate_step(self, obs, step, verbose=False):
        """
        The orchestrator agent will generate a plan of actions.
        Returns the (actions, price, tokens) triple from the orchestrator.
        """
        actions, price, tokens = self.orchestrator.step(obs, step, verbose=verbose)
        return actions, price, tokens

    def individual_step(self, obs, step, verbose=False):
        """
        Each individual agent generates actions for itself.
        Returns (all actions concatenated, summed price, tokens keyed 'agent{id}').
        """
        all_actions = []
        total_price = 0
        all_tokens = {}
        for agent in self.agents:
            actions, price, tokens = agent.step(obs, step, verbose=verbose)
            if actions:  # tolerate agents that return None / no actions
                all_actions.extend(actions)
            total_price += price
            all_tokens[f'agent{agent.agent_id}'] = tokens

        return all_actions, total_price, all_tokens

    def planner_step(self, obs, step, verbose=False):
        """
        Generate a general plan with the planner agent for all the agents when a new dish is added, removed, or completed.
        Then generate the actions with the executor agents.

        Returns (all executor actions, total price, tokens keyed 'planner' /
        'agent{id}') and stores the per-step cost split in ``self.step_costs``.
        """
        total_price = 0
        all_tokens = {}
        step_costs = {'planner': 0, 'executors': []}  # Track costs for this step

        # Check if there are dish-related events that require planning
        should_plan = (step == 0 or  # First step always needs a plan
                       obs.just_new_task or  # New dish arrived
                       obs.just_failed or  # Dish removed
                       obs.task_just_success)  # Dish completed

        if should_plan:
            if verbose:
                print(f"Step {step}: Generating new plan due to dish state change")

            performance_history = self._format_performance_history()
            self.current_plan, price, tokens = self.planner.plan(
                obs,
                step,
                verbose=verbose,
                performance_history=performance_history
            )
            print(f"PLAN[step={step}]: ", self.current_plan)
            total_price += price
            step_costs['planner'] = price
            # Label the cost with the planner's actual model when it exposes one
            # (assumed attribute, as on the executors); otherwise keep the old label.
            planner_model = getattr(self.planner, 'model', None)
            planner_label = self._model_label(planner_model) if isinstance(planner_model, str) else "gpt-4o-v2"
            print(f"{planner_label}: ${price:.4f}")
            all_tokens['planner'] = tokens

        # Every executor sees the current (possibly unchanged) plan and picks its own actions.
        all_actions = []
        for agent in self.agents:
            agent_actions, price, tokens = agent.step(obs, step, self.current_plan, verbose=verbose)
            if agent_actions:
                all_actions.extend(agent_actions)
            total_price += price
            step_costs['executors'].append(price)
            print(f"{self._model_label(agent.model)}: ${price:.4f}")
            all_tokens[f'agent{agent.agent_id}'] = tokens

        self.step_costs = step_costs
        return all_actions, total_price, all_tokens

    def store_history(self, filename):
        """Write agent histories to ``filename`` as JSON.

        Stores the orchestrator's or planner's history for those schemes, and
        one entry per agent (keyed by agent id) for the individual scheme.
        """
        if self.agent_type == 'orchestrator':
            history = {'orchestrator': self.orchestrator.history}
        elif self.agent_type == 'planner':
            history = {'planner': self.planner.history}
        else:
            history = {agent.agent_id: agent.history for agent in self.agents}

        with open(filename, 'w') as f:
            json.dump(history, f)


