"""
Visualization module for MoE (Mixture of Experts) placement over DragonFly topology
"""

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import networkx as nx
from typing import Dict, List, Tuple, Any, Optional
from collections import defaultdict
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap, Normalize
import matplotlib.cm as cm

from topologies import DragonFlyGraph, prepare_network_topology


class PlacementVisualizer:
    """
    Visualizer for MoE placement over DragonFly network topology
    """
    
    def __init__(self, topology: DragonFlyGraph, distance_matrix: np.ndarray, num_servers: int = 32):
        """
        Initialize the visualizer with a network topology and its distance matrix.

        Args:
            topology: DragonFly topology object. The layout/drawing helpers read its
                ``group_size``, ``num_groups`` and ``routers`` attributes and its
                networkx ``graph`` attribute, so a bare ``nx.Graph`` is not enough.
            distance_matrix: Server-to-server distances indexed as
                ``distance_matrix[src][dst]`` by server ID; the plots label these
                values as hops.
            num_servers: Number of servers; sets the number of server rows/columns
                in the heatmaps.
        """
        self.num_servers = num_servers
        self.topology = topology
        self.distance_matrix = distance_matrix

        # Fill colors for the two layer kinds
        self.layer_colors = {
            'attention': '#b3cde3',  # Pale blue for attention
            'moe': '#fbb4ae'        # Pale red/pink for MoE
        }

        # 32 RGBA samples of the qualitative Set3 colormap. Set3 only defines 12
        # distinct colors, so the 32 entries contain repeats.
        self.expert_colors = plt.cm.Set3(np.linspace(0, 1, 32))
        
    def plot_placement_heatmap(self, expert_placements: List[Dict[str, Any]], 
                              layer_placements: List[Dict[str, Any]],
                              figsize: Tuple[int, int] = (16, 10)) -> Tuple[plt.Figure, plt.Axes]:
        """
        Plot how attention layers and MoE experts are distributed across servers.

        Rows are layers (attention layers first, then MoE layers, each group sorted
        by ``layer_id``); columns are server IDs ``0 .. num_servers - 1``.

        * Attention rows show the hosting server as a blue cell with a check mark.
          If several entries share a ``layer_id``, the first one is used.
        * MoE rows are red with opacity proportional to the number of the layer's
          experts on each server, relative to that layer's busiest server; each
          occupied cell is annotated with the absolute expert count.

        Args:
            expert_placements: Expert placement dicts with at least ``'layer_id'``
                and ``'server_id'``.
            layer_placements: Layer placement dicts with at least ``'layer_type'``
                (``'attention'`` or ``'moe'``) and ``'layer_id'``; attention entries
                also need ``'server_id'``.
            figsize: Figure size in inches.

        Returns:
            Tuple of (figure, axes) objects
        """
        fig, ax = plt.subplots(1, 1, figsize=figsize)

        attention_layers = [l for l in layer_placements if l['layer_type'] == 'attention']
        moe_layers = [l for l in layer_placements if l['layer_type'] == 'moe']
        attention_layer_ids = sorted({l['layer_id'] for l in attention_layers})
        moe_layer_ids = sorted({l['layer_id'] for l in moe_layers})
        num_attention = len(attention_layer_ids)
        num_layers = num_attention + len(moe_layer_ids)

        if num_layers == 0:
            # Nothing to place: show a message instead of an empty image.
            ax.text(0.5, 0.5, 'No layers found', ha='center', va='center',
                    transform=ax.transAxes, fontsize=16)
            ax.set_title('Layer and Expert Placement Heatmap (No Layers)', fontsize=14, fontweight='bold')
            return fig, ax

        layer_labels = ([f'A-{layer_id}' for layer_id in attention_layer_ids]
                        + [f'M-{layer_id}' for layer_id in moe_layer_ids])

        # counts[row, server]: 1 on the server hosting an attention layer, or the
        # number of experts an MoE layer has on that server. Built in one pass over
        # the placements so annotation does not rescan the expert list per cell.
        counts = np.zeros((num_layers, self.num_servers), dtype=int)
        first_attention_server = {}
        for layer in attention_layers:
            first_attention_server.setdefault(layer['layer_id'], layer['server_id'])
        for row, layer_id in enumerate(attention_layer_ids):
            counts[row, first_attention_server[layer_id]] = 1

        moe_row = {layer_id: num_attention + k for k, layer_id in enumerate(moe_layer_ids)}
        for expert in expert_placements:
            row = moe_row.get(expert['layer_id'])
            if row is not None:
                counts[row, expert['server_id']] += 1

        # Relative expert density of each MoE cell (count / busiest server of the layer)
        intensity = np.zeros((num_layers, self.num_servers))
        for row in range(num_attention, num_layers):
            row_max = counts[row].max()
            if row_max > 0:
                intensity[row] = counts[row] / row_max

        # RGBA image: white at 10% alpha for empty cells, solid blue for attention
        # placements, red whose alpha grows from 0.3 to 1.0 with expert density.
        occupied = counts > 0
        is_attention_row = np.arange(num_layers)[:, None] < num_attention
        attention_cells = occupied & is_attention_row
        moe_cells = occupied & ~is_attention_row
        rgba = np.tile([1.0, 1.0, 1.0, 0.1], (num_layers, self.num_servers, 1))
        rgba[attention_cells] = [0.2, 0.4, 0.8, 0.9]
        rgba[moe_cells, :3] = [0.8, 0.2, 0.2]
        rgba[moe_cells, 3] = 0.3 + 0.7 * intensity[moe_cells]

        ax.imshow(rgba, aspect='auto', interpolation='nearest')

        # Annotate occupied cells: check mark for attention, expert count for MoE.
        for row, server in zip(*np.nonzero(counts)):
            if row < num_attention:
                text, color = '✓', 'white'
            else:
                text = str(counts[row, server])
                # Dense (darker red) cells get white text, pale ones black text.
                color = 'white' if intensity[row, server] > 0.5 else 'black'
            ax.text(server, row, text, ha='center', va='center',
                    color=color, fontweight='bold', fontsize=8)

        ax.set_xticks(range(self.num_servers))
        ax.set_xticklabels([f'S{i}' for i in range(self.num_servers)], fontsize=8)
        ax.set_yticks(range(num_layers))
        ax.set_yticklabels(layer_labels, fontsize=9)

        ax.set_xlabel('Server ID', fontsize=12, fontweight='bold')
        ax.set_ylabel('Layer ID', fontsize=12, fontweight='bold')
        ax.set_title('Layer and Expert Placement Heatmap', fontsize=14, fontweight='bold')

        legend_elements = [
            patches.Patch(facecolor=[0.2, 0.4, 0.8, 0.9], label='Attention Layer Placement'),
            patches.Patch(facecolor=[0.8, 0.2, 0.2, 0.7], label='MoE Expert Count (High)'),
            patches.Patch(facecolor=[0.8, 0.2, 0.2, 0.3], label='MoE Expert Count (Low)')
        ]
        ax.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.15, 1))

        explanation = ("Blue: Attention layers (✓ = placed)\n"
                      "Red: MoE experts (number = expert count)\n"
                      "Intensity shows relative expert density")
        ax.text(0.02, 0.98, explanation,
                transform=ax.transAxes, fontsize=10, va='top', ha='left',
                bbox=dict(boxstyle='round,pad=0.3', facecolor='lightyellow', alpha=0.8))

        fig.tight_layout()
        return fig, ax

    def plot_communication_cost_heatmap(self, expert_placements: List[Dict[str, Any]], 
                                       layer_placements: List[Dict[str, Any]],
                                       figsize: Tuple[int, int] = (16, 10)) -> Tuple[plt.Figure, plt.Axes]:
        """
        Plot the dispatch -> expert -> collect communication cost of every expert.

        Rows are MoE layers, sorted and deduplicated by ``layer_id``. When a layer
        is listed more than once, its first entry supplies the dispatch/collect
        servers. Columns are expert positions within a layer after sorting its
        experts by ``expert_id``. Each cell holds
        ``distance_matrix[dispatch][expert] + distance_matrix[expert][collect]``.
        Positions a layer does not fill are drawn gray.

        Args:
            expert_placements: Expert placement dicts with ``'layer_id'``,
                ``'expert_id'`` and ``'server_id'``.
            layer_placements: Layer placement dicts; MoE entries
                (``layer_type == 'moe'``) need ``'layer_id'``, ``'dispatch_server'``
                and ``'collect_server'``.
            figsize: Figure size in inches.

        Returns:
            Tuple of (figure, axes) objects
        """
        fig, ax = plt.subplots(1, 1, figsize=figsize)

        moe_layers = [l for l in layer_placements if l['layer_type'] == 'moe']
        # Deduplicate so a layer listed twice does not produce two identical rows
        moe_layer_ids = sorted({l['layer_id'] for l in moe_layers})

        if not moe_layer_ids:
            ax.text(0.5, 0.5, 'No MoE layers found', ha='center', va='center', 
                   transform=ax.transAxes, fontsize=16)
            ax.set_title('Communication Cost Heatmap (No MoE Layers)', fontsize=14, fontweight='bold')
            return fig, ax

        # Index layers and experts once instead of rescanning both lists per layer.
        layer_info = {}
        for layer in moe_layers:
            layer_info.setdefault(layer['layer_id'], layer)
        experts_by_layer = defaultdict(list)
        for expert in expert_placements:
            experts_by_layer[expert['layer_id']].append(expert)

        layer_costs = {}
        for layer_id in moe_layer_ids:
            dispatch_server = layer_info[layer_id]['dispatch_server']
            collect_server = layer_info[layer_id]['collect_server']
            layer_experts = sorted(experts_by_layer.get(layer_id, []), key=lambda e: e['expert_id'])
            # Total cost of one expert: dispatch -> expert -> collect
            layer_costs[layer_id] = [
                self.distance_matrix[dispatch_server][e['server_id']]
                + self.distance_matrix[e['server_id']][collect_server]
                for e in layer_experts
            ]

        num_layers = len(moe_layer_ids)
        max_experts = max((len(costs) for costs in layer_costs.values()), default=0)
        if max_experts == 0:
            # MoE layers exist but none has an expert: nothing to put in the matrix.
            ax.text(0.5, 0.5, 'No experts placed in MoE layers', ha='center', va='center',
                    transform=ax.transAxes, fontsize=16)
            ax.set_title('Communication Cost Heatmap (No Experts)', fontsize=14, fontweight='bold')
            return fig, ax

        # rows = layers, columns = expert positions; NaN (drawn gray) = no expert there
        cost_matrix = np.full((num_layers, max_experts), np.nan)
        for i, layer_id in enumerate(moe_layer_ids):
            costs = layer_costs[layer_id]
            cost_matrix[i, :len(costs)] = costs

        cmap = plt.cm.viridis.copy()
        cmap.set_bad(color='lightgray', alpha=0.3)  # Color for NaN values
        im = ax.imshow(cost_matrix, cmap=cmap, aspect='auto', interpolation='nearest')
        fig.colorbar(im, ax=ax, label='Communication Cost (Hops)')

        # imshow maps costs linearly from the smallest to the largest value.
        # viridis runs from dark purple (low) to bright yellow (high), so
        # low-cost cells need white text and high-cost cells black text.
        vmin = np.nanmin(cost_matrix)
        span = np.nanmax(cost_matrix) - vmin
        for i, layer_id in enumerate(moe_layer_ids):
            for j, cost in enumerate(layer_costs[layer_id]):
                position = (cost - vmin) / span if span > 0 else 0.0
                color = 'white' if position < 0.5 else 'black'
                # ':g' prints integral costs as e.g. '4' without truncating fractions
                ax.text(j, i, f'{cost:g}', ha='center', va='center', 
                       color=color, fontweight='bold', fontsize=8)

        ax.set_xticks(range(max_experts))
        ax.set_xticklabels([f'E{i}' for i in range(max_experts)], fontsize=8)
        ax.set_yticks(range(num_layers))
        ax.set_yticklabels([f'L{layer_id}' for layer_id in moe_layer_ids], fontsize=9)

        ax.set_xlabel('Expert ID', fontsize=12, fontweight='bold')
        ax.set_ylabel('MoE Layer ID', fontsize=12, fontweight='bold')
        ax.set_title('Communication Cost Heatmap (Dispatch → Expert → Collect)', 
                    fontsize=14, fontweight='bold')

        # Summary statistics (at least one cost exists past the guard above)
        all_costs = [cost for costs in layer_costs.values() for cost in costs]
        min_cost = min(all_costs)
        max_cost = max(all_costs)
        avg_cost = np.mean(all_costs)
        stats_text = f"Cost Statistics:\nMin: {min_cost:.1f} hops\nMax: {max_cost:.1f} hops\nAvg: {avg_cost:.1f} hops"
        ax.text(0.02, 0.98, stats_text,
               transform=ax.transAxes, fontsize=10, va='top', ha='left',
               bbox=dict(boxstyle='round,pad=0.3', facecolor='lightyellow', alpha=0.8))

        explanation = ("Each cell shows total communication cost:\n"
                      "dispatch_server → expert_server → collect_server\n"
                      "Gray cells: No expert at this position")
        ax.text(0.98, 0.02, explanation,
               transform=ax.transAxes, fontsize=9, va='bottom', ha='right',
               bbox=dict(boxstyle='round,pad=0.3', facecolor='lightblue', alpha=0.8))

        fig.tight_layout()
        return fig, ax

    def plot_communication_heatmap(self, expert_placements: List[Dict[str, Any]], 
                                 layer_placements: List[Dict[str, Any]],
                                 figsize: Tuple[int, int] = (12, 10)) -> Tuple[plt.Figure, Tuple[plt.Axes, plt.Axes]]:
        """
        Plot the network distance matrix next to the MoE communication intensity matrix.

        The left panel shows ``self.distance_matrix``. In the right panel, each
        expert of each MoE layer adds one to the (dispatch_server, expert_server)
        cell and one to the (expert_server, collect_server) cell. Rows are
        sources and columns are destinations. MoE entries listed more than once
        are counted once per entry.

        Args:
            expert_placements: Expert placement dicts with ``'layer_id'`` and
                ``'server_id'``.
            layer_placements: Layer placement dicts; MoE entries
                (``layer_type == 'moe'``) need ``'layer_id'``, ``'dispatch_server'``
                and ``'collect_server'``.
            figsize: Figure size in inches.

        Returns:
            Tuple of (figure, (distance_axes, intensity_axes)).
        """
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

        # Panel 1: distance matrix
        im1 = ax1.imshow(self.distance_matrix, cmap='viridis', aspect='auto')
        ax1.set_title('Network Distance Matrix', fontweight='bold')
        ax1.set_xlabel('Server ID')
        ax1.set_ylabel('Server ID')
        fig.colorbar(im1, ax=ax1, label='Hops')

        # Panel 2: communication intensity. Group expert servers by layer once
        # instead of rescanning every expert for each MoE layer.
        expert_servers_by_layer = defaultdict(list)
        for expert in expert_placements:
            expert_servers_by_layer[expert['layer_id']].append(expert['server_id'])

        comm_matrix = np.zeros((self.num_servers, self.num_servers))
        for layer in layer_placements:
            if layer['layer_type'] != 'moe':
                continue
            dispatch_server = layer['dispatch_server']
            collect_server = layer['collect_server']
            for expert_server in expert_servers_by_layer.get(layer['layer_id'], ()):
                comm_matrix[dispatch_server][expert_server] += 1
                comm_matrix[expert_server][collect_server] += 1

        im2 = ax2.imshow(comm_matrix, cmap='Reds', aspect='auto')
        ax2.set_title('Communication Intensity Matrix', fontweight='bold')
        ax2.set_xlabel('Destination Server ID')
        ax2.set_ylabel('Source Server ID')
        fig.colorbar(im2, ax=ax2, label='Communication Count')

        fig.tight_layout()
        return fig, (ax1, ax2)
    
    def plot_complete_visualization(self, expert_placements: List[Dict[str, Any]], 
                                  layer_placements: List[Dict[str, Any]],
                                  figsize: Tuple[int, int] = (28, 20)) -> plt.Figure:
        """
        Build a dashboard figure combining topology, placement and communication views.

        Panels on a 3x3 grid:
          * (0, 0) network topology, (0, 2) placement heatmap
          * (1, 1) placement statistics, (1, 2) communication cost heatmap
          * row 2, all columns: communication intensity matrix
        Grid cells (0, 1) and (1, 0) are left empty.

        Args:
            expert_placements: List of expert placement dictionaries
            layer_placements: List of layer placement dictionaries
            figsize: Figure size in inches.

        Returns:
            The dashboard figure.
        """
        fig = plt.figure(figsize=figsize)

        gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.2, 
                             width_ratios=[1, 1, 1.2], height_ratios=[1, 1, 1])

        # Topology structure (top left)
        ax1 = fig.add_subplot(gs[0, 0])
        pos = self._compute_group_layout()
        self._draw_basic_topology(ax1, pos, alpha=1.0, show_labels=True)
        ax1.set_title('Network Topology', fontweight='bold')
        ax1.axis('off')

        # Placement heatmap (top right)
        ax_heatmap = fig.add_subplot(gs[0, 2])
        self._plot_placement_heatmap_on_ax(ax_heatmap, expert_placements, layer_placements)
        ax_heatmap.set_title('Placement Heatmap', fontweight='bold')

        # Statistics (center)
        ax4 = fig.add_subplot(gs[1, 1])
        self._plot_placement_statistics(ax4, expert_placements, layer_placements)
        ax4.set_title('Placement Statistics', fontweight='bold')

        # Communication cost heatmap (middle right)
        ax_comm_cost = fig.add_subplot(gs[1, 2])
        self._plot_communication_cost_heatmap_on_ax(ax_comm_cost, expert_placements, layer_placements)
        ax_comm_cost.set_title('Communication Cost Heatmap', fontweight='bold')

        # Communication patterns (bottom row, spanning all columns)
        ax_comm = fig.add_subplot(gs[2, :])
        self._plot_communication_pattern_on_ax(ax_comm, expert_placements, layer_placements)
        ax_comm.set_title('Communication Patterns', fontweight='bold')

        # Title this figure explicitly; pyplot's "current figure" may have been
        # changed by the helper calls above.
        fig.suptitle('MoE Placement Visualization Dashboard', fontsize=20, fontweight='bold')

        return fig
    
    def _compute_group_layout(self) -> Dict[str, Tuple[float, float]]:
        """
        Place routers and servers so that each DragonFly group forms its own cluster.

        Group centers are evenly spaced on a circle of radius 3.0. Inside group
        ``g``, router ``r`` sits on a ring of radius 0.8 around the group center.
        The server with ID ``g * group_size + r`` (only if below ``num_servers``)
        sits 0.3 further out along the same direction. Routers whose name is not
        in ``self.topology.routers`` are skipped together with their server.

        Returns:
            Mapping from node name (``"router_{g}_{r}"`` / ``"server_{id}"``) to (x, y).
        """
        group_size = self.topology.group_size
        num_groups = self.topology.num_groups
        group_ring = 3.0      # radius of the circle holding the group centers
        router_ring = 0.8     # radius of the router ring inside a group
        server_offset = 0.3   # extra radial distance from a router to its server

        layout: Dict[str, Tuple[float, float]] = {}
        for g in range(num_groups):
            theta_g = 2 * np.pi * g / num_groups
            cx = group_ring * np.cos(theta_g)
            cy = group_ring * np.sin(theta_g)

            for r in range(group_size):
                router_name = f"router_{g}_{r}"
                if router_name not in self.topology.routers:
                    continue

                theta_r = 2 * np.pi * r / group_size
                dx = np.cos(theta_r)
                dy = np.sin(theta_r)
                rx = cx + router_ring * dx
                ry = cy + router_ring * dy
                layout[router_name] = (rx, ry)

                sid = g * group_size + r
                if sid < self.num_servers:
                    layout[f"server_{sid}"] = (rx + server_offset * dx, ry + server_offset * dy)

        return layout
    
    def _draw_group_boundaries(self, ax: plt.Axes, pos: Dict[str, Tuple[float, float]]):
        """
        Outline every DragonFly group with a dashed gray circle and label it.

        The circle is centered on the mean position of the group's servers and
        routers found in ``pos``. Its radius reaches the farthest of them plus a
        0.3 margin. A ``"Group {g}"`` label is placed just above the circle.
        Groups with no positioned nodes are skipped.
        """
        group_size = self.topology.group_size

        for g in range(self.topology.num_groups):
            members = []
            for r in range(group_size):
                sid = g * group_size + r
                server_key = f"server_{sid}"
                if sid < self.num_servers and server_key in pos:
                    members.append(pos[server_key])
                router_key = f"router_{g}_{r}"
                if router_key in pos:
                    members.append(pos[router_key])

            if not members:
                continue

            cx = np.mean([p[0] for p in members])
            cy = np.mean([p[1] for p in members])
            radius = 0.3 + max(np.sqrt((p[0] - cx)**2 + (p[1] - cy)**2) for p in members)

            outline = patches.Circle((cx, cy), radius, fill=False, edgecolor='gray',
                                     linestyle='--', alpha=0.5, linewidth=2)
            ax.add_patch(outline)
            ax.text(cx, cy + radius + 0.2, f'Group {g}',
                    ha='center', va='bottom', fontsize=10, fontweight='bold')
    
    def _draw_basic_topology(self, ax: plt.Axes, pos: Dict[str, Tuple[float, float]], 
                           alpha: float = 0.3, show_labels: bool = False):
        """
        Draw the DragonFly network skeleton on ``ax``.

        Drawing rules:
          * Router-router links are gray at half of ``alpha``.
          * Server-router links are light blue at ``alpha``.
          * Routers are orange circles.
          * Servers get no marker; with ``show_labels`` their ``server_id`` node
            attribute is written at their position instead.

        Graph nodes without an entry in ``pos`` are skipped, along with every edge
        touching one, because networkx cannot draw a node it has no position for.
        Nodes with no ``'type'`` attribute are skipped as well.

        Args:
            ax: Axes to draw on.
            pos: Node name -> (x, y) position.
            alpha: Base opacity of edges and router markers.
            show_labels: Whether to write server IDs at the server positions.
        """
        graph = self.topology.graph

        # Node type of every positioned node, classified once (graph order preserved)
        node_type = {n: data.get('type') for n, data in graph.nodes(data=True) if n in pos}

        router_edges = [(u, v) for u, v in graph.edges()
                        if node_type.get(u) == 'router' and node_type.get(v) == 'router']
        server_edges = [(u, v) for u, v in graph.edges()
                        if {node_type.get(u), node_type.get(v)} == {'server', 'router'}]

        nx.draw_networkx_edges(graph, pos, edgelist=router_edges, 
                              edge_color='gray', alpha=alpha*0.5, width=1, ax=ax)
        nx.draw_networkx_edges(graph, pos, edgelist=server_edges,
                              edge_color='lightblue', alpha=alpha, width=0.5, ax=ax)

        router_nodes = [n for n, kind in node_type.items() if kind == 'router']
        nx.draw_networkx_nodes(graph, pos, nodelist=router_nodes,
                              node_color='orange', node_size=100, alpha=alpha,
                              node_shape='o', ax=ax)

        if show_labels:
            server_labels = {n: str(graph.nodes[n]['server_id'])
                             for n, kind in node_type.items() if kind == 'server'}
            nx.draw_networkx_labels(graph, pos, labels=server_labels,
                                   font_size=6, ax=ax)
    
    def _plot_placement_heatmap_on_ax(self, ax: plt.Axes, expert_placements: List[Dict[str, Any]], 
                                     layer_placements: List[Dict[str, Any]]):
        """
        Draw a compact layer/expert placement heatmap on ``ax`` for the dashboard.

        Rows are attention layers followed by MoE layers (each sorted by
        ``layer_id``); columns are servers ``0 .. num_servers - 1``.

        * Attention cells are blue with a check mark on the hosting server. For a
          duplicated ``layer_id`` the first entry is used.
        * MoE cells are red with opacity proportional to the layer's expert count
          on that server, relative to its busiest server. Each is annotated with
          the absolute count.

        Args:
            ax: Axes to draw on.
            expert_placements: Expert placement dicts with at least ``'layer_id'``
                and ``'server_id'``.
            layer_placements: Layer placement dicts with at least ``'layer_type'``
                and ``'layer_id'``; attention entries also need ``'server_id'``.
        """
        attention_layers = [l for l in layer_placements if l['layer_type'] == 'attention']
        moe_layers = [l for l in layer_placements if l['layer_type'] == 'moe']
        attention_layer_ids = sorted({l['layer_id'] for l in attention_layers})
        moe_layer_ids = sorted({l['layer_id'] for l in moe_layers})
        num_attention = len(attention_layer_ids)
        num_layers = num_attention + len(moe_layer_ids)

        if num_layers == 0:
            # Nothing to place: show a message instead of an empty image.
            ax.text(0.5, 0.5, 'No layers found', ha='center', va='center',
                    transform=ax.transAxes, fontsize=12)
            return

        layer_labels = ([f'A-{layer_id}' for layer_id in attention_layer_ids]
                        + [f'M-{layer_id}' for layer_id in moe_layer_ids])

        # counts[row, server]: 1 on the server hosting an attention layer, or the
        # number of experts an MoE layer has on that server. Built in one pass over
        # the placements so annotation does not rescan the expert list per cell.
        counts = np.zeros((num_layers, self.num_servers), dtype=int)
        first_attention_server = {}
        for layer in attention_layers:
            first_attention_server.setdefault(layer['layer_id'], layer['server_id'])
        for row, layer_id in enumerate(attention_layer_ids):
            counts[row, first_attention_server[layer_id]] = 1

        moe_row = {layer_id: num_attention + k for k, layer_id in enumerate(moe_layer_ids)}
        for expert in expert_placements:
            row = moe_row.get(expert['layer_id'])
            if row is not None:
                counts[row, expert['server_id']] += 1

        # Relative expert density of each MoE cell (count / busiest server of the layer)
        intensity = np.zeros((num_layers, self.num_servers))
        for row in range(num_attention, num_layers):
            row_max = counts[row].max()
            if row_max > 0:
                intensity[row] = counts[row] / row_max

        # RGBA image: white at 10% alpha for empty cells, solid blue for attention
        # placements, red whose alpha grows from 0.3 to 1.0 with expert density.
        occupied = counts > 0
        is_attention_row = np.arange(num_layers)[:, None] < num_attention
        attention_cells = occupied & is_attention_row
        moe_cells = occupied & ~is_attention_row
        rgba = np.tile([1.0, 1.0, 1.0, 0.1], (num_layers, self.num_servers, 1))
        rgba[attention_cells] = [0.2, 0.4, 0.8, 0.9]
        rgba[moe_cells, :3] = [0.8, 0.2, 0.2]
        rgba[moe_cells, 3] = 0.3 + 0.7 * intensity[moe_cells]

        ax.imshow(rgba, aspect='auto', interpolation='nearest')

        # Annotate occupied cells: check mark for attention, expert count for MoE.
        for row, server in zip(*np.nonzero(counts)):
            if row < num_attention:
                text, color = '✓', 'white'
            else:
                text = str(counts[row, server])
                # Dense (darker red) cells get white text, pale ones black text.
                color = 'white' if intensity[row, server] > 0.5 else 'black'
            ax.text(server, row, text, ha='center', va='center',
                    color=color, fontweight='bold', fontsize=7)

        # Smaller fonts for the dashboard
        ax.set_xticks(range(self.num_servers))
        ax.set_xticklabels([f'S{i}' for i in range(self.num_servers)], fontsize=7)
        ax.set_yticks(range(num_layers))
        ax.set_yticklabels(layer_labels, fontsize=8)

        ax.set_xlabel('Server ID', fontsize=10, fontweight='bold')
        ax.set_ylabel('Layer ID', fontsize=10, fontweight='bold')

        legend_elements = [
            patches.Patch(facecolor=[0.2, 0.4, 0.8, 0.9], label='Attention'),
            patches.Patch(facecolor=[0.8, 0.2, 0.2, 0.7], label='MoE (High)'),
            patches.Patch(facecolor=[0.8, 0.2, 0.2, 0.3], label='MoE (Low)')
        ]
        ax.legend(handles=legend_elements, loc='upper right', fontsize=8)

    def _plot_communication_cost_heatmap_on_ax(self, ax: plt.Axes, expert_placements: List[Dict[str, Any]], 
                                              layer_placements: List[Dict[str, Any]],
                                              max_layers: int = 10, max_expert_columns: int = 20):
        """
        Draw a compact communication cost heatmap on ``ax`` for the dashboard.

        Each cell is one expert's dispatch -> expert -> collect distance, read
        from ``self.distance_matrix``.
          * Rows are MoE layers, sorted and deduplicated by ``layer_id``. The first
            entry of a layer supplies its dispatch/collect servers.
          * Columns are expert positions within the layer after sorting by
            ``expert_id``.
          * Only the first ``max_layers`` layers and ``max_expert_columns``
            experts per layer are drawn, to keep the panel readable.
          * Gray cells mean the layer has no expert at that position.

        Args:
            ax: Axes to draw on; a colorbar axes is appended to its right.
            expert_placements: Expert placement dicts with ``'layer_id'``,
                ``'expert_id'`` and ``'server_id'``.
            layer_placements: Layer placement dicts; MoE entries need
                ``'layer_id'``, ``'dispatch_server'`` and ``'collect_server'``.
            max_layers: Maximum number of MoE layers (rows) shown; expected positive.
            max_expert_columns: Maximum number of expert columns shown; expected positive.
        """
        from mpl_toolkits.axes_grid1 import make_axes_locatable

        moe_layers = [l for l in layer_placements if l['layer_type'] == 'moe']
        # Deduplicate so a layer listed twice does not produce two identical rows
        moe_layer_ids = sorted({l['layer_id'] for l in moe_layers})

        if not moe_layer_ids:
            ax.text(0.5, 0.5, 'No MoE layers found', ha='center', va='center', 
                   transform=ax.transAxes, fontsize=12)
            return

        # Limit rows for readability; costs are only computed for displayed layers.
        moe_layer_ids = moe_layer_ids[:max_layers]
        num_layers = len(moe_layer_ids)

        # Index layers and experts once instead of rescanning both lists per layer.
        layer_info = {}
        for layer in moe_layers:
            layer_info.setdefault(layer['layer_id'], layer)
        experts_by_layer = defaultdict(list)
        for expert in expert_placements:
            experts_by_layer[expert['layer_id']].append(expert)

        layer_costs = {}
        for layer_id in moe_layer_ids:
            dispatch_server = layer_info[layer_id]['dispatch_server']
            collect_server = layer_info[layer_id]['collect_server']
            layer_experts = sorted(experts_by_layer.get(layer_id, []), key=lambda e: e['expert_id'])
            # Total cost of one expert: dispatch -> expert -> collect (columns limited)
            layer_costs[layer_id] = [
                self.distance_matrix[dispatch_server][e['server_id']]
                + self.distance_matrix[e['server_id']][collect_server]
                for e in layer_experts[:max_expert_columns]
            ]

        max_experts = max((len(costs) for costs in layer_costs.values()), default=0)
        if max_experts == 0:
            # Displayed MoE layers have no experts: nothing to put in the matrix.
            ax.text(0.5, 0.5, 'No experts placed in MoE layers', ha='center', va='center',
                    transform=ax.transAxes, fontsize=12)
            return

        # rows = layers, columns = expert positions; NaN (drawn gray) = no expert there
        cost_matrix = np.full((num_layers, max_experts), np.nan)
        for i, layer_id in enumerate(moe_layer_ids):
            costs = layer_costs[layer_id]
            cost_matrix[i, :len(costs)] = costs

        cmap = plt.cm.viridis.copy()
        cmap.set_bad(color='lightgray', alpha=0.3)  # Color for NaN values
        im = ax.imshow(cost_matrix, cmap=cmap, aspect='auto', interpolation='nearest')

        # imshow maps costs linearly from the smallest to the largest value.
        # viridis runs from dark purple (low) to bright yellow (high), so
        # low-cost cells need white text and high-cost cells black text.
        vmin = np.nanmin(cost_matrix)
        span = np.nanmax(cost_matrix) - vmin
        for i, layer_id in enumerate(moe_layer_ids):
            for j, cost in enumerate(layer_costs[layer_id]):
                position = (cost - vmin) / span if span > 0 else 0.0
                color = 'white' if position < 0.5 else 'black'
                # ':g' prints integral costs as e.g. '4' without truncating fractions
                ax.text(j, i, f'{cost:g}', ha='center', va='center', 
                       color=color, fontweight='bold', fontsize=6)

        # Smaller fonts for the dashboard; label every other expert column
        ax.set_xticks(range(0, max_experts, 2))
        ax.set_xticklabels([f'E{i}' for i in range(0, max_experts, 2)], fontsize=7)
        ax.set_yticks(range(num_layers))
        ax.set_yticklabels([f'L{layer_id}' for layer_id in moe_layer_ids], fontsize=8)

        ax.set_xlabel('Expert ID', fontsize=10, fontweight='bold')
        ax.set_ylabel('MoE Layer', fontsize=10, fontweight='bold')

        # Compact colorbar in a thin axes appended to the right of the heatmap
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        ax.figure.colorbar(im, cax=cax, label='Hops')

    def _plot_communication_pattern_on_ax(self, ax: plt.Axes, expert_placements: List[Dict[str, Any]],
                                         layer_placements: List[Dict[str, Any]]):
        """Plot a compact server-to-server communication-count matrix (dashboard version).

        For every MoE layer, each of its experts contributes one count from the
        layer's dispatch server to the expert's server, and one count from the
        expert's server to the layer's collect server.

        Args:
            ax: Axes to draw on; a colorbar axis is appended to its right.
            expert_placements: Expert placement dicts; each must provide
                'layer_id' and 'server_id'.
            layer_placements: Layer placement dicts; each must provide
                'layer_type'. Entries with layer_type == 'moe' must also provide
                'layer_id', 'dispatch_server' and 'collect_server'.
        """
        from mpl_toolkits.axes_grid1 import make_axes_locatable

        # Rows are source servers, columns are destination servers.
        comm_matrix = np.zeros((self.num_servers, self.num_servers))

        # Group expert servers by layer once instead of rescanning the whole
        # expert list for every MoE layer (O(E + L) rather than O(E * L)).
        expert_servers_by_layer = defaultdict(list)
        for expert in expert_placements:
            expert_servers_by_layer[expert['layer_id']].append(expert['server_id'])

        for layer in layer_placements:
            if layer['layer_type'] != 'moe':
                continue
            dispatch_server = layer['dispatch_server']
            collect_server = layer['collect_server']
            for expert_server in expert_servers_by_layer.get(layer['layer_id'], ()):
                # dispatch -> expert and expert -> collect traffic
                comm_matrix[dispatch_server, expert_server] += 1
                comm_matrix[expert_server, collect_server] += 1

        im = ax.imshow(comm_matrix, cmap='Reds', aspect='auto')
        ax.set_xlabel('Destination Server ID', fontsize=10)
        ax.set_ylabel('Source Server ID', fontsize=10)
        ax.tick_params(labelsize=8)

        # Attach the colorbar to the figure that owns `ax` rather than to
        # pyplot's "current" figure, which may be a different one.
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="2%", pad=0.05)
        ax.figure.colorbar(im, cax=cax, label='Count')

    def _plot_placement_statistics(self, ax: plt.Axes, expert_placements: List[Dict[str, Any]], 
                                 layer_placements: List[Dict[str, Any]]):
        """Plot placement statistics"""
        # Calculate statistics
        server_expert_count = defaultdict(int)
        server_layer_count = defaultdict(int)
        layer_type_count = defaultdict(int)
        
        for expert in expert_placements:
            server_expert_count[expert['server_id']] += 1
        
        for layer in layer_placements:
            server_layer_count[layer['server_id']] += 1
            layer_type_count[layer['layer_type']] += 1
        
        # Create subplots for statistics
        ax.clear()
        
        # Expert distribution histogram
        expert_counts = list(server_expert_count.values())
        if expert_counts:
            ax.hist(expert_counts, bins=10, alpha=0.7, color='skyblue', edgecolor='black')
            ax.set_xlabel('Experts per Server')
            ax.set_ylabel('Number of Servers')
            ax.grid(True, alpha=0.3)
            
            # Add statistics text
            stats_text = f"""
Statistics:
• Total Experts: {len(expert_placements)}
• Total Layers: {len(layer_placements)}
• Attention Layers: {layer_type_count.get('attention', 0)}
• MoE Layers: {layer_type_count.get('moe', 0)}
• Servers Used: {len(server_expert_count)}
• Max Experts/Server: {max(expert_counts) if expert_counts else 0}
• Avg Experts/Server: {np.mean(expert_counts):.1f}
            """
            ax.text(0.95, 0.95, stats_text.strip(), transform=ax.transAxes,
                   verticalalignment='top', horizontalalignment='right',
                   bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8),
                   fontsize=8)


