"""
Optimal MoE (Mixture‑of‑Experts) placement for the 32‑server Dragonfly
benchmark used in the MetaEvolve "moe_pack_new" problem.

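The routine below solves a *0–1 mixed‑integer program* that minimises the
**per‑server load imbalance** – for every MoE layer, the L1 deviation of each
server's expert load from that layer's per‑server average – rather than the
expected number of network hops, subject to **all the capacity constraints**
enforced by `validate_moe_placement`: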

* at most ``max_experts_per_server`` experts total on any server,
* at most ``max_layer_experts_per_server`` experts **from the same MoE layer**
  on a single server,
* at most ``max_layers_per_server`` layers (attention or MoE) per server.

The **dispatch** and **collect** servers of MoE layer *ℓ* are *fixed* to be the
servers that run the *current* and *next* attention layers, respectively – the
same convention as in the starter `optimize.py` file.

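The model is built with CVXPY and solved with MOSEK (not CBC via PuLP), so a
working MOSEK installation is required.  The problem instance shipped with the
task has 32 layers × 32 experts each.  The solve times previously quoted here
(≈ 0.2 s on a laptop, usually below 1 s on the competition VM) were stated for
a CBC/PuLP solve and have not been re‑measured for this formulation.  Feel free
to swap in another mixed‑integer solver supported by CVXPY if you prefer.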
"""
from typing import Dict, List, Tuple, Optional, Any
import numpy as np
import torch
import cvxpy as cvx
import numpy as np

# PuLP is pre‑installed on Kaggle / Jupyter images; if it is missing, raise a
# clean error so the user can `pip install pulp`.  (numpy, torch and cvxpy,
# imported above, are required as well.)
try:
    import pulp
except ImportError as _err:  # pragma: no cover
    raise ImportError("This implementation requires the 'pulp' package.\n"
                      "Install it with 'pip install pulp' and try again.") from _err

def _place_attention_layers(num_servers: int, num_layers: int, group_size: int = 8) -> Tuple[List[int], Dict[int, int]]:
    """Round‑robin placement identical to the baseline code.

    Returns a list ``attention_servers[layer_id//2]`` and an initial
    ``server_layer_count`` dictionary that already includes those layers.
    """
    servers_per_group = group_size
    num_groups = num_servers // servers_per_group

    attention_servers: List[int] = []
    server_layer_count: Dict[int, int] = {sid: 0 for sid in range(num_servers)}

    for layer_id in range(0, num_layers, 2):  # even indices only
        attn_idx = layer_id // 2
        group_idx = (attn_idx // servers_per_group) % num_groups
        server_idx_in_group = attn_idx % servers_per_group
        server_id = group_idx * servers_per_group + server_idx_in_group

        attention_servers.append(server_id)
        server_layer_count[server_id] += 1  # count the attention layer itself

    return attention_servers, server_layer_count

def _place_attentions_round_robin(num_servers: int, num_layers: int, max_per_server: int, stride: int = 1) -> Tuple[List[int], Dict[int, int]]:
    attention_servers: List[int] = []
    server_layer_count: Dict[int, int] = {sid: 0 for sid in range(num_servers)}

    for layer_id in range(0, num_layers, 2):
        attn_idx = layer_id//2
        server_id = (attn_idx * stride) % num_servers
        while server_layer_count[server_id] >= max_per_server:
            server_id = (server_id + 1) % num_servers
        
        attention_servers.append(server_id)
        server_layer_count[server_id] += 1  # count the attention layer itself
    
    return attention_servers, server_layer_count

def construct_moe_placement(
    distance_matrix: np.ndarray,
    neighbor_info: Dict[int, List[int]],
    per_layer_stats,
    *,
    num_layers: int = 32,
    experts_per_layer: int = 32,
    max_experts_per_server: int = 32,
    max_layers_per_server: int = 2,
    max_layer_experts_per_server: int = 4,
    random_seed: Optional[int] = None,
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], Dict[int, int], Dict[int, int]]:
    """Place attention layers round-robin and experts with a load-balancing MIP.

    Attention layers (even layer IDs) are spread over the servers by
    ``_place_attentions_round_robin``.  For every MoE layer (odd layer ID)
    each expert is then assigned to exactly one server by a 0-1 program,
    built with CVXPY and solved with MOSEK, that minimises the summed L1
    deviation of the per-server load from the per-layer average load.

    The signature matches the original baseline so that `usage.py` works
    unchanged.

    Args:
        distance_matrix: Server distance matrix; only its first dimension
            (the number of servers) is used here.
        neighbor_info: Unused; kept for signature compatibility.
        per_layer_stats: Per-expert load statistics, indexed as
            ``per_layer_stats[moe_idx][expert_id]``.  Presumably a 2-D NumPy
            array of shape (MoE layers, experts per layer) - confirm against
            the caller.
        num_layers: Total number of layers (attention and MoE interleaved).
        experts_per_layer: Number of experts in each MoE layer.
        max_experts_per_server: Cap on the experts hosted by one server.
        max_layers_per_server: Cap on the attention layers hosted by one
            server during attention placement.
        max_layer_experts_per_server: Cap on the experts of a single MoE
            layer hosted by one server.
        random_seed: If given, seeds NumPy's global RNG.

    Returns:
        ``(expert_placements, layer_placements, server_expert_count,
        server_layer_count)``.  ``expert_placements`` holds one dict per
        (MoE layer, expert) with keys ``expert_id``, ``layer_id`` and
        ``server_id``; ``layer_placements`` holds one dict per layer;
        ``server_expert_count`` maps server ID to its number of experts;
        ``server_layer_count`` maps server ID to its number of attention
        layers (MoE layers are not added to it).

    Raises:
        AssertionError: If the number of placed attention layers differs
            from ``num_layers // 2``.
        RuntimeError: If the solver returns no solution, or if a final
            capacity sanity check fails.
    """
    if random_seed is not None:
        np.random.seed(random_seed)

    num_servers: int = distance_matrix.shape[0]

    # ---------------------------------------------------------------------
    # 1.  Fix attention-layer positions (strided round-robin heuristic).
    # ---------------------------------------------------------------------
    # Spread the attention layers evenly over the whole server range.
    stride = num_servers // (num_layers // 2)
    attention_servers, server_layer_count = _place_attentions_round_robin(
        num_servers, num_layers, max_layers_per_server, stride
    )

    # Sanity: there are num_layers / 2 attention layers.
    assert len(attention_servers) == num_layers // 2

    print("experts_per_layer:", experts_per_layer)
    print("num_layers:", num_layers)

    print("max_experts_per_server:", max_experts_per_server)
    print("max_layer_experts_per_server:", max_layer_experts_per_server)

    print("server_layer_count: ", server_layer_count)

    # ---------------------------------------------------------------------
    # 2.  Build the 0-1 program.
    #     y[i][e, s] == 1  <=>  expert e of MoE layer i runs on server s.
    #     The objective is pure load balance; hop distances from
    #     ``distance_matrix`` are not part of it.
    # ---------------------------------------------------------------------
    moe_layers = [lid for lid in range(num_layers) if lid % 2 == 1]
    num_moe_layers = len(moe_layers)

    y = [cvx.Variable((experts_per_layer, num_servers), boolean=True) for _ in range(num_moe_layers)]

    # Per-layer, per-server load totals: length == num_servers
    server_loads = [
        cvx.sum(cvx.multiply(y[i], per_layer_stats[i][:, None]), axis=0)
        for i in range(num_moe_layers)
    ]

    # Constant average load per server for each layer
    avg_per_server = [
        float(per_layer_stats[i].sum()) / num_servers
        for i in range(num_moe_layers)
    ]

    # Deviations (vector of length num_servers for each layer)
    devs = [server_loads[i] - avg_per_server[i] for i in range(num_moe_layers)]

    objective = cvx.sum([cvx.norm(d, 1) for d in devs])

    constr = []
    # Every expert is placed on exactly one server.
    for i in range(num_moe_layers):
        constr.append(cvx.sum(y[i], axis=1) == 1)

    # Total number of experts (over all MoE layers) per server is capped.
    for s in range(num_servers):
        val = 0
        for i in range(num_moe_layers):
            val += cvx.sum(y[i][:, s])
        constr.append(val <= max_experts_per_server)

    # Number of experts of the same MoE layer per server is capped.
    for i in range(num_moe_layers):
        constr.append(cvx.sum(y[i], axis=0) <= max_layer_experts_per_server)

    # ---------------------------------------------------------------------
    # 3.  Solve (requires a working MOSEK installation).
    # ---------------------------------------------------------------------
    problem = cvx.Problem(cvx.Minimize(objective), constr)
    problem.solve(solver=cvx.MOSEK, verbose=True)

    # Variables keep ``value is None`` when no solution was produced (e.g.
    # the model is infeasible); fail with a clear message instead of an
    # opaque TypeError further down.
    if any(var.value is None for var in y):
        raise RuntimeError(
            f"Solver did not return a placement (status: {problem.status})."
        )

    # ---------------------------------------------------------------------
    # 4.  Extract the solution and build the return structures.
    # ---------------------------------------------------------------------
    expert_placements: List[Dict[str, Any]] = []
    server_expert_count: Dict[int, int] = {sid: 0 for sid in range(num_servers)}

    # moe_layers is [1, 3, 5, ...], so its enumeration index equals layer_id // 2.
    for moe_idx, layer_id in enumerate(moe_layers):
        # Each row holds exactly one (numerically near-)1 entry, so the
        # arg-max is the chosen server and is robust to solver round-off.
        chosen = np.argmax(np.asarray(y[moe_idx].value), axis=1)
        for expert_id in range(experts_per_layer):
            srv = int(chosen[expert_id])
            expert_placements.append({
                "expert_id": expert_id,
                "layer_id": layer_id,
                "server_id": srv,
            })
            server_expert_count[srv] += 1

    # ---------------------------------------------------------------------
    # 5.  Compile layer_placements.
    # ---------------------------------------------------------------------
    layer_placements: List[Dict[str, Any]] = []

    # Attention layers first (even IDs).
    for layer_id in range(0, num_layers, 2):
        attn_server = attention_servers[layer_id // 2]
        layer_placements.append({
            "layer_id": layer_id,
            "layer_type": "attention",
            "server_id": attn_server,
        })

    # MoE layers: dispatch from the server of the preceding attention layer,
    # collect on the server of the following one (wrapping around at the end).
    for layer_id in moe_layers:
        disp = attention_servers[layer_id // 2]
        coll = attention_servers[(layer_id // 2 + 1) % len(attention_servers)]

        layer_placements.append({
            "layer_id": layer_id,
            "layer_type": "moe",
            "server_id": disp,
            "dispatch_server": disp,
            "collect_server": coll,
        })

    # ---------------------------------------------------------------------
    # 6.  Final sanity checks on the per-server layer and expert counts.
    # ---------------------------------------------------------------------
    for sid, cnt in server_layer_count.items():
        if cnt > max_layers_per_server:
            raise RuntimeError(f"Server {sid} hosts {cnt} layers (>{max_layers_per_server}).")

    for sid, cnt in server_expert_count.items():
        if cnt > max_experts_per_server:
            raise RuntimeError(f"Server {sid} hosts {cnt} experts (>{max_experts_per_server}).")

    return expert_placements, layer_placements, server_expert_count, server_layer_count
