import torch as th

def message_design_instruction():
    """
    Return the natural-language description of the message protocol.

    Protocol summary:

    - **Senders**: only the two Overseers (agents 20 and 21). Banelings
      (agents 0-19) never send.
    - **Per-Overseer message (40 dims)**, in order:
        - Enemy0: visibility flag o[..., 4], relative X o[..., 6], relative Y o[..., 7]
        - Enemy1: visibility flag o[..., 12], relative X o[..., 14], relative Y o[..., 15]
        - Movement possibility (N, S, E, W): o[..., 0:4]
        - Last action one-hot (no-op, stop, move N/S/E/W, attack enemy0,
          attack enemy1): o[..., 170:178]
        - Sender identity, 22-dim one-hot: o[..., 178:200]
      Total: 6 + 4 + 8 + 22 = 40.
    - **Delivery**:
        - Each Baneling receives both Overseer messages concatenated (80 dims).
        - Each Overseer receives only the other Overseer's message (40 dims),
          zero-padded on the right to 80 dims.
    - **Rationale**: Banelings cannot observe Roaches or Overseers directly, so
      the Overseer messages give them Roach positions plus Overseer movement
      intent and last action; the sender identity keeps each message
      attributable.

    Returns:
        str: the fixed description text (a "Message Design:" header line
        followed by one paragraph).
    """
    header = "Message Design:\n"
    # The body is a single paragraph; sentences are joined by one space each.
    sentences = (
        "- Only Overseers (agents 20 and 21) send messages.",
        "Each Overseer broadcasts a message containing, for each visible Roach "
        "(Enemy0, Enemy1): visibility flag (1/0), relative X position, and "
        "relative Y position (6 values total), their own movement possibility "
        "(can move N/S/E/W; 4 values), their last action (one-hot across 8 "
        "values), and a 22-dimensional one-hot sender identity.",
        "Each message is thus 40 dimensions.",
        "Every Baneling receives both Overseers' messages (concatenated: 80 "
        "dims), while each Overseer receives the other Overseer's message (40 "
        "dims, zero-padded for consistency).",
        "Sender identity is always included for explicit grounding.",
        "This ensures Banelings receive uniquely held, actionable information "
        "needed for task success, including both Roach positions and Overseer "
        "behavioral context, while avoiding redundancy.",
    )
    return header + " ".join(sentences)

def communication(o):
    """
    Append Overseer broadcast messages to every agent's observation.

    Each Overseer (agents 20 and 21) produces a 40-dim message copied from its
    own observation row, in this order:

        [E0_visible, E0_relX, E0_relY,      <- o[4], o[6:8]
         E1_visible, E1_relX, E1_relY,      <- o[12], o[14:16]
         Move_N, Move_S, Move_E, Move_W,    <- o[0:4]
         last-action one-hot (8),           <- o[170:178]
         sender-identity one-hot (22)]      <- o[178:200]

    The values are copied verbatim; enemy positions are NOT masked by the
    visibility flag.

    Delivery:
        - Banelings (agents 0-19) receive [msg_from_20, msg_from_21] (80 dims).
        - Overseer 20 receives msg_from_21 in the first 40 dims, zeros after.
        - Overseer 21 receives msg_from_20 in the first 40 dims, zeros after.

    Args:
        o: torch.Tensor of shape (..., 22, F) with F >= 200, typically
           (batch, 22, 200). Any number of leading dimensions (including none,
           or an empty batch) is accepted.

    Returns:
        torch.Tensor of shape (..., 22, F + 80): the input observations with
        the 80-dim received message appended along the last axis. Same dtype
        and device as ``o``.

    Raises:
        ValueError: if ``o`` does not have 22 agents on the second-to-last
            axis or fewer than 200 features on the last axis.
    """
    n_agents = 22
    n_banelings = 20  # agents 0-19 only receive
    first_overseer, second_overseer = 20, 21
    min_obs_dim = 200

    # (start, stop) feature ranges of an Overseer observation, in message order.
    field_slices = (
        (4, 5), (6, 8),      # Enemy0: visible, relX, relY
        (12, 13), (14, 16),  # Enemy1: visible, relX, relY
        (0, 4),              # movement possibility N/S/E/W
        (170, 178),          # last action one-hot
        (178, 200),          # sender identity one-hot
    )

    if o.dim() < 2 or o.shape[-2] != n_agents or o.shape[-1] < min_obs_dim:
        raise ValueError(
            f"expected observations of shape (..., {n_agents}, >={min_obs_dim}), "
            f"got {tuple(o.shape)}"
        )

    lead_shape = o.shape[:-2]

    # Overseer rows only: (..., 2, F)
    overseer_obs = o[..., first_overseer:second_overseer + 1, :]

    # Per-Overseer message: (..., 2, 40)
    msg = th.cat([overseer_obs[..., lo:hi] for lo, hi in field_slices], dim=-1)
    msg_dim = msg.shape[-1]
    message_dim = 2 * msg_dim

    # Receive buffer for all agents; untouched entries stay zero (the padding).
    messages = th.zeros(
        (*lead_shape, n_agents, message_dim), dtype=o.dtype, device=o.device
    )

    # Banelings get both messages back to back. The target size is spelled out
    # explicitly (no -1) so that an empty batch reshapes without ambiguity.
    both_msgs = msg.reshape(*lead_shape, message_dim)
    messages[..., :n_banelings, :] = both_msgs.unsqueeze(-2)

    # Each Overseer gets only the *other* Overseer's message, right-padded.
    messages[..., first_overseer, :msg_dim] = msg[..., 1, :]
    messages[..., second_overseer, :msg_dim] = msg[..., 0, :]

    return th.cat([o, messages], dim=-1)
