Stage: validate_comm_function

Traceback (most recent call last):
  File "/root/epymarl/src/LLM/llm_core.py", line 740, in comm_update
    message_dim = self.code_utils.validate_communication_function(cur_module, self.test_obs, timestep_wise)
  File "/root/epymarl/src/LLM/code_utils.py", line 71, in validate_communication_function
    cur_com = module.communication(test_obs_tw)
  File "/root/epymarl/src/llm_source/Final_gpt-4.1-2025-04-14_MSE_0.05/protoss_5_vs_5/comm_update_timestep_wise2.py", line 203, in communication
    message = th.cat(msg_blocks, dim=-1)  # (batch, agents, 108)
RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 5 but got size 32 for tensor number 1 in the list.

Communication function code:
def communication(o):
    """
    Communication function for protoss_5_vs_5 with temporal, intent, and coordination cues.
    Args:
        o: Tensor of shape (batch, T, 5, 108), where T >= 10 (last 10 steps incl. current)
    Returns:
        Tensor of shape (batch, 5, 108 + 432) = (batch, 5, 540)
    """
    # o: (batch, timesteps, n_agents, n_obs)
    device = o.device
    batch, T, n_agents, n_obs = o.shape
    assert n_agents == 5
    assert n_obs == 108
    assert T >= 10, "Input must have at least 10 timesteps (last 10 incl. current)"

    # --- 1. Sender identity ---
    sender_id = th.eye(n_agents, device=device)[None, None, :, :]  # (1, 1, 5, 5)
    sender_id = sender_id.expand(batch, 1, n_agents, n_agents)     # (batch, 1, 5, 5)

    # --- 2. Temporal Action Summary (last 5 actions) ---
    # Action one-hot indices: 92-102 (11 actions)
    action_idx = th.arange(92, 103, device=device)  # (11,)
    # For last 5 timesteps (most recent at -1)
    last5 = o[:, -5:, :, :]  # (batch, 5, 5, 108)
    actions_last5 = last5[..., action_idx]  # (batch, 5, 5, 11)
    actions_last5 = actions_last5.permute(0,2,1,3)  # (batch, agents, 5, 11)
    actions_last5 = actions_last5.reshape(batch, n_agents, -1)  # (batch, agents, 55)

    # --- 3. Intended Action (use last action, as proxy for intent) ---
    # Take most recent step (t=-1)
    last_action = o[:, -1, :, action_idx]  # (batch, 5, 11)

    # --- 4. Recent Movement Vector (last 5, dx, dy) ---
    pos_x = o[:, -6:, :, 87]  # (batch, 6, agents)
    pos_y = o[:, -6:, :, 88]  # (batch, 6, agents)
    # Compute displacements for last 5 steps: dx = x[t] - x[t-1]
    dx = pos_x[:, 1:, :] - pos_x[:, :-1, :]  # (batch, 5, agents)
    dy = pos_y[:, 1:, :] - pos_y[:, :-1, :]  # (batch, 5, agents)
    # Flatten last 5 steps: (batch, agents, 10)
    movement = th.stack([dx, dy], dim=-1)  # (batch, 5, agents, 2)
    movement = movement.permute(0,2,1,3).reshape(batch, n_agents, -1)  # (batch, agents, 10)

    # --- 5. Target Focus (coordination cue) ---
    # For most recent step: check which attack_enemy_* action is 1
    attack_action_idx = th.arange(98, 103, device=device)  # attack_enemy_0..4
    attack_action = o[:, -1, :, attack_action_idx]  # (batch, 5, 5)
    # For each agent, argmax if any attack, else 'no target'
    attack_mask = (attack_action > 0.5)
    any_attack = attack_mask.any(dim=-1, keepdim=True)  # (batch, 5, 1)
    # One-hot: [no target, enemy_0..4], length 6
    target_focus = th.cat([
        (~any_attack).float(),  # (batch, 5, 1)
        attack_mask.float()     # (batch, 5, 5)
    ], dim=-1)  # (batch, 5, 6)

    # --- 6. Low shield/help request ---
    own_shield = o[:, -1, :, 86]  # (batch, 5)
    help_flag = (own_shield < 0.2).float().unsqueeze(-1)  # (batch, 5, 1)

    # --- 7. Freshest Enemy Sightings (per enemy: [seen_flag, health, abs_x, abs_y]) ---
    # For each enemy (0..4), for each agent, search last 10 timesteps for most recent sighting.
    enemy_offsets = th.arange(5, device=device) * 9
    health_idx = 8 + enemy_offsets
    relx_idx = 6 + enemy_offsets
    rely_idx = 7 + enemy_offsets
    shootable_idx = 4 + enemy_offsets

    own_x = o[:, -10:, :, 87].unsqueeze(-1)  # (batch, 10, agents, 1)
    own_y = o[:, -10:, :, 88].unsqueeze(-1)
    # For each enemy, get fields over last 10 steps
    enemy_health = o[:, -10:, :, health_idx]  # (batch, 10, agents, 5)
    enemy_relx = o[:, -10:, :, relx_idx]
    enemy_rely = o[:, -10:, :, rely_idx]
    enemy_shootable = o[:, -10:, :, shootable_idx]  # (batch, 10, agents, 5)
    # abs positions
    enemy_absx = own_x + enemy_relx
    enemy_absy = own_y + enemy_rely
    # seen_flag: 1 if shootable > 0
    seen_flag = (enemy_shootable > 0.5).float()  # (batch, 10, agents, 5)

    # For each enemy, for each agent, find freshest step (largest t where seen_flag==1)
    # We'll use torch.argmax on reversed time axis to find most recent
    seen_flag_rev = th.flip(seen_flag, [1])  # (batch, 10, agents, 5)
    has_seen = seen_flag_rev.any(dim=1)  # (batch, agents, 5)
    idx_rev = th.argmax(seen_flag_rev, dim=1)  # (batch, agents, 5)
    idx = 9 - idx_rev  # Map back to original time index

    # Gather freshest sighting for each enemy-agent pair
    batch_idx = th.arange(batch, device=device)[:, None, None]
    agent_idx = th.arange(n_agents, device=device)[None, :, None]
    enemy_idx = th.arange(5, device=device)[None, None, :]
    # Shape: (batch, agents, 5)
    enemy_health_fresh = enemy_health[batch_idx, idx, agent_idx, enemy_idx]
    enemy_absx_fresh = enemy_absx[batch_idx, idx, agent_idx, enemy_idx]
    enemy_absy_fresh = enemy_absy[batch_idx, idx, agent_idx, enemy_idx]
    seen_flag_fresh = has_seen.float()

    # If never seen, fill -1 for fields, 0 for flag
    missing = (1 - has_seen.float())
    enemy_health_fresh = enemy_health_fresh * has_seen + (-1.0) * missing
    enemy_absx_fresh = enemy_absx_fresh * has_seen + (-1.0) * missing
    enemy_absy_fresh = enemy_absy_fresh * has_seen + (-1.0) * missing

    # Stack per-enemy: [seen_flag, health, abs_x, abs_y] × 5 = 20
    enemy_freshest = th.stack([
        seen_flag_fresh, enemy_health_fresh, enemy_absx_fresh, enemy_absy_fresh
    ], dim=-1)  # (batch, agents, 5, 4)
    enemy_freshest = enemy_freshest.reshape(batch, n_agents, -1)  # (batch, agents, 20)

    # --- 8. Assemble message (order: sender_id, action_hist, intent, movement, target, help, enemy_freshest) ---
    msg_blocks = [
        sender_id.expand(-1, n_agents, -1, -1)[0],  # (batch, agents, 5)
        actions_last5,                              # (batch, agents, 55)
        last_action,                                # (batch, agents, 11)
        movement,                                   # (batch, agents, 10)
        target_focus,                               # (batch, agents, 6)
        help_flag,                                  # (batch, agents, 1)
        enemy_freshest                              # (batch, agents, 20)
    ]
    # All blocks: (batch, agents, X)
    message = th.cat(msg_blocks, dim=-1)  # (batch, agents, 108)

    # --- 9. Peer-to-peer exchange (broadcast, exclude self) ---
    # For each agent, receive messages from others (not self), in sender index order
    agent_indices = th.arange(n_agents, device=device)
    other_agent_indices = th.stack([
        th.cat([agent_indices[:i], agent_indices[i+1:]])
        for i in range(n_agents)
    ])  # (5, 4)
    received = th.stack(
        [message[:, other_agent_indices[i], :] for i in range(n_agents)],
        dim=1
    )  # (batch, agents, 4, 108)
    received = received.reshape(batch, n_agents, -1)  # (batch, agents, 432)

    # --- 10. Take most recent agent obs (last step) ---
    last_obs = o[:, -1, :, :]  # (batch, agents, 108)

    # --- 11. Concatenate ---
    messages_o = th.cat([last_obs, received], dim=-1)  # (batch, agents, 540)
    return messages_o