from my_agent.agent import Agent
from utils.llm import rules, prepend_history, query_llm
from levels.utils import convert_to_prompt
from prompts.orchestrator_agent_action_only import orchestrator_examples, orchestrator_info_prompt

class OrchestratorAgent(Agent):
    def __init__(self, env, model, agent_id, total_num_agents, with_feedback, with_notes, look_ahead_steps, budget, costs):
        self.env = env
        self.model = model
        self.agent_id = agent_id
        self.total_num_agents = total_num_agents
        self.with_feedback = with_feedback
        self.with_notes = with_notes
        self.look_ahead_steps = look_ahead_steps

        self.budget = budget
        self.costs = costs
        self.history, self.feedback, self.suggestions = self.initialize_prompt()
        self.initial_history_length = len(self.history)
        self.prompt_history = []
        self.previous_actions = []
        
        # Initialize success rate tracking
        self.action_success_rates = {}  # Overall action success rates
        self.action_attempts = {}
        self.action_successes = {}
        
        # Initialize agent-specific success rate tracking
        self.agent_action_success_rates = {}  # Per-agent action success rates
        self.agent_action_attempts = {}
        self.agent_action_successes = {}
        
        # Initialize all agent tracking dictionaries
        for agent_idx in range(self.total_num_agents):
            self.agent_action_success_rates[agent_idx] = {}
            self.agent_action_attempts[agent_idx] = {}
            self.agent_action_successes[agent_idx] = {}


    def initialize_prompt(self):
        """
        Build the starting message history and the initial feedback and
        suggestion texts.

        The history is the rules message, followed by the orchestrator
        examples, followed by the formatted info message.

        Returns:
            tuple: ``(history, feedback, suggestions)``. The two texts are
            empty-list placeholders when ``with_feedback`` is set and empty
            strings otherwise.
        """
        rules_message = ("user", rules(self.env, self.with_notes))
        info_text = orchestrator_info_prompt.format(
            total_num_agents=self.total_num_agents,
            budget=self.budget,
            costs=self.costs_to_txt(self.costs),
        )
        history = [rules_message] + orchestrator_examples + [("user", info_text)]

        if not self.with_feedback:
            return history, '', ''

        return (
            history,
            '-execution error messages:\n  --  []\n',
            '-execution suggestions:\n  --  []\n',
        )

    
    def step(self, obs, step, cost, budget, verbose=False):
        """
        Generate a plan for the orchestrator agent for a given step.

        The method folds the outcome of the previous actions into the
        success-rate statistics, builds the user prompt from the latest
        feedback, suggestions and observation, caps the message history,
        queries the LLM and parses the reply into actions.

        Args:
            obs (object): The current observation of the environment.
            step (int): The current step number in the simulation.
            cost (object): Not read by this method; kept so the call signature stays unchanged.
            budget (object): Not read by this method; kept so the call signature stays unchanged.
            verbose (bool, optional): If True, enables verbose logging. Defaults to False.

        Returns:
            The value returned by ``extract_actions`` for the LLM reply, or an
            empty list when extraction raises.
        """

        # Update success rates if we have previous actions
        if step != 0 and self.previous_actions:
            self.update_success_rates(self.previous_actions)

            # Log action success information if available
            if verbose:
                success_history = getattr(self.env, 'action_success_history', None)
                action_successes = success_history[-1] if success_history else None
                if action_successes:
                    for action, success in zip(self.previous_actions, action_successes):
                        print(f"Action {action}: {'Success' if success else 'Failed'}")

        # update history
        if self.with_feedback and step != 0:
            self.feedback = '-execution error messages:\n  --  ' + str(self.env.feedback) + '\n'
        if self.with_notes and step != 0:
            self.suggestions = '-execution suggestions:\n  --  ' + str(self.env.suggestions) + '\n'
            if 'agent ids cannot be the same' in self.feedback:
                self.suggestions += f'  --  You can only control and plan the actions for agent{self.agent_id}. \n'

        # NOTE(review): success rates are tracked above but are not part of the
        # prompt built here; get_success_rates_text() renders them if they
        # should be included — confirm intent.
        if step == 0:
            prompt = self.feedback + self.suggestions + convert_to_prompt(obs) + '-action:\n'
        else:
            prompt = self.actions_to_txt(self.previous_actions) + self.feedback + self.suggestions + convert_to_prompt(obs) + '-action:\n'
        print("PROMPT: ", prompt)

        # cap message length: keep the fixed initial messages plus the most
        # recent (look_ahead_steps - 1) entries before adding the new prompt
        if len(self.history) >= self.look_ahead_steps + self.initial_history_length:
            keep = self.look_ahead_steps - 1
            # A slice of the form [-0:] returns the WHOLE list, so the
            # "keep nothing" case has to be handled explicitly.
            recent = self.history[-keep:] if keep > 0 else []
            self.history = self.history[:self.initial_history_length] + recent
        self.history = prepend_history(self.history, prompt, verbose=verbose)

        # GENERATE ACTION
        action = query_llm(self.history, model=self.model)
        print(f"ACTION[agent{self.agent_id}]:", action)

        try:
            parsed_actions = self.extract_actions(action)
        except Exception as exc:
            # Best effort: an unparsable reply yields no actions. Only
            # Exception is caught so KeyboardInterrupt/SystemExit propagate.
            if verbose:
                print(f"ACTION[agent{self.agent_id}] could not be parsed: {exc}")
            parsed_actions = []

        if parsed_actions:
            self.update_history(parsed_actions, role='assistant', verbose=verbose)
            self.previous_actions = parsed_actions

        return parsed_actions


    def update_success_rates(self, actions):
        """
        Update the success rates based on action success information from the environment.
        
        Args:
            actions (list): List of actions that were attempted
        """
        # Get action success information from the environment
        action_successes = self.env.action_success_history[-1] if self.env.action_success_history else []
        
        # Process each action with its success status
        for i, action in enumerate(actions):
            # Skip if we don't have success information for this action
            if i >= len(action_successes):
                continue
                
            # Extract the action type (e.g., 'move', 'grab', etc.)
            action_parts = action.split('_')
            action_type = action_parts[0]
            
            # Extract agent ID if present in the action
            agent_id = None
            for part in action_parts:
                if part.startswith('agent') and len(part) > 5:
                    try:
                        agent_id = int(part[5:])
                        break
                    except ValueError:
                        pass
            
            # If no agent ID found, use the current agent's ID
            if agent_id is None:
                agent_id = self.agent_id
                
            # Ensure agent_id is within valid range
            if agent_id >= self.total_num_agents:
                continue
            
            # Update overall action stats
            if action_type not in self.action_attempts:
                self.action_attempts[action_type] = 0
                self.action_successes[action_type] = 0
            
            self.action_attempts[action_type] += 1
            if action_successes[i]:
                self.action_successes[action_type] += 1
            
            self.action_success_rates[action_type] = self.action_successes[action_type] / self.action_attempts[action_type]
            
            # Update agent-specific action stats
            if action_type not in self.agent_action_attempts[agent_id]:
                self.agent_action_attempts[agent_id][action_type] = 0
                self.agent_action_successes[agent_id][action_type] = 0
            
            self.agent_action_attempts[agent_id][action_type] += 1
            if action_successes[i]:
                self.agent_action_successes[agent_id][action_type] += 1
            
            self.agent_action_success_rates[agent_id][action_type] = (
                self.agent_action_successes[agent_id][action_type] / 
                self.agent_action_attempts[agent_id][action_type]
            )
    
    def get_success_rates_text(self):
        """
        Generate a text representation of the current success rates.
        
        Returns:
            str: Text representation of success rates
        """
        if not self.action_success_rates and not any(self.agent_action_success_rates.values()):
            return "-action success rates: No data available yet\n"
        
        text = "-action success rates:\n"
        
        # Overall action success rates
        if self.action_success_rates:
            text += "  -- Overall:\n"
            for action_type, rate in self.action_success_rates.items():
                text += f"     * {action_type}: {rate:.2f} ({self.action_successes[action_type]}/{self.action_attempts[action_type]} successful)\n"
        
        # Per-agent action success rates
        for agent_id in range(self.total_num_agents):
            if self.agent_action_success_rates[agent_id]:
                text += f"  -- Agent {agent_id}:\n"
                for action_type, rate in self.agent_action_success_rates[agent_id].items():
                    text += f"     * {action_type}: {rate:.2f} ({self.agent_action_successes[agent_id][action_type]}/{self.agent_action_attempts[agent_id][action_type]} successful)\n"
        
        return text + "\n"