
from typing import Dict, List, Optional, Tuple, Union
from pydantic import Field
from src.utils.logsetup import logger

from src.agent import BaseAgent, AutoAgent, ToolActorAgent, ToolMakerAgent
from src.schema.message import Memory
from src.tools.task_classifier import TaskClassifier

class AutoAgentRLFLow:
    """A flow for auto Agent reinforcement learning.

    One call to :meth:`run` performs a single episode:

    1. classify the request with :class:`TaskClassifier`,
    2. obtain tools (generated by the tool-maker agent, or taken verbatim from
       ``gold_tool``),
    3. execute the request with the AutoAgent using those tools,
    4. return the AutoAgent's message history converted by ``Memory.to_sharegpt``.

    Agents are looked up by key in ``self.agents_dict``: ``"autogent"`` (the
    AutoAgent), ``"toolmaker"`` and ``"toolactor"``.

    NOTE(review): the AutoAgent key is spelled ``"autogent"``. Presumably other
    code builds the agents dict with this same spelling, so it must not be
    "fixed" here in isolation.
    """

    def __init__(
        self,
        agents: Union[BaseAgent, Dict[str, BaseAgent]],
    ):
        """Store the agents and create the task classifier.

        Args:
            agents: Either a single ``AutoAgent`` (registered under the
                ``"autogent"`` key) or a dict mapping agent keys to agents.
                Any value that is not an ``AutoAgent`` is stored as-is and is
                assumed to be a dict.
        """
        # A lone AutoAgent is wrapped; everything else is used as the mapping directly.
        if isinstance(agents, AutoAgent):
            self.agents_dict = {"autogent": agents}
        else:
            self.agents_dict = agents

        # Initialize task classifier
        self.task_classifier = TaskClassifier()

    def get_auto_agent(
        self,
        all_tools: Optional[List[Dict]],
        toolactor: Optional[ToolActorAgent] = None
    ) -> AutoAgent:
        """Return the primary AutoAgent, configuring it in place first.

        Args:
            all_tools: Tools passed to ``set_tools``. Skipped when falsy.
            toolactor: Tool actor passed to ``set_toolactor``. Skipped when falsy.

        Returns:
            The agent registered under ``"autogent"``. It is the same object
            stored in ``agents_dict``, so the configuration persists across calls.

        Raises:
            ValueError: if no agent is registered under ``"autogent"``.
        """
        autoagent = self.agents_dict.get("autogent")
        if autoagent is None:
            # Fail with an explicit message instead of an AttributeError on None later.
            raise ValueError("No AutoAgent registered under key 'autogent'")
        if all_tools:
            autoagent.set_tools(all_tools)
        if toolactor:
            autoagent.set_toolactor(toolactor)
        return autoagent

    def get_tool_maker_agent(self) -> Optional[ToolMakerAgent]:
        """Return the tool maker agent, or ``None`` if none is registered."""
        return self.agents_dict.get("toolmaker")

    def get_tool_actor_agent(self) -> Optional[ToolActorAgent]:
        """Return the tool actor agent, or ``None`` if none is registered."""
        return self.agents_dict.get("toolactor")

    async def run(
        self,
        request: Optional[str] = None,
        gold_tool: Optional[str] = None,
        gold_answer: Optional[str] = None,
        **kwargs
    ) -> str:
        """Run the classify -> make tools -> act pipeline for one request.

        Args:
            request: The task query to solve.
            gold_tool: Pre-supplied tools. When truthy, tool generation is
                skipped and this value is handed to the AutoAgent unchanged.
                NOTE(review): annotated as ``str``, but it is forwarded where
                ``get_auto_agent`` expects ``List[Dict]``. Confirm the real type.
            gold_answer: Reference answer. It is forwarded to the tool maker and
                stored on the AutoAgent as ``gold_answer``.
            **kwargs: Accepted for interface compatibility. Currently ignored.

        Returns:
            The result of ``Memory.to_sharegpt`` on success. If any step raises,
            returns the string ``"Error <exc> occurred during the flow."``
            instead of propagating the exception.
        """
        try:
            # Step 1: Classify the task (only at the beginning)
            logger.info("🔍 Step 1: Classifying task...")
            task_classification = await self.task_classifier.execute(task_query=request)
            logger.info(f"📊 Task Classification Result: {task_classification}")

            # Step 2: Generate tools based on classification
            logger.info("🛠️ Step 2: Generating specialized tools...")
            if not gold_tool:
                toolmaker = self.get_tool_maker_agent()
                all_tools = await toolmaker.run(
                    task_query=request,
                    gold_answer=gold_answer,
                    task_classification=task_classification  # Pass classification to toolmaker
                )
            else:
                # Gold tools bypass the tool maker entirely.
                all_tools = gold_tool

            # Step 3: Execute the task with AutoAgent
            logger.info("🤖 Step 3: Executing task with AutoAgent...")
            toolactor = self.get_tool_actor_agent()
            autoagen = self.get_auto_agent(all_tools, toolactor)

            # Pass classification info to AutoAgent for strategy guidance
            autoagen.task_classification = task_classification

            # Pass gold answer to AutoAgent for toolactor guidance
            autoagen.gold_answer = gold_answer

            # Execute react flow
            await autoagen.run(request)
            trajectory_messages = Memory.to_sharegpt(autoagen.messages, tools=all_tools)

            # Print token summary at the end
            self._print_token_summary(autoagen)

            return trajectory_messages

        except Exception as e:
            # logger.exception keeps the traceback; logger.error(e) recorded only the message.
            logger.exception(e)
            return f"Error {e} occurred during the flow."

    def _collect_llm_instances(self, autoagen: AutoAgent) -> List[Tuple[str, object]]:
        """Return ``(label, llm)`` pairs for each component that exposes an LLM.

        Checks, in order: the AutoAgent's ``llm``, its tool actor's
        ``toolactor_llm``, the tool maker's ``llm``, and the task
        classifier's ``llm``. Missing or falsy LLMs are skipped. The exception
        is ``toolactor_llm``, which is added whenever the attribute exists,
        matching the original checks.
        """
        llm_instances: List[Tuple[str, object]] = []

        # AutoAgent's LLM
        if getattr(autoagen, 'llm', None):
            llm_instances.append(('AutoAgent', autoagen.llm))

        # ToolActor's LLM
        toolactor = getattr(autoagen, 'toolactor', None)
        if toolactor and hasattr(toolactor, 'toolactor_llm'):
            llm_instances.append(('ToolActor', toolactor.toolactor_llm))

        # ToolMaker's LLM (the tool maker may be absent, in which case getattr yields None)
        toolmaker = self.get_tool_maker_agent()
        if getattr(toolmaker, 'llm', None):
            llm_instances.append(('ToolMaker', toolmaker.llm))

        # TaskClassifier's LLM
        if getattr(self.task_classifier, 'llm', None):
            llm_instances.append(('TaskClassifier', self.task_classifier.llm))

        return llm_instances

    def _print_token_summary(self, autoagen: AutoAgent):
        """Log task token usage summary to log system.

        Sums ``total_input_tokens`` / ``total_output_tokens`` over every LLM
        returned by :meth:`_collect_llm_instances`. It logs a per-component
        breakdown, a grand total and a compact line between START/END markers.
        This is best-effort: it never raises. A failure while logging must not
        cause :meth:`run` to discard an already computed trajectory.
        """
        try:
            # Calculate totals and build detailed breakdown
            total_input = 0
            total_output = 0
            breakdown_data = {}

            for name, llm in self._collect_llm_instances(autoagen):
                if llm and hasattr(llm, 'total_input_tokens') and hasattr(llm, 'total_output_tokens'):
                    input_tokens = llm.total_input_tokens
                    output_tokens = llm.total_output_tokens

                    total_input += input_tokens
                    total_output += output_tokens

                    breakdown_data[name] = {
                        "input": input_tokens,
                        "output": output_tokens,
                        "subtotal": input_tokens + output_tokens
                    }

            total_tokens = total_input + total_output

            # Log detailed breakdown
            logger.info("📊 TASK_COMPLETED_TOKEN_SUMMARY_START")

            for name, data in breakdown_data.items():
                logger.info(f"📊 TOKEN_BREAKDOWN | {name}: Input={data['input']:,} Output={data['output']:,} Subtotal={data['subtotal']:,}")

            logger.info(f"📊 TOKEN_GRAND_TOTAL | TotalInput={total_input:,} TotalOutput={total_output:,} TotalTokens={total_tokens:,}")

            # Also log a compact summary for easy parsing
            logger.info(f"📊 TOKEN_SUMMARY_COMPACT | INPUT={total_input} OUTPUT={total_output} TOTAL={total_tokens}")

            logger.info("📊 TASK_COMPLETED_TOKEN_SUMMARY_END")

        except Exception as e:
            logger.warning(f"Error logging token summary: {e}")
            # Fallback: try to log AutoAgent's statistics at least. Guarded so a
            # failure here cannot escape into run(), where it would replace a
            # successful trajectory with an error string.
            try:
                llm = getattr(autoagen, 'llm', None)
                if llm and hasattr(llm, 'get_task_token_summary'):
                    logger.info(llm.get_task_token_summary())
            except Exception as fallback_error:
                logger.warning(f"Error logging fallback token summary: {fallback_error}")



        