import math
from typing import Any, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np

# Import the topology classes
try:
    from topologies import DragonFlyGraph, FatTreeGraph, prepare_network_topology
except ImportError:
    print("Make sure topologies.py is in the same directory")
    raise


class TopologyVisualizer:
    """
    A comprehensive visualizer for network topologies that groups nodes by their
    groups and uses different colors for intra-group and inter-group connections.

    Nodes are interpreted through their ``type`` attribute (``'server'``,
    ``'router'``, ``'edge_router'``, ``'aggregate_router'``, ``'spine_router'``,
    ``'top_spine_router'``) and an optional ``group`` attribute.
    """

    # Node types that count as the "router" end of a server-to-router link.
    ROUTER_TYPES = frozenset({'router', 'edge_router', 'aggregate_router'})

    # Group-coloured node types -> (alpha override or None, marker size).
    _GROUP_NODE_STYLES = {
        'server': (None, 400),
        'edge_router': (0.7, 500),
        'aggregate_router': (0.8, 600),
        'router': (0.6, 600),
    }

    def __init__(self):
        # Set3 is not referenced by the methods below; tab10 drives group colours.
        self.color_palette = plt.cm.Set3
        self.server_color_palette = plt.cm.tab10

    def get_node_groups(self, graph: nx.Graph) -> Dict[int, List[str]]:
        """
        Extract nodes grouped by their group attribute.

        Args:
            graph: Graph whose nodes may carry a ``group`` attribute.

        Returns:
            Mapping ``group_id -> [node, ...]`` in graph iteration order.
            Nodes without a ``group`` attribute are omitted.
        """
        groups: Dict[int, List[str]] = {}
        for node, data in graph.nodes(data=True):
            if 'group' in data:
                groups.setdefault(data['group'], []).append(node)
        return groups

    def classify_edges(self, graph: nx.Graph) -> Tuple[List, List, List]:
        """
        Classify edges into three types:
        - server_to_router: edges connecting a server to a router
          (``'router'``, ``'edge_router'`` or ``'aggregate_router'``)
        - intra_group: edges whose endpoints share the same group
        - inter_group: edges between different groups, or touching a node
          without a group (e.g. spine routers)

        Returns:
            ``(server_to_router, intra_group, inter_group)`` lists of ``(u, v)``.
        """
        server_to_router: List = []
        intra_group: List = []
        inter_group: List = []
        router_types = self.ROUTER_TYPES

        for u, v in graph.edges():
            u_data = graph.nodes[u]
            v_data = graph.nodes[v]
            u_type = u_data.get('type', 'unknown')
            v_type = v_data.get('type', 'unknown')

            if ((u_type == 'server' and v_type in router_types)
                    or (v_type == 'server' and u_type in router_types)):
                server_to_router.append((u, v))
            elif 'group' in u_data and 'group' in v_data:
                if u_data['group'] == v_data['group']:
                    intra_group.append((u, v))
                else:
                    inter_group.append((u, v))
            else:
                # At least one endpoint has no group (e.g. spine routers):
                # treat the link as crossing groups.
                inter_group.append((u, v))

        return server_to_router, intra_group, inter_group

    @staticmethod
    def _ring_positions(nodes, cx: float, cy: float,
                        radius: float) -> Dict[Any, Tuple[float, float]]:
        """Spread ``nodes`` evenly on a circle of ``radius`` centred on (cx, cy)."""
        count = len(nodes)
        ring = {}
        for j, node in enumerate(nodes):
            theta = 2 * math.pi * j / count if count > 1 else 0.0
            ring[node] = (cx + radius * math.cos(theta), cy + radius * math.sin(theta))
        return ring

    def grouped_circular_layout(self, graph: nx.Graph, outer_radius: float = 12.0,
                                inner_radius: float = 3.0) -> Dict[str, Tuple[float, float]]:
        """
        Create a layout where each group is arranged in a circle on the outer radius,
        and nodes within each group are arranged in a smaller circle.

        Group centres are spread evenly on a circle of ``outer_radius`` around the
        origin. Inside a group, servers sit on a ring of radius
        ``0.6 * inner_radius`` and every other member (routers of any kind) on a
        ring of radius ``inner_radius``. Nodes without a ``group`` attribute are
        placed on a unit ring around the origin, or exactly at the origin when
        there is only one of them.

        Returns:
            Mapping ``node -> (x, y)`` covering every node of ``graph``.
        """
        groups = self.get_node_groups(graph)
        pos: Dict[Any, Tuple[float, float]] = {}

        # Set lookup keeps ungrouped detection linear in the number of nodes.
        grouped = {node for members in groups.values() for node in members}
        ungrouped_nodes = [node for node in graph.nodes() if node not in grouped]

        group_ids = sorted(groups)
        n_groups = len(group_ids)
        for idx, group_id in enumerate(group_ids):
            angle = 2 * math.pi * idx / n_groups if n_groups > 1 else 0.0
            cx, cy = outer_radius * math.cos(angle), outer_radius * math.sin(angle)

            members = groups[group_id]
            servers = [n for n in members if graph.nodes[n].get('type') == 'server']
            # Every non-server member goes on the router ring, so edge/aggregate
            # routers and other typed nodes also receive a position. A node without
            # a position makes networkx's drawing functions raise.
            routers = [n for n in members if graph.nodes[n].get('type') != 'server']

            pos.update(self._ring_positions(servers, cx, cy, inner_radius * 0.6))
            pos.update(self._ring_positions(routers, cx, cy, inner_radius))

        if ungrouped_nodes:
            # A lone ungrouped node (e.g. a single spine) sits exactly at the origin.
            center_radius = 1.0 if len(ungrouped_nodes) > 1 else 0.0
            pos.update(self._ring_positions(ungrouped_nodes, 0.0, 0.0, center_radius))

        return pos

    @staticmethod
    def _place_group_tier(groups: Dict[int, List[str]], tier_nodes: List,
                          y: float) -> Dict[Any, Tuple[float, float]]:
        """Place one router tier at height ``y``, clustered by group.

        Grouped nodes go to ``x = group_id * 8 + 2 * index`` (assumes integer
        group ids). Tier nodes that have no group are appended to the right, so
        every tier node receives a position.
        """
        tier = set(tier_nodes)
        placed: Dict[Any, Tuple[float, float]] = {}
        for group_id in sorted(groups):
            members = [n for n in groups[group_id] if n in tier]
            for j, node in enumerate(members):
                placed[node] = (group_id * 8 + j * 2, y)

        leftovers = [n for n in tier_nodes if n not in placed]
        start = max((x for x, _ in placed.values()), default=-2) + 2
        for k, node in enumerate(leftovers):
            placed[node] = (start + k * 2, y)
        return placed

    def hierarchical_layout(self, graph: nx.Graph) -> Dict[str, Tuple[float, float]]:
        """
        Create a hierarchical layout suitable for fat-tree topologies.

        Rows, from bottom to top: servers (y=0), edge routers (y=2), aggregate
        routers (y=4), unrecognised nodes (y=5), spine routers (y=6) and
        top-level spine routers (y=8). Edge and aggregate routers are clustered
        horizontally by group.

        Returns:
            Mapping ``node -> (x, y)`` covering every node of ``graph``.
        """
        servers: List = []
        edge_routers: List = []
        aggregate_routers: List = []
        spine_routers: List = []
        top_spine_routers: List = []
        other_routers: List = []

        for node, data in graph.nodes(data=True):
            node_type = data.get('type', 'unknown')
            if node_type == 'server':
                servers.append(node)
            elif node_type == 'edge_router':
                edge_routers.append(node)
            elif node_type == 'aggregate_router':
                aggregate_routers.append(node)
            elif node_type == 'top_spine_router':
                top_spine_routers.append(node)
            elif node_type == 'spine_router' or 'spine' in str(node):
                # Nodes whose id mentions "spine" are also treated as spines;
                # str() keeps this working for non-string node ids.
                spine_routers.append(node)
            else:
                other_routers.append(node)

        pos: Dict[Any, Tuple[float, float]] = {}

        # Servers along the bottom (level 0).
        for i, server in enumerate(servers):
            pos[server] = (i * 2, 0)

        groups = self.get_node_groups(graph)
        pos.update(self._place_group_tier(groups, edge_routers, 2))       # level 1
        pos.update(self._place_group_tier(groups, aggregate_routers, 4))  # level 2

        # Spine routers (level 3): centred if alone, otherwise spread out.
        spine_y = 6
        for i, spine in enumerate(spine_routers):
            x = len(servers) // 2 if len(spine_routers) == 1 else i * 6
            pos[spine] = (x, spine_y)

        # Top-level spine router(s) centred above everything (level 4).
        top_spine_y = 8
        for top_spine in top_spine_routers:
            pos[top_spine] = (len(servers) // 2, top_spine_y)

        # Fallback row for nodes that matched no known type.
        other_y = 5
        for i, router in enumerate(other_routers):
            pos[router] = (i * 3, other_y)

        return pos

    def _node_style(self, node_data: Dict[str, Any]) -> Tuple[Any, int]:
        """Return ``(color, marker_size)`` for a node from its type and group."""
        node_type = node_data.get('type', 'unknown')
        if node_type == 'spine_router':
            return 'gold', 800
        if node_type == 'top_spine_router':
            return 'red', 900
        if node_type not in self._GROUP_NODE_STYLES:
            return 'orange', 700

        alpha, size = self._GROUP_NODE_STYLES[node_type]
        color = self.server_color_palette(node_data.get('group', 0) % 10)
        if alpha is not None:
            # Routers reuse their group's colour with a custom alpha channel.
            color = (*color[:3], alpha)
        return color, size

    @staticmethod
    def _node_label(node: Any, node_data: Dict[str, Any]) -> str:
        """Return the short text label drawn on a node."""
        node_type = node_data.get('type')
        group_id = node_data.get('group', '?')
        router_id = node_data.get('router', '?')

        if node_type == 'server':
            server_id = node_data.get('server_id')
            if server_id is None:
                # Fall back to the id suffix ("server_3" -> "3"); str() supports
                # non-string node ids.
                server_id = str(node).split('_')[-1]
            return f"S{server_id}\nG{group_id}"
        if node_type == 'edge_router':
            return f"ER{group_id}-{router_id}"
        if node_type == 'aggregate_router':
            return f"AR{group_id}-{router_id}"
        if node_type == 'spine_router':
            return f"SPINE{router_id}"
        if node_type == 'top_spine_router':
            return "TOP_SPINE"
        if 'spine' in str(node):
            return "SPINE"
        return f"R{group_id}-{router_id}"

    def draw_topology(self, topology_graph, title: str = "Network Topology",
                      layout_type: str = "circular", figsize: Tuple[int, int] = (12, 10),
                      *, save_path: Optional[str] = None, dpi: int = 300,
                      show: bool = True):
        """
        Draw the topology with group-aware coloring and layout.

        Args:
            topology_graph: The topology graph object (DragonFlyGraph or FatTreeGraph);
                its ``.graph`` attribute is the networkx graph that gets drawn.
            title: Title for the plot
            layout_type: "hierarchical", otherwise the grouped circular layout is used
            figsize: Figure size tuple
            save_path: If given, the figure is saved here *before* being shown.
            dpi: Resolution used when saving.
            show: Call ``plt.show()`` after drawing.

        Returns:
            The matplotlib Figure that was drawn on.
        """
        graph = topology_graph.graph

        if layout_type == "hierarchical":
            pos = self.hierarchical_layout(graph)
        else:
            pos = self.grouped_circular_layout(graph)

        server_to_router, intra_group, inter_group = self.classify_edges(graph)

        styles = [self._node_style(graph.nodes[node]) for node in graph.nodes()]
        node_colors = [color for color, _ in styles]
        node_sizes = [size for _, size in styles]

        fig = plt.figure(figsize=figsize)

        # NOTE(review): a scalar ``alpha`` here likely overrides the per-node alpha
        # channels produced by _node_style — confirm against matplotlib behaviour.
        nx.draw_networkx_nodes(graph, pos, node_color=node_colors,
                               node_size=node_sizes, alpha=0.8)

        if server_to_router:
            nx.draw_networkx_edges(graph, pos, edgelist=server_to_router,
                                   edge_color='lightblue', width=1.0, alpha=0.6,
                                   style='dashed')
        if intra_group:
            nx.draw_networkx_edges(graph, pos, edgelist=intra_group,
                                   edge_color='gray', width=1.5, alpha=0.7)
        if inter_group:
            nx.draw_networkx_edges(graph, pos, edgelist=inter_group,
                                   edge_color='red', width=2.5, alpha=0.8)

        labels = {node: self._node_label(node, graph.nodes[node]) for node in graph.nodes()}
        nx.draw_networkx_labels(graph, pos, labels=labels, font_size=8, font_weight='bold')

        # Only list edge categories that actually appear in the plot.
        legend_elements = []
        if server_to_router:
            legend_elements.append(plt.Line2D([0], [0], color='lightblue',
                                              linestyle='--', label='Server-Router'))
        if intra_group:
            legend_elements.append(plt.Line2D([0], [0], color='gray',
                                              linewidth=2, label='Intra-group'))
        if inter_group:
            legend_elements.append(plt.Line2D([0], [0], color='red',
                                              linewidth=3, label='Inter-group'))
        if legend_elements:
            plt.legend(handles=legend_elements, loc='upper right')

        plt.title(title, fontsize=16, fontweight='bold')
        plt.axis('off')
        plt.tight_layout()

        # Save before showing: once show() returns, the figure may already be
        # closed, and saving afterwards can produce an empty image.
        if save_path is not None:
            fig.savefig(save_path, dpi=dpi)
        if show:
            plt.show()
        return fig

    def visualize_all_topologies(self, num_servers: int = 16, *,
                                 raise_on_error: bool = False):
        """
        Visualize all available topology types for comparison.

        Each topology is drawn, saved to ``"<title>_<num_servers>.png"`` at
        300 dpi, and then shown.

        Args:
            num_servers: Number of servers passed to every topology constructor.
            raise_on_error: Re-raise the first failure instead of printing it and
                continuing with the next topology.
        """
        # topology_type -> (title, factory, layout_type)
        topologies = {
            'dragonfly': ('DragonFly (Dense)',
                          lambda: DragonFlyGraph(num_servers), "circular"),
            'dragonfly_sparse': ('DragonFly (Sparse)',
                                 lambda: DragonFlyGraph(num_servers, num_diameter_links=1),
                                 "circular"),
            'fat_tree': ('Fat Tree (Full)',
                         lambda: FatTreeGraph(num_servers, share_connects_per_spine=1.0),
                         "hierarchical"),
            'fat_tree_2_level': ('Fat Tree (2-Level)',
                                 lambda: FatTreeGraph(num_servers, share_connects_per_spine=0.25),
                                 "hierarchical"),
        }

        for topology_type, (title, build, layout) in topologies.items():
            try:
                topo = build()
                self.draw_topology(topo, f"{title} ({num_servers} servers)",
                                   layout_type=layout,
                                   save_path=f"{title}_{num_servers}.png", dpi=300)
            except Exception as e:  # demo boundary: report and move on
                if raise_on_error:
                    raise
                print(f"Error visualizing {topology_type}: {e}")


def demo_visualizations(num_servers: int = 16):
    """
    Demonstrate the topology visualizations.

    Args:
        num_servers: Number of servers used to build every demo topology.
    """
    visualizer = TopologyVisualizer()

    print("Generating topology visualizations...")

    # Show all topologies with `num_servers` servers (16 by default).
    visualizer.visualize_all_topologies(num_servers=num_servers)


if __name__ == "__main__":
    # Script entry point: render the demo topologies.
    demo_visualizations()
    