import torch as th

def message_design_instruction():
    """Return the textual description of the LBF (grid obs) message protocol.

    The returned text describes a broadcast scheme in which every agent
    sends one message to all other agents.

    Message layout described by the text (per sender):
        - absolute (x, y) position, normalized to [0, 1]       -> 2 dims
        - intended next move, one-hot (No-op, N, S, W, E)       -> 5 dims
        - sender agent ID, one-hot                              -> 6 dims
        - optional carrying-food flag                           -> 1 dim

    Receiver side, as described by the text:
        - each agent receives the 5 messages of the other agents, ordered
          by sender ID with its own message skipped;
        - the 5 messages are flattened and appended to the agent's own
          39-dim observation.

    Note on sizes: the figures quoted in the text (14 dims per message,
    5 * 14 = 70 message dims, final shape (batch, 6, 109)) assume the
    optional carrying-food flag is included. Without the flag a message has
    2 + 5 + 6 = 13 dims, giving 5 * 13 = 65 message dims and a final width
    of 104; that is what ``communication`` in this file produces when its
    ``carrying_food_idx`` argument is left at ``None``.

    Returns:
        str: the multi-line protocol description (no trailing newline).
    """
    # What each agent puts into its broadcast message.
    payload_lines = (
        "Each agent broadcasts a message containing:",
        "- Its absolute (x, y) position in the grid, normalized to [0,1] (2 dims, e.g. o[...,pos_x], o[...,pos_y])",
        "- Its intended next move as a 5-dim one-hot vector (No-op, N, S, W, E) (5 dims)",
        "- Its one-hot agent ID (6 dims)",
        "- (Optional) 1-dim flag if the agent is carrying food or just picked up food",
    )
    # How the messages are routed and merged into the observation.
    routing_lines = (
        "Messages are sent to all other agents (broadcast, 5 senders per agent),",
        "and for each receiver, the 5 incoming messages are concatenated (in agent-id order, skipping self),",
        "then flattened (5*14=70 dims) and concatenated to the original observation (39 dims),",
        "yielding an output tensor of shape (batch, 6, 109).",
    )
    return "\n".join(payload_lines + routing_lines)


def communication(o, pos_x_idx=0, pos_y_idx=1, carrying_food_idx=None):
    """Append every other agent's broadcast message to each agent's observation.

    Each agent's message is cut out of its own observation row and laid out
    as ``[o[pos_x_idx], o[pos_y_idx], o[27:32], o[33:39]]`` followed, when
    ``carrying_food_idx`` is given, by ``o[carrying_food_idx]``. Every
    receiver then gets the messages of the five other agents, in ascending
    sender index with its own message skipped, flattened and concatenated
    after its original observation.

    NOTE(review): the protocol treats ``o[..., 27:32]`` as the intended-move
    one-hot and ``o[..., 33:39]`` as the agent-ID one-hot; this function only
    copies those slices and cannot verify their meaning — confirm against the
    environment's observation layout.

    Args:
        o: torch.Tensor of shape (batch, 6, 39).
        pos_x_idx: Feature index of the agent's x-position. Negative indices
            count from the end.
        pos_y_idx: Feature index of the agent's y-position. Negative indices
            count from the end.
        carrying_food_idx: Feature index of the carrying-food flag, or None
            to leave the flag out of the message.

    Returns:
        torch.Tensor of shape (batch, 6, 39 + 5 * msg_dim), where msg_dim is
        13 without the carrying-food flag and 14 with it, i.e.
        (batch, 6, 104) by default and (batch, 6, 109) when
        ``carrying_food_idx`` is provided.

    Raises:
        AssertionError: if ``o`` does not have 6 agents and 39 features.
        IndexError: if any of the feature indices is out of range.
    """
    device = o.device
    batch_size, n_agents, obs_dim = o.shape
    assert n_agents == 6 and obs_dim == 39, "Expected shape (batch, 6, 39)"
    n_senders = n_agents - 1

    def _feature(idx):
        # Integer indexing (rather than the slice idx:idx+1) so that negative
        # indices work and out-of-range indices raise instead of silently
        # yielding an empty slice that would shrink the message.
        return o[..., idx].unsqueeze(-1)  # (batch, 6, 1)

    # 1. Assemble the per-agent message from its own observation.
    fields = [
        _feature(pos_x_idx),
        _feature(pos_y_idx),
        o[..., 27:32],  # (batch, 6, 5)
        o[..., 33:39],  # (batch, 6, 6)
    ]
    if carrying_food_idx is not None:
        fields.append(_feature(carrying_food_idx))
    msg = th.cat(fields, dim=-1)  # (batch, 6, msg_dim)
    msg_dim = msg.shape[-1]

    # 2. For receiver i, slot k holds sender k when k < i and k + 1 otherwise,
    #    i.e. all agents except i in ascending order.
    slots = th.arange(n_senders, device=device)[None, :]    # (1, 5)
    receivers = th.arange(n_agents, device=device)[:, None]  # (6, 1)
    sender_indices = slots + (slots >= receivers).to(slots.dtype)  # (6, 5)

    # 3. Gather the senders' messages for every receiver and flatten them.
    messages_received = msg[:, sender_indices]  # (batch, 6, 5, msg_dim)
    # Explicit last dimension: reshape(..., -1) is ambiguous for empty batches.
    messages_received_flat = messages_received.reshape(
        batch_size, n_agents, n_senders * msg_dim
    )

    messages_o = th.cat([o, messages_received_flat], dim=-1)
    return messages_o
