import torch
from torch import nn
from Language import *
from Observation import *
from Environment import *

# ====== LOCALIZATION ========

class PosAgentB(nn.Module):
    """Role-B agent for the localization game.

    ``guess`` scores candidate viewpoints against the dialogue history and ``speak``
    decodes a message conditioned on the history plus the guess-weighted viewpoint features.
    """

    def __init__(self):
        super(PosAgentB, self).__init__()
        self.language_model = Language(role="B")
        self.observe_model = PosObservation(role="B")
        # Single learned embedding inserted between the history and the view features in the decoder context.
        self.obs_token = nn.Embedding(1, Param.emb_size)
        # Caches populated by load_obs_emb().
        self.obs_emb = None
        self.scan_ids, self.scan_ids_dict = None, None
        self.view_ids, self.view_ids_dict = None, None

    def forward(self, history, view_ids, viewpoint_vec, choose_method="sample"):
        """Guess a location, then speak one message.

        :param history: (batch, history len) token ids of the dialogue so far
        :param view_ids: per-example lists of candidate viewpoint ids
        :param viewpoint_vec: (batch, max node num, 36, emb dim) encoded candidate views
        :param choose_method: "sample" draws from the distributions; any other value takes the argmax
        :return: (new history, guess idx, log-prob of guess idx, msg, msg prob)
        """
        # -------- guess ----------
        guess_score = self.guess(history, view_ids, viewpoint_vec)
        guess_score_softmax = torch.softmax(guess_score, dim=1)
        sampler = Categorical(guess_score_softmax)
        if choose_method == "sample":
            guess_idx = sampler.sample()
        else:
            guess_idx = torch.argmax(guess_score_softmax, dim=1)
        # --------- speak ----------
        msg, msg_prob = self.speak(history, guess_score_softmax, viewpoint_vec, choose_method=choose_method)
        # Append the "<msg B>" separator followed by the new message to the history.
        sep_token = torch.full((history.shape[0], 1), Param.tokens["<msg B>"]["pos"], device=history.device)
        new_history = torch.cat([history, sep_token, msg], dim=1)
        return new_history, guess_idx, sampler.log_prob(guess_idx), msg, msg_prob

    def guess(self, history, view_ids, viewpoint_vec):
        """Score every candidate viewpoint against the mean-pooled encoded history.

        :param history: (batch, history len) token ids
        :param view_ids: per-example candidate viewpoint ids; only the first ``len(view_ids[b])``
            columns of row ``b`` are real candidates, the rest are padding
        :param viewpoint_vec: (batch, max node num, 36, emb dim)
        :return: (batch, max node num) unnormalized scores, padded candidates set to -inf
        """
        history_vec = self.language_model.encode(history).mean(dim=1)  # (batch, emb dim)
        viewpoint_vec_mean = torch.mean(viewpoint_vec, dim=2)  # (batch, max node num, emb dim)
        scores = torch.bmm(viewpoint_vec_mean, history_vec.unsqueeze(2)).squeeze(2)
        # Pad mask built on the scores' device and applied out-of-place.
        lengths = torch.tensor([len(cur_view_ids) for cur_view_ids in view_ids], dtype=torch.long,
                               device=scores.device)
        positions = torch.arange(scores.shape[1], device=scores.device)
        pad_mask = positions.unsqueeze(0) >= lengths.unsqueeze(1)
        return scores.masked_fill(pad_mask, float("-inf"))

    def speak(self, history, guess_score, viewpoint_vec, choose_method="sample"):
        """Decode a message from the history and the guess-weighted viewpoint features.

        :param history: (batch, history len) token ids
        :param guess_score: (batch, max node num) candidate weights (forward passes the softmaxed guess scores)
        :param viewpoint_vec: (batch, max node num, 36, emb dim)
        :param choose_method: "sample" or greedy, forwarded to the decoder
        :return: (msg, msg prob) as returned by ``language_model.decode_msg``
        """
        history_vec = self.language_model.encode(history)  # (batch, history len, emb dim)
        if Param.is_spread_token is True:
            # Keep one token per view: weight each candidate's flattened (views * emb) block by its score.
            # Shapes come from the tensor, so a final batch smaller than Param.batch_size still works.
            batch_size, num_nodes, num_views, emb_size = viewpoint_vec.shape
            flat_views = viewpoint_vec.reshape(batch_size, num_nodes, num_views * emb_size)
            attn_view_vec = torch.bmm(guess_score.unsqueeze(1), flat_views).view(batch_size, num_views, emb_size)
        else:
            viewpoint_vec_mean = torch.mean(viewpoint_vec, dim=2)  # (batch, max node num, emb dim)
            attn_view_vec = torch.bmm(guess_score.unsqueeze(1), viewpoint_vec_mean)  # (batch, 1, emb dim)
        obs_idx = torch.zeros(history.shape[0], dtype=torch.long, device=self.obs_token.weight.device)
        cur_obs_token = self.obs_token(obs_idx).unsqueeze(1)  # (batch, 1, emb dim)
        context_vec = torch.cat([history_vec, cur_obs_token, attn_view_vec], dim=1)
        msg, msg_prob = self.language_model.decode_msg(context_vec, choose_method=choose_method)
        return msg, msg_prob

    def load_obs_emb(self, train_env: FullEnvDataset, eval_env: FullEnvDataset, train_scan_ids, eval_scan_ids):
        """Encode every viewpoint of the train and eval scans once, and build id -> index lookups.

        Sets ``scan_ids`` (train followed by eval), ``view_ids``,
        ``obs_emb`` of shape (all scan num, max node num, 36, emb dim),
        ``scan_ids_dict`` (scan id -> row of ``obs_emb``) and
        ``view_ids_dict`` (per scan: view id -> column of ``obs_emb``).
        """
        self.scan_ids = train_scan_ids + eval_scan_ids
        train_view_ids = train_env.get_view_ids(train_scan_ids)
        eval_view_ids = eval_env.get_view_ids(eval_scan_ids)
        self.view_ids = train_view_ids + eval_view_ids
        train_viewpoints = train_env.collect_obs(train_scan_ids, train_view_ids)  # (scan num, max node num, 36, 1000)
        eval_viewpoints = eval_env.collect_obs(eval_scan_ids, eval_view_ids)
        viewpoints = torch.cat([train_viewpoints, eval_viewpoints], dim=0)
        # Feature width is read from the data rather than hard-coded to 1000.
        num_scans, num_nodes, num_views, feat_len = viewpoints.shape
        flat_viewpoints = viewpoints.reshape(num_scans * num_nodes, num_views, feat_len)
        self.obs_emb = self.observe_model.encode(flat_viewpoints)
        self.obs_emb = self.obs_emb.view((num_scans, num_nodes, num_views, Param.emb_size))
        self.scan_ids_dict = {cur_scan_id: i for i, cur_scan_id in enumerate(self.scan_ids)}
        self.view_ids_dict = [{cur_view_id: i for i, cur_view_id in enumerate(cur_batch_view_ids)}
                              for cur_batch_view_ids in self.view_ids]

    def cal_loss_alg_blu(self, msg_b_prob, guess_b_prob, rewards):
        """Loss in which only the guess is trained.

        The message term is multiplied by zeros, so it contributes 0 to the loss and 0 gradient.
        The guess term is ``-sum(guess_b_prob * rewards)``. guess_b_prob is presumably the
        log-prob returned by forward — confirm at the call site.
        """
        loss_msg = torch.sum(msg_b_prob * torch.zeros_like(msg_b_prob))
        loss_guess = - torch.sum(guess_b_prob * rewards)
        return loss_msg + loss_guess

    def cal_loss(self, msg_b_prob, guess_b_prob, rewards):
        """Reward-weighted loss on both the message and the guess.

        Assumes msg_b_prob is (batch, T, msg len) and rewards is (batch, T) — TODO confirm.
        The message term is divided by ``msg_b_prob.shape[2]``.
        """
        loss_msg = - torch.sum(msg_b_prob * rewards.unsqueeze(2)) / msg_b_prob.shape[2]
        loss_guess = - torch.sum(guess_b_prob * rewards)
        return loss_msg + loss_guess

