# EVOLVE-BLOCK-START
"""
MoE (Mixture of Experts) placement optimization for datacenter networks
"""
import numpy as np
from typing import Dict, List, Tuple, Any, Optional

def _place_attentions_round_robin(num_servers: int, num_layers: int, max_per_server: int, stride: int = 1) -> Tuple[List[int], Dict[int, int]]:
    attention_servers: List[int] = []
    server_layer_count: Dict[int, int] = {sid: 0 for sid in range(num_servers)}

    for layer_id in range(0, num_layers, 2):
        attn_idx = layer_id//2
        server_id = (int(attn_idx * stride)) % num_servers
        while server_layer_count[server_id] >= max_per_server:
            server_id = (server_id + 1) % num_servers
        
        attention_servers.append(server_id)
        server_layer_count[server_id] += 1  # count the attention layer itself
    
    return attention_servers, server_layer_count


def _first_server_with_capacity(start: int, num_servers: int,
                                layer_count: Dict[int, int], total_count: Dict[int, int],
                                max_layer: int, max_total: int) -> int:
    """Return the first server, probing cyclically from ``start``, that can take one more expert.

    A server qualifies when it holds fewer than ``max_layer`` experts of the
    current layer (``layer_count``) and fewer than ``max_total`` experts
    overall (``total_count``).

    Raises:
        ValueError: if no server qualifies after visiting every server once.
    """
    server_id = start
    for _ in range(num_servers):
        if layer_count[server_id] < max_layer and total_count[server_id] < max_total:
            return server_id
        server_id = (server_id + 1) % num_servers
    raise ValueError(
        "no server has capacity for another expert; increase "
        "max_layer_experts_per_server / max_experts_per_server or add servers"
    )


def construct_moe_placement(distance_matrix, neighbor_info, num_layers=32, experts_per_layer=32,
                           max_experts_per_server: int = 32,
                           max_layers_per_server: int = 2,
                           max_layer_experts_per_server: int = 4,
                           random_seed: Optional[int] = 42) -> Tuple[List[Dict], List[Dict], Dict[int, int], Dict[int, int]]:
    """
    Construct an MoE placement for a datacenter network.

    Even layer ids are attention layers and odd layer ids are MoE layers.
    Attention layers are spread across servers with a fractional stride
    (see ``_place_attentions_round_robin``).  Each MoE layer's experts are
    packed ``max_layer_experts_per_server`` at a time onto consecutive server
    ids.  The window of servers is roughly centred on the midpoint of the ids
    of the layer's dispatch server (preceding attention layer) and collect
    server (following attention layer).  When a preferred server is full, the
    next server ids are probed cyclically.

    Args:
        distance_matrix: Pre-computed shortest path distances between servers;
            only ``distance_matrix.shape[0]`` (the server count) is used here.
        neighbor_info: Pre-computed neighbor information for each server;
            currently unused by this function.
        num_layers (int): Number of layers in the model.
        experts_per_layer (int): Number of experts per MoE layer.
        max_experts_per_server (int): Maximum number of experts per server.
        max_layers_per_server (int): Maximum number of attention layers per
            server (MoE layers are not counted against it).
        max_layer_experts_per_server (int): Maximum number of experts of one
            layer per server.
        random_seed (Optional[int]): If not None, seeds NumPy's global RNG.

    Returns:
        Tuple[
            List[Dict],      # expert_placements: {'expert_id', 'layer_id', 'server_id'}
            List[Dict],      # layer_placements: attention entries first, then MoE entries
                             #   (MoE entries add 'experts', 'dispatch_server', 'collect_server')
            Dict[int, int],  # server_expert_count: number of experts per server
            Dict[int, int],  # server_layer_count: number of attention layers per server
        ]

    Raises:
        ValueError: if ``max_layer_experts_per_server`` is not positive while
            experts must be placed, or if an expert cannot be placed on any
            server without exceeding a capacity limit.
    """
    num_servers = distance_matrix.shape[0]

    if random_seed is not None:
        # Nothing below draws random numbers.  The call is kept because it
        # resets NumPy's global RNG, which callers may rely on.
        np.random.seed(random_seed)

    num_moe_layers = num_layers // 2  # count of odd layer ids below num_layers
    needs_experts = num_moe_layers > 0 and experts_per_layer > 0
    if needs_experts and max_layer_experts_per_server <= 0:
        raise ValueError("max_layer_experts_per_server must be positive when experts must be placed")

    # Phase 1: attention layers.  max(1, ...) avoids dividing by zero when
    # num_layers < 2; for all other inputs the stride is unchanged.
    stride = num_servers / max(1, num_layers // 2)
    attention_servers, server_layer_count = _place_attentions_round_robin(
        num_servers, num_layers, max_layers_per_server, stride)

    layer_placements: List[Dict] = [
        {'layer_id': 2 * attn_idx, 'layer_type': 'attention', 'server_id': server_id}
        for attn_idx, server_id in enumerate(attention_servers)
    ]

    # Phase 2: MoE layers.
    expert_placements: List[Dict] = []
    server_expert_count = {sid: 0 for sid in range(num_servers)}

    # Loop-invariant shift: half the number of servers a layer's experts span,
    # so the expert window starts that many ids before the midpoint.
    start_serve_offset = (
        int((experts_per_layer / max_layer_experts_per_server) / 2) if needs_experts else 0
    )

    for layer_id in range(1, num_layers, 2):
        attention_idx = layer_id // 2
        # The preceding attention layer (layer_id - 1) always exists, so this index is valid.
        dispatch_server = attention_servers[attention_idx]
        # The last MoE layer may have no following attention layer; fall back to the last one.
        collect_server = (attention_servers[attention_idx + 1]
                          if attention_idx + 1 < len(attention_servers)
                          else attention_servers[-1])

        window_start = (dispatch_server + collect_server) // 2 - start_serve_offset
        server_expert_from_layer_count = {sid: 0 for sid in range(num_servers)}
        layer_experts = []
        for expert_id in range(experts_per_layer):
            preferred = (window_start + expert_id // max_layer_experts_per_server) % num_servers
            server_id = _first_server_with_capacity(
                preferred, num_servers,
                server_expert_from_layer_count, server_expert_count,
                max_layer_experts_per_server, max_experts_per_server)

            expert_placements.append({
                'expert_id': expert_id,
                'layer_id': layer_id,
                'server_id': server_id
            })
            layer_experts.append(expert_id)
            server_expert_count[server_id] += 1
            server_expert_from_layer_count[server_id] += 1

        layer_placements.append({
            'layer_id': layer_id,
            'layer_type': 'moe',
            'server_id': dispatch_server,
            'experts': layer_experts,
            'dispatch_server': dispatch_server,
            'collect_server': collect_server
        })

    return expert_placements, layer_placements, server_expert_count, server_layer_count


# EVOLVE-BLOCK-END