def visualize_moe_placement(expert_placements: List[Dict[str, Any]],
                          layer_placements: List[Dict[str, Any]],
                          num_servers: int = 32,
                          save_path: Optional[str] = None) -> plt.Figure:
    """
    Build the complete MoE placement figure and optionally write it to disk.

    Args:
        expert_placements: Expert placement dictionaries, forwarded unchanged
        layer_placements: Layer placement dictionaries, forwarded unchanged
        num_servers: Server count passed to the visualizer constructor
        save_path: When truthy, the figure is saved here at 300 dpi with a
            tight bounding box; otherwise nothing is written

    Returns:
        The figure returned by ``plot_complete_visualization``
    """
    # NOTE(review): the visualizer class visible in this module is
    # PlacementVisualizer(topology, distance_matrix, num_servers). Confirm that
    # MoETopologyVisualizer (taking only num_servers) is defined or imported
    # elsewhere; otherwise this call raises NameError.
    figure = MoETopologyVisualizer(num_servers).plot_complete_visualization(
        expert_placements, layer_placements
    )

    if save_path:
        figure.savefig(save_path, dpi=300, bbox_inches='tight')

    return figure


# # Example usage and testing
# if __name__ == "__main__":
#     # Example with dummy data
#     from lpopt import construct_moe_placement
    
