import torch as th

def message_design_instruction():
    """
    Improved Message Design Instruction for 5z_vs_1ul (SMAC):

    1. **Purpose**: 
       This protocol enables full-team coordination for kiting, focus fire, and survivability by explicitly sharing each agent's absolute position, health, and intent, as well as opportunistically sharing enemy state if observed. The design directly addresses the inability to reconstruct global state and behavioral cues in the previous protocol.

    2. **Message Content & Justification**:
       **Each agent broadcasts a message containing:**

       [A] **Own Absolute Position** (2 floats):  
           - Absolute X, Y coordinates derived from observed relative positions of all visible allies and self.  
           - *Rationale*: Allows agents to reconstruct the team's formation, support kiting, and synchronize movement even under partial observability.

       [B] **Own Health Ratio** (1 float) and **Own Shield Ratio** (1 float):  
           - *Rationale*: Enables the team to prioritize protection and focus fire, and to identify which agents are vulnerable.

       [C] **Own Last Action** (7 floats, one-hot):  
           - *Rationale*: Shares behavioral intent for synchronized movement/attack.

       [D] **Own "Ready to Engage" Signal** (1 float, binary):  
           - 1 if agent is in position to attack (can see Ultralisk and is not moving), else 0.  
           - *Rationale*: Facilitates precise timing for coordinated attacks or retreats.

       [E] **Ultralisk Observation** (5 floats):  
           - If Ultralisk visible: absolute X, Y (calculated via own absolute position + relative), health ratio, shield ratio, and 'enemy_seen' (1).  
           - Else: zeros.  
           - *Rationale*: Enables global enemy state reconstruction, supporting focus fire and joint kiting.

       [F] **Sender Identity** (5 floats, one-hot):  
           - *Rationale*: For explicit sender-receiver grounding.

       **Total message_dim = 2 (own XY) + 2 (health/shield) + 7 (last action) + 1 (ready) + 5 (Ultralisk obs) + 5 (sender ID) = 22**

    3. **Communication Protocol**:
       - **Broadcast**: Each agent sends its message to all others (not to itself).
       - Each agent receives messages from the other 4 agents (not itself), resulting in 4 × 22 = 88 message dimensions appended to its own observation.
       - Messages are concatenated in sender order for each receiver.

    4. **Explicitness & Compactness**:
       - Each field has a fixed semantic order.
       - Only shares information that cannot be reliably inferred from local observations.
       - Excludes redundant fields present in previous design.

    5. **Why these fields?**
       - **Absolute positions** are necessary for global team reasoning and cannot be inferred from relative-only local observations.
       - **Status and intent** support dynamic, context-aware micro-management and coordination.
       - **Ultralisk absolute state** is only available when seen, and is essential for focus fire and kiting.
       - **Explicit "ready" cue** enables precise synchronization for engagement.

    6. **Efficiency**:
       - All operations are batch- and vectorized; no explicit for-loops over batch or agent dimension.
       - No trainable parameters.

    7. **Resulting Output**:
       - Output tensor shape: (batch, 5, 48 + 88).

    """
    return (
        "Each agent broadcasts a message composed of: "
        "[Own absolute X, Own absolute Y, Own health ratio, Own shield ratio, "
        "Own last action (one-hot, 7), Ready-to-engage flag (1 if can attack Ultralisk and not moving, else 0), "
        "Ultralisk absolute X, Ultralisk absolute Y, Ultralisk health ratio, Ultralisk shield ratio, Ultralisk seen (1/0), "
        "Sender identity (one-hot, 5)]. "
        "If Ultralisk is not visible, enemy fields are zeroed. "
        "Each agent receives the 4 messages from teammates (not including itself), concatenating them to its own observation. "
        "Total added message_dim = 22*4 = 88. "
        "This protocol addresses previous weaknesses by allowing all agents to reconstruct the global state (team and enemy locations/status), "
        "and to coordinate using explicit intent/engagement cues, all in a compact and non-redundant structure."
    )

