Stage: validate_comm_function

Traceback (most recent call last):
  File "/root/epymarl/src/LLM/llm_core.py", line 740, in comm_update
    message_dim = self.code_utils.validate_communication_function(cur_module, self.test_obs, timestep_wise)
  File "/root/epymarl/src/LLM/code_utils.py", line 71, in validate_communication_function
    cur_com = module.communication(test_obs_tw)
  File "/root/epymarl/src/llm_source/Final_gpt-4.1-2025-04-14_MSE_0.05/zerg_5_vs_5/comm_update_timestep_wise2.py", line 188, in communication
    last_type = gather_last_seen(unit_type_win, sel_idx.unsqueeze(-1).expand(*sel_idx.shape,3))
  File "/root/epymarl/src/llm_source/Final_gpt-4.1-2025-04-14_MSE_0.05/zerg_5_vs_5/comm_update_timestep_wise2.py", line 183, in gather_last_seen
    out = arr.gather(3, sel_idx_exp.unsqueeze(-1) if arr.ndim==4 else sel_idx_exp.unsqueeze(-1).expand(*sel_idx.shape,arr.shape[-1]))
RuntimeError: expand(torch.LongTensor{[32, 5, 5, 3, 3, 1]}, size=[32, 5, 5, 3, 3]): the number of sizes provided (5) must be greater or equal to the number of dimensions in the tensor (6)

Communication function code:
def communication(o):
    """
    Implements the temporal-intent-aware protocol.
    Inputs:
        o: (batch, T, n_agents, obs_dim=98) -- contains previous 10 steps up to current for each agent
    Returns:
        enhanced_o: (batch, n_agents, obs_dim + n_agents*message_dim) -- for current time step only
    """
    # --- Setup ---
    batch, T, n_agents, obs_dim = o.shape
    device = o.device
    window = 10  # Use last 10 steps (including current)
    assert T >= window, "Observation buffer must contain at least 10 steps"
    step_idxs = th.arange(T-window, T, device=device)  # shape (window,)

    # === 1. Sender ID ===
    sender_ids = th.eye(n_agents, device=device).unsqueeze(0).repeat(batch, 1, 1)  # (batch, n_agents, 5)

    # === 2. Temporal Window: own state, last action, intent ===
    # Indices for own state
    idx_own_health = 76
    idx_own_unit_type = [77, 78, 79]
    idx_own_pos_x = 80
    idx_own_pos_y = 81
    idx_last_action = [82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92]  # 11 possible, but we will use 7 (see below)

    # For each step in window, gather own state
    o_win = o[:, step_idxs, :, :]  # (batch, window, n_agents, obs_dim)
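    # Field slices below keep o_win's (batch, window, n_agents, ...) layout; they are transposed to (batch, n_agents, window, ...) only when own_traj is built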
    own_health_win = o_win[..., idx_own_health]  # (batch, window, n_agents)
    own_pos_x_win = o_win[..., idx_own_pos_x]
    own_pos_y_win = o_win[..., idx_own_pos_y]
    own_unit_type_win = o_win[..., idx_own_unit_type]  # (batch, window, n_agents, 3)
    # Last action: use 7 actions: no-op (82), stop (83), move_north (84), move_south (85), move_east (86), move_west (87), attack_any (any of 88-92)
    last_action_all = o_win[..., idx_last_action]  # (batch, window, n_agents, 11)
    # Reduce to 7:
    # attack_any = max of attack_enemy_0-4 (88-92)
    attack_any = last_action_all[..., 6:11].max(dim=-1, keepdim=True)[0]  # (batch, window, n_agents, 1)
    last_action_7 = th.cat([
        last_action_all[..., 0:6],  # no-op, stop, move_north, move_south, move_east, move_west
        attack_any                  # attack_any
    ], dim=-1)  # (batch, window, n_agents, 7)

    # Inferred high-level intent (0=unknown/other, 1=attack, 2=support, 3=move/retreat):
    # - If attack_any==1: attack (1)
    # - If move_*==1: move/retreat (3)
    # - If own_unit_type is Medivac (unit_type bits: e.g., [0,0,1]) and last action is stop or move: support (2)
    # - Else: unknown (0)
    # Medivac: unit_type bits, let's assume [0,0,1] indicates Medivac (check this for SMACv2)
    is_medivac = (own_unit_type_win[..., 2] > 0.5) & (own_unit_type_win[..., 0:2].sum(-1) < 0.5)  # (batch, window, n_agents)
    attack_flag = last_action_7[..., 6] > 0.5  # attack_any
    move_flag = last_action_7[..., 2:6].sum(-1) > 0.5
    stop_flag = last_action_7[..., 1] > 0.5
    # Compose intent: one-hot (3 bits)
    intent = th.zeros(o_win.shape[0], window, n_agents, 3, device=device)
    # attack
    intent[..., 0] = attack_flag.float()
    # support: medivac & (stop or move)
    intent[..., 1] = (is_medivac & (move_flag | stop_flag)).float()
    # move/retreat
    intent[..., 2] = ((~is_medivac) & move_flag).float()
    # If none: all zeros
    # (batch, window, n_agents, 3)

    # Stack temporal window per agent: [x, y, health, unit_type(3), last_action(7), intent(3)] * window
    own_traj = th.cat([
        own_pos_x_win.transpose(1,2)[..., None],    # (batch, n_agents, window, 1)
        own_pos_y_win.transpose(1,2)[..., None],
        own_health_win.transpose(1,2)[..., None],
        own_unit_type_win.transpose(1,2),           # (batch, n_agents, window, 3)
        last_action_7.transpose(1,2),               # (batch, n_agents, window, 7)
        intent.transpose(1,2)                       # (batch, n_agents, window, 3)
    ], dim=-1)  # (batch, n_agents, window, 16)
    own_traj_flat = own_traj.reshape(batch, n_agents, window*16)  # (batch, n_agents, 160)

    # === 3. Last-seen enemy info ===
    # Enemy indices (for relative positions etc.)
    enemy_offsets = [4 + i*9 for i in range(5)]
    idx_enemy_shootable = [offset for offset in enemy_offsets]
    idx_enemy_rel_x = [offset+2 for offset in enemy_offsets]
    idx_enemy_rel_y = [offset+3 for offset in enemy_offsets]
    idx_enemy_health = [offset+4 for offset in enemy_offsets]
    idx_enemy_unit_type = [[offset+5, offset+6, offset+7] for offset in enemy_offsets]
    idx_own_pos_x = 80
    idx_own_pos_y = 81

    # For each agent and each enemy, get for each step in window: shootable, rel_x, rel_y, health, unit_type
    # (batch, window, n_agents, 5)
    shootable_win = th.stack([o_win[..., idx] for idx in idx_enemy_shootable], dim=-1)
    rel_x_win = th.stack([o_win[..., idx] for idx in idx_enemy_rel_x], dim=-1)
    rel_y_win = th.stack([o_win[..., idx] for idx in idx_enemy_rel_y], dim=-1)
    health_win = th.stack([o_win[..., idx] for idx in idx_enemy_health], dim=-1)
    unit_type_win = th.stack([o_win[..., idxs] for idxs in idx_enemy_unit_type], dim=-2)  # (batch, window, n_agents, 5, 3)
    # (batch, window, n_agents) -- unsqueezed on the last dim below so it broadcasts across the 5 enemies
    own_pos_x_win = o_win[..., idx_own_pos_x]
    own_pos_y_win = o_win[..., idx_own_pos_y]
    # Absolute positions for each enemy at each window step
    abs_x_win = own_pos_x_win.unsqueeze(-1) + rel_x_win
    abs_y_win = own_pos_y_win.unsqueeze(-1) + rel_y_win

    # For each enemy, for each agent, get last time step seen, last known (x,y,health,type)
    # shootable_win: (batch, window, n_agents, 5)
    is_visible = (shootable_win > 0.5)
    # For "last time step seen": 0 if visible at current step (window-1), else how many steps ago (at most window-1), 11 if never seen
    last_seen_idx = is_visible.flip(1).float().argmax(dim=1)  # (batch, n_agents, 5) -- index in reversed window, 0=latest
    # If never seen: all False, argmax returns 0, but is_visible.sum(1)==0
    never_seen = (is_visible.sum(1) == 0)  # (batch, n_agents, 5)
    last_seen_step = th.where(never_seen, th.full_like(last_seen_idx, 11), last_seen_idx)  # 11=never seen
    # For each agent, enemy, select last seen field from window
    # For each field: use last seen index (from back), so index = window-1 - last_seen_idx
    sel_idx = (window - 1 - last_seen_idx).clamp(0, window-1)  # (batch, n_agents, 5)
    # Gather last seen info