import torch
from torch import nn
from Language import *
from Observation import *
from Environment import *

# ======== Localization Agents =========

class PosAgentA(nn.Module):
    """Localization agent A: produces a message from the history and its own observation.

    Holds a ``Language`` model, a ``PosObservation`` encoder and a one-entry
    embedding (``obs_token``) that is inserted between the encoded history and
    the encoded observation before decoding.
    """

    def __init__(self):
        super().__init__()
        self.language_model = Language(role='A')
        self.observe_model = PosObservation(role='A')
        # One-row embedding table; row 0 is the learned "observation follows" marker.
        self.obs_token = nn.Embedding(1, Param.emb_size)

    def forward(self, history, observation, choose_method="sample"):
        """
        Generate one message and append it (after a separator token) to the history.

        :param history: 2-D token tensor (batch, history len); it is concatenated
            along dim 1 with a (batch, 1) separator and the new message
        :param observation: (batch, 36, feat len) per the original notes -- TODO confirm
        :param choose_method: sample or greedy
        :return: (new_history, msg, msg_prob)
        """
        # -------- speak ----------
        msg, msg_prob = self.speak(history, observation, choose_method=choose_method)
        # Build the separator on the same device as the history so the cat below
        # also works when the inputs do not live on the default device.
        sep_token = torch.full((history.shape[0], 1), Param.tokens['<msg A>']['pos'],
                               device=history.device)
        new_history = torch.cat([history, sep_token, msg], dim=1)
        return new_history, msg, msg_prob

    def speak(self, history, observation, choose_method="sample"):
        """
        Decode a message from [encoded history, obs marker, encoded observation].

        :param history: (batch, history len)
        :param observation: (batch, 36, feat len) per the original notes -- TODO confirm
        :param choose_method: sample or greedy
        :return: (msg, msg_prob) exactly as returned by ``language_model.decode_msg``
        """
        history_vec = self.language_model.encode(history)  # (batch, history len, emb dim)
        if Param.is_spread_token is True:
            observe_vec = self.observe_model.encode(observation)  # (batch, 36, emb dim)
        else:
            observe_vec = self.observe_model(observation).unsqueeze(1)  # (batch, 1, emb dim)
        # Index 0 for every batch row, created directly as int64 on the observation's device.
        marker_idx = torch.zeros((observe_vec.shape[0], 1), dtype=torch.int64, device=observe_vec.device)
        cur_obs_token = self.obs_token(marker_idx)  # (batch, 1, emb dim)
        context_vec = torch.cat([history_vec, cur_obs_token, observe_vec], dim=1)
        msg, msg_prob = self.language_model.decode_msg(context_vec, choose_method)
        return msg, msg_prob

    @staticmethod
    def _msg_loss(msg_a_prob, rewards):
        """Return -sum(msg_a_prob * rewards broadcast over dim 2) / msg_a_prob.shape[2]."""
        return - torch.sum(msg_a_prob * rewards.unsqueeze(2)) / msg_a_prob.shape[2]

    def cal_loss_alg_blu(self, msg_a_prob, rewards):
        """
        Message loss; currently identical to :meth:`cal_loss`.

        :param msg_a_prob: tensor with at least 3 dims (presumably per-token log-probs -- confirm)
        :param rewards: tensor broadcast over dim 2 of ``msg_a_prob``
        :return: scalar loss tensor
        """
        return self._msg_loss(msg_a_prob, rewards)

    def cal_loss(self, msg_a_prob, rewards):
        """
        Reward-weighted message loss, normalised by the size of dim 2 of ``msg_a_prob``.

        :param msg_a_prob: tensor with at least 3 dims (presumably per-token log-probs -- confirm)
        :param rewards: tensor broadcast over dim 2 of ``msg_a_prob``
        :return: scalar loss tensor
        """
        return self._msg_loss(msg_a_prob, rewards)

# ========= Guidance Agents ============

