"""
Hierarchical Tree Multi-Agent Communication Architecture System

This module implements a tree topology with hierarchical decision-making structure
where information flows bottom-up from leaf nodes to the root, enabling structured
decision-making processes and distributed problem-solving.
"""

import openai
import json
import time
import math
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from datetime import datetime
from config import get_openai_config


@dataclass
class AgentResponse:
    """
    Data structure for agent responses in multi-agent conversations.
    
    Attributes:
        agent_name: Unique identifier for the responding agent
        content: The textual response content
        round_number: Conversation round when response was generated
        timestamp: Time of response generation
        referenced_agents: Names of the agents whose responses were referenced.
            May be omitted (or passed as None) at construction time; it is then
            stored as an empty list, so readers never have to handle None.
    """
    agent_name: str
    content: str
    round_number: int
    timestamp: datetime
    referenced_agents: Optional[List[str]] = None

    def __post_init__(self):
        # A list literal cannot be used as a dataclass default, so None acts as
        # the "not provided" marker and is replaced here by a fresh list that
        # is private to this instance.
        if self.referenced_agents is None:
            self.referenced_agents = []


class Agent:
    """
    Individual agent implementation for tree topology multi-agent systems.
    
    Each agent keeps a history of the responses it generated successfully and
    can answer a user query either directly or with other agents' responses
    (typically its children in the hierarchy) supplied as reference material.
    """
    
    def __init__(self, name: str, system_prompt: str = None):
        """
        Initialize an agent with a unique name and system prompt.
        
        Args:
            name: Unique identifier for the agent
            system_prompt: System-level instructions for the agent's behavior;
                a generic assistant prompt is used when omitted or empty
        """
        self.name = name
        self.system_prompt = system_prompt or (
            "You are an intelligent assistant. Please provide helpful and "
            "informative responses to user questions based on your knowledge "
            "and any provided context."
        )
        self.conversation_history = []
        
    def generate_response(self, user_query: str, reference_responses: List["AgentResponse"] = None, 
                         round_number: int = 1) -> "AgentResponse":
        """
        Generate this agent's response to the user query.
        
        Args:
            user_query: The question being discussed
            reference_responses: Responses of other agents to show to the model
                as context; entries authored by this agent are not shown
            round_number: Conversation round the response belongs to. Round 1
                asks for a fresh answer, later rounds ask the model to refine.
        
        Returns:
            The generated AgentResponse, which is also appended to the agent's
            history. If anything fails, the error is printed and an
            AgentResponse whose content describes the error is returned
            instead; that error response is not stored in the history.
        """
        try:
            messages = self._build_messages(user_query, reference_responses, round_number)
            content = self._request_completion(messages)
            
            agent_response = AgentResponse(
                agent_name=self.name,
                content=content,
                round_number=round_number,
                timestamp=datetime.now(),
                referenced_agents=[resp.agent_name for resp in reference_responses] if reference_responses else []
            )
            
            self.conversation_history.append(agent_response)
            return agent_response
            
        except Exception as e:
            # Deliberate best-effort boundary: one failing agent must not abort
            # the whole multi-agent conversation, so the failure is reported
            # through the returned response instead of being raised.
            print(f"Error generating response for agent {self.name}: {str(e)}")
            return AgentResponse(
                agent_name=self.name,
                content=f"Technical error occurred: {str(e)}",
                round_number=round_number,
                timestamp=datetime.now(),
                referenced_agents=[]
            )
    
    def _build_messages(self, user_query: str, reference_responses, round_number: int) -> List[Dict[str, str]]:
        """Assemble the system and user chat messages for one completion request."""
        # The agent's own earlier answers are never fed back as "references".
        others = [resp for resp in (reference_responses or []) if resp.agent_name != self.name]
        
        reference_text = ""
        if others:
            # NOTE: the header says "child agents", but callers may pass other
            # agents' responses as well (TreeAgentSystem also passes the
            # parent's previous-round response in later rounds).
            reference_text = "\n\nReference responses from child agents:\n" + "".join(
                f"\n{resp.agent_name}'s response:\n{resp.content}\n" for resp in others
            )
        
        if round_number == 1 and not others:
            # Nothing to build on: answer the user question directly.
            prompt = f"User Question: {user_query}\n\nPlease provide your response."
        elif round_number == 1:
            # References supplied in the first round must reach the model,
            # otherwise they would be recorded as referenced without ever
            # having been shown.
            prompt = (
                f"User Question: {user_query}\n{reference_text}\n\n"
                "Please provide your response, taking the above information into account."
            )
        else:
            prompt = (
                f"Original user question: {user_query}\n{reference_text}\n\n"
                "Please refine and optimize your response based on the above information."
            )
        
        return [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": prompt},
        ]
    
    def _request_completion(self, messages: List[Dict[str, str]]) -> str:
        """Send the messages to the chat completion API and return the reply text."""
        config = get_openai_config()
        request = {
            "model": config['model'],
            "messages": messages,
            "temperature": config['temperature'],
            "max_tokens": config['max_tokens'],
        }
        
        try:
            from openai import OpenAI
        except ImportError:
            # Old SDK without the client class: use the module-level API, which
            # relies on the credentials configured on the openai module.
            response = openai.ChatCompletion.create(**request)
        else:
            # Errors raised by the client propagate unchanged. Retrying through
            # the legacy openai.ChatCompletion entry point after a failed
            # request would only replace the real error with an unrelated one,
            # because SDK versions that ship the client class no longer support
            # that entry point.
            client = OpenAI(api_key=config['api_key'], base_url=config['api_base'])
            response = client.chat.completions.create(**request)
        
        return response.choices[0].message.content
    
    def get_history(self) -> List["AgentResponse"]:
        """
        Retrieve the agent's conversation history.
        
        Returns:
            List of AgentResponse objects generated successfully by this agent,
            in generation order (the live list, not a copy)
        """
        return self.conversation_history
    
    def clear_history(self):
        """Clear the agent's conversation history."""
        self.conversation_history = []