#     num_servers = 32
#     num_layers = 27
#     experts_per_layer = 64
#     max_experts_per_server = 27*2
#     max_layers_per_server = 2
#     max_layer_experts_per_server = 6
    
#     # Create example placement using the optimal solver
#     distance_matrix, neighbor_info = prepare_network_topology(num_servers)
    
#     try:
#         expert_placements, layer_placements, server_expert_count, server_layer_count = construct_moe_placement(
#             distance_matrix=distance_matrix,
#             neighbor_info=neighbor_info,
#             num_layers=num_layers,
#             experts_per_layer=experts_per_layer,
#             max_experts_per_server=max_experts_per_server,
#             max_layers_per_server=max_layers_per_server,
#             max_layer_experts_per_server=max_layer_experts_per_server,
#             random_seed=42
#         )
        
#         # Create visualization
#         fig = visualize_moe_placement(expert_placements, layer_placements, num_servers=num_servers)
#         plt.show()
        
#         # Create individual plots
#         visualizer = MoETopologyVisualizer(num_servers)
        
#         # Plot communication heatmap
#         fig4, (ax4, ax5) = visualizer.plot_communication_heatmap(expert_placements, layer_placements)
#         plt.savefig("visualize/communication_heatmap.png")
#         plt.show()
        
#     except ImportError as e:
#         print(f"Error: {e}")
#         print("Please install PuLP to run the optimal placement: pip install pulp")
        
#         # Create visualization with dummy data instead
#         print("Creating visualization with dummy data...")
        
#         # Simple dummy placement
#         expert_placements = []
#         for layer_id in range(1, num_layers, 2):  # MoE layers (odd indices)
#             for expert_id in range(num_layers):
#                 expert_placements.append({
#                     'expert_id': expert_id,
#                     'layer_id': layer_id,
#                     'server_id': expert_id % num_servers
#                 })
        
#         layer_placements = []
#         for layer_id in range(num_layers):
#             if layer_id % 2 == 0:  # Attention layers
#                 layer_placements.append({
#                     'layer_id': layer_id,
#                     'layer_type': 'attention',
#                     'server_id': (layer_id // 2) % num_servers
#                 })
#             else:  # MoE layers
#                 layer_placements.append({
#                     'layer_id': layer_id,
#                     'layer_type': 'moe',
#                     'server_id': (layer_id // 2) % num_servers,
#                     'dispatch_server': (layer_id // 2) % num_servers,
#                     'collect_server': ((layer_id // 2) + 1) % num_servers
#                 })
        
#         # Create visualization with dummy data
#         fig = visualize_moe_placement(expert_placements, layer_placements, num_servers=num_servers)
#         plt.show()