class GuideAgentA(nn.Module):
    """Guidance agent A: picks a neighbouring viewpoint (or stays) and then emits a message.

    Candidate embeddings are pre-computed by :meth:`load_act_emb` and looked up
    in :meth:`collect_candidates`; candidate slot 0 is always the "stop" entry
    that maps back to the current viewpoint.
    """

    def __init__(self):
        super().__init__()
        self.language_model = Language(role='A')
        self.observe_model = PosObservation(role='A')
        # One-row embedding tables used as learned marker tokens.
        self.neigh_obs_token = nn.Embedding(1, Param.emb_size)
        self.act_token = nn.Embedding(1, Param.emb_size)
        # Filled by load_act_emb(); both stay None until it is called.
        self.act_emb_dict = None
        self.cand_act_emb_dict = None

    def forward(self, history, states, env: "FullEnvDataset", choose_method="sample"):
        """
        Act (choose a candidate viewpoint), then speak from the new state.

        :param history: (batch, history len)
        :param states: iterable of (scan_id, view_id, angle) triples, one per batch row
        :param env: environment, forwarded to act/speak
        :param choose_method: sample or greedy
        :return: (new_history, msg, msg_prob, new_view_ids, act_prob)
        """
        # Materialise once: states is unpacked here and again inside guess().
        states = list(states)
        # ------ act ---------
        scan_ids, view_ids, angles = zip(*states)
        act_idx, act_prob, cand_neigh_ids = self.act(history, states, env, choose_method=choose_method)
        new_view_ids = [cand_ids[i] for i, cand_ids in zip(act_idx.tolist(), cand_neigh_ids)]
        new_states = list(zip(scan_ids, new_view_ids, angles))
        # ----- speak -------
        msg, msg_prob = self.speak(history, new_states, env, choose_method=choose_method)
        sep_token = torch.full((history.shape[0], 1), Param.tokens["<msg A>"]["pos"],
                               device=history.device)
        new_history = torch.cat([history, sep_token, msg], dim=1)
        return new_history, msg, msg_prob, new_view_ids, act_prob

    def act(self, history, states, env: "FullEnvDataset", choose_method="sample"):
        """
        Choose a candidate index per batch row from the softmax of the guess scores.

        :return: (idx, log-probability of idx under the softmax distribution, candidate ids)
        """
        guess_score, cand_neigh_ids, _ = self.guess(history, states, env)
        guess_score_softmax = torch.softmax(guess_score, dim=1)
        sampler = Categorical(guess_score_softmax)
        if choose_method == "sample":
            idx = sampler.sample()
        else:
            idx = torch.argmax(guess_score_softmax, dim=1)
        return idx, sampler.log_prob(idx), cand_neigh_ids

    @staticmethod
    def _mask_padded_candidates(score, cand_view_ids):
        """Set the score of every slot beyond a row's real candidate count to -inf."""
        counts = torch.tensor([len(ids) for ids in cand_view_ids], device=score.device)
        slots = torch.arange(score.shape[1], device=score.device)
        pad_mask = slots.unsqueeze(0) >= counts.unsqueeze(1)
        return score.masked_fill(pad_mask, float("-inf"))

    def guess(self, history, states, env: "FullEnvDataset"):
        """
        Score each candidate by the dot product with the mean encoded history.

        :param history: (batch, history len)
        :param states: iterable of (scan_id, view_id, angle) triples
        :param env: unused here; kept for a uniform signature
        :return: (score, candidate view ids, un-pooled candidate embeddings)
        """
        scan_ids, view_ids, _ = zip(*states)
        history_vec = self.language_model.encode(history)
        cand_obs_vec_extend, cand_view_ids = self.collect_candidates(scan_ids, view_ids)
        if Param.is_spread_token is True:
            # Pool the per-image axis so every candidate is a single vector.
            cand_obs_vec = cand_obs_vec_extend.mean(dim=2)
        else:
            cand_obs_vec = cand_obs_vec_extend
        # ----- guess -----
        history_emb = torch.mean(history_vec, dim=1)  # (batch, emb dim)
        score = torch.bmm(cand_obs_vec, history_emb.unsqueeze(2)).squeeze(2)
        score = self._mask_padded_candidates(score, cand_view_ids)
        if history.dim() == 2 and history.shape[1] == 0:
            # No history yet: force candidate 0 (the stop / current viewpoint entry).
            score = torch.full_like(score, float("-inf"))
            score[:, 0] = 0
        return score, cand_view_ids, cand_obs_vec_extend  # (batch, neigh len, 6, emb dim)

    def speak(self, history, states, env: "FullEnvDataset", choose_method="sample"):
        """
        Decode a message from [encoded history, neighbour marker, score-weighted candidates].

        The candidate weights are detached, so the message loss does not flow
        back through the guess scores.

        :return: (msg, msg_prob) exactly as returned by ``language_model.decode_msg``
        """
        history_vec = self.language_model.encode(history)
        guess_score, cand_neigh_ids, cand_obs_vec = self.guess(history, states, env)
        weights = torch.softmax(guess_score, dim=1).unsqueeze(1).detach()
        if Param.is_spread_token is True:
            # Flatten (img, emb) so one bmm mixes whole candidates, then restore it.
            batch, n_cand, n_img, emb = cand_obs_vec.shape
            flat = cand_obs_vec.reshape(batch, n_cand, n_img * emb)
            attn_neigh_vec = torch.bmm(weights, flat).squeeze(1).view(batch, n_img, emb)
        else:
            attn_neigh_vec = torch.bmm(weights, cand_obs_vec)
        marker_idx = torch.zeros((history_vec.shape[0], 1), dtype=torch.int64, device=history_vec.device)
        cur_neigh_obs_token = self.neigh_obs_token(marker_idx)
        context_vec = torch.cat([history_vec, cur_neigh_obs_token, attn_neigh_vec], dim=1)
        msg, msg_prob = self.language_model.decode_msg(context_vec, choose_method)
        return msg, msg_prob

    def _cache_env_act_emb(self, env, scan_ids):
        """Encode and cache the neighbour embeddings of every viewpoint in ``scan_ids``."""
        for scan_id in scan_ids:
            self.act_emb_dict[scan_id] = {}
            self.cand_act_emb_dict[scan_id] = {}
            if scan_id not in env.graphs:
                env.graphs[scan_id] = env.load_graph(scan_id)[scan_id]
            graph = env.graphs[scan_id]
            for view_id in graph:
                neigh_ids = list(graph[view_id].keys())
                obs_per_neigh, idx_per_neigh = [], []
                for neigh_id in neigh_ids:
                    feats, slots = [], []
                    for _, h, e, _ in graph[view_id][neigh_id]:
                        feats.append(torch.from_numpy(env.features[scan_id][view_id][(h, e)]))
                        slots.append(h // 30 * 3 + e // 30 + 1)
                    obs_per_neigh.append(torch.stack(feats, dim=0))  # (6, 1000)
                    idx_per_neigh.append(torch.from_numpy(np.array(slots)))
                cur_obs = torch.stack(obs_per_neigh, dim=0)  # (neigh len, 6, 1000)
                cur_obs_idx = torch.stack(idx_per_neigh, dim=0)
                cur_obs_emb = self.observe_model.encode(cur_obs.to(torch.float32),
                                                        cur_obs_idx.to(torch.int64))  # (neigh len, 6, emb dim)
                # Zero-pad the neighbour axis up to Param.max_neigh_num so entries can be stacked.
                padding = cur_obs_emb.new_zeros((Param.max_neigh_num - cur_obs_emb.shape[0],
                                                 Param.max_path_img_num, Param.emb_size))
                self.cand_act_emb_dict[scan_id][view_id] = {"neigh_ids": neigh_ids,
                                                            "neigh_emb": torch.cat([cur_obs_emb, padding], dim=0)}
                # The list round-trip stores detached float64 copies, one per neighbour.
                for neigh_id, cur_neigh_emb in zip(neigh_ids, cur_obs_emb.tolist()):
                    self.act_emb_dict[scan_id][(view_id, neigh_id)] = torch.from_numpy(np.array(cur_neigh_emb))  # (6, emb dim)

    def load_act_emb(self, train_env: "FullEnvDataset", eval_env: "FullEnvDataset", train_scan_ids, eval_scan_ids):
        """
        Rebuild ``act_emb_dict`` and ``cand_act_emb_dict`` for the given scans.

        Train scans are processed first; an eval scan with the same id overwrites
        the train entry. Graphs missing from ``env.graphs`` are loaded on demand.
        """
        self.act_emb_dict, self.cand_act_emb_dict = {}, {}
        self._cache_env_act_emb(train_env, train_scan_ids)
        self._cache_env_act_emb(eval_env, eval_scan_ids)

    def collect_candidates(self, scan_ids, view_ids):
        """
        Look up cached candidate embeddings and prepend the stop entry.

        :return: (float32 embeddings with the stop token in slot 0,
                  per-row id lists starting with the current view id)
        """
        neighbor_ids, neighbor_embs = [], []
        for scan_id, view_id in zip(scan_ids, view_ids):
            cur_dict = self.cand_act_emb_dict[scan_id][view_id]
            neighbor_ids.append(cur_dict["neigh_ids"])
            if Param.is_spread_token is True:
                neighbor_embs.append(cur_dict["neigh_emb"])  # (max neigh num, 6, emb dim)
            else:
                neighbor_embs.append(torch.mean(cur_dict["neigh_emb"], dim=1))  # (max neigh num, emb dim)
        neighbor_embs = torch.stack(neighbor_embs, dim=0)  # (batch, max neigh num, 6, emb dim)
        if Param.is_spread_token is True:
            stop_shape = (neighbor_embs.shape[0], 1, Param.max_path_img_num)  # -> (batch, 1, 6, emb dim)
        else:
            stop_shape = (neighbor_embs.shape[0], 1)
        stop_idx = torch.zeros(stop_shape, dtype=torch.int64, device=neighbor_embs.device)
        stop_tokens = self.observe_model.stop_token(stop_idx)
        neighbor_embs = torch.cat([stop_tokens, neighbor_embs], dim=1)
        neighbor_ids = [[cur_view_id] + cur_neigh_ids for cur_view_id, cur_neigh_ids in zip(view_ids, neighbor_ids)]
        return neighbor_embs.to(torch.float32), neighbor_ids

    def cal_loss_alu_blg(self, msg_a_prob, act_a_prob, rewards):
        """
        Action-only loss: the message term is multiplied by zeros.

        The zero-weighted message term is kept so ``msg_a_prob`` stays part of
        the returned expression.
        """
        loss_msg = - torch.sum(msg_a_prob * torch.zeros_like(msg_a_prob))
        loss_act = - torch.sum(act_a_prob * rewards)
        return loss_msg + loss_act

    def cal_loss(self, msg_a_prob, act_a_prob, rewards):
        """
        Message loss weighted by the next column's reward plus action loss.

        :param msg_a_prob: tensor with at least 3 dims; normalised by the size of dim 2
        :param act_a_prob: same shape as ``rewards``
        :param rewards: 2-D tensor; column t+1 weights message t, the last message gets 0
        :return: scalar loss tensor
        """
        # Size the zero column from rewards itself so any batch size works.
        shift_reward = torch.cat([rewards[:, 1:], rewards.new_zeros((rewards.shape[0], 1))], dim=1)
        loss_msg = - torch.sum(msg_a_prob * shift_reward.unsqueeze(2)) / msg_a_prob.shape[2]
        loss_act = - torch.sum(act_a_prob * rewards)
        return loss_msg + loss_act


class GuideMultiStepAgentA(nn.Module):
    """Multi-step guidance agent A: tracks the path taken and may stay silent when confident.

    Besides the dialogue history the agent carries ``history_path_vec``, the
    stacked embeddings of the candidates chosen so far. Candidate slot 0 is the
    "stop" entry that maps back to the current viewpoint.
    """

    def __init__(self):
        super().__init__()
        self.language_model = Language(role='A')
        self.observe_model = PosObservation(role='A')
        # One-row embedding tables used as learned marker tokens.
        self.neigh_obs_token = nn.Embedding(1, Param.emb_size)
        self.path_obs_token = nn.Embedding(1, Param.emb_size)
        self.act_token = nn.Embedding(1, Param.emb_size)
        # Filled by load_act_emb(); both stay None until it is called.
        self.act_emb_dict = None
        self.cand_act_emb_dict = None

    def forward(self, history, history_path_vec, states, env: "FullEnvDataset", choose_method="sample"):
        """
        Act, extend the path memory with the chosen candidate, then speak.

        :param history: (batch, history len)
        :param history_path_vec: (batch, steps, (6,) emb dim) or None on the first step
        :param states: iterable of (scan_id, view_id, angle) triples
        :param env: environment, forwarded to act/speak
        :param choose_method: sample or greedy
        :return: (new_history, new_history_path_vec, msg, msg_prob, pred_act_ids, act_prob, is_speak)
        """
        # Materialise once: states is unpacked here and again inside guess().
        states = list(states)
        # -------- act ------------
        scan_ids, view_ids, angles = zip(*states)
        act_idx, act_prob, cand_neigh_ids, cand_obs_vec = self.act(history, history_path_vec, states, env, choose_method)
        pred_act_ids = [cand_ids[i] for i, cand_ids in zip(act_idx.tolist(), cand_neigh_ids)]
        new_states = list(zip(scan_ids, pred_act_ids, angles))
        # Pick the chosen candidate per row: (batch, 6, emb dim) or (batch, emb dim).
        row_idx = torch.arange(cand_obs_vec.shape[0], device=act_idx.device)
        cur_path_vec = cand_obs_vec[row_idx, act_idx]
        if history_path_vec is not None:
            new_history_path_vec = torch.cat([history_path_vec, cur_path_vec.unsqueeze(1)], dim=1)
        else:
            new_history_path_vec = cur_path_vec.unsqueeze(1)  # (batch, step, 6, emb dim) or (batch, step, emb dim)
        # ------- speak -----------
        msg, msg_prob, is_speak = self.speak(history, new_history_path_vec, new_states, env, choose_method)
        sep_token = torch.full((history.shape[0], 1), Param.tokens["<msg A>"]["pos"],
                               device=history.device)
        if is_speak is not None:
            # Silent rows get their whole message overwritten with the start-pos token.
            msg[torch.logical_not(is_speak), :] = Param.tokens["<msg A>"]["start pos"]
        new_history = torch.cat([history, sep_token, msg], dim=1)
        return new_history, new_history_path_vec, msg, msg_prob, pred_act_ids, act_prob, is_speak

    def act(self, history, history_path_vec, states, env: "FullEnvDataset", choose_method="sample"):
        """
        Choose a candidate index per batch row from the softmax of the guess scores.

        :return: (idx, log-probability of idx, candidate ids, un-pooled candidate embeddings)
        """
        guess_score, cand_neigh_ids, cand_obs_vec = self.guess(history, history_path_vec, states, env)
        guess_score_softmax = torch.softmax(guess_score, dim=1)
        sampler = Categorical(guess_score_softmax)
        if choose_method == "sample":
            idx = sampler.sample()
        else:
            idx = torch.argmax(guess_score_softmax, dim=1)
        # NOTE add
        # if self.training is False:
        #     cur_is_sure = self.is_move_sure(guess_score_softmax)
        #     idx[cur_is_sure == False] = 0
        return idx, sampler.log_prob(idx), cand_neigh_ids, cand_obs_vec

    @staticmethod
    def _placeholder_mask(history):
        """Return a mask (same dtype as history) that is 1 where history holds the start-pos token."""
        return (history == Param.tokens["<msg A>"]["start pos"]).to(history.dtype)

    @staticmethod
    def _mask_padded_candidates(score, cand_view_ids):
        """Set the score of every slot beyond a row's real candidate count to -inf."""
        counts = torch.tensor([len(ids) for ids in cand_view_ids], device=score.device)
        slots = torch.arange(score.shape[1], device=score.device)
        pad_mask = slots.unsqueeze(0) >= counts.unsqueeze(1)
        return score.masked_fill(pad_mask, float("-inf"))

    def guess(self, history, history_path_vec, states, env: "FullEnvDataset"):
        """
        Score candidates against (mean unmasked history - mean path embedding).

        :param history: (batch, history len)
        :param history_path_vec: (batch, steps, (6,) emb dim) or None
        :param states: iterable of (scan_id, view_id, angle) triples
        :param env: unused here; kept for a uniform signature
        :return: (score, candidate view ids, un-pooled candidate embeddings)
        """
        scan_ids, view_ids, _ = zip(*states)
        history_mask = self._placeholder_mask(history)
        history_anti_mask = 1 - history_mask
        history_vec = self.language_model.encode(history, mask=history_mask)
        # Mean over the non-placeholder positions only.
        history_emb = torch.sum(history_vec * history_anti_mask.unsqueeze(2), dim=1) \
            / torch.sum(history_anti_mask, dim=1).unsqueeze(1)  # (batch, emb dim)
        cand_obs_vec_extend, cand_view_ids = self.collect_candidates(scan_ids, view_ids)
        if Param.is_spread_token is True:
            cand_obs_vec = cand_obs_vec_extend.mean(dim=2)
            history_path_emb = torch.mean(torch.mean(history_path_vec, dim=2), dim=1) \
                if history_path_vec is not None else None
        else:
            cand_obs_vec = cand_obs_vec_extend
            history_path_emb = torch.mean(history_path_vec, dim=1) \
                if history_path_vec is not None else None  # (batch, emb dim)
        if history_path_emb is not None:
            context_emb = history_emb - history_path_emb  # (batch, emb dim)
        else:
            context_emb = history_emb
        score = torch.bmm(cand_obs_vec, context_emb.unsqueeze(2)).squeeze(2)
        score = self._mask_padded_candidates(score, cand_view_ids)
        return score, cand_view_ids, cand_obs_vec_extend

    def speak(self, history, history_path_vec, states, env: "FullEnvDataset", choose_method="sample"):
        """
        Decode a message and decide, per row, whether the agent should speak.

        A row speaks when :meth:`is_sure` is False for it; ``Param.is_all_speak``
        and ``Param.is_all_no_speak`` override that for the whole batch.

        :return: (msg, msg_prob, is_speak_mask)
        """
        history_mask = self._placeholder_mask(history)
        history_vec = self.language_model.encode(history, mask=history_mask)
        guess_score, cand_neigh_ids, cand_obs_vec = self.guess(history, history_path_vec, states, env)
        score = torch.softmax(guess_score, dim=1)
        is_speak_mask = torch.logical_not(self.is_sure(score))
        if Param.is_all_speak is True:
            is_speak_mask = torch.full_like(is_speak_mask, True)
        elif Param.is_all_no_speak is True:
            is_speak_mask = torch.full_like(is_speak_mask, False)
        if history_path_vec is not None:
            history_path_emb = torch.mean(history_path_vec, dim=1)  # (batch, 6, emb dim) or (batch, emb dim)
            cand_path_vec = cand_obs_vec + history_path_emb.unsqueeze(dim=1)
        else:
            cand_path_vec = cand_obs_vec  # (batch, neigh len, 6, emb dim) or (batch, neigh len, emb dim)
        weights = score.unsqueeze(1).detach()
        if Param.is_spread_token is True:
            # Flatten (img, emb) so one bmm mixes whole candidates, then restore it.
            batch, n_cand, n_img, emb = cand_path_vec.shape
            flat = cand_path_vec.reshape(batch, n_cand, n_img * emb)
            attn_neigh_vec = torch.bmm(weights, flat).squeeze(1).view(batch, n_img, emb)
        else:
            attn_neigh_vec = torch.bmm(weights, cand_path_vec)
        marker_idx = torch.zeros((history_vec.shape[0], 1), dtype=torch.int64, device=history_vec.device)
        cur_neigh_obs_token = self.neigh_obs_token(marker_idx)
        context_vec = torch.cat([history_vec, cur_neigh_obs_token, attn_neigh_vec], dim=1)
        # Extend the mask with zeros for the marker and the attended candidate positions.
        mask_pad = torch.zeros((history_mask.shape[0], 1 + attn_neigh_vec.shape[1]), device=history_mask.device)
        history_mask = torch.cat([history_mask, mask_pad], dim=1)
        msg, msg_prob = self.language_model.decode_msg(context_vec, choose_method, mem_mask=history_mask)
        return msg, msg_prob, is_speak_mask

    @staticmethod
    def _top_margin(guess_score):
        """Return per row: max score minus the max of the row with that entry set to 0."""
        row_idx = torch.arange(guess_score.shape[0], device=guess_score.device)
        max_score_idx = torch.argmax(guess_score, dim=1)
        max_score = guess_score[row_idx, max_score_idx]
        remainder = guess_score.clone()
        remainder[row_idx, max_score_idx] = 0
        sec_max_score = torch.max(remainder, dim=1).values
        return max_score - sec_max_score

    def is_sure(self, guess_score):
        """Return a bool tensor: True where the top score leads the runner-up by more than 0.99."""
        return self._top_margin(guess_score) > 0.99

    def is_move_sure(self, guess_score):
        """Return a bool tensor: True where the top score leads the runner-up by more than 0.2."""
        return self._top_margin(guess_score) > 0.2

    def _cache_env_act_emb(self, env, scan_ids):
        """Encode and cache the neighbour embeddings of every viewpoint in ``scan_ids``."""
        for scan_id in scan_ids:
            self.act_emb_dict[scan_id] = {}
            self.cand_act_emb_dict[scan_id] = {}
            if scan_id not in env.graphs:
                env.graphs[scan_id] = env.load_graph(scan_id)[scan_id]
            graph = env.graphs[scan_id]
            for view_id in graph:
                neigh_ids = list(graph[view_id].keys())
                obs_per_neigh, idx_per_neigh = [], []
                for neigh_id in neigh_ids:
                    feats, slots = [], []
                    for _, h, e, _ in graph[view_id][neigh_id]:
                        feats.append(torch.from_numpy(env.features[scan_id][view_id][(h, e)]))
                        slots.append(h // 30 * 3 + e // 30 + 1)
                    obs_per_neigh.append(torch.stack(feats, dim=0))  # (6, 1000)
                    idx_per_neigh.append(torch.from_numpy(np.array(slots)))
                cur_obs = torch.stack(obs_per_neigh, dim=0)  # (neigh len, 6, 1000)
                cur_obs_idx = torch.stack(idx_per_neigh, dim=0)
                cur_obs_emb = self.observe_model.encode(cur_obs.to(torch.float32),
                                                        cur_obs_idx.to(torch.int64))  # (neigh len, 6, emb dim)
                # Zero-pad the neighbour axis up to Param.max_neigh_num so entries can be stacked.
                padding = cur_obs_emb.new_zeros((Param.max_neigh_num - cur_obs_emb.shape[0],
                                                 Param.max_path_img_num, Param.emb_size))
                self.cand_act_emb_dict[scan_id][view_id] = {"neigh_ids": neigh_ids,
                                                            "neigh_emb": torch.cat([cur_obs_emb, padding], dim=0)}
                # The list round-trip stores detached float64 copies, one per neighbour.
                for neigh_id, cur_neigh_emb in zip(neigh_ids, cur_obs_emb.tolist()):
                    self.act_emb_dict[scan_id][(view_id, neigh_id)] = torch.from_numpy(np.array(cur_neigh_emb))  # (6, emb dim)

    def load_act_emb(self, train_env: "FullEnvDataset", eval_env: "FullEnvDataset", train_scan_ids, eval_scan_ids):
        """
        Rebuild ``act_emb_dict`` and ``cand_act_emb_dict`` for the given scans.

        Train scans are processed first; an eval scan with the same id overwrites
        the train entry. Graphs missing from ``env.graphs`` are loaded on demand.
        """
        self.act_emb_dict, self.cand_act_emb_dict = {}, {}
        self._cache_env_act_emb(train_env, train_scan_ids)
        self._cache_env_act_emb(eval_env, eval_scan_ids)

    def collect_candidates(self, scan_ids, view_ids):
        """
        Look up cached candidate embeddings and prepend the stop entry.

        :return: (float32 embeddings with the stop token in slot 0,
                  per-row id lists starting with the current view id)
        """
        neighbor_ids, neighbor_embs = [], []
        for scan_id, view_id in zip(scan_ids, view_ids):
            cur_dict = self.cand_act_emb_dict[scan_id][view_id]
            neighbor_ids.append(cur_dict["neigh_ids"])
            if Param.is_spread_token is True:
                neighbor_embs.append(cur_dict["neigh_emb"])  # (max neigh num, 6, emb dim)
            else:
                neighbor_embs.append(torch.mean(cur_dict["neigh_emb"], dim=1))  # (max neigh num, emb dim)
        neighbor_embs = torch.stack(neighbor_embs, dim=0)  # (batch, max neigh num, 6, emb dim)
        if Param.is_spread_token is True:
            stop_shape = (neighbor_embs.shape[0], 1, Param.max_path_img_num)  # -> (batch, 1, 6, emb dim)
        else:
            stop_shape = (neighbor_embs.shape[0], 1)
        stop_idx = torch.zeros(stop_shape, dtype=torch.int64, device=neighbor_embs.device)
        stop_tokens = self.observe_model.stop_token(stop_idx)
        neighbor_embs = torch.cat([stop_tokens, neighbor_embs], dim=1)
        neighbor_ids = [[cur_view_id] + cur_neigh_ids for cur_view_id, cur_neigh_ids in zip(view_ids, neighbor_ids)]
        return neighbor_embs.to(torch.float32), neighbor_ids

    def cal_loss(self, msg_a_prob, act_a_prob, rewards_msg_a, rewards_act_a):
        """
        Reward-weighted message loss (normalised by dim 2 of ``msg_a_prob``) plus action loss.

        :return: scalar loss tensor
        """
        loss_msg = - torch.sum(msg_a_prob * rewards_msg_a.unsqueeze(2)) / msg_a_prob.shape[2]
        loss_act = - torch.sum(act_a_prob * rewards_act_a)
        return loss_msg + loss_act

    def cal_loss_alu_blg(self, msg_a_prob, act_a_prob, rewards_msg_a, rewards_act_a):
        """
        Action-only loss: the message term is multiplied by zeros.

        ``rewards_msg_a`` is accepted for signature parity but not used.
        """
        loss_msg = - torch.sum(msg_a_prob * torch.zeros_like(msg_a_prob))
        loss_act = - torch.sum(act_a_prob * rewards_act_a)
        return loss_msg + loss_act


# ========= Navigation Agents ============

class NavAgentA(nn.Module):
    """Navigation agent A: moves to a candidate viewpoint, then describes both where it is and where to go.

    The message context combines the localization part (current observation)
    and the guidance part (score-weighted candidate embeddings). Candidate
    slot 0 is the "stop" entry that maps back to the current viewpoint.
    """

    def __init__(self):
        super().__init__()
        self.language_model = Language(role='A')
        self.observe_model = PosObservation(role='A')
        # One-row embedding tables used as learned marker tokens.
        self.obs_token = nn.Embedding(1, Param.emb_size)
        self.neigh_obs_token = nn.Embedding(1, Param.emb_size)
        self.act_token = nn.Embedding(1, Param.emb_size)
        # Filled by load_act_emb(); both stay None until it is called.
        self.act_emb_dict = None
        self.cand_act_emb_dict = None

    def forward(self, history, states, env: "FullEnvDataset", choose_method="sample"):
        """
        Act (choose a candidate viewpoint), then speak from the new state.

        :param history: (batch, history len)
        :param states: iterable of (scan_id, view_id, angle) triples, one per batch row
        :param env: environment, forwarded to act/speak
        :param choose_method: sample or greedy
        :return: (new_history, msg, msg_prob, new_view_ids, act_prob)
        """
        # Materialise once: states is unpacked here and again inside guess().
        states = list(states)
        # --------- act -----------
        scan_ids, view_ids, angles = zip(*states)
        act_idx, act_prob, cand_neigh_ids = self.act(history, states, env, choose_method=choose_method)
        new_view_ids = [cand_ids[i] for i, cand_ids in zip(act_idx.tolist(), cand_neigh_ids)]
        new_states = list(zip(scan_ids, new_view_ids, angles))
        # -------- speak ----------
        msg, msg_prob = self.speak(history, new_states, env, choose_method)
        sep_token = torch.full((history.shape[0], 1), Param.tokens["<msg A>"]["pos"],
                               device=history.device)
        new_history = torch.cat([history, sep_token, msg], dim=1)
        return new_history, msg, msg_prob, new_view_ids, act_prob

    def act(self, history, states, env: "FullEnvDataset", choose_method="sample"):
        """
        Choose a candidate index per batch row from the softmax of the guess scores.

        :return: (idx, log-probability of idx under the softmax distribution, candidate ids)
        """
        guess_score, cand_neigh_ids, _ = self.guess(history, states, env)
        guess_score_softmax = torch.softmax(guess_score, dim=1)
        sampler = Categorical(guess_score_softmax)
        if choose_method == "sample":
            idx = sampler.sample()
        else:
            idx = torch.argmax(guess_score_softmax, dim=1)
        return idx, sampler.log_prob(idx), cand_neigh_ids

    @staticmethod
    def _mask_padded_candidates(score, cand_view_ids):
        """Set the score of every slot beyond a row's real candidate count to -inf."""
        counts = torch.tensor([len(ids) for ids in cand_view_ids], device=score.device)
        slots = torch.arange(score.shape[1], device=score.device)
        pad_mask = slots.unsqueeze(0) >= counts.unsqueeze(1)
        return score.masked_fill(pad_mask, float("-inf"))

    def guess(self, history, states, env: "FullEnvDataset"):
        """
        Score each candidate by the dot product with the mean encoded history.

        :param history: (batch, history len)
        :param states: iterable of (scan_id, view_id, angle) triples
        :param env: unused here; kept for a uniform signature
        :return: (score, candidate view ids, un-pooled candidate embeddings)
        """
        scan_ids, view_ids, _ = zip(*states)
        history_vec = self.language_model.encode(history)
        cand_obs_vec_extend, cand_neigh_ids = self.collect_candidates(scan_ids, view_ids)
        if Param.is_spread_token is True:
            # Pool the per-image axis so every candidate is a single vector.
            cand_obs_vec = cand_obs_vec_extend.mean(dim=2)
        else:
            cand_obs_vec = cand_obs_vec_extend
        history_emb = torch.mean(history_vec, dim=1)  # (batch, emb dim)
        score = torch.bmm(cand_obs_vec, history_emb.unsqueeze(2)).squeeze(2)
        score = self._mask_padded_candidates(score, cand_neigh_ids)
        # score = torch.softmax(score, dim=1)  # (batch, neigh len)
        if history.dim() == 2 and history.shape[1] == 0:
            # No history yet: force candidate 0 (the stop / current viewpoint entry).
            score = torch.full_like(score, float("-inf"))
            score[:, 0] = 0
        return score, cand_neigh_ids, cand_obs_vec_extend  # (batch, neigh len, 6, emb dim)

    def speak(self, history, states, env: "FullEnvDataset", choose_method="sample"):
        """
        Decode a message from history + current observation + score-weighted candidates.

        The candidate weights are detached, so the message loss does not flow
        back through the guess scores.

        :return: (msg, msg_prob) exactly as returned by ``language_model.decode_msg``
        """
        scan_ids, view_ids, angles = zip(*states)
        history_vec = self.language_model.encode(history)
        batch = history_vec.shape[0]
        # ------ speak localization -----
        observation = env.collect_single_obs_multi_views(scan_ids, view_ids, angles).to(torch.float32)  # (batch, 36, feat len)
        if Param.is_spread_token is True:
            observe_vec = self.observe_model.encode(observation)  # (batch, 36, emb dim)
        else:
            observe_vec = self.observe_model(observation).unsqueeze(1)  # (batch, 1, emb dim)
        marker_idx = torch.zeros((batch, 1), dtype=torch.int64, device=history_vec.device)
        cur_obs_token = self.obs_token(marker_idx)
        context_vec = torch.cat([history_vec, cur_obs_token, observe_vec], dim=1)
        # ------ speak guidance --------
        guess_score, cand_neigh_ids, cand_obs_vec = self.guess(history, list(zip(scan_ids, view_ids, angles)), env)
        weights = torch.softmax(guess_score, dim=1).unsqueeze(1).detach()
        if Param.is_spread_token is True:
            # Flatten (img, emb) so one bmm mixes whole candidates, then restore it.
            n_rows, n_cand, n_img, emb = cand_obs_vec.shape
            flat = cand_obs_vec.reshape(n_rows, n_cand, n_img * emb)
            attn_neigh_vec = torch.bmm(weights, flat).squeeze(1).view(n_rows, n_img, emb)
        else:
            attn_neigh_vec = torch.bmm(weights, cand_obs_vec)
        cur_neigh_obs_token = self.neigh_obs_token(marker_idx)
        context_vec = torch.cat([context_vec, cur_neigh_obs_token, attn_neigh_vec], dim=1)
        msg, msg_prob = self.language_model.decode_msg(context_vec, choose_method)
        return msg, msg_prob

    def _cache_env_act_emb(self, env, scan_ids):
        """Encode and cache the neighbour embeddings of every viewpoint in ``scan_ids``."""
        for scan_id in scan_ids:
            self.act_emb_dict[scan_id] = {}
            self.cand_act_emb_dict[scan_id] = {}
            if scan_id not in env.graphs:
                env.graphs[scan_id] = env.load_graph(scan_id)[scan_id]
            graph = env.graphs[scan_id]
            for view_id in graph:
                neigh_ids = list(graph[view_id].keys())
                obs_per_neigh, idx_per_neigh = [], []
                for neigh_id in neigh_ids:
                    feats, slots = [], []
                    for _, h, e, _ in graph[view_id][neigh_id]:
                        feats.append(torch.from_numpy(env.features[scan_id][view_id][(h, e)]))
                        slots.append(h // 30 * 3 + e // 30 + 1)
                    obs_per_neigh.append(torch.stack(feats, dim=0))  # (6, 1000)
                    idx_per_neigh.append(torch.from_numpy(np.array(slots)))
                cur_obs = torch.stack(obs_per_neigh, dim=0)  # (neigh len, 6, 1000)
                cur_obs_idx = torch.stack(idx_per_neigh, dim=0)
                cur_obs_emb = self.observe_model.encode(cur_obs.to(torch.float32),
                                                        cur_obs_idx.to(torch.int64))  # (neigh len, 6, emb dim)
                # Zero-pad the neighbour axis up to Param.max_neigh_num so entries can be stacked.
                padding = cur_obs_emb.new_zeros((Param.max_neigh_num - cur_obs_emb.shape[0],
                                                 Param.max_path_img_num, Param.emb_size))
                self.cand_act_emb_dict[scan_id][view_id] = {"neigh_ids": neigh_ids,
                                                            "neigh_emb": torch.cat([cur_obs_emb, padding], dim=0)}
                # The list round-trip stores detached float64 copies, one per neighbour.
                for neigh_id, cur_neigh_emb in zip(neigh_ids, cur_obs_emb.tolist()):
                    self.act_emb_dict[scan_id][(view_id, neigh_id)] = torch.from_numpy(np.array(cur_neigh_emb))  # (6, emb dim)

    def load_act_emb(self, train_env: "FullEnvDataset", eval_env: "FullEnvDataset", train_scan_ids, eval_scan_ids):
        """
        Rebuild ``act_emb_dict`` and ``cand_act_emb_dict`` for the given scans.

        Train scans are processed first; an eval scan with the same id overwrites
        the train entry. Graphs missing from ``env.graphs`` are loaded on demand.
        """
        self.act_emb_dict, self.cand_act_emb_dict = {}, {}
        self._cache_env_act_emb(train_env, train_scan_ids)
        self._cache_env_act_emb(eval_env, eval_scan_ids)

    def collect_candidates(self, scan_ids, view_ids):
        """
        Look up cached candidate embeddings and prepend the stop entry.

        :return: (float32 embeddings with the stop token in slot 0,
                  per-row id lists starting with the current view id)
        """
        neighbor_ids, neighbor_embs = [], []
        for scan_id, view_id in zip(scan_ids, view_ids):
            cur_dict = self.cand_act_emb_dict[scan_id][view_id]
            neighbor_ids.append(cur_dict["neigh_ids"])
            if Param.is_spread_token is True:
                neighbor_embs.append(cur_dict["neigh_emb"])  # (max neigh num, 6, emb dim)
            else:
                neighbor_embs.append(torch.mean(cur_dict["neigh_emb"], dim=1))  # (max neigh num, emb dim)
        neighbor_embs = torch.stack(neighbor_embs, dim=0)  # (batch, max neigh num, 6, emb dim)
        if Param.is_spread_token is True:
            stop_shape = (neighbor_embs.shape[0], 1, Param.max_path_img_num)  # -> (batch, 1, 6, emb dim)
        else:
            stop_shape = (neighbor_embs.shape[0], 1)
        stop_idx = torch.zeros(stop_shape, dtype=torch.int64, device=neighbor_embs.device)
        stop_tokens = self.observe_model.stop_token(stop_idx)
        neighbor_embs = torch.cat([stop_tokens, neighbor_embs], dim=1)
        neighbor_ids = [[cur_view_id] + cur_neigh_ids for cur_view_id, cur_neigh_ids in zip(view_ids, neighbor_ids)]
        return neighbor_embs.to(torch.float32), neighbor_ids

    def cal_loss(self, msg_a_prob, act_a_prob, rewards_loc, rewards_nav):
        """
        Message loss weighted by localization + navigation rewards, plus action loss.

        The first action column is excluded from the action loss.

        :param msg_a_prob: tensor with at least 3 dims; normalised by the size of dim 2
        :param act_a_prob: same shape as ``rewards_nav``
        :param rewards_loc: 2-D rewards added to the navigation term for the message loss
        :param rewards_nav: 2-D navigation rewards
        :return: scalar loss tensor
        """
        # Keeps every column but the last and appends a zero column, i.e. only the
        # final navigation reward is dropped from the message term.
        # NOTE(review): the name says "shift" but no column moves -- confirm this is intended.
        shift_rewards_nav = torch.cat([rewards_nav[:, :-1],
                                       rewards_nav.new_zeros((rewards_nav.shape[0], 1))], dim=1)
        loss_msg = - torch.sum(msg_a_prob * (rewards_loc + shift_rewards_nav).unsqueeze(2)) / msg_a_prob.shape[2]
        loss_act = - torch.sum(act_a_prob[:, 1:] * rewards_nav[:, 1:])
        return loss_msg + loss_act


class NavMultiStepAgentA(nn.Module):
    def __init__(self):
        """Create the language/observation sub-models, the marker embeddings and the (still empty) caches."""
        super().__init__()
        # Sub-models for role 'A'; creation order is kept so parameter registration is unchanged.
        self.language_model = Language(role='A')
        self.observe_model = PosObservation(role='A')
        emb_dim = Param.emb_size
        # Single-entry embedding tables used as learned marker tokens in the decoder context.
        self.obs_token = nn.Embedding(1, emb_dim)
        self.neigh_obs_token = nn.Embedding(1, emb_dim)
        self.act_token = nn.Embedding(1, emb_dim)
        # Embedding caches, filled by load_act_emb().
        self.act_emb_dict = None
        self.cand_act_emb_dict = None

    def forward(self, history, history_path_vec, states, env: FullEnvDataset, choose_method="sample"):
        """
        Take one navigation step, then generate (or suppress) a message from the new position.

        :param history: dialogue token ids so far, indexed as (batch, history len)
        :param history_path_vec: embeddings of the steps taken so far, or None before the first step
        :param states: iterable of (scan id, view id, angle) triples, one per batch element
        :param env: environment handed through to ``act`` and ``speak``
        :param choose_method: sample or greedy
        :return: (new_history, new_history_path_vec, msg, msg_prob, is_speak, pred_act_ids, act_prob)
        """
        # `states` is unpacked here and again inside act()/guess(); materialise it so a
        # one-shot iterator (e.g. a zip object) is not silently exhausted.
        states = list(states)
        # ------------ act ----------------
        scan_ids, _, angles = zip(*states)
        act_idx, act_prob, cand_view_ids, cand_obs_vec = self.act(history, history_path_vec, states, env, choose_method)
        pred_act_ids = [cur_cand_view_ids[i] for i, cur_cand_view_ids in zip(act_idx.tolist(), cand_view_ids)]
        new_states = list(zip(scan_ids, pred_act_ids, angles))
        # Gather the embedding of the chosen candidate per batch row. The row count comes from
        # act_idx rather than Param.batch_size so batches of any size work. The result is
        # (batch, 6, emb dim) with spread tokens, (batch, emb dim) otherwise.
        batch_idx = torch.arange(act_idx.shape[0], device=act_idx.device)
        cur_path_vec = cand_obs_vec[batch_idx, act_idx]
        if history_path_vec is not None:
            new_history_path_vec = torch.cat([history_path_vec, cur_path_vec.unsqueeze(1)], dim=1)
        else:
            new_history_path_vec = cur_path_vec.unsqueeze(1)  # (batch, step, 6, emb dim) or (batch, step, emb dim)
        # ----------- speak ---------------
        msg, msg_prob, next_step_score = self.speak(history, new_history_path_vec, new_states, env, choose_method)
        # ---------- is speak -------------
        is_speak = self.is_speak(next_step_score)
        if history.shape[1] == 0 or Param.is_all_speak == True:
            # Always speak on the very first turn, or when configured to speak every turn.
            is_speak = torch.full_like(is_speak, True)
        # Silent rows have their whole message overwritten with the "start pos" token id.
        msg[torch.logical_not(is_speak), :] = Param.tokens["<msg A>"]["start pos"]
        sep_token = torch.full((history.shape[0], 1), Param.tokens["<msg A>"]["pos"])
        new_history = torch.cat([history, sep_token, msg], dim=1)
        return new_history, new_history_path_vec, msg, msg_prob, is_speak, pred_act_ids, act_prob

    def act(self, history, history_path_vec, states, env: FullEnvDataset, choose_method="sample"):
        """
        Pick one candidate per batch element from the scores produced by ``guess``.

        :param choose_method: "sample" draws from the categorical distribution; anything else takes the argmax
        :return: (chosen index, log-probability of that index, candidate view ids, candidate embeddings)
        """
        scores, cand_view_ids, cand_obs_vec = self.guess(history, history_path_vec, states, env)
        probs = torch.softmax(scores, dim=1)
        dist = Categorical(probs)
        chosen = dist.sample() if choose_method == "sample" else torch.argmax(probs, dim=1)
        # The log-probability is reported for greedy picks as well.
        return chosen, dist.log_prob(chosen), cand_view_ids, cand_obs_vec

    def guess(self, history, history_path_vec, states, env: FullEnvDataset):
        """
        Score every candidate next view (index 0 = stay/stop) against the dialogue history.

        :param history: dialogue token ids, indexed as (batch, history len)
        :param history_path_vec: embeddings of the steps taken so far, or None
        :param states: iterable of (scan id, view id, angle) triples
        :param env: unused here; kept for a signature parallel to ``speak``
        :return: (scores with padded candidates set to -inf, candidate view ids,
                  candidate embeddings as returned by ``collect_candidates``)
        """
        scan_ids, view_ids, _ = zip(*states)
        # Positions holding the "start pos" token are masked for the encoder and
        # excluded from the mean-pooled history embedding.
        history_mask = torch.full_like(history, 0)
        history_mask[history == Param.tokens["<msg A>"]["start pos"]] = 1
        history_anti_mask = 1 - history_mask
        history_vec = self.language_model.encode(history, mask=history_mask)
        history_emb = torch.sum(history_vec * history_anti_mask.unsqueeze(2), dim=1) \
            / torch.sum(history_anti_mask, dim=1).unsqueeze(1)  # (batch, emb dim)

        cand_obs_vec_extend, cand_view_ids = self.collect_candidates(scan_ids, view_ids)
        if Param.is_spread_token is True:
            # Collapse the per-image axis of both candidates and path history.
            cand_obs_vec = cand_obs_vec_extend.mean(dim=2)
            history_path_emb = torch.mean(torch.mean(history_path_vec, dim=2), dim=1) \
                if history_path_vec is not None else None
        else:
            cand_obs_vec = cand_obs_vec_extend
            history_path_emb = torch.mean(history_path_vec, dim=1) \
                if history_path_vec is not None else None  # (batch, emb dim)

        context_emb = history_emb - history_path_emb if history_path_emb is not None else history_emb  # (batch, emb dim)
        score = torch.bmm(cand_obs_vec, context_emb.unsqueeze(2)).squeeze(2)

        # Candidate slots beyond each row's real candidate count are padding: force them to -inf.
        max_neigh_len = cand_obs_vec.shape[1]
        cand_counts = torch.tensor([len(cur_cand_view_ids) for cur_cand_view_ids in cand_view_ids],
                                   device=score.device)
        pad_mask = torch.arange(max_neigh_len, device=score.device).unsqueeze(0) >= cand_counts.unsqueeze(1)
        score = score.masked_fill(pad_mask, -math.inf)

        # With no dialogue yet the pooled history embedding is 0/0; force the "stay" candidate instead.
        # The check looks only at the history length, so it also holds for batches smaller than Param.batch_size.
        if history.shape[1] == 0:
            score = torch.full_like(score, -math.inf)
            score[:, 0] = 0
        return score, cand_view_ids, cand_obs_vec_extend

    def speak(self, history, history_path_vec, states, env: FullEnvDataset, choose_method="sample"):
        """
        Decode a message conditioned on the history, the current observation and the
        score-weighted candidate-path embedding.

        :param history: dialogue token ids, indexed as (batch, history len)
        :param history_path_vec: embeddings of the steps taken so far, or None
        :param states: iterable of (scan id, view id, angle) triples
        :param env: environment providing ``collect_single_obs_multi_views``
        :param choose_method: sample or greedy
        :return: (msg, msg_prob, softmax of the candidate scores from ``guess``)
        """
        scan_ids, view_ids, angles = zip(*states)
        # Batch size is read from the input so batches smaller than Param.batch_size also work.
        batch_size = history.shape[0]
        history_mask = torch.full_like(history, 0)
        history_mask[history == Param.tokens["<msg A>"]["start pos"]] = 1
        history_vec = self.language_model.encode(history, mask=history_mask)
        # --------- speak localization ----------
        observation = env.collect_single_obs_multi_views(scan_ids, view_ids, angles).to(
            torch.float32)  # (batch, 36, feat len)
        if Param.is_spread_token is True:
            observe_vec = self.observe_model.encode(observation)  # (batch, 36, emb dim)
        else:
            observe_vec = self.observe_model(observation).unsqueeze(1)  # (batch, 1, emb dim)
        cur_obs_token = self.obs_token(torch.zeros((batch_size, 1)).to(torch.int64))
        context_vec = torch.cat([history_vec, cur_obs_token, observe_vec], dim=1)
        # The appended marker and observation positions are left unmasked (zeros).
        history_mask = torch.cat([history_mask, torch.zeros((history_mask.shape[0], 1 + observe_vec.shape[1]))], dim=1)
        # -------- speak guidance -------------
        states = list(zip(scan_ids, view_ids, angles))
        guess_score, cand_view_ids, cand_obs_vec = self.guess(history, history_path_vec, states, env)
        score = torch.softmax(guess_score, dim=1)
        history_path_emb = torch.mean(history_path_vec, dim=1) if history_path_vec is not None else None  # (batch, 6, emb dim) or (batch, emb dim)
        if history_path_emb is not None:
            cand_path_vec = cand_obs_vec + history_path_emb.unsqueeze(dim=1)
        else:
            cand_path_vec = cand_obs_vec  # (batch, neigh len, 6, emb dim) or (batch, neigh len, emb dim)
        # Score-weighted average over candidates; the scores are detached so no gradient
        # flows into them through this path.
        attn_weights = score.unsqueeze(1).detach()
        if Param.is_spread_token is True:
            # Fold the trailing (image, emb) axes into one for bmm, then restore them.
            # The sizes come from the tensor itself rather than from Param constants.
            attn_flat = torch.bmm(attn_weights, cand_path_vec.flatten(start_dim=2)).squeeze(1)  # (batch, 6 * emb dim)
            attn_neigh_vec = attn_flat.view(cand_path_vec.shape[0], *cand_path_vec.shape[2:])  # (batch, 6, emb dim)
        else:
            attn_neigh_vec = torch.bmm(attn_weights, cand_path_vec)
        cur_neigh_obs_token = self.neigh_obs_token(torch.zeros((batch_size, 1)).to(torch.int64))
        context_vec = torch.cat([context_vec, cur_neigh_obs_token, attn_neigh_vec], dim=1)
        history_mask = torch.cat([history_mask, torch.zeros((history_mask.shape[0], 1 + attn_neigh_vec.shape[1]))], dim=1)
        msg, msg_prob = self.language_model.decode_msg(context_vec, choose_method, mem_mask=history_mask)
        return msg, msg_prob, score

    def is_speak(self, guess_score):
        """
        Decide per batch row whether to speak: True when the best candidate does not
        lead the runner-up by at least 0.99.

        :param guess_score: (batch, candidates) float scores; the caller passes softmax probabilities
        :return: bool tensor of shape (batch,)
        """
        # Row indices are derived from the input so any batch size works, not only Param.batch_size.
        rows = torch.arange(guess_score.shape[0], device=guess_score.device)
        max_score_idx = torch.argmax(guess_score, dim=1)
        max_score = guess_score[rows, max_score_idx]
        # Knock out the winner on a copy and take the max again to get the runner-up.
        # -inf (rather than 0) keeps this correct for negative scores; for probabilities
        # the outcome is the same as masking with 0.
        runner_up_pool = guess_score.clone()
        runner_up_pool[rows, max_score_idx] = float("-inf")
        sec_max_score = torch.max(runner_up_pool, dim=1).values
        return (max_score - sec_max_score) < 0.99

    def load_act_emb(self, train_env: FullEnvDataset, eval_env: FullEnvDataset, train_scan_ids, eval_scan_ids):
        """
        Rebuild both embedding caches from scratch for the given train and eval scans.

        :param train_env: environment whose graphs/features are used for ``train_scan_ids``
        :param eval_env: environment whose graphs/features are used for ``eval_scan_ids``
        :param train_scan_ids: scan ids to embed from ``train_env``
        :param eval_scan_ids: scan ids to embed from ``eval_env``
        :return: None; fills ``self.act_emb_dict`` and ``self.cand_act_emb_dict``
        """
        self.act_emb_dict, self.cand_act_emb_dict = {}, {}
        # Train first, then eval: a scan id listed in both ends up with the eval-env embeddings.
        self._cache_env_act_emb(train_env, train_scan_ids)
        self._cache_env_act_emb(eval_env, eval_scan_ids)

    def _cache_env_act_emb(self, env: FullEnvDataset, scan_ids):
        """Embed every (view, neighbour) edge of the given scans from ``env`` into both caches."""
        for scan_id in scan_ids:
            self.act_emb_dict[scan_id] = {}
            self.cand_act_emb_dict[scan_id] = {}
            if scan_id not in env.graphs:
                # Load the graph on first use and keep it on the environment.
                env.graphs[scan_id] = env.load_graph(scan_id)[scan_id]
            graph = env.graphs[scan_id]
            for view_id in graph:
                neigh_ids = list(graph[view_id].keys())
                cur_obs, cur_obs_idx = [], []
                for neigh_id in neigh_ids:
                    edge_feats, edge_idx = [], []
                    for _ix, h, e, _d in graph[view_id][neigh_id]:
                        edge_feats.append(torch.from_numpy(env.features[scan_id][view_id][(h, e)]))
                        # Index derived from h and e in steps of 30; presumably a
                        # heading/elevation bin id (TODO confirm).
                        edge_idx.append(h // 30 * 3 + e // 30 + 1)
                    cur_obs.append(torch.stack(edge_feats, dim=0))  # (6, 1000)
                    cur_obs_idx.append(torch.from_numpy(np.array(edge_idx)))
                cur_obs, cur_obs_idx = torch.stack(cur_obs, dim=0), torch.stack(cur_obs_idx, dim=0)  # (neigh len, 6, 1000)
                cur_obs_emb = self.observe_model.encode(cur_obs.to(torch.float32), cur_obs_idx.to(torch.int64))  # (neigh len, 6, emb dim)
                # Zero-pad along the neighbour axis so every view has Param.max_neigh_num rows.
                padding = torch.zeros((Param.max_neigh_num - cur_obs_emb.shape[0], Param.max_path_img_num, Param.emb_size))
                self.cand_act_emb_dict[scan_id][view_id] = {"neigh_ids": neigh_ids,
                                                            "neigh_emb": torch.cat([cur_obs_emb, padding], dim=0)}
                # Per-edge copies go through Python lists and numpy, so they are detached
                # from the autograd graph of cur_obs_emb.
                for neigh_id, cur_neigh_emb in zip(neigh_ids, cur_obs_emb.tolist()):
                    self.act_emb_dict[scan_id][(view_id, neigh_id)] = torch.from_numpy(np.array(cur_neigh_emb))  # (6, emb dim)

    def collect_candidates(self, scan_ids, view_ids):
        """
        Fetch the cached neighbour embeddings for each (scan, view) pair and put a "stop" candidate in front.

        :param scan_ids: scan id per batch element
        :param view_ids: current view id per batch element (iterated twice, so it must be re-iterable)
        :return: (float32 candidate embeddings with the stop embedding first along dim 1,
                  candidate id lists that start with the current view id)
        """
        keep_images = Param.is_spread_token is True
        neigh_id_lists, emb_rows = [], []
        for scan_id, view_id in zip(scan_ids, view_ids):
            cached = self.cand_act_emb_dict[scan_id][view_id]
            neigh_id_lists.append(cached["neigh_ids"])
            if keep_images:
                emb_rows.append(cached["neigh_emb"])  # presumably (max neigh num, 6, emb dim)
            else:
                emb_rows.append(cached["neigh_emb"].mean(dim=1))  # image axis averaged out
        batch_embs = torch.stack(emb_rows, dim=0)

        # All-zero indices select the single stop embedding, repeated to match batch_embs' trailing layout.
        n_rows = batch_embs.shape[0]
        if keep_images:
            stop_idx = torch.zeros((n_rows, 1, Param.max_path_img_num), dtype=torch.int64)
        else:
            stop_idx = torch.zeros((n_rows, 1), dtype=torch.int64)
        stop_tokens = self.observe_model.stop_token(stop_idx)

        candidates = torch.cat([stop_tokens, batch_embs], dim=1)
        candidate_ids = [[current] + neighbours for current, neighbours in zip(view_ids, neigh_id_lists)]
        return candidates.to(torch.float32), candidate_ids

    def cal_loss(self, msg_a_prob, guess_a_prob, rewards_msg_a, rewards_nav):
        """
        Reward-weighted loss over message and guess probabilities.

        :param msg_a_prob: per-token message probabilities, indexed as (batch, step, msg len)
        :param guess_a_prob: per-step guess/action probabilities, same shape as ``rewards_nav``
        :param rewards_msg_a: per-step message rewards, broadcast over the token axis
        :param rewards_nav: per-step navigation rewards
        :return: scalar tensor ``loss_act + loss_msg``
        """
        msg_len = msg_a_prob.shape[2]
        token_rewards = rewards_msg_a.unsqueeze(2)  # one reward per step, shared by all its tokens
        weighted_msg_total = (msg_a_prob * token_rewards).sum()
        weighted_act_total = (guess_a_prob * rewards_nav).sum()
        loss_msg = -weighted_msg_total / msg_len
        loss_act = -weighted_act_total
        return loss_act + loss_msg
