# EVOLVE-BLOCK-START
"""
MoE (Mixture of Experts) placement optimization for datacenter networks
"""
import numpy as np
from typing import Dict, List, Tuple, Any, Optional


def construct_moe_placement(distance_matrix, neighbor_info, num_layers=32, experts_per_layer=32,
                           max_experts_per_server: int = 32,
                           max_layers_per_server: int = 2,
                           max_layer_experts_per_server: int = 4,
                           random_seed: Optional[int] = 42) -> Tuple[List[Dict], List[Dict], Dict[int, int], Dict[int, int]]:
    """
    Construct a greedy, distance-aware MoE placement for a datacenter network.

    Even-numbered layers are treated as attention layers and odd-numbered
    layers as MoE layers. The k-th attention layer is placed on server
    ``k % num_servers``. Each MoE layer is dispatched from the server that
    hosts the attention layer just before it and collected on the server that
    hosts the attention layer just after it (or the last attention layer when
    none follows). Its experts are packed onto the servers closest to the
    dispatch server, nearest first, with ties going to the lower server index.

    Args:
        distance_matrix: Square array-like exposing ``.shape`` and
            ``[i][j]`` indexing; entry ``[i][j]`` is the distance from
            server ``i`` to server ``j``.
        neighbor_info: Pre-computed neighbor information for each server.
            Currently unused by this placement strategy.
        num_layers (int): Number of layers in the model (attention + MoE).
        experts_per_layer (int): Number of experts per MoE layer.
        max_experts_per_server (int): Maximum number of experts a server may
            host in total, across all layers.
        max_layers_per_server (int): Maximum number of layers per server.
            NOTE: accepted for interface compatibility but not enforced here.
        max_layer_experts_per_server (int): Maximum number of experts of a
            single MoE layer that one server may host.
        random_seed (Optional[int]): When not None, seeds NumPy's global RNG.
            The placement itself draws no random numbers; the seeding is kept
            because callers may rely on that side effect.

    Returns:
        Tuple[
            List[Dict],      # expert_placements: dicts with keys
                             #   'expert_id', 'layer_id', 'server_id'
            List[Dict],      # layer_placements: all attention layers first,
                             #   then all MoE layers; MoE entries additionally
                             #   carry 'experts', 'dispatch_server' and
                             #   'collect_server'
            Dict[int, int],  # server_expert_count: experts hosted per server
            Dict[int, int],  # server_layer_count: attention layers hosted per
                             #   server (MoE layers are not counted)
        ]

    Raises:
        ValueError: If ``distance_matrix`` describes zero servers while
            ``num_layers`` is positive.

    Warns:
        RuntimeWarning: If some experts could not be placed because every
            server had reached a capacity limit. Those experts are omitted
            from the returned placement.
    """
    import warnings

    num_servers = distance_matrix.shape[0]
    if num_servers == 0 and num_layers > 0:
        # Without this guard the modulo below fails with an opaque
        # ZeroDivisionError.
        raise ValueError("distance_matrix must describe at least one server")

    if random_seed is not None:
        np.random.seed(random_seed)

    # Initialize placement data structures
    expert_placements = []
    layer_placements = []
    server_expert_count = {i: 0 for i in range(num_servers)}
    server_layer_count = {i: 0 for i in range(num_servers)}

    # Phase 1: place attention layers (even layer ids) round-robin over servers.
    attention_servers = []
    for layer_id in range(0, num_layers, 2):
        server_id = (layer_id // 2) % num_servers
        attention_servers.append(server_id)
        layer_placements.append({
            'layer_id': layer_id,
            'layer_type': 'attention',
            'server_id': server_id
        })
        server_layer_count[server_id] += 1

    # Phase 2: place MoE layers (odd layer ids), packing experts near the
    # dispatch server.
    last_attention_idx = len(attention_servers) - 1
    # The distance ordering depends only on the dispatch server, so compute it
    # once per distinct dispatch server instead of once per MoE layer.
    servers_by_distance = {}
    unplaced_total = 0
    short_layers = []

    for layer_id in range(1, num_layers, 2):
        # An odd layer id always has a preceding attention layer at this index.
        attention_idx = layer_id // 2
        dispatch_server = attention_servers[attention_idx]
        # The final MoE layer has no following attention layer; fall back to
        # the last attention server.
        collect_server = attention_servers[min(attention_idx + 1, last_attention_idx)]

        if dispatch_server not in servers_by_distance:
            distances = distance_matrix[dispatch_server]
            # sorted() is stable, so equidistant servers keep index order.
            servers_by_distance[dispatch_server] = sorted(
                range(num_servers), key=lambda s: distances[s]
            )

        # Greedy fill: capacity only ever shrinks, so the nearest server with
        # room stays the first choice until it is full. One pass over the
        # servers therefore matches a per-expert "first server with room" scan.
        placed = 0
        for server_id in servers_by_distance[dispatch_server]:
            if placed >= experts_per_layer:
                break
            room = min(
                max_layer_experts_per_server,
                max_experts_per_server - server_expert_count[server_id],
                experts_per_layer - placed,
            )
            if room <= 0:
                continue
            expert_placements.extend(
                {'expert_id': expert_id, 'layer_id': layer_id, 'server_id': server_id}
                for expert_id in range(placed, placed + room)
            )
            server_expert_count[server_id] += room
            placed += room

        if placed < experts_per_layer:
            unplaced_total += experts_per_layer - placed
            short_layers.append(layer_id)

        layer_placements.append({
            'layer_id': layer_id,
            'layer_type': 'moe',
            'server_id': dispatch_server,
            'experts': list(range(placed)),
            'dispatch_server': dispatch_server,
            'collect_server': collect_server
        })

    if unplaced_total:
        # Surface the shortfall instead of silently returning a partial
        # placement; the return value itself is left untouched.
        warnings.warn(
            f"{unplaced_total} expert(s) in MoE layer(s) {short_layers} could not be "
            "placed within the server capacity limits and were omitted",
            RuntimeWarning,
            stacklevel=2,
        )

    return expert_placements, layer_placements, server_expert_count, server_layer_count


# EVOLVE-BLOCK-END