# ======= GUIDANCE ==========

class GuideAgentB(nn.Module):
    """Role-B agent for the guidance game: describes the next step from each candidate location."""

    def __init__(self):
        super(GuideAgentB, self).__init__()
        self.language_model = Language(role='B')
        self.observe_model = PosObservation(role='B')
        # Single learned embedding marking the start of the plan features in the decoder context.
        self.plan_token = nn.Embedding(1, Param.emb_size)
        # {scan_id: {(view_id, neigh_id): (num imgs, emb dim) float64 tensor}}, filled by load_act_emb().
        self.act_emb_dict = None

    def forward(self, history, env: FullEnvDataset, scan_ids, end_ids, view_ids, true_loc, choose_method="sample"):
        """Speak one guidance message and append it (after a "<msg B>" separator) to the history.

        :return: (new history, msg, msg prob)
        """
        msg, msg_prob = self.speak(history, env, scan_ids, end_ids, view_ids, true_loc, choose_method)
        sep_token = torch.full((history.shape[0], 1), Param.tokens["<msg B>"]["pos"], device=history.device)
        new_history = torch.cat([history, sep_token, msg], dim=1)
        return new_history, msg, msg_prob

    def speak(self, history, env: FullEnvDataset, scan_ids, end_ids, view_ids, true_loc, choose_method="sample"):
        """Decode a message from the history and the ``true_loc``-weighted next-step embeddings.

        :param history: (batch, history len) token ids
        :param true_loc: (batch, max node num) weights over candidate locations
        :return: (msg, msg prob) from ``language_model.decode_msg``
        """
        history_vec = self.language_model.encode(history)
        plan_idx = torch.zeros((history.shape[0], 1), dtype=torch.long, device=self.plan_token.weight.device)
        cur_plan_token = self.plan_token(plan_idx)  # (batch, 1, emb dim)
        next_step_info = get_next_step_ids_full(env.graph_paths, scan_ids, view_ids, end_ids)
        next_step_vec = self.get_path_vec(scan_ids, view_ids, next_step_info)
        if Param.is_spread_token is True:
            # next_step_vec: (batch, max node num, num imgs, emb dim); keep one token per image.
            batch_size, num_nodes, num_imgs, emb_size = next_step_vec.shape
            flat_steps = next_step_vec.reshape(batch_size, num_nodes, num_imgs * emb_size)
            attn_step_vec = torch.bmm(true_loc.unsqueeze(1), flat_steps).view(batch_size, num_imgs, emb_size)
        else:
            attn_step_vec = torch.bmm(true_loc.unsqueeze(1), next_step_vec)  # (batch, 1, emb dim)
        context_vec = torch.cat([history_vec, cur_plan_token, attn_step_vec], dim=1)
        msg, msg_prob = self.language_model.decode_msg(context_vec, choose_method=choose_method)
        return msg, msg_prob

    def _stop_vec(self):
        """Stop-token embedding for a candidate with no next step.

        Returns (max path img num, emb dim) in spread mode, else (emb dim,).
        """
        if Param.is_spread_token is True:
            return self.observe_model.stop_token(torch.zeros((Param.max_path_img_num,), dtype=torch.int64))
        return self.observe_model.stop_token(torch.zeros((1,), dtype=torch.int64)).squeeze()

    def get_path_vec(self, scan_ids, cand_ids, paths):
        """Embed the next step of each candidate, zero-padded to ``Param.max_num_node`` candidates.

        :param paths: per example, per candidate, either None (no step) or a dict with key "next"
        :return: float32 (batch, max node num, num imgs, emb dim) in spread mode,
            else (batch, max node num, emb dim)
        """
        batch_vecs = []
        for path, scan_id, cand_view_ids in zip(paths, scan_ids, cand_ids):  # batch
            cand_vecs = []
            for cur_path, cur_view_id in zip(path, cand_view_ids):  # cand len
                if cur_path is None:
                    cand_vecs.append(self._stop_vec())
                    continue
                step_emb = self.act_emb_dict[scan_id][(cur_view_id, cur_path["next"])]  # (num imgs, emb dim)
                cand_vecs.append(step_emb if Param.is_spread_token is True else torch.mean(step_emb, dim=0))
            cand_vecs = torch.stack(cand_vecs, dim=0)
            # Pad with zeros shaped like one candidate. The old fixed 3-D padding broke non-spread mode.
            padding = cand_vecs.new_zeros((Param.max_num_node - cand_vecs.shape[0],) + tuple(cand_vecs.shape[1:]))
            batch_vecs.append(torch.cat([cand_vecs, padding], dim=0))
        return torch.stack(batch_vecs, dim=0).to(torch.float32)

    def _encode_scan_actions(self, env, scan_ids):
        """Encode the images along every graph edge of ``scan_ids`` into ``self.act_emb_dict``.

        Scan graphs that are not yet loaded are loaded and cached on ``env.graphs``.
        """
        for scan_id in scan_ids:
            self.act_emb_dict[scan_id] = {}
            if scan_id not in env.graphs:
                env.graphs[scan_id] = env.load_graph(scan_id)[scan_id]
            graph = env.graphs[scan_id]
            for view_id in graph:
                neighbours = graph[view_id]
                neigh_ids = list(neighbours.keys())
                if not neigh_ids:
                    continue  # no outgoing edges: nothing to encode (torch.stack([]) would raise)
                # (neigh len, num imgs, feat len): image features along each edge.
                edge_feats = torch.stack([
                    torch.stack([torch.from_numpy(env.features[scan_id][view_id][(h, e)])
                                 for _, h, e, _ in neighbours[neigh_id]], dim=0)
                    for neigh_id in neigh_ids], dim=0)
                # (neigh len, num imgs). Presumably maps (heading, elevation) in degrees to a
                # panorama view index — confirm against Environment.
                edge_idx = torch.tensor([[h // 30 * 3 + e // 30 + 1 for _, h, e, _ in neighbours[neigh_id]]
                                         for neigh_id in neigh_ids], dtype=torch.int64)
                # The cached embeddings are constants, so no autograd graph is built.
                with torch.no_grad():
                    edge_emb = self.observe_model.encode(edge_feats.to(torch.float32), edge_idx)
                edge_emb = edge_emb.cpu().to(torch.float64)  # (neigh len, num imgs, emb dim)
                for neigh_id, cur_neigh_emb in zip(neigh_ids, edge_emb):
                    self.act_emb_dict[scan_id][(view_id, neigh_id)] = cur_neigh_emb

    def load_act_emb(self, train_env: FullEnvDataset, eval_env: FullEnvDataset, train_scan_ids, eval_scan_ids):
        """Rebuild ``act_emb_dict`` for the train scans, then the eval scans."""
        self.act_emb_dict = {}
        self._encode_scan_actions(train_env, train_scan_ids)
        self._encode_scan_actions(eval_env, eval_scan_ids)

    def cal_loss_alu_blg(self, msg_b_prob, rewards):
        """Same loss as ``cal_loss``."""
        return self.cal_loss(msg_b_prob, rewards)

    def cal_loss(self, msg_b_prob, rewards):
        """Reward-weighted message loss.

        The sum is divided by ``msg_b_prob.shape[1]``. Assumes rewards broadcasts after
        ``unsqueeze(2)`` — TODO confirm the shapes.
        """
        loss_msg = - torch.sum(msg_b_prob * rewards.unsqueeze(2)) / msg_b_prob.shape[1]
        return loss_msg

class GuideMultiStepAgentB(nn.Module):
    """Role-B agent for multi-step guidance: describes the route from the believed location to the goal."""

    def __init__(self):
        super(GuideMultiStepAgentB, self).__init__()
        self.language_model = Language(role='B')
        self.observe_model = PosObservation(role='B')
        # Single learned embedding marking the start of the route features in the decoder context.
        self.plan_token = nn.Embedding(1, Param.emb_size)
        # {scan_id: {(view_id, neigh_id): (num imgs, emb dim) float64 tensor}}, filled by load_act_emb().
        self.act_emb_dict = None

    def forward(self, history, env: FullEnvDataset, scan_ids, end_ids, view_ids, is_speak, true_loc, choose_method="sample", is_true_loc=False):
        """Speak one message and append it (after a "<msg B>" separator) to the history.

        Rows where ``is_speak == False`` have their message overwritten with the
        "<msg A>" start-pos token, which ``speak`` masks out of the history.

        :return: (new history, msg, msg prob)
        """
        msg, msg_prob = self.speak(history, env, scan_ids, end_ids, view_ids, true_loc, choose_method, is_true_loc=is_true_loc)
        sep_token = torch.full((history.shape[0], 1), Param.tokens["<msg B>"]["pos"], device=history.device)
        if is_speak is not None:
            msg[is_speak == False, :] = Param.tokens["<msg A>"]["start pos"]
        new_history = torch.cat([history, sep_token, msg], dim=1)
        return new_history, msg, msg_prob

    def speak(self, history, env: FullEnvDataset, scan_ids, end_ids, view_ids, true_loc, choose_method="sample", is_true_loc=False):
        """Decode a message from the masked history and the route embeddings.

        The route starts at the arg-max location of ``true_loc``.

        :param true_loc: (batch, max node num) weights over candidate locations
        :param is_true_loc: if True, the route embedding is used directly instead of being weighted by ``true_loc``
        :return: (msg, msg prob) from ``language_model.decode_msg``
        """
        batch_size = history.shape[0]
        # 1 marks filler "<msg A>" start-pos tokens (presumably 1 = ignore in encode/decode — confirm).
        history_mask = torch.zeros_like(history)
        history_mask[history == Param.tokens["<msg A>"]["start pos"]] = 1
        history_vec = self.language_model.encode(history, mask=history_mask)
        plan_idx = torch.zeros((batch_size, 1), dtype=torch.long, device=self.plan_token.weight.device)
        cur_plan_token = self.plan_token(plan_idx)  # (batch, 1, emb dim)
        # Only the arg-max location of each example is used as the route start.
        true_loc_list = torch.argmax(true_loc, dim=1).tolist()
        true_view_ids = [[cur_view_ids[i]] for cur_view_ids, i in zip(view_ids, true_loc_list)]
        next_route_info = get_route_ids(env.graph_paths, scan_ids, true_view_ids, end_ids)
        route_vec = self.get_route_vec(scan_ids, true_view_ids, next_route_info, is_true_loc=is_true_loc)
        emb_size = route_vec.shape[-1]
        if is_true_loc is False:
            # NOTE(review): route_vec holds the arg-max location's route only in candidate slot 0
            # (the rest is zero padding), yet it is weighted by true_loc over all slots. For a
            # one-hot true_loc whose arg-max is not slot 0, this yields zeros. Confirm intent.
            flat_route = route_vec.reshape(batch_size, route_vec.shape[1], -1)
            attn_route_vec = torch.bmm(true_loc.unsqueeze(1), flat_route).squeeze(1)
            # spread: (batch, max turns * num imgs, emb dim); otherwise (batch, max turns, emb dim)
            attn_route_vec = attn_route_vec.view(batch_size, -1, emb_size)
        else:
            attn_route_vec = route_vec.squeeze(1).reshape(batch_size, -1, emb_size)
        context_vec = torch.cat([history_vec, cur_plan_token, attn_route_vec], dim=1)
        # The plan token and route tokens are never masked.
        history_mask = torch.cat([history_mask,
                                  torch.zeros((batch_size, 1 + attn_route_vec.shape[1]), device=history_mask.device)],
                                 dim=1)
        msg, msg_prob = self.language_model.decode_msg(context_vec, choose_method=choose_method, mem_mask=history_mask)
        return msg, msg_prob

    def _stop_vec(self):
        """Stop-token embedding for a finished or missing route step.

        Returns (max path img num, emb dim) in spread mode, else (emb dim,).
        """
        if Param.is_spread_token is True:
            return self.observe_model.stop_token(torch.zeros((Param.max_path_img_num,), dtype=torch.int64))
        return self.observe_model.stop_token(torch.zeros((1,), dtype=torch.int64)).squeeze()

    def get_route_vec(self, scan_ids, cand_ids, paths, is_true_loc=False):
        """Embed each candidate's route step by step.

        :param scan_ids: scan id per example
        :param cand_ids: per example, the candidate start view ids
        :param paths: per example, per candidate, a list of steps. Each step is None or a dict with
            key "next". Steps that are None or do not move are embedded as the stop token.
        :param is_true_loc: if False, candidates are zero-padded to ``Param.max_num_node``
        :return: float32 (batch, cand len, route len, num imgs, emb dim) in spread mode,
            else (batch, cand len, route len, emb dim)
        """
        batch_vecs = []
        for cur_routes, scan_id, cand_view_ids in zip(paths, scan_ids, cand_ids):
            cand_vecs = []
            for cur_route, cur_view_id in zip(cur_routes, cand_view_ids):
                step_vecs = []
                last_id = cur_view_id
                for cur_step in cur_route:
                    if cur_step is not None and last_id != cur_step["next"]:
                        cur_neigh_id = cur_step["next"]
                        step_emb = self.act_emb_dict[scan_id][(last_id, cur_neigh_id)]
                        step_vecs.append(step_emb if Param.is_spread_token is True else torch.mean(step_emb, dim=0))
                        last_id = cur_neigh_id
                    else:
                        step_vecs.append(self._stop_vec())
                cand_vecs.append(torch.stack(step_vecs, dim=0))  # (route len, [num imgs,] emb dim)
            cand_vecs = torch.stack(cand_vecs, dim=0)
            if is_true_loc is False:
                # Pad with zeros shaped like one candidate. The old fixed 4-D padding broke non-spread mode.
                padding = cand_vecs.new_zeros((Param.max_num_node - cand_vecs.shape[0],) + tuple(cand_vecs.shape[1:]))
                cand_vecs = torch.cat([cand_vecs, padding], dim=0)
            batch_vecs.append(cand_vecs)
        return torch.stack(batch_vecs, dim=0).to(torch.float32)

    def _encode_scan_actions(self, env, scan_ids):
        """Encode the images along every graph edge of ``scan_ids`` into ``self.act_emb_dict``.

        Scan graphs that are not yet loaded are loaded and cached on ``env.graphs``.
        """
        for scan_id in scan_ids:
            self.act_emb_dict[scan_id] = {}
            if scan_id not in env.graphs:
                env.graphs[scan_id] = env.load_graph(scan_id)[scan_id]
            graph = env.graphs[scan_id]
            for view_id in graph:
                neighbours = graph[view_id]
                neigh_ids = list(neighbours.keys())
                if not neigh_ids:
                    continue  # no outgoing edges: nothing to encode (torch.stack([]) would raise)
                # (neigh len, num imgs, feat len): image features along each edge.
                edge_feats = torch.stack([
                    torch.stack([torch.from_numpy(env.features[scan_id][view_id][(h, e)])
                                 for _, h, e, _ in neighbours[neigh_id]], dim=0)
                    for neigh_id in neigh_ids], dim=0)
                # (neigh len, num imgs). Presumably maps (heading, elevation) in degrees to a
                # panorama view index — confirm against Environment.
                edge_idx = torch.tensor([[h // 30 * 3 + e // 30 + 1 for _, h, e, _ in neighbours[neigh_id]]
                                         for neigh_id in neigh_ids], dtype=torch.int64)
                # The cached embeddings are constants, so no autograd graph is built.
                with torch.no_grad():
                    edge_emb = self.observe_model.encode(edge_feats.to(torch.float32), edge_idx)
                edge_emb = edge_emb.cpu().to(torch.float64)  # (neigh len, num imgs, emb dim)
                for neigh_id, cur_neigh_emb in zip(neigh_ids, edge_emb):
                    self.act_emb_dict[scan_id][(view_id, neigh_id)] = cur_neigh_emb

    def load_act_emb(self, train_env: FullEnvDataset, eval_env: FullEnvDataset, train_scan_ids, eval_scan_ids):
        """Rebuild ``act_emb_dict`` for the train scans, then the eval scans."""
        self.act_emb_dict = {}
        self._encode_scan_actions(train_env, train_scan_ids)
        self._encode_scan_actions(eval_env, eval_scan_ids)

    def cal_loss(self, msg_b_prob, rewards_msg_b):
        """Reward-weighted message loss.

        The sum is divided by ``msg_b_prob.shape[2]``. Assumes msg_b_prob is (batch, T, msg len)
        and rewards_msg_b is (batch, T) — TODO confirm.
        """
        loss_msg = - torch.sum(msg_b_prob * rewards_msg_b.unsqueeze(2)) / msg_b_prob.shape[2]
        return loss_msg

# ======= Navigation ==========

class NavAgentB(nn.Module):
    """Role-B agent for navigation: guesses the location and speaks about both localization and the next step."""

    def __init__(self):
        super(NavAgentB, self).__init__()
        self.language_model = Language(role="B")
        self.observe_model = PosObservation(role="B")
        # Single learned embeddings marking the observation and plan sections of the decoder context.
        self.obs_token = nn.Embedding(1, Param.emb_size)
        self.plan_token = nn.Embedding(1, Param.emb_size)
        # {scan_id: {(view_id, neigh_id): (num imgs, emb dim) float64 tensor}}, filled by load_act_emb().
        self.act_emb_dict = None
        # Caches populated by load_obs_emb().
        self.obs_emb = None
        self.scan_ids, self.scan_ids_dict = None, None
        self.view_ids, self.view_ids_dict = None, None

    def forward(self, history, viewpoint_vec, env, scan_ids, end_ids, view_ids, true_loc=None, choose_method="sample"):
        """Guess a location, then speak one message.

        :param history: (batch, history len) token ids
        :param viewpoint_vec: (batch, max node num, 36, emb dim) encoded candidate views
        :param true_loc: optional (batch, max node num) location weights for the guidance part;
            defaults to the softmaxed guess
        :param choose_method: "sample" draws from the distributions; any other value takes the argmax
        :return: (new history, guess idx, log-prob of guess idx, msg, msg prob)
        """
        # ------ guess --------
        guess_score = self.guess(history, view_ids, viewpoint_vec)
        guess_score_softmax = torch.softmax(guess_score, dim=1)
        sampler = Categorical(guess_score_softmax)
        if choose_method == "sample":
            guess_idx = sampler.sample()
        else:
            guess_idx = torch.argmax(guess_score_softmax, dim=1)
        # ------ speak --------
        msg, msg_prob = self.speak(history, viewpoint_vec, guess_score_softmax, env, scan_ids, end_ids, view_ids,
                                   true_loc=true_loc, choose_method=choose_method)
        sep_token = torch.full((history.shape[0], 1), Param.tokens["<msg B>"]["pos"], device=history.device)
        new_history = torch.cat([history, sep_token, msg], dim=1)
        return new_history, guess_idx, sampler.log_prob(guess_idx), msg, msg_prob

    def guess(self, history, view_ids, viewpoint_vec):
        """Score candidate viewpoints against the mean-pooled encoded history.

        :return: (batch, max node num) unnormalized scores; slots beyond ``len(view_ids[b])`` are -inf
        """
        history_vec = self.language_model.encode(history).mean(dim=1)  # (batch, emb dim)
        viewpoint_vec_mean = torch.mean(viewpoint_vec, dim=2)  # (batch, max node num, emb dim)
        scores = torch.bmm(viewpoint_vec_mean, history_vec.unsqueeze(2)).squeeze(2)
        # Pad mask built on the scores' device and applied out-of-place.
        lengths = torch.tensor([len(cur_view_ids) for cur_view_ids in view_ids], dtype=torch.long,
                               device=scores.device)
        positions = torch.arange(scores.shape[1], device=scores.device)
        pad_mask = positions.unsqueeze(0) >= lengths.unsqueeze(1)
        return scores.masked_fill(pad_mask, float("-inf"))

    def speak(self, history, viewpoint_vec, guess_score, env: FullEnvDataset, scan_ids, end_ids, view_ids,
              true_loc=None, choose_method="sample"):
        """Decode one message from the history, the guess-weighted views, and the next-step embeddings.

        :return: (msg, msg prob) from ``language_model.decode_msg``
        """
        history_vec = self.language_model.encode(history)  # (batch, history_len, emb dim)
        # --------- speak localization -----------
        if Param.is_spread_token is True:
            batch_size, num_nodes, num_views, emb_size = viewpoint_vec.shape
            flat_views = viewpoint_vec.reshape(batch_size, num_nodes, num_views * emb_size)
            attn_view_vec = torch.bmm(guess_score.unsqueeze(1), flat_views).view(batch_size, num_views, emb_size)
        else:
            # Average over the views first, as PosAgentB does. bmm needs a 3-D operand, and the
            # 4-D viewpoint_vec used here before always raised.
            viewpoint_vec_mean = torch.mean(viewpoint_vec, dim=2)  # (batch, max node num, emb dim)
            attn_view_vec = torch.bmm(guess_score.unsqueeze(1), viewpoint_vec_mean)  # (batch, 1, emb dim)
        obs_idx = torch.zeros(viewpoint_vec.shape[0], dtype=torch.long, device=self.obs_token.weight.device)
        cur_obs_token = self.obs_token(obs_idx).unsqueeze(1)  # (batch, 1, emb dim)
        context_vec = torch.cat([history_vec, cur_obs_token, attn_view_vec], dim=1)
        # --------- speak guidance --------------
        plan_idx = torch.zeros((history.shape[0], 1), dtype=torch.long, device=self.plan_token.weight.device)
        cur_plan_token = self.plan_token(plan_idx)  # (batch, 1, emb dim)
        next_step_info = get_next_step_ids_full(env.graph_paths, scan_ids, view_ids, end_ids)
        next_step_vec = self.get_path_vec(scan_ids, view_ids, next_step_info)
        if true_loc is None:
            true_loc = guess_score
        if Param.is_spread_token is True:
            batch_size, num_nodes, num_imgs, emb_size = next_step_vec.shape
            flat_steps = next_step_vec.reshape(batch_size, num_nodes, num_imgs * emb_size)
            attn_step_vec = torch.bmm(true_loc.unsqueeze(1), flat_steps).view(batch_size, num_imgs, emb_size)
        else:
            attn_step_vec = torch.bmm(true_loc.unsqueeze(1), next_step_vec)  # (batch, 1, emb dim)
        context_vec = torch.cat([context_vec, cur_plan_token, attn_step_vec], dim=1)
        msg, msg_prob = self.language_model.decode_msg(context_vec, choose_method=choose_method)
        return msg, msg_prob

    def _stop_vec(self):
        """Stop-token embedding for a candidate with no next step.

        Returns (max path img num, emb dim) in spread mode, else (emb dim,).
        """
        if Param.is_spread_token is True:
            return self.observe_model.stop_token(torch.zeros((Param.max_path_img_num,), dtype=torch.int64))
        return self.observe_model.stop_token(torch.zeros((1,), dtype=torch.int64)).squeeze()

    def get_path_vec(self, scan_ids, cand_ids, paths):
        """Embed the next step of each candidate, zero-padded to ``Param.max_num_node`` candidates.

        :return: float32 (batch, max node num, num imgs, emb dim) in spread mode,
            else (batch, max node num, emb dim)
        """
        batch_vecs = []
        for path, scan_id, cand_view_ids in zip(paths, scan_ids, cand_ids):  # batch
            cand_vecs = []
            for cur_path, cur_view_id in zip(path, cand_view_ids):  # cand len
                if cur_path is None:
                    cand_vecs.append(self._stop_vec())
                    continue
                step_emb = self.act_emb_dict[scan_id][(cur_view_id, cur_path["next"])]  # (num imgs, emb dim)
                cand_vecs.append(step_emb if Param.is_spread_token is True else torch.mean(step_emb, dim=0))
            cand_vecs = torch.stack(cand_vecs, dim=0)
            # Pad with zeros shaped like one candidate. The old fixed 3-D padding broke non-spread mode.
            padding = cand_vecs.new_zeros((Param.max_num_node - cand_vecs.shape[0],) + tuple(cand_vecs.shape[1:]))
            batch_vecs.append(torch.cat([cand_vecs, padding], dim=0))
        return torch.stack(batch_vecs, dim=0).to(torch.float32)

    def _encode_scan_actions(self, env, scan_ids):
        """Encode the images along every graph edge of ``scan_ids`` into ``self.act_emb_dict``.

        Scan graphs that are not yet loaded are loaded and cached on ``env.graphs``.
        """
        for scan_id in scan_ids:
            self.act_emb_dict[scan_id] = {}
            if scan_id not in env.graphs:
                env.graphs[scan_id] = env.load_graph(scan_id)[scan_id]
            graph = env.graphs[scan_id]
            for view_id in graph:
                neighbours = graph[view_id]
                neigh_ids = list(neighbours.keys())
                if not neigh_ids:
                    continue  # no outgoing edges: nothing to encode (torch.stack([]) would raise)
                # (neigh len, num imgs, feat len): image features along each edge.
                edge_feats = torch.stack([
                    torch.stack([torch.from_numpy(env.features[scan_id][view_id][(h, e)])
                                 for _, h, e, _ in neighbours[neigh_id]], dim=0)
                    for neigh_id in neigh_ids], dim=0)
                # (neigh len, num imgs). Presumably maps (heading, elevation) in degrees to a
                # panorama view index — confirm against Environment.
                edge_idx = torch.tensor([[h // 30 * 3 + e // 30 + 1 for _, h, e, _ in neighbours[neigh_id]]
                                         for neigh_id in neigh_ids], dtype=torch.int64)
                # The cached embeddings are constants, so no autograd graph is built.
                with torch.no_grad():
                    edge_emb = self.observe_model.encode(edge_feats.to(torch.float32), edge_idx)
                edge_emb = edge_emb.cpu().to(torch.float64)  # (neigh len, num imgs, emb dim)
                for neigh_id, cur_neigh_emb in zip(neigh_ids, edge_emb):
                    self.act_emb_dict[scan_id][(view_id, neigh_id)] = cur_neigh_emb

    def load_act_emb(self, train_env: FullEnvDataset, eval_env: FullEnvDataset, train_scan_ids, eval_scan_ids):
        """Rebuild ``act_emb_dict`` for the train scans, then the eval scans."""
        self.act_emb_dict = {}
        self._encode_scan_actions(train_env, train_scan_ids)
        self._encode_scan_actions(eval_env, eval_scan_ids)

    def load_obs_emb(self, train_env: FullEnvDataset, eval_env: FullEnvDataset, train_scan_ids, eval_scan_ids):
        """Encode every viewpoint of the train and eval scans once, and build id -> index lookups.

        Sets ``scan_ids``, ``view_ids``, ``obs_emb`` (all scan num, max node num, 36, emb dim),
        ``scan_ids_dict`` (scan id -> row) and ``view_ids_dict`` (per scan: view id -> column).
        """
        self.scan_ids = train_scan_ids + eval_scan_ids
        train_view_ids = train_env.get_view_ids(train_scan_ids)
        eval_view_ids = eval_env.get_view_ids(eval_scan_ids)
        self.view_ids = train_view_ids + eval_view_ids
        train_viewpoints = train_env.collect_obs(train_scan_ids, train_view_ids)  # (scan num, max node num, 36, 1000)
        eval_viewpoints = eval_env.collect_obs(eval_scan_ids, eval_view_ids)
        viewpoints = torch.cat([train_viewpoints, eval_viewpoints], dim=0)
        # Feature width is read from the data rather than hard-coded to 1000.
        num_scans, num_nodes, num_views, feat_len = viewpoints.shape
        flat_viewpoints = viewpoints.reshape(num_scans * num_nodes, num_views, feat_len)
        self.obs_emb = self.observe_model.encode(flat_viewpoints)
        self.obs_emb = self.obs_emb.view((num_scans, num_nodes, num_views, Param.emb_size))
        self.scan_ids_dict = {cur_scan_id: i for i, cur_scan_id in enumerate(self.scan_ids)}
        self.view_ids_dict = [{cur_view_id: i for i, cur_view_id in enumerate(cur_batch_view_ids)}
                              for cur_batch_view_ids in self.view_ids]

    def cal_loss(self, msg_b_prob, guess_b_prob, rewards_loc, rewards_nav):
        """Reward-weighted loss.

        The message at turn t is weighted by the turn-(t+1) localization and navigation rewards;
        the last turn is weighted by 0. The guess is weighted by the same-turn localization reward.
        """
        # Shift rewards one turn left and append a zero column; the batch size comes from the tensor.
        shift_rewards_nav = torch.cat([rewards_nav[:, 1:], rewards_nav.new_zeros((rewards_nav.shape[0], 1))], dim=1)
        shift_rewards_loc = torch.cat([rewards_loc[:, 1:], rewards_loc.new_zeros((rewards_loc.shape[0], 1))], dim=1)
        loss_msg = - torch.sum(msg_b_prob * (shift_rewards_loc + shift_rewards_nav).unsqueeze(2)) / msg_b_prob.shape[1]
        loss_guess = - torch.sum(guess_b_prob * rewards_loc)
        return loss_msg + loss_guess


class NavMultiStepAgentB(nn.Module):
    """Agent B for multi-step navigation dialogues.

    Each ``forward`` call does three things:
      1. Scores the candidate viewpoints against the dialogue history (``guess``).
      2. Decodes a message (``speak``) conditioned on the history, the
         guess-weighted candidate observations and the embeddings of the route(s)
         from the candidate(s) towards the goal.
      3. Appends that message to the history.

    ``load_act_emb`` must be called before ``speak``, which reads ``self.act_emb_dict``.
    """

    def __init__(self):
        super(NavMultiStepAgentB, self).__init__()
        self.language_model = Language(role='B')
        self.observe_model = PosObservation(role='B')
        # Learned marker embeddings placed before the observation segment and the
        # plan segment of the decoder memory built in ``speak``.
        self.obs_token = nn.Embedding(1, Param.emb_size)
        self.plan_token = nn.Embedding(1, Param.emb_size)
        # Filled by load_obs_emb.
        self.obs_emb = None
        self.scan_ids, self.scan_ids_dict = None, None
        self.view_ids, self.view_ids_dict = None, None
        # Filled by load_act_emb: {scan_id: {(view_id, neigh_id): tensor}}.
        self.act_emb_dict = None

    def forward(self, history, viewpoint_vec, env: FullEnvDataset, scan_ids, end_ids, view_ids, is_speak,
                true_loc=None, choose_method="sample", is_true_loc=False):
        """Guess A's location, speak a message and append it to the history.

        :param history: dialogue token ids, (batch, history len)
        :param viewpoint_vec: candidate observation embeddings (see ``guess``)
        :param env: environment providing ``graph_paths`` for route lookup
        :param scan_ids: scan id per sample
        :param end_ids: goal viewpoint id per sample
        :param view_ids: per-sample list of candidate viewpoint ids
        :param is_speak: per-sample mask; where it equals False the message is
            replaced by the "<msg A>" start-pos token
        :param true_loc: optional ground-truth location weights (batch, cand len)
        :param choose_method: "sample" samples the guess; any other value takes
            the argmax. It is also forwarded to the message decoder.
        :param is_true_loc: build the route guidance from ``true_loc`` only
        :return: (new_history, guess_idx, log-prob of guess_idx, msg, msg_prob)
        """
        # -------- guess -----------
        guess_score = self.guess(history, view_ids, viewpoint_vec)
        guess_score_softmax = torch.softmax(guess_score, dim=1)
        sampler = Categorical(guess_score_softmax)
        if choose_method == "sample":
            guess_idx = sampler.sample()
        else:
            guess_idx = torch.argmax(guess_score_softmax, dim=1)
        # -------- speak -----------
        # Use the caller's decoding strategy for the message too, so greedy
        # evaluation does not silently fall back to sampling.
        msg, msg_prob = self.speak(history, viewpoint_vec, guess_score_softmax, env, scan_ids, end_ids, view_ids,
                                   true_loc, choose_method=choose_method, is_true_loc=is_true_loc)
        # ----- is speak -----
        # Silent samples get a message made only of the "<msg A>" start-pos token,
        # which guess/speak mask out of the encoded history.
        msg[is_speak == False, :] = Param.tokens["<msg A>"]["start pos"]  # noqa: E712 (elementwise tensor compare)
        sep_token = torch.full((history.shape[0], 1), Param.tokens["<msg B>"]["pos"])
        new_history = torch.cat([history, sep_token, msg], dim=1)
        return new_history, guess_idx, sampler.log_prob(guess_idx), msg, msg_prob

    def guess(self, history, view_ids, viewpoint_vec):
        """Score every candidate viewpoint against the encoded dialogue history.

        :param history: dialogue token ids, (batch, history len)
        :param view_ids: per-sample list of candidate viewpoint ids; positions at or
            beyond ``len(view_ids[b])`` (up to Param.max_num_node) are padding
        :param viewpoint_vec: candidate embeddings, averaged over dim 2; presumably
            (batch, max node num, 36, emb dim) — TODO confirm
        :return: scores, (batch, Param.max_num_node), with -inf at padded candidates
        """
        # 1 marks "<msg A>" start-pos placeholder tokens that must be ignored.
        history_mask = torch.full_like(history, 0)
        history_mask[history == Param.tokens["<msg A>"]["start pos"]] = 1
        history_anti_mask = 1 - history_mask
        history_vec = self.language_model.encode(history, mask=history_mask)
        # Mean over the kept (non-placeholder) positions only.
        # NOTE(review): yields NaN if a row is entirely placeholder tokens.
        history_emb = (torch.sum(history_vec * history_anti_mask.unsqueeze(2), dim=1)
                       / torch.sum(history_anti_mask, dim=1).unsqueeze(1))  # (batch, emb dim)
        viewpoint_vec_mean = torch.mean(viewpoint_vec, dim=2)
        scores = torch.bmm(viewpoint_vec_mean, history_emb.unsqueeze(2)).squeeze(2)
        mask = [[0] * len(cur_view_ids) + [1] * (Param.max_num_node - len(cur_view_ids)) for cur_view_ids in view_ids]
        mask = torch.from_numpy(np.array(mask))
        scores[mask == 1] = -math.inf
        return scores

    def speak(self, history, viewpoint_vec, guess_score, env: FullEnvDataset, scan_ids, end_ids, view_ids,
              true_loc=None, choose_method="sample", is_true_loc=False):
        """Decode B's message.

        The decoder memory is laid out as
        ``[history | obs token | attended views | plan token | attended routes]``.
        Placeholder history tokens are masked; all appended positions are unmasked.

        :param guess_score: candidate weights (batch, max node num), e.g. the softmaxed guess
        :param true_loc: optional location weights used to attend over routes. When
            None, the detached ``guess_score`` is used.
        :param is_true_loc: if True, only the route from ``argmax(true_loc)`` is embedded
        :return: (msg, msg_prob) from ``language_model.decode_msg``
        """
        batch_size = history.shape[0]
        history_mask = torch.full_like(history, 0)
        history_mask[history == Param.tokens["<msg A>"]["start pos"]] = 1
        history_vec = self.language_model.encode(history, mask=history_mask)
        # --------- speak localization -----------
        if Param.is_spread_token is True:
            # Keep all 36 view tokens: attend over the flattened (36 * emb) node features.
            flat_views = viewpoint_vec.view((batch_size, Param.max_num_node, 36 * Param.emb_size))
            attn_view_vec = torch.bmm(guess_score.unsqueeze(1), flat_views).squeeze(1)
            attn_view_vec = attn_view_vec.view((batch_size, 36, Param.emb_size))
        else:
            # NOTE(review): bmm needs a 3-D viewpoint_vec here, while guess averages
            # dim 2 of it — confirm the non-spread input shape.
            attn_view_vec = torch.bmm(guess_score.unsqueeze(1), viewpoint_vec)  # (batch, 1, emb dim)
        cur_obs_token = self.obs_token(torch.zeros(viewpoint_vec.shape[0]).to(torch.int64)).unsqueeze(1)  # (batch, 1, emb dim)
        context_vec = torch.cat([history_vec, cur_obs_token, attn_view_vec], dim=1)
        history_mask = torch.cat([history_mask, torch.zeros((history_mask.shape[0], 1 + attn_view_vec.shape[1]))], dim=1)
        # -------- speak guidance -------------
        cur_plan_token = self.plan_token(torch.zeros(history.shape[0], 1).to(torch.int64))
        if is_true_loc is True:
            # A single candidate per sample: the argmax of true_loc.
            true_loc_list = torch.argmax(true_loc, dim=1).tolist()
            true_view_ids = [[cur_view_ids[i]] for cur_view_ids, i in zip(view_ids, true_loc_list)]
            next_step_info = get_route_ids(env.graph_paths, scan_ids, true_view_ids, end_ids)
            route_vec = self.get_route_vec(scan_ids, true_view_ids, next_step_info, is_true_loc=is_true_loc)
        else:
            next_step_info = get_route_ids(env.graph_paths, scan_ids, view_ids, end_ids)
            route_vec = self.get_route_vec(scan_ids, view_ids, next_step_info, is_true_loc=is_true_loc)
        if true_loc is None:
            true_loc = guess_score.detach()  # (batch, cand len)
        if is_true_loc is False:
            # Weight every candidate's route embedding by true_loc.
            if Param.is_spread_token is True:
                attn_route_vec = torch.bmm(true_loc.unsqueeze(1),
                                           route_vec.view((batch_size, Param.max_num_node,
                                                           Param.max_turns * Param.max_path_img_num * Param.emb_size))).squeeze(1)
                attn_route_vec = attn_route_vec.view((batch_size, Param.max_turns * Param.max_path_img_num, Param.emb_size))
            else:
                attn_route_vec = torch.bmm(true_loc.unsqueeze(1),
                                           route_vec.view((batch_size, Param.max_num_node,
                                                           Param.max_turns * Param.emb_size))).squeeze(1)
                attn_route_vec = attn_route_vec.view((batch_size, Param.max_turns, Param.emb_size))
        else:
            if Param.is_spread_token is True:
                attn_route_vec = route_vec.squeeze(1).view((batch_size, Param.max_turns * Param.max_path_img_num, Param.emb_size))
            else:
                attn_route_vec = route_vec.squeeze(1)
        context_vec = torch.cat([context_vec, cur_plan_token, attn_route_vec], dim=1)
        history_mask = torch.cat([history_mask, torch.zeros((history_mask.shape[0], 1 + attn_route_vec.shape[1]))], dim=1)
        msg, msg_prob = self.language_model.decode_msg(context_vec, choose_method=choose_method, mem_mask=history_mask)
        return msg, msg_prob

    def get_route_vec(self, scan_ids, cand_ids, paths, is_true_loc=False):
        """Embed, for every candidate, the sequence of actions along its route.

        :param scan_ids: scan id per sample
        :param cand_ids: per-sample list of candidate (start) viewpoint ids
        :param paths: per-sample, per-candidate list of route steps. Each step is
            None or a dict whose "next" key is the next viewpoint id.
        :param is_true_loc: when False, pad the candidate axis with zeros up to
            Param.max_num_node
        :return: float32 tensor, (batch, cand len, turns, Param.max_path_img_num, emb dim)
            when Param.is_spread_token is True, else (batch, cand len, turns, emb dim)
        """
        # Per-candidate shape of the zero padding; it must match the stacked step
        # embeddings of the active mode and the views taken in ``speak``.
        if Param.is_spread_token is True:
            pad_step_shape = (Param.max_turns, Param.max_path_img_num, Param.emb_size)
        else:
            pad_step_shape = (Param.max_turns, Param.emb_size)
        obs_emb = []
        for cur_routes, scan_id, cand_view_id in zip(paths, scan_ids, cand_ids):
            cand_embs = []
            for cur_route, cur_view_id in zip(cur_routes, cand_view_id):
                step_embs = []
                last_id = cur_view_id
                for cur_step in cur_route:
                    if cur_step is not None and last_id != cur_step["next"]:
                        cur_neigh_id = cur_step["next"]
                        act_emb = self.act_emb_dict[scan_id][(last_id, cur_neigh_id)]
                        if Param.is_spread_token is True:
                            step_embs.append(act_emb)
                        else:
                            step_embs.append(torch.mean(act_emb, dim=0))
                        last_id = cur_neigh_id
                    elif Param.is_spread_token is True:
                        # No move at this step: one stop token per image slot.
                        step_embs.append(self.observe_model.stop_token(
                            torch.zeros((Param.max_path_img_num,)).to(torch.int64)))
                    else:
                        step_embs.append(
                            self.observe_model.stop_token(torch.zeros((1,)).to(torch.int64)).squeeze())
                cand_embs.append(torch.stack(step_embs, dim=0))
            sample_emb = torch.stack(cand_embs, dim=0)  # (cand len, turns, ...)
            if is_true_loc is False:
                pad = torch.zeros((Param.max_num_node - sample_emb.shape[0],) + pad_step_shape)
                sample_emb = torch.cat([sample_emb, pad])
            obs_emb.append(sample_emb)
        obs_emb = torch.stack(obs_emb)
        return obs_emb.to(torch.float32)

    def load_act_emb(self, train_env: FullEnvDataset, eval_env: FullEnvDataset, train_scan_ids, eval_scan_ids):
        """Precompute a detached embedding for every (viewpoint, neighbour) edge.

        Rebuilds ``self.act_emb_dict`` as ``{scan_id: {(view_id, neigh_id): tensor}}``.
        It covers the train scans (read from ``train_env``) and then the eval scans
        (read from ``eval_env``). Graphs missing from an env are loaded into it.
        """
        self.act_emb_dict = {}
        self._load_scan_act_emb(train_env, train_scan_ids)
        self._load_scan_act_emb(eval_env, eval_scan_ids)

    def _load_scan_act_emb(self, env, scan_ids):
        """Encode the edge observations of ``scan_ids`` from ``env`` into ``self.act_emb_dict``."""
        for scan_id in scan_ids:
            self.act_emb_dict[scan_id] = {}
            if scan_id not in env.graphs:
                env.graphs[scan_id] = env.load_graph(scan_id)[scan_id]
            graph = env.graphs[scan_id]
            for view_id in graph:
                neigh_ids = list(graph[view_id].keys())
                if not neigh_ids:
                    continue  # isolated viewpoint: no edges to embed (torch.stack([]) would raise)
                cur_obs, cur_obs_idx = [], []
                for neigh_id in neigh_ids:
                    edge_obs, edge_idx = [], []
                    for _, h, e, _ in graph[view_id][neigh_id]:
                        edge_obs.append(torch.from_numpy(env.features[scan_id][view_id][(h, e)]))
                        # (h, e) bucketed in steps of 30 into a position index, offset by 1.
                        edge_idx.append(h // 30 * 3 + e // 30 + 1)
                    cur_obs.append(torch.stack(edge_obs, dim=0))
                    cur_obs_idx.append(torch.from_numpy(np.array(edge_idx)))
                cur_obs = torch.stack(cur_obs, dim=0)  # (neigh len, img num, feat len)
                cur_obs_idx = torch.stack(cur_obs_idx, dim=0)
                # Results are detached via tolist() anyway, so skip building the autograd graph.
                with torch.no_grad():
                    cur_obs_emb = self.observe_model.encode(cur_obs.to(torch.float32),
                                                            cur_obs_idx.to(torch.int64)).tolist()
                for neigh_id, cur_neigh_emb in zip(neigh_ids, cur_obs_emb):
                    self.act_emb_dict[scan_id][(view_id, neigh_id)] = torch.from_numpy(np.array(cur_neigh_emb))

    def load_obs_emb(self, train_env: FullEnvDataset, eval_env: FullEnvDataset, train_scan_ids, eval_scan_ids):
        """Encode the observations of every train and eval viewpoint into ``self.obs_emb``.

        Also sets ``self.scan_ids``/``self.view_ids`` (train first, then eval) and the
        index lookups ``self.scan_ids_dict`` (scan -> row) and ``self.view_ids_dict``
        (per scan: view -> column).
        """
        self.scan_ids = train_scan_ids + eval_scan_ids
        train_view_ids = train_env.get_view_ids(train_scan_ids)
        eval_view_ids = eval_env.get_view_ids(eval_scan_ids)
        self.view_ids = train_view_ids + eval_view_ids
        train_viewpoints = train_env.collect_obs(train_scan_ids, train_view_ids)  # presumably (scan num, max node num, 36, 1000)
        eval_viewpoints = eval_env.collect_obs(eval_scan_ids, eval_view_ids)
        viewpoints = torch.cat([train_viewpoints, eval_viewpoints], dim=0)
        self.obs_emb = self.observe_model.encode(viewpoints.view((viewpoints.shape[0] * viewpoints.shape[1], 36, 1000)))
        self.obs_emb = self.obs_emb.view((viewpoints.shape[0], viewpoints.shape[1], 36, Param.emb_size))
        self.scan_ids_dict = {cur_scan_id: i for i, cur_scan_id in enumerate(self.scan_ids)}
        self.view_ids_dict = [{cur_view_id: i for i, cur_view_id in enumerate(cur_batch_view_ids)}
                              for cur_batch_view_ids in self.view_ids]

    def cal_loss(self, msg_b_prob, guess_b_prob, rewards_msg_b, rewards_loc):
        """Policy-gradient style loss: -sum(prob * reward) for messages and guesses.

        The message term is broadcast over dim 2 of ``msg_b_prob`` and divided by
        its dim-1 size. The probabilities are presumably log-probabilities.
        """
        loss_msg = - torch.sum(msg_b_prob * rewards_msg_b.unsqueeze(2)) / msg_b_prob.shape[1]
        loss_guess = - torch.sum(guess_b_prob * rewards_loc)
        return loss_msg + loss_guess