def communication(o):
    """
    Input:
        o: torch.Tensor, shape (batch, 5, 48), original agent observations.

    Output:
        messages_o: torch.Tensor, shape (batch, 5, 48+88), concatenated observation and received messages.
    """
    device = o.device
    dtype = o.dtype
    B, N, D = o.shape
    assert N == 5 and D == 48

    # -- 1. Compute own absolute position (X, Y) for each agent --
    # For each agent, we reconstruct own absolute position by referencing a visible ally (preferably the lowest-index visible ally).
    # If no ally is visible, use (0,0).
    # For agent i, fields for ally j (j != i) are:
    #   - Is ally visible: 10 + 6*(j-1) if j < i, 10 + 6*j if j > i
    #   - Rel X to ally:   12 + 6*(j-1) if j < i, 12 + 6*j if j > i
    #   - Rel Y to ally:   13 + 6*(j-1) if j < i, 13 + 6*j if j > i

    own_abs_xy = th.zeros((B, N, 2), device=device, dtype=dtype)
    for ref in range(N):
        # skip own index
        visible_mask = []
        rel_x = []
        rel_y = []
        for i in range(N):
            if i == ref:
                # skip (cannot see self)
                visible_mask.append(th.zeros(B, 1, device=device, dtype=dtype))
                rel_x.append(th.zeros(B, 1, device=device, dtype=dtype))
                rel_y.append(th.zeros(B, 1, device=device, dtype=dtype))
            else:
                if ref < i:
                    start = 10 + 6*(i-1)
                else:
                    start = 10 + 6*i
                visible_mask.append(o[:, ref, start:start+1])
                rel_x.append(o[:, ref, start+2:start+3])
                rel_y.append(o[:, ref, start+3:start+4])
        # For agent 'ref', stack visibility and relative positions to all other agents
        visible_mask_stack = th.cat(visible_mask, dim=-1)  # (B, N)
        rel_x_stack = th.cat(rel_x, dim=-1)                # (B, N)
        rel_y_stack = th.cat(rel_y, dim=-1)                # (B, N)

        # For agent 'ref', find first visible ally (excluding self)
        # Mask out self
        visible_mask_stack[:, ref] = 0
        # Find first visible agent (lowest index)
        has_visible = visible_mask_stack.sum(dim=-1, keepdim=True) > 0  # (B, 1)
        first_visible_idx = visible_mask_stack.float().argmax(dim=-1)   # (B,)
        # Gather rel_x/y for first visible
        rel_x_first = rel_x_stack[th.arange(B), first_visible_idx]      # (B,)
        rel_y_first = rel_y_stack[th.arange(B), first_visible_idx]      # (B,)
        # For batch elements with no visible allies, use (0,0)
        rel_x_first = rel_x_first * has_visible.squeeze(-1)
        rel_y_first = rel_y_first * has_visible.squeeze(-1)
        # For own absolute position, assume the visible ally is at (0,0), so own abs pos = -rel
        own_abs_xy[:, ref, 0] = -rel_x_first
        own_abs_xy[:, ref, 1] = -rel_y_first
        # If no visible ally, remain (0,0)
    # own_abs_xy: (B, N, 2)

    # -- 2. Own health & shield ratios --
    own_health = o[..., 34:35]  # (B, N, 1)
    own_shield = o[..., 35:36]  # (B, N, 1)

    # -- 3. Own last action one-hot (7 dims) --
    last_action = o[..., 36:43]  # (B, N, 7)

    # -- 4. Ready-to-engage flag --
    # 1 if agent can see Ultralisk and is not moving (i.e., last action is Attack or Stop)
    ultra_visible = o[..., 4:5]  # (B, N, 1)
    last_attack = o[..., 42:43]  # Last Action-Attack (B, N, 1)
    last_stop = o[..., 37:38]    # Last Action-Stop
    ready_flag = ultra_visible * ((last_attack + last_stop) > 0).float()  # (B, N, 1)

    # -- 5. Ultralisk absolute X, Y, health, shield, seen --
    # If Ultralisk is visible, Ultralisk abs pos = own abs pos + rel (from dims 5,6)
    rel_ultra_x = o[..., 5:6]  # (B, N, 1)
    rel_ultra_y = o[..., 6:7]
    ultra_health = o[..., 8:9]
    ultra_shield = o[..., 9:10]
    ultra_seen_flag = ultra_visible  # (B, N, 1)
    ultra_abs_x = own_abs_xy[..., 0:1] + rel_ultra_x * ultra_visible
    ultra_abs_y = own_abs_xy[..., 1:2] + rel_ultra_y * ultra_visible
    # Zero out all if not visible
    ultra_abs_x = ultra_abs_x * ultra_visible
    ultra_abs_y = ultra_abs_y * ultra_visible
    ultra_health = ultra_health * ultra_visible
    ultra_shield = ultra_shield * ultra_visible
    ultra_seen_flag = ultra_seen_flag

    ultra_fields = th.cat([ultra_abs_x, ultra_abs_y, ultra_health, ultra_shield, ultra_seen_flag], dim=-1)  # (B, N, 5)

    # -- 6. Sender identity: one-hot (B, N, 5) --
    sender_id = o[..., 43:48]

    # -- 7. Stack message fields --
    msg_fields = [
        own_abs_xy,           # (B, N, 2)
        own_health,           # (B, N, 1)
        own_shield,           # (B, N, 1)
        last_action,          # (B, N, 7)
        ready_flag,           # (B, N, 1)
        ultra_fields,         # (B, N, 5)
        sender_id             # (B, N, 5)
    ]
    msg = th.cat(msg_fields, dim=-1)  # (B, N, 22)

    # -- 8. For each agent, collect messages from the other 4 agents (not self) --
    mask = (1 - th.eye(N, device=device, dtype=dtype)).unsqueeze(0).unsqueeze(-1)  # (1, N, N, 1)
    msg_exp = msg.unsqueeze(1).expand(B, N, N, 22)  # (B, N, N, 22)
    msg_masked = msg_exp * mask  # self-message is zeroed

    # For each receiver (agent axis=1), extract the messages from other 4 agents (axis=2)
    indices = []
    for i in range(N):
        indices.append([j for j in range(N) if j != i])
    indices = th.tensor(indices, device=device)  # (N, 4)
    indices_expand = indices.unsqueeze(0).expand(B, -1, -1)  # (B, N, 4)
    indices_expand = indices_expand.unsqueeze(-1).expand(-1, -1, -1, 22)  # (B, N, 4, 22)
    msgs_from_others = th.gather(msg_masked, 2, indices_expand)  # (B, N, 4, 22)
    msgs_from_others = msgs_from_others.reshape(B, N, 4*22)  # (B, N, 88)

    # -- 9. Concatenate with original observation --
    messages_o = th.cat([o, msgs_from_others], dim=-1)  # (B, N, 48+88)

    return messages_o