class TreeAgentSystem:
    """
    Hierarchical tree topology multi-agent system implementation.
    
    This class manages a collection of agents arranged in a tree structure.
    Within every round the levels are processed bottom-up, from the deepest
    level to the root, and each agent is handed the current-round responses of
    its children. From the second round on, an agent is additionally handed
    its parent's response from the previous round.
    """
    
    def __init__(self, openai_api_key: str, openai_api_base: str = None, num_agents: int = 7):
        """
        Initialize the hierarchical tree multi-agent system.
        
        Args:
            openai_api_key: OpenAI API key for language model access
            openai_api_base: Base URL for OpenAI API (optional)
            num_agents: Number of agents (default: 7, forms 3-layer tree: 1 root + 2 middle + 4 leaves)
        
        Raises:
            ValueError: If num_agents is smaller than 1
        """
        # Configure the module-level OpenAI settings.
        openai.api_key = openai_api_key
        # Only override the base URL when one was supplied: assigning None
        # would wipe out the SDK's own default endpoint.
        if openai_api_base:
            openai.api_base = openai_api_base
        
        # Build the tree first, then create one agent per node.
        self.agents = {}
        self.tree_structure = self._build_tree_structure(num_agents)
        for agent_name in self.tree_structure['all_agents']:
            self.agents[agent_name] = Agent(agent_name)
        
        self.conversation_rounds = []
        self.current_user_query = None
        
    def _build_tree_structure(self, num_agents: int) -> Dict[str, Any]:
        """
        Build the tree structure for the multi-agent system.
        
        Levels are filled top-down; level k holds at most 2**k agents, so only
        the last level can be partially filled. The root is named
        "Root_Agent", all other agents "Agent_1", "Agent_2", ... in level order.
        
        Args:
            num_agents: Total number of agents
            
        Returns:
            Dict with the keys 'all_agents', 'levels' (agents per level),
            'level_agents' (level index -> names), 'parent_child_map',
            'child_parent_map', 'leaf_agents' and 'root_agent'
        
        Raises:
            ValueError: If num_agents is smaller than 1
        """
        if num_agents < 1:
            raise ValueError("Number of agents must be at least 1")
        
        # Sizes per level: 1, 2, 4, ... until every agent is placed.
        levels = []
        remaining = num_agents
        while remaining > 0:
            level_size = min(remaining, 2 ** len(levels))
            levels.append(level_size)
            remaining -= level_size
        
        # Names per level.
        all_agents = []
        level_agents = {}
        next_id = 1
        for level_idx, level_size in enumerate(levels):
            if level_idx == 0:
                names = ["Root_Agent"]
            else:
                names = [f"Agent_{next_id + offset}" for offset in range(level_size)]
                next_id += level_size
            level_agents[level_idx] = names
            all_agents.extend(names)
        
        # Parent-child links between consecutive levels. Children are handed
        # out in order and as evenly as possible; when the split is uneven the
        # earlier parents receive one extra child.
        parent_child_map = {}
        child_parent_map = {}
        for level_idx in range(len(levels) - 1):
            parents = level_agents[level_idx]
            children = level_agents[level_idx + 1]
            base_count, extra = divmod(len(children), len(parents))
            
            start = 0
            for parent_idx, parent in enumerate(parents):
                count = base_count + (1 if parent_idx < extra else 0)
                assigned = children[start:start + count]
                parent_child_map[parent] = assigned
                for child in assigned:
                    child_parent_map[child] = parent
                start += count
        
        # Leaves are agents without children; this includes agents on an inner
        # level that received no child because the last level is incomplete.
        leaf_agents = [agent for agent in all_agents if not parent_child_map.get(agent)]
        
        return {
            'all_agents': all_agents,
            'levels': levels,
            'level_agents': level_agents,
            'parent_child_map': parent_child_map,
            'child_parent_map': child_parent_map,
            'leaf_agents': leaf_agents,
            'root_agent': level_agents[0][0]
        }
    
    def get_children(self, agent_name: str) -> List[str]:
        """Get the child nodes of a specified agent (empty list if it has none or is unknown)."""
        return self.tree_structure['parent_child_map'].get(agent_name, [])
    
    def get_parent(self, agent_name: str) -> Optional[str]:
        """Get the parent node of a specified agent (None for the root or an unknown name)."""
        return self.tree_structure['child_parent_map'].get(agent_name)
    
    def get_level_agents(self, level: int) -> List[str]:
        """Get all agents at a specified level (0 is the root level; empty list if the level does not exist)."""
        return self.tree_structure['level_agents'].get(level, [])
    
    def _collect_references(self, agent_name: str, current_round_responses: Dict[str, Any],
                            previous_responses: Dict[str, Any]) -> List[Any]:
        """Gather the responses handed to an agent as reference material."""
        # Children are processed before their parent, so their current-round
        # responses are already available. Leaf agents have no children.
        references = [
            current_round_responses[child_name]
            for child_name in self.get_children(agent_name)
            if child_name in current_round_responses
        ]
        
        # From the second round on (previous_responses is empty in round 1) an
        # agent that answered in the previous round also receives its parent's
        # previous-round response.
        if agent_name in previous_responses:
            parent = self.get_parent(agent_name)
            if parent and parent in previous_responses:
                references.append(previous_responses[parent])
        
        return references
    
    def run_conversation(self, user_query: str, num_rounds: int = 3) -> List[Dict[str, "AgentResponse"]]:
        """
        Execute a multi-round conversation in hierarchical tree architecture.
        
        In every round the levels are processed from the deepest one up to the
        root, so each agent is asked after all of its children. Progress and
        all responses are printed, and there is a one-second pause between
        rounds.
        
        Args:
            user_query: The initial question or prompt for the agents
            num_rounds: Number of conversation rounds to execute
            
        Returns:
            List of dictionaries containing agent responses for each round
            (agent name -> response). The same list is stored in
            self.conversation_rounds; it is empty when num_rounds < 1.
        """
        self.current_user_query = user_query
        print(f"Starting conversation with {num_rounds} rounds")
        print(f"User Query: {user_query}")
        print("=" * 50)
        
        all_rounds = []
        # Publish the result list right away so that the stored rounds always
        # belong to the stored query, even when no round is executed.
        self.conversation_rounds = all_rounds
        
        num_levels = len(self.tree_structure['levels'])
        
        for round_num in range(1, num_rounds + 1):
            print(f"\nRound {round_num}:")
            print("-" * 30)
            
            previous_responses = all_rounds[-1] if all_rounds else {}
            current_round_responses = {}
            
            for level_idx in range(num_levels - 1, -1, -1):  # From bottom level to root node
                level_agents = self.get_level_agents(level_idx)
                print(f"\nProcessing Level {level_idx + 1}: {', '.join(level_agents)}")
                
                for agent_name in level_agents:
                    print(f"\n{agent_name} is thinking...")
                    
                    reference_responses = self._collect_references(
                        agent_name, current_round_responses, previous_responses
                    )
                    
                    response = self.agents[agent_name].generate_response(
                        user_query=user_query,
                        reference_responses=reference_responses,
                        round_number=round_num
                    )
                    current_round_responses[agent_name] = response
                    
                    print(f"{agent_name}'s response:")
                    print(response.content)
                    
                    # Display referenced nodes
                    if reference_responses:
                        referenced = [ref.agent_name for ref in reference_responses]
                        print(f"(Referenced: {', '.join(referenced)})")
                    print()
            
            all_rounds.append(current_round_responses)
            
            # Pause between rounds
            if round_num < num_rounds:
                print(f"Round {round_num} completed, preparing next round...")
                time.sleep(1)
        
        print("\nConversation completed")
        return all_rounds
    
    def _tree_metadata(self) -> Dict[str, Any]:
        """Return the tree description shared by the summary and the export."""
        return {
            "levels": self.tree_structure['levels'],
            "root_agent": self.tree_structure['root_agent'],
            "leaf_agents": self.tree_structure['leaf_agents'],
            "parent_child_relationships": self.tree_structure['parent_child_map']
        }
    
    def get_conversation_summary(self) -> Dict[str, Any]:
        """
        Generate a summary of the conversation.
        
        Returns:
            Dictionary containing conversation statistics and tree structure
            metadata, or {"error": ...} when no round has been recorded
        """
        if not self.conversation_rounds:
            return {"error": "No conversation has been conducted"}
        
        rounds_detail = []
        for round_number, round_responses in enumerate(self.conversation_rounds, 1):
            rounds_detail.append({
                "round_number": round_number,
                "responses": {
                    agent_name: {
                        "content_length": len(response.content),
                        "timestamp": response.timestamp.isoformat(),
                        "referenced_agents": response.referenced_agents or []
                    }
                    for agent_name, response in round_responses.items()
                }
            })
        
        return {
            "total_rounds": len(self.conversation_rounds),
            "agents": list(self.agents.keys()),
            "tree_structure": self._tree_metadata(),
            "rounds_detail": rounds_detail
        }
    
    def export_conversation(self, filename: str = None) -> str:
        """
        Export conversation history to JSON file.
        
        Args:
            filename: Output filename (auto-generated from the current time if
                not provided)
            
        Returns:
            String path to the exported file
        """
        if filename is None:
            filename = f"tree_conversation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        
        conversation_rounds = []
        for round_number, round_responses in enumerate(self.conversation_rounds, 1):
            conversation_rounds.append({
                "round_number": round_number,
                "responses": {
                    agent_name: {
                        "agent_name": response.agent_name,
                        "content": response.content,
                        "round_number": response.round_number,
                        "timestamp": response.timestamp.isoformat(),
                        "referenced_agents": response.referenced_agents or []
                    }
                    for agent_name, response in round_responses.items()
                }
            })
        
        export_data = {
            "system_type": "tree_architecture",
            "user_query": self.current_user_query,
            "tree_structure": self._tree_metadata(),
            "conversation_rounds": conversation_rounds
        }
        
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(export_data, f, ensure_ascii=False, indent=2)
        
        print(f"Conversation exported to: {filename}")
        return filename
    
    def clear_all_history(self):
        """Clear conversation history for all agents and the system."""
        for agent in self.agents.values():
            agent.clear_history()
        self.conversation_rounds = []
        self.current_user_query = None
        print("All conversation history cleared")
    
    def print_tree_structure(self):
        """Print the tree structure visualization."""
        print("🌳 Tree Structure:")
        num_levels = len(self.tree_structure['levels'])
        
        for level_idx in range(num_levels):
            level_agents = self.get_level_agents(level_idx)
            indent = "  " * level_idx
            level_name = "Root Node" if level_idx == 0 else f"Level {level_idx + 1}"
            print(f"{indent}📍 {level_name}: {', '.join(level_agents)}")
            
            # The deepest level has no children to list.
            if level_idx == num_levels - 1:
                continue
            
            child_indent = "  " * (level_idx + 1)
            for agent in level_agents:
                children = self.get_children(agent)
                if children:
                    print(f"{child_indent}└─ {agent} → {', '.join(children)}")
        print()
