# EVOLVE-BLOCK-START
"""
MoE (Mixture of Experts) placement optimization for datacenter networks
"""
import numpy as np
from typing import Dict, List, Tuple, Any, Optional

def _place_attentions_round_robin(num_servers: int, num_layers: int, max_per_server: int, stride: int = 1) -> Tuple[List[int], Dict[int, int]]:
    """Assign every attention layer to a server in strided round-robin order.

    Attention layers are the even-numbered layers ``0, 2, 4, ...`` below
    ``num_layers``. The i-th attention layer (layer ``2 * i``) prefers server
    ``int(i * stride) % num_servers``. If that server already holds
    ``max_per_server`` attention layers, the following servers are tried
    cyclically until one has room.

    Args:
        num_servers: Number of servers, identified as ``0 .. num_servers - 1``.
        num_layers: Total number of model layers (attention and MoE interleaved).
        max_per_server: Maximum number of attention layers one server may hold.
        stride: Spacing between the preferred servers of consecutive attention
            layers. ``int()`` is applied to ``i * stride``, so a fractional
            stride is accepted as well.

    Returns:
        ``(attention_servers, server_layer_count)``:
        ``attention_servers[i]`` is the server of the i-th attention layer, and
        ``server_layer_count`` maps every server id to the number of attention
        layers placed on it.

    Raises:
        ValueError: If the servers cannot hold all attention layers, i.e. there
            are more attention layers than ``num_servers * max_per_server``.
    """
    attention_servers: List[int] = []
    server_layer_count: Dict[int, int] = {sid: 0 for sid in range(num_servers)}

    num_attention_layers = len(range(0, num_layers, 2))
    capacity = max(num_servers, 0) * max(max_per_server, 0)
    # The cyclic probe below would spin forever once every server is full
    # (and divide by zero when there are no servers), so reject infeasible
    # inputs before placing anything.
    if num_attention_layers > capacity:
        raise ValueError(
            f"cannot place {num_attention_layers} attention layers on "
            f"{num_servers} servers with at most {max_per_server} per server"
        )

    for attn_idx in range(num_attention_layers):
        server_id = int(attn_idx * stride) % num_servers
        # Guaranteed to terminate: the capacity check ensures a free server exists.
        while server_layer_count[server_id] >= max_per_server:
            server_id = (server_id + 1) % num_servers

        attention_servers.append(server_id)
        server_layer_count[server_id] += 1  # count the attention layer itself

    return attention_servers, server_layer_count

def construct_moe_placement(distance_matrix, neighbor_info, num_layers=32, experts_per_layer=32,
                           max_experts_per_server: int = 32,
                           max_layers_per_server: int = 2,
                           max_layer_experts_per_server: int = 4,
                           random_seed: Optional[int] = 42) -> Tuple[List[Dict], List[Dict], Dict[int, int], Dict[int, int]]:
    """
    Construct an MoE placement for a datacenter network.

    Even-numbered layers are attention layers and odd-numbered layers are MoE
    layers. Attention layers are spread over the servers with a strided
    round-robin (see ``_place_attentions_round_robin``). Each MoE layer ``l``
    dispatches from the server of attention layer ``l - 1``. It collects at the
    server of attention layer ``l + 1``, or at its own dispatch server when no
    attention layer follows it. Its experts are then greedily packed onto the
    servers with the smallest detour ``dist(dispatch, s) + dist(s, collect)``,
    subject to the per-layer and per-server expert limits.

    Args:
        distance_matrix: Pre-computed shortest path distances between servers.
            Must expose ``.shape`` (e.g. a square NumPy array) and support
            ``distance_matrix[a][b]`` indexing.
        neighbor_info: Pre-computed neighbor information for each server.
            Not used by this implementation; kept for interface compatibility.
        num_layers (int): Number of layers in the model.
        experts_per_layer (int): Number of experts per MoE layer.
        max_experts_per_server (int): Maximum number of experts per server
            (summed over all MoE layers).
        max_layers_per_server (int): Maximum number of attention layers per
            server. MoE layers are not counted against this limit.
        max_layer_experts_per_server (int): Maximum number of experts of a
            single MoE layer per server.
        random_seed (Optional[int]): If not None, seeds NumPy's global RNG.
            The placement itself draws no random numbers.

    Returns:
        Tuple[
            List[Dict],                # expert_placements: List of expert placement dicts
            List[Dict],                # layer_placements: all attention layers first, then all MoE layers
            Dict[int, int],            # server_expert_count: Number of experts per server
            Dict[int, int],            # server_layer_count: Number of attention layers per server
        ]

    Raises:
        ValueError: If an expert cannot be placed on any server without
            exceeding ``max_layer_experts_per_server`` or
            ``max_experts_per_server``.
    """
    num_servers = distance_matrix.shape[0]

    if random_seed is not None:
        # Kept for backward compatibility (global RNG side effect); nothing
        # below draws random numbers.
        np.random.seed(random_seed)

    expert_placements: List[Dict] = []
    layer_placements: List[Dict] = []
    server_expert_count = {i: 0 for i in range(num_servers)}

    # Phase 1: place attention layers (even layer ids) round-robin.
    # max(1, ...) avoids a ZeroDivisionError when num_layers < 2.
    stride = int(num_servers / max(1, num_layers // 2))
    attention_servers, server_layer_count = _place_attentions_round_robin(
        num_servers, num_layers, max_layers_per_server, stride)
    for attn_idx, server_id in enumerate(attention_servers):
        layer_placements.append({
            'layer_id': 2 * attn_idx,
            'layer_type': 'attention',
            'server_id': server_id
        })

    # Phase 2: place MoE layers (odd layer ids) near their dispatch/collect path.
    last_attn_idx = len(attention_servers) - 1
    for layer_id in range(1, num_layers, 2):
        attn_idx = layer_id // 2
        # MoE layer 2k+1 sits between attention layers 2k and 2k+2; the last
        # MoE layer has no successor, so it collects where it dispatches.
        dispatch_server = attention_servers[attn_idx]
        collect_server = attention_servers[min(attn_idx + 1, last_attn_idx)]

        # Servers ordered by the extra path an expert on them adds (stable sort,
        # so ties keep ascending server id).
        detour = [distance_matrix[dispatch_server][s] + distance_matrix[s][collect_server]
                  for s in range(num_servers)]
        servers_sorted = sorted(range(num_servers), key=detour.__getitem__)

        layer_experts: List[int] = []
        server_experts_per_layer = {server_id: 0 for server_id in range(num_servers)}

        for expert_id in range(experts_per_layer):
            # Closest server that still has room both for this layer and overall.
            target = next(
                (s for s in servers_sorted
                 if server_experts_per_layer[s] < max_layer_experts_per_server
                 and server_expert_count[s] < max_experts_per_server),
                None,
            )
            if target is None:
                raise ValueError(
                    f"cannot place expert {expert_id} of layer {layer_id}: every server "
                    f"is at its per-layer ({max_layer_experts_per_server}) or total "
                    f"({max_experts_per_server}) expert limit"
                )
            expert_placements.append({
                'expert_id': expert_id,
                'layer_id': layer_id,
                'server_id': target
            })
            layer_experts.append(expert_id)
            server_experts_per_layer[target] += 1
            server_expert_count[target] += 1

        layer_placements.append({
            'layer_id': layer_id,
            'layer_type': 'moe',
            'server_id': dispatch_server,
            'experts': layer_experts,
            'dispatch_server': dispatch_server,
            'collect_server': collect_server
        })

    return expert_placements, layer_placements, server_expert_count, server_layer_count


# EVOLVE-BLOCK-END