import torch as th

def message_design_instruction():
    """
    Improved Message Design Instruction for 5z_vs_1ul (SMAC):

    1. **Purpose**:
       Address the previous protocol's weaknesses by letting every agent infer the global team configuration and critical tactical cues (absolute positions, health, shield, weapon cooldown, and recent movement intent) under persistent partial observability.

    2. **Content Selection & Justification**:
       Each agent broadcasts a message constructed as follows, leveraging a short temporal window (last 10 steps):

       [A] **Own Absolute Position History (last 10 steps):**
           - Normalized absolute X and Y positions at each step (2 × 10 = 20 dims), stored as the 10-step X history followed by the 10-step Y history.
           - *Why*: Directly provides the agent's recent trajectory, enabling precise team formation inference and movement intent prediction, which are not inferable from relative/partial local observations alone.

       [B] **Own Health & Shield Ratio History (last 10 steps):**
           - Health ratio (1 × 10 = 10 dims)
           - Shield ratio (1 × 10 = 10 dims)
           - *Why*: Shares critical survivability and shield-break info over time, allowing teammates to coordinate protection and focus fire more reliably.

       [C] **Own Weapon Cooldown History (last 10 steps):**
           - Weapon cooldown status (1 × 10 = 10 dims)
           - *Why*: Essential for synchronizing focus fire and coordinated attack/retreat cycles.

       [D] **Current Step Ultralisk Observation Flag:**
           - 1 dim: Whether this agent currently sees the Ultralisk (from o[..., 4])
           - *Why*: Allows all agents to reason about who has the most up-to-date enemy info.

       [E] **Sender Identity:**
           - One-hot vector (5 dims)
           - *Why*: Allows receivers to attribute state and trajectory info to the correct teammate.

       **Total message_dim = 20 + 10 + 10 + 10 + 1 + 5 = 56**
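
       The layout above can be sketched as a table of field offsets inside one message. This is an illustration only: `FIELD_SIZES` and `field_slices` are hypothetical names, and the ordering (full X history, then full Y history, then the remaining fields) is assumed to follow the concatenation order in `communication`.

```python
# Illustrative field sizes for one message, in assumed concatenation order.
FIELD_SIZES = [
    ("abs_x", 10), ("abs_y", 10), ("health", 10), ("shield", 10),
    ("cooldown", 10), ("ultra_visible", 1), ("sender_id", 5),
]

def field_slices(sizes=FIELD_SIZES):
    # Return ({name: slice}, total_length) for the given field sizes.
    slices, start = {}, 0
    for name, width in sizes:
        slices[name] = slice(start, start + width)
        start += width
    return slices, start

SLICES, MESSAGE_DIM = field_slices()
```

       Under these assumptions `MESSAGE_DIM` comes out to the stated total, and a single field can be read from a message `m` as `m[SLICES["health"]]`.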

    3. **How fields are extracted:**
       - Absolute position: For each of the last 10 steps, use the agent's own normalized X/Y relative to the map center (zeros if unavailable).
         - The relative-position dims 12,13,18,19,24,25,30,31 (to allies) and 5,6 (to the Ultralisk) do **NOT** provide absolute position.
         - This design assumes the agent's own normalized absolute X/Y are available at each step as the first two dims of its observation. If the environment does not expose them there, append them to the observation in a preprocessing step.
         - If absolute positions cannot be obtained at all, replace this field with the relative positions to all visible allies (zeros where not visible); this is less robust.

       - Health/shield/cooldown: Use dims 34 (health), 35 (shield), and 36 (cooldown) at each of the last 10 steps.
         - This design assumes dim 36 is the agent's own weapon cooldown; if the observation does not include it, set this field to zero.

    4. **Communication Protocol:**
       - **Broadcast**: Each agent sends its message to all others.
       - Each agent receives the 4 messages from teammates (not itself), for a total incoming message size of 4 × 56 = 224.
       - Each received message is a concatenation of [A,B,C,D,E] from each teammate.
       - The tensor returned for the policy input has shape (batch, 5, 48+224): each agent's current observation followed by its 4 received messages.
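
       A minimal sketch of how a receiver might split its 224 incoming values back into per-sender messages. It assumes senders are concatenated in ascending agent index with the receiver skipped, as the index construction in `communication` does; `senders_for` and `split_incoming` are hypothetical helper names, and the input is placeholder data.

```python
N_AGENTS, MESSAGE_DIM = 5, 56

def senders_for(receiver, n_agents=N_AGENTS):
    # Teammate indices in ascending order, skipping the receiver itself.
    return [j for j in range(n_agents) if j != receiver]

def split_incoming(flat, receiver):
    # Split a flat 224-length sequence into {sender_index: 56-length chunk}.
    senders = senders_for(receiver)
    assert len(flat) == len(senders) * MESSAGE_DIM
    return {
        s: flat[k * MESSAGE_DIM:(k + 1) * MESSAGE_DIM]
        for k, s in enumerate(senders)
    }

incoming = list(range(224))  # placeholder values, not real observations
chunks = split_incoming(incoming, receiver=2)
```

       The sender one-hot at the end of each chunk gives a second way to attribute a message, so a receiver does not have to rely on chunk order alone.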

    5. **Why these fields?**
       - Absolute position, health, shield, and cooldown history are not inferable by others under partial observability, but are critical for coordinated kiting, focus fire, and synchronized retreat/advance.
       - Sharing a short history allows for intent/pattern prediction (e.g., which agent is moving to flank or kite).
       - The Ultralisk visibility flag helps resolve which teammates have the freshest enemy info.
       - Sender identity ensures explicit message attribution.

    6. **Efficiency:**
       - All operations are batch- and vectorized; no explicit for-loops over batch or agent dimension.
       - No trainable parameters.
    """
    return (
        "Each agent's message consists of: "
        "1) Its own normalized absolute X and Y position history for the last 10 steps (20 dims), "
        "2) Its own health ratio history (10 dims), "
        "3) Its own shield ratio history (10 dims), "
        "4) Its own weapon cooldown status history (10 dims), "
        "5) Its current Ultralisk visibility flag (1 dim), "
        "6) Its sender identity (one-hot, 5 dims). "
        "Each agent receives the messages from the other 4 agents (not itself), concatenating these 4 messages to its own observation. "
        "Total added message_dim = 56*4 = 224. "
        "This protocol enables robust, synchronized tactical execution by letting all agents reconstruct the full team state—including absolute positions, health, shield, weapon readiness, and intent—under partial observability, as well as reason about teammates' recent behavioral patterns."
    )

def communication(o):
    """
    Input:
        o: torch.Tensor, shape (batch, T, 5, 48), agent observations over the last T steps (T >= 10); only the final 10 steps are used.
    Output:
        messages_o: torch.Tensor, shape (batch, 5, 48+224), concatenated current observation and received messages.
    """
    # Device and dtype preservation
    device = o.device
    dtype = o.dtype
    batch_size, T, n_agents, obs_dim = o.shape
    assert n_agents == 5 and obs_dim == 48
    assert T >= 10, "Need at least 10 steps to extract temporal context"

    # Use last 10 steps (including current)
    o_last10 = o[:, -10:, :, :]  # (batch, 10, 5, 48)

    # 1. Absolute X/Y position.
    # Assumption: dims 0 and 1 of each agent's own observation hold its normalized absolute X/Y.
    # If they are stored elsewhere, replace the two lines below accordingly; if the environment
    # does not expose them at all, add them in preprocessing.
    abs_x = o_last10[..., 0]  # (batch, 10, 5)
    abs_y = o_last10[..., 1]  # (batch, 10, 5)

    # 2. Health ratio (dim 34), 3. Shield ratio (dim 35), 4. Weapon cooldown (dim 36)
    health = o_last10[..., 34]  # (batch, 10, 5)
    shield = o_last10[..., 35]  # (batch, 10, 5)
    # Weapon cooldown is assumed to sit at dim 36; obs_dim == 48 is asserted above, so the dim always exists
    cooldown = o_last10[..., 36]  # (batch, 10, 5)

    # 5. Current Ultralisk visibility (from current step only, dim 4)
    ultra_vis_cur = o[:, -1, :, 4]  # (batch, 5)

    # 6. Sender identity (one-hot, dims 43:48)
    sender_id = o[:, -1, :, 43:48]  # (batch, 5, 5)

    # Reshape histories: (batch, 5, 10)
    abs_x = abs_x.permute(0, 2, 1)  # (batch, 5, 10)
    abs_y = abs_y.permute(0, 2, 1)
    health = health.permute(0, 2, 1)
    shield = shield.permute(0, 2, 1)
    cooldown = cooldown.permute(0, 2, 1)

    # Stack history vectors per agent: (batch, 5, 20+10+10+10) = (batch, 5, 50)
    hist_vec = th.cat([
        abs_x,        # 10
        abs_y,        # 10
        health,       # 10
        shield,       # 10
        cooldown      # 10
    ], dim=-1)  # (batch, 5, 50)

    # Add current Ultralisk visibility; sender_id is already (batch, 5, 5)
    ultra_vis_cur = ultra_vis_cur.unsqueeze(-1)  # (batch, 5, 1)

    msg = th.cat([hist_vec, ultra_vis_cur, sender_id], dim=-1)  # (batch, 5, 56)

    # For each receiver, collect the messages of the other 4 agents (not itself).
    # msg_exp[b, i, j] is the message sent by agent j as seen by receiver i.
    msg_exp = msg.unsqueeze(1).expand(batch_size, n_agents, n_agents, 56)  # (batch, 5, 5, 56)

    # Sender indices per receiver, ascending and skipping the receiver itself.
    # Built without a Python loop by dropping the diagonal of an index grid.
    all_idx = th.arange(n_agents, device=device).expand(n_agents, n_agents)  # (5, 5)
    off_diag = ~th.eye(n_agents, dtype=th.bool, device=device)  # (5, 5)
    indices = all_idx[off_diag].view(n_agents, n_agents - 1)  # (5, 4)

    # Gathering only the off-diagonal senders makes a separate self-message mask unnecessary.
    indices_expand = indices.view(1, n_agents, n_agents - 1, 1).expand(batch_size, -1, -1, 56)  # (batch, 5, 4, 56)
    msgs_from_others = th.gather(msg_exp, 2, indices_expand)  # (batch, 5, 4, 56)

    # Reshape to (batch, 5, 224)
    msgs_from_others = msgs_from_others.reshape(batch_size, n_agents, 4 * 56)

    # Get the current observation for each agent: o[:, -1, :, :] -> (batch, 5, 48)
    o_cur = o[:, -1, :, :]

    # Concatenate with current observation
    messages_o = th.cat([o_cur, msgs_from_others], dim=-1)  # (batch, 5, 48+224)

    return messages_o
