import torch
from torch import nn
from AgentA import *
from AgentB import *
import Util


class NavigationMultiStep(nn.Module):
    """Two-agent, multi-turn navigation game.

    Agent A steps through candidate viewpoints and emits messages; agent B
    sees embeddings of every candidate viewpoint of the episode, reads A's
    messages, produces a guess and replies. ``forward`` plays
    ``Param.max_turns`` rounds of this exchange; ``backward`` sums both
    agents' losses and backpropagates.
    """

    def __init__(self):
        super().__init__()
        self.agent_a = NavMultiStepAgentA()
        self.agent_b = NavMultiStepAgentB()

    def _candidate_viewpoint_features(self, scan_ids, cand_view_ids):
        """Gather agent B's stored embeddings for each episode's candidate viewpoints.

        Rows are looked up in ``self.agent_b.obs_emb`` and zero-padded along
        the candidate axis up to ``Param.max_num_node`` so that all episodes
        can be stacked into one tensor.

        Returns:
            Tensor of shape ``(batch, Param.max_num_node, 36, Param.emb_size)``.
        """
        per_episode = []
        for scan_id, scan_cand_ids in zip(scan_ids, cand_view_ids):
            scan_idx = self.agent_b.scan_ids_dict[scan_id]
            view_lookup = self.agent_b.view_ids_dict[scan_idx]
            # Explicit long dtype: a numpy round-trip would give float64 for an
            # empty list and int32 on some platforms, neither usable as an index.
            view_idx = torch.as_tensor([view_lookup[view_id] for view_id in scan_cand_ids],
                                       dtype=torch.long)
            feats = self.agent_b.obs_emb[scan_idx, view_idx, :, :]  # (cand num, 36, emb)
            # new_zeros matches feats' dtype/device, so the cat also works
            # when obs_emb is not a default-dtype CPU tensor.
            padding = feats.new_zeros((Param.max_num_node - feats.shape[0], 36, Param.emb_size))
            per_episode.append(torch.cat([feats, padding], dim=0))
        return torch.stack(per_episode, dim=0)

    def forward(self, scan_ids, cand_view_ids, start_idxes, end_idxes, env: FullEnvDataset, choose_method="sample"):
        """Play ``Param.max_turns`` rounds of the A/B communication game.

        Args:
            scan_ids: one scan id per episode in the batch.
            cand_view_ids: per episode, the list of candidate viewpoint ids.
            start_idxes: per episode, index into ``cand_view_ids`` of A's
                starting viewpoint.
            end_idxes: per episode, index into ``cand_view_ids`` of the end
                viewpoint; converted to ids and passed to agent B as ``end_ids``.
            env: environment object handed through to both agents.
            choose_method: forwarded to both agents (default ``"sample"``).

        Returns:
            Tuple ``(history_a, history_b, total_msg_a_prob, total_msg_b_prob,
            total_act_a_prob, total_guess_b_prob, path_a, total_guess_b,
            total_is_speak)``.
            - The four ``*_prob`` tensors are stacked per turn and transposed
              so that dim 0 becomes the batch axis. This assumes the agents
              return batch-first tensors (TODO confirm).
            - ``path_a`` has one column for the start index plus one per turn.
            - ``total_guess_b`` and ``total_is_speak`` have one column per turn.
        """
        batch_size = Param.batch_size
        # Token histories start empty: (batch, 0) integer tensors.
        history_a = torch.full((batch_size, 0), 0)
        history_b = torch.full((batch_size, 0), 0)
        history_path_vec = None
        # One random pose per episode, reused unchanged for every turn.
        # Presumably (heading, elevation) in degrees -- TODO confirm with env.
        poses = [(random.choice(range(0, 360, 30)), random.choice([-30, 0, 30]))
                 for _ in range(batch_size)]
        start_ids = [cands[idx] for idx, cands in zip(start_idxes, cand_view_ids)]
        end_ids = [cands[idx] for idx, cands in zip(end_idxes, cand_view_ids)]
        view_ids_a = start_ids

        viewpoint_vec_b = self._candidate_viewpoint_features(scan_ids, cand_view_ids)

        total_msg_a_prob, total_msg_b_prob = [], []
        total_act_a_prob, total_guess_b_prob = [], []
        total_guess_b, total_is_speak = [], []
        path_a = [torch.as_tensor(start_idxes, dtype=torch.long).unsqueeze(1)]

        for _turn in range(Param.max_turns):
            # ------- agent A: move and speak ----------
            states_a = list(zip(scan_ids, view_ids_a, poses))
            history_a, history_path_vec, msg_a, msg_a_prob, is_speak, view_ids_a, act_a_prob = \
                self.agent_a(history_a, history_path_vec, states_a, env, choose_method=choose_method)
            path_a.append(Util.id2idx1d(view_ids_a, cand_view_ids).unsqueeze(1))
            total_act_a_prob.append(act_a_prob)
            total_msg_a_prob.append(msg_a_prob)
            total_is_speak.append(is_speak.unsqueeze(1))
            sep_a = torch.full((history_b.shape[0], 1), Param.tokens["<msg A>"]["pos"])
            history_b = torch.cat([history_b, sep_a, msg_a], dim=1)

            # -------- agent B: guess and reply -----------
            # In training mode B also receives a one-hot of A's current candidate index.
            true_loc = None
            if self.training:
                true_loc = torch.zeros((batch_size, Param.max_num_node))
                true_loc[torch.arange(batch_size), path_a[-1].squeeze(1)] = 1.0
            history_b, guess_b, guess_b_prob, msg_b, msg_b_prob = \
                self.agent_b(history_b, viewpoint_vec_b, env, scan_ids, end_ids, cand_view_ids, is_speak,
                             true_loc=true_loc, choose_method=choose_method, is_true_loc=true_loc is not None)
            total_guess_b.append(guess_b.unsqueeze(1))
            total_guess_b_prob.append(guess_b_prob)
            total_msg_b_prob.append(msg_b_prob)
            sep_b = torch.full((history_a.shape[0], 1), Param.tokens["<msg B>"]["pos"])
            history_a = torch.cat([history_a, sep_b, msg_b], dim=1)

        # (turns, batch, ...) -> (batch, turns, ...)
        total_msg_a_prob = torch.stack(total_msg_a_prob, dim=0).transpose(1, 0)
        total_msg_b_prob = torch.stack(total_msg_b_prob, dim=0).transpose(1, 0)
        total_act_a_prob = torch.stack(total_act_a_prob, dim=0).transpose(1, 0)
        total_guess_b_prob = torch.stack(total_guess_b_prob, dim=0).transpose(1, 0)
        path_a = torch.cat(path_a, dim=1)
        total_guess_b = torch.cat(total_guess_b, dim=1)
        total_is_speak = torch.cat(total_is_speak, dim=1)
        return history_a, history_b, total_msg_a_prob, total_msg_b_prob, total_act_a_prob, total_guess_b_prob, \
               path_a, total_guess_b, total_is_speak

    def backward(self, msg_a_prob, msg_b_prob, act_a_prob, guess_b_prob, msg_a_gain, msg_b_gain, loc_rewards, nav_rewards):
        """Compute both agents' losses and backpropagate their sum.

        Only ``loss.backward()`` is called here; zeroing gradients and stepping
        the optimizer are left to the caller.

        Returns:
            ``(loss_a, loss_b)`` as Python numbers (via ``.item()``).
        """
        loss_a = self.agent_a.cal_loss(msg_a_prob, act_a_prob, msg_a_gain, nav_rewards)
        loss_b = self.agent_b.cal_loss(msg_b_prob, guess_b_prob, msg_b_gain, loc_rewards)
        loss = loss_a + loss_b
        loss.backward()
        return loss_a.item(), loss_b.item()
