Stage: validate_comm_function

Traceback (most recent call last):
  File "/root/epymarl/src/LLM/llm_core.py", line 740, in comm_update
    message_dim = self.code_utils.validate_communication_function(cur_module, self.test_obs, timestep_wise)
  File "/root/epymarl/src/LLM/code_utils.py", line 71, in validate_communication_function
    cur_com = module.communication(test_obs_tw)
  File "/root/epymarl/src/llm_source/Final_gpt-4.1-2025-04-14_MSE_0.05/terran_5_vs_5/comm_update_timestep_wise2.py", line 173, in communication
    last_seen_t = (~enemy_shootable_bin.flip(1)).all(1).long() * last_steps + th.argmax(enemy_shootable_bin.flip(1), dim=1)
RuntimeError: "argmax_cpu" not implemented for 'Bool'

Communication function code:
def communication(o):
    """
    Input:
        o: Tensor of shape (batch, T, n_agents=5, obs_dim=98)
    Output:
        messages_o: Tensor of shape (batch, n_agents=5, 98 + 580)
    """
    # Shapes
    B, T, N, D = o.shape  # (batch, T, 5, 98)
    device = o.device

    # --- 1. Sender ID (one-hot, 5) ---
    sender_id_vec = th.eye(N, device=device).unsqueeze(0).expand(B, N, N)  # (B, N, 5)
    sender_id_vec = sender_id_vec[:, th.arange(N), :]  # (B, N, 5)

    # We'll work with the last 3 timesteps (if T < 3, pad with zeros)
    last_steps = 3
    if T < last_steps:
        pad = th.zeros(B, last_steps-T, N, D, device=device, dtype=o.dtype)
        o_padded = th.cat([pad, o], dim=1)
    else:
        o_padded = o
    o_last3 = o_padded[:, -last_steps:, :, :]  # (B, 3, N, 98)

    # --- 2. Own Recent Trajectory: last 3 abs positions, last 3 health, last 3 last-actions ---
    own_abs_x = o_last3[..., 80]  # (B, 3, N)
    own_abs_y = o_last3[..., 81]  # (B, 3, N)
    own_health = o_last3[..., 76] # (B, 3, N)
    own_last_action = o_last3[..., 82:93]  # (B, 3, N, 11)

    own_abs_x = own_abs_x.permute(0,2,1)  # (B, N, 3)
    own_abs_y = own_abs_y.permute(0,2,1)
    own_health = own_health.permute(0,2,1) # (B, N, 3)
    own_last_action = own_last_action.permute(0,2,1,3).reshape(B,N,3*11) # (B,N,33)

    # --- 3. Own Current Intent ---
    # (a) Intended movement direction: argmax over move[N/S/E/W] in current step
    move_idx = [0,1,2,3]  # north, south, east, west in obs
    move_vals = o_padded[:, -1, :, move_idx]  # (B, N, 4)
    # Add "stay" as extra dim if no movement >0
    stay_mask = (move_vals.abs().sum(-1, keepdim=True) < 1e-5).float()  # (B, N, 1)
    move_dir = th.cat([move_vals, stay_mask], dim=-1)  # (B,N,5)
    move_dir_oh = (move_dir == move_dir.max(-1, keepdim=True)[0]).float()  # (B,N,5), one-hot of intended move

    # (b) Target focus: which enemy is being attacked in last action (from current step)
    last_action = o_padded[:, -1, :, 88:93]  # (B,N,5), attack_enemy_0~4
    attack_focus_oh = (last_action == last_action.max(-1, keepdim=True)[0]).float()  # (B,N,5), one-hot of focused enemy

    # (c) Support focus: for Medivac, which ally is being supported (attack action not available, so use "move toward" as proxy)
    # We'll define support focus as the closest ally (excluding self) if unit type is Medivac, else all zeros.
    is_medivac = (o_padded[:, -1, :, 77:80] == th.tensor([0,0,1], device=device)).all(-1)  # (B,N)
    # For each agent, find closest other agent (using ally_*_distance)
    ally_distance_idx = [45,53,61,69]
    ally_distances = th.stack([o_padded[:, -1, :, idx] for idx in ally_distance_idx], dim=-1)  # (B,N,4)
    # Insert inf for self
    inf = th.full_like(ally_distances[:,:,0:1], 1e6)
    ally_distances_full = th.cat([inf, ally_distances], dim=-1)  # (B,N,5)
    # For each agent, set own index to inf
    for i in range(N):
        ally_distances_full[:,i,i] = 1e6
    # Get min index (closest ally)
    min_idx = ally_distances_full.min(-1)[1]  # (B,N)
    support_focus_oh = th.zeros(B,N,5,device=device)
    support_focus_oh.scatter_(-1, min_idx.unsqueeze(-1), 1.0)
    support_focus_oh = support_focus_oh * is_medivac.unsqueeze(-1).float()  # Only Medivac have nonzero

    # --- 4. Observed Enemy Movement & Belief Summary ---
    # For each enemy (0-4), last 3 timesteps: rel pos x/y, health, unit type, shootable
    enemy_idx = [0,1,2,3,4]
    enemy_relx_idx = [6,14,22,30,38]
    enemy_rely_idx = [7,15,23,31,39]
    enemy_health_idx = [8,16,24,32,40]
    enemy_utype_idx = [
        [9,10,11], [17,18,19], [25,26,27], [33,34,35], [41,42,43]
    ]
    enemy_shootable_idx = [4,12,20,28,36]

    # For each enemy, get last 3 rel positions, health, type, shootable mask
    enemy_relx = th.stack([o_last3[..., idx] for idx in enemy_relx_idx],dim=-1) #(B,3,N,5)
    enemy_rely = th.stack([o_last3[..., idx] for idx in enemy_rely_idx],dim=-1)
    enemy_health = th.stack([o_last3[..., idx] for idx in enemy_health_idx],dim=-1)
    enemy_shootable = th.stack([o_last3[..., idx] for idx in enemy_shootable_idx],dim=-1)  # (B,3,N,5)
    enemy_utype = th.stack([o_last3[..., idxs] for idxs in enemy_utype_idx], dim=-2) #(B,3,N,5,3)

    # For each agent, for each enemy, find last timestep in [t-1, t-2, t-3] when enemy was observed (shootable>0)
    # We'll treat "shootable" as proxy for observable
    enemy_shootable_bin = (enemy_shootable > 0.5)
    # last_seen_t: (B,N,5): timestep index (0=most recent) of last seen, or 3 if never seen
    last_seen_t = (~enemy_shootable_bin.flip(1)).all(1).long() * last_steps + th.argmax(enemy_shootable_bin.flip(1), dim=1)
    last_seen_t = last_steps-1 - last_seen_t  # convert to index in [0,2]
    # For each agent, enemy, get last relx, rely, health, type at last seen
    batch_idx = th.arange(B, device=device)[:,None,None].expand(B,N,5)
    agent_idx = th.arange(N, device=device)[None,:,None].expand(B,N,5)
    enemy_idx_t = th.arange(5, device=device)[None,None,:].expand(B,N,5)
    seen_t = last_seen_t.clamp(0,last_steps-1)
    relx_last = enemy_relx[batch_idx, seen_t, agent_idx, enemy_idx_t]  # (B,N,5)
    rely_last = enemy_rely[batch_idx, seen_t, agent_idx, enemy_idx_t]
    health_last = enemy_health[batch_idx, seen_t, agent_idx, enemy_idx_t]
    utype_last = enemy_utype[batch_idx, seen_t, agent_idx, enemy_idx_t]  # (B,N,5,3)
    # For each enemy, compute rel movement delta over last 2 steps where observed
    relx_prev = enemy_relx[batch_idx, (seen_t-1).clamp(0,last_steps-1), agent_idx, enemy_idx_t]
    rely_prev = enemy_rely[batch_idx, (seen_t-1).clamp(0,last_steps-1), agent_idx, enemy_idx_t]
    dx = relx_last - relx_prev
    dy = rely_last - rely_prev
    # last_seen_time normalized: (seen_t+1)/last_steps
    last_seen_time = (seen_t.float()+1)/last_steps  # (B,N,5)

    # For unobserved enemies (never observed in last 3 steps), share inferred pos mean/std (estimate from all available history in window)
    enemy_relx_hist = enemy_relx  # (B,3,N,5)
    enemy_rely_hist = enemy_rely
    # Mask for whether enemy ever observed in last 3
    ever_obs = (enemy_shootable_bin.sum(1) > 0)  # (B,N,5)
    # Compute mean and std over last 3 steps for each enemy if observed, else zeros
    relx_mean = (enemy_relx_hist * enemy_shootable_bin.float()).sum(1) / (enemy_shootable_bin.float().sum(1)+1e-6)
    rely_mean = (enemy_rely_hist * enemy_shootable_bin.float()).sum(1) / (enemy_shootable_bin.float().sum(1)+1e-6)
    relx_sq = (enemy_relx_hist**2 * enemy_shootable_bin.float()).sum(1) / (enemy_shootable_bin.float().sum(1)+1e-6)
    rely_sq = (enemy_rely_hist**2 * enemy_shootable_bin.float()).sum(1) / (enemy_shootable_bin.float().sum(1)+1e-6)
    relx_std = (relx_sq - relx_mean**2).clamp(min=0).sqrt()
    rely_std = (rely_sq - rely_mean**2).clamp(min=0).sqrt()
    relx_mean = relx_mean * ever_obs.float()
    rely_mean = rely_mean * ever_obs.float()
    relx_std = relx_std * ever_obs.float()
    rely_std = rely_std * ever_obs.float()
    # For never observed, all zeros
    # Stack inferred pos/uncertainty (mean, std)
    inferred_pos = th.stack([relx_mean, rely_mean, relx_std, rely_std], dim=-1)  # (B,N,5,4)

    # For each enemy, for each agent, select:
    # If observed at least once, report (dx,dy), health_last, utype_last, last_seen_time; else inferred_pos
    obs_fields = th.cat([
        dx.unsqueeze(-1), dy.unsqueeze(-1), health_last.unsqueeze(-1), utype_last, last_seen_time.unsqueeze(-1)
    ], dim=-1)  # (B,N,5,8)
    # Mask: (B,N,5,1) True if observed at least once
    obs_mask = ever_obs.unsqueeze(-1)
    # Compose final enemy info: if observed, obs_fields; else inferred_pos + zeros for rest
    enemy_info = th.where(obs_mask, obs_fields, th.cat([inferred_pos, th.zeros(B,N,5,4,device=device)],dim=-1))  # (B,N,5,8)
    # For unobserved, fill remaining 4 dims with zeros
    # Reshape to (B,N,5*8=40)
    enemy_info_flat = enemy_info.reshape(B,N,5*8)

    # --- 5. Observed Ally Movement Summary (for each other ally) ---
    # For each other ally: last rel movement delta (dx,dy), last health, unit type, last seen
    ally_relx_idx = [46,54,62,70]
    ally_rely_idx = [47,55,63,71]
    ally_health_idx = [48,56,64,72]
    ally_utype_idx = [
        [49,50,51], [57,58,59], [65,66,67], [73,74,75]
    ]
    ally_visible_idx = [44,52,60,68]
    # For each agent, for each other agent
    ally_relx = th.stack([o_last3[..., idx] for idx in ally_relx_idx],dim=-1) #(B,3,N,4)
    ally_rely = th.stack([o_last3[..., idx] for idx in ally_rely_idx],dim=-1)
    ally_health = th.stack([o_last3[..., idx] for idx in ally_health_idx],dim=-1)
    ally_utype = th.stack([o_last3[..., idxs] for idxs in ally_utype_idx], dim=-2) #(B,3,N,4,3)
    ally_visible = th.stack([o_last3[..., idx] for idx in ally_visible_idx],dim=-1) #(B,3,N,4)
    # For each agent, for each other, get last seen in last 3
    ally_visible_bin = (ally_visible > 0.5)
    last_seen_t_ally = (~ally_visible_bin.flip(1)).all(1).long() * last_steps + th.argmax(ally_visible_bin.flip(1), dim=1)
    last_seen_t_ally = last_steps-1 - last_seen_t_ally  # (B,N,4)
    # Indexing
    batch_idx4 = th.arange(B, device=device)[:,None,None].expand(B,N,4)
    agent_idx4 = th.arange(N, device=device)[None,:,None].expand(B,N,4)
    ally_idx4 = th.arange(4, device=device)[None,None,:].expand(B,N,4)
    seen_t_ally = last_seen_t_ally.clamp(0,last_steps-1)
    relx_last_ally = ally_relx[batch_idx4, seen_t_ally, agent_idx4, ally_idx4]
    rely_last_ally = ally_rely[batch_idx4, seen_t_ally, agent_idx4, ally_idx4]
    health_last_ally = ally_health[batch_idx4, seen_t_ally, agent_idx4, ally_idx4]
    utype_last_ally = ally_utype[batch_idx4, seen_t_ally, agent_idx4, ally_idx4]  # (B,N,4,3)
    relx_prev_ally = ally_relx[batch_idx4, (seen_t_ally-1).clamp(0,last_steps-1), agent_idx4, ally_idx4]
    rely_prev_ally = ally_rely[batch_idx4, (seen_t_ally-1).clamp(0,last_steps-1), agent_idx4, ally_idx4]
    dx_ally = relx_last_ally - relx_prev_ally
    dy_ally = rely_last_ally - rely_prev_ally
    last_seen_time_ally = (seen_t_ally.float()+1)/last_steps  # (B,N,4)
    # Compose (dx,dy,health,utype(3),last_seen_time) = 7 dims × 4 = 28
    ally_info = th.cat([
        dx_ally.unsqueeze(-1), dy_ally.unsqueeze(-1), health_last_ally.unsqueeze(-1), utype_last_ally, last_seen_time_ally.unsqueeze(-1)
    ], dim=-1) # (B,N,4,7)
    ally_info_flat = ally_info.reshape(B,N,4*7)

    # --- Compose message: (5+6+3+33+5+5+5+40+28) = 130 dims ---
    msg_parts = [
        sender_id_vec,                # (B,N,5)
        own_abs_x,                    # (B,N,3)
        own_abs_y,                    # (B,N,3)
        own_health,                   # (B,N,3)
        own_last_action,              # (B,N,33)
        move_dir_oh,                  # (B,N,5)
        attack_focus_oh,              # (B,N,5)
        support_focus_oh,             # (B,N,5)
        enemy_info_flat,              # (B,N,40)
        ally_info_flat                # (B,N,28)
    ]
    message = th.cat(msg_parts, dim=-1)  # (B,N,130)

    # --- Distribute messages: for each agent, collect messages from all other agents (not self) ---
    message_exp = message.unsqueeze(1).expand(B, N, N, message.shape[-1])  # (B, receiver, sender, 130)
    mask = ~th.eye(N, dtype=th.bool, device=device).unsqueeze(0).expand(B, N, N)  # (B, N, N)
    msg_others = message_exp[mask].view(B, N, N-1, message.shape[-1])      # (B, N, 4, 130)
    msg_others_flat = msg_others.reshape(B, N, 4*message.shape[-1])        # (B, N, 520)

    # --- Concatenate to current observation (last step only) ---
    o_last = o_padded[:, -1, :, :]   # (B, N, 98)
    messages_o = th.cat([o_last, msg_others_flat], dim=-1)    # (B, N, 98+520=618)

    return messages_o