import torch
from torch import nn
from AgentA import *
from AgentB import *
from Environment import *


class LocalizationPos(nn.Module):
    """Multi-turn dialogue game between ``PosAgentA`` and ``PosAgentB``.

    On every turn agent A produces a message from its observation and its
    dialogue history; the message is appended to agent B's history, agent B
    produces a guess plus a reply message, and the reply is appended to
    agent A's history.

    NOTE(review): ``Param`` and ``random`` are not imported explicitly in this
    file; presumably they are provided by the star imports -- confirm.
    """

    def __init__(self):
        super().__init__()
        self.agent_a = PosAgentA()
        self.agent_b = PosAgentB()

    def _lookup_candidate_embeddings(self, scan_ids, cand_view_ids):
        """Gather agent B's stored embeddings for every candidate viewpoint.

        For each sample the candidates are looked up in
        ``self.agent_b.obs_emb`` and zero-padded along the candidate axis up
        to ``Param.max_num_node`` so that all samples can be stacked.

        Returns:
            Tensor of shape ``(batch, Param.max_num_node, 36, Param.emb_size)``.
        """
        padded = []
        for cur_scan_id, cur_cand_view_ids in zip(scan_ids, cand_view_ids):
            scan_idx = self.agent_b.scan_ids_dict[cur_scan_id]
            view_lookup = self.agent_b.view_ids_dict[scan_idx]
            # Explicit long dtype keeps the index valid even for an empty candidate list.
            view_idx = torch.tensor([view_lookup[view_id] for view_id in cur_cand_view_ids],
                                    dtype=torch.long)
            emb = self.agent_b.obs_emb[scan_idx, view_idx, :, :]  # (cand num, 36, emb size)
            # new_zeros makes the padding follow the embedding's dtype and device,
            # so the concatenation does not depend on the global tensor defaults.
            pad = emb.new_zeros((Param.max_num_node - emb.shape[0], 36, Param.emb_size))
            padded.append(torch.cat([emb, pad], dim=0))
        return torch.stack(padded, dim=0)

    @staticmethod
    def _backpropagate(loss_a, loss_b):
        """Backpropagate the sum of both losses and return each one as a Python number."""
        (loss_a + loss_b).backward()
        return loss_a.item(), loss_b.item()

    def forward(self, scan_ids, cand_view_ids, tgt_idxes, env: FullEnvDataset, choose_method="sample"):
        """Play ``Param.max_turns`` dialogue turns for one batch.

        Args:
            scan_ids: per-sample scan ids (keys of ``agent_b.scan_ids_dict``);
                its length defines the batch size.
            cand_view_ids: per-sample sequence of candidate viewpoint ids.
            tgt_idxes: per-sample index into ``cand_view_ids`` selecting the
                target viewpoint that agent A observes.
            env: environment used to collect agent A's observation.
            choose_method: forwarded unchanged to both agents.

        Returns:
            Tuple ``(history_a, history_b, total_msg_a_prob, total_msg_b_prob,
            total_guess_b_prob, total_guess_b)``. The four ``total_*`` tensors
            hold the per-turn outputs concatenated along dim 1.
        """
        # Size everything from the actual batch so the histories, the sampled
        # poses and the candidate embeddings always agree with each other.
        batch_size = len(scan_ids)
        history_a, history_b = torch.full((batch_size, 0), 0), torch.full((batch_size, 0), 0)
        # One random (first value in 0..330 step 30, second value in {-30, 0, 30}) pair per sample.
        pos = [(random.choice(range(0, 360, 30)), random.choice([-30, 0, 30])) for _ in range(batch_size)]
        tgt_ids = [cur_cand_view_ids[i] for i, cur_cand_view_ids in zip(tgt_idxes, cand_view_ids)]
        observe_a = env.collect_single_obs_multi_views(scan_ids, tgt_ids, pos).to(torch.float32)
        # Agent B's candidate embeddings are read from agent_b.obs_emb rather
        # than being encoded from environment observations on every call.
        viewpoint_vec_b = self._lookup_candidate_embeddings(scan_ids, cand_view_ids)

        total_msg_a_prob, total_msg_b_prob = [], []
        total_guess_b_prob, total_guess_b = [], []
        for _ in range(Param.max_turns):
            # ------ a -------
            history_a, msg_a, msg_a_prob = self.agent_a(history_a, observe_a, choose_method=choose_method)
            sep_token = torch.full((history_b.shape[0], 1), Param.tokens["<msg A>"]["pos"])
            history_b = torch.cat([history_b, sep_token, msg_a], dim=1)
            # ------ b -------
            history_b, guess_b, guess_b_prob, msg_b, msg_b_prob = self.agent_b(history_b, cand_view_ids,
                                                                               viewpoint_vec_b,
                                                                               choose_method=choose_method)
            sep_token = torch.full((history_a.shape[0], 1), Param.tokens["<msg B>"]["pos"])
            history_a = torch.cat([history_a, sep_token, msg_b], dim=1)
            # ----- collect prob -----
            total_msg_a_prob.append(msg_a_prob.unsqueeze(1))
            total_msg_b_prob.append(msg_b_prob.unsqueeze(1))
            total_guess_b_prob.append(guess_b_prob.unsqueeze(1))
            total_guess_b.append(guess_b.unsqueeze(1))
        total_msg_a_prob = torch.cat(total_msg_a_prob, dim=1)  # (batch, turns, sent len)
        total_msg_b_prob = torch.cat(total_msg_b_prob, dim=1)  # (batch, turns, sent len)
        total_guess_b_prob = torch.cat(total_guess_b_prob, dim=1)  # (batch, turns)
        total_guess_b = torch.cat(total_guess_b, dim=1)  # (batch, turns)
        return history_a, history_b, total_msg_a_prob, total_msg_b_prob, total_guess_b_prob, total_guess_b

    def backward_alg_blu(self, msg_a_prob, msg_b_prob, guess_b_prob, rewards):
        """Backpropagate the agents' ``cal_loss_alg_blu`` losses.

        Returns:
            ``(loss_a, loss_b)`` as Python numbers.
        """
        loss_a = self.agent_a.cal_loss_alg_blu(msg_a_prob, rewards)
        loss_b = self.agent_b.cal_loss_alg_blu(msg_b_prob, guess_b_prob, rewards)
        return self._backpropagate(loss_a, loss_b)

    def backward(self, msg_a_prob, msg_b_prob, guess_b_prob, rewards):
        """Backpropagate the agents' ``cal_loss`` losses.

        Returns:
            ``(loss_a, loss_b)`` as Python numbers.
        """
        loss_a = self.agent_a.cal_loss(msg_a_prob, rewards)
        loss_b = self.agent_b.cal_loss(msg_b_prob, guess_b_prob, rewards)
        return self._backpropagate(loss_a, loss_b)