import torch
from torch import nn
from AgentA import *
from AgentB import *
import Util
from Environment import *


class Guidance(nn.Module):
    """Two-agent guidance dialogue built from a ``GuideAgentA`` and a ``GuideAgentB``.

    :meth:`forward` alternates the two agents for ``Param.max_turns`` turns:
    agent B is called first (with a one-hot of A's current candidate-view
    index), and its message is appended to A's history; then agent A is called,
    returning its own message plus its next view ids, and A's message is
    appended to B's history. Each appended message is preceded by a separator
    token (``"<msg B>"`` / ``"<msg A>"`` from ``Param.tokens``).
    """

    def __init__(self):
        super().__init__()
        self.agent_a = GuideAgentA()
        self.agent_b = GuideAgentB()

    def forward(self, scan_ids, cand_view_ids, start_idxes, end_idxes, env: FullEnvDataset, choose_method="sample"):
        """Roll out one ``Param.max_turns``-turn dialogue for a batch.

        Args:
            scan_ids: per-example scan identifiers, forwarded to both agents.
            cand_view_ids: per-example indexable sequences of candidate view
                ids; ``start_idxes`` / ``end_idxes`` index into these.
            start_idxes: per-example index of agent A's starting view.
            end_idxes: per-example index of the goal view (passed to agent B
                as the looked-up view id).
            env: environment dataset, forwarded to both agents.
            choose_method: forwarded to both agents (default ``"sample"``).

        Returns:
            ``(history_a, history_b, total_msg_a_prob, total_msg_b_prob,
            total_act_a_prob, path_a)``: the final token histories of both
            agents; the per-turn probabilities stacked along a new turn axis
            and transposed so the turn axis is second (assumes each agent
            returns probabilities with batch as the leading dimension — TODO
            confirm); and ``path_a`` of shape ``(batch, turns + 1)`` holding
            A's candidate-view index at the start and after every turn.

        Note:
            ``torch.stack`` raises if ``Param.max_turns`` is 0.
        """
        batch_size = Param.batch_size

        # Both dialogue histories start as zero-length token sequences. Token
        # tensors are built as int64 so they concatenate with message tokens.
        history_a = torch.zeros((batch_size, 0), dtype=torch.long)
        history_b = torch.zeros((batch_size, 0), dtype=torch.long)

        # One random pose per example (presumably (heading, elevation) in
        # degrees — TODO confirm). The outer list only ever holds this single
        # entry; only poses[-1] is read below.
        poses = [[(random.choice(range(0, 360, 30)), random.choice([-30, 0, 30]))
                  for _ in range(batch_size)]]  # (1, batch)

        start_ids = [views[idx] for idx, views in zip(start_idxes, cand_view_ids)]
        end_ids = [views[idx] for idx, views in zip(end_idxes, cand_view_ids)]
        view_ids_a = start_ids

        total_msg_a_prob, total_msg_b_prob = [], []
        total_act_a_prob = []
        # Force int64: the column is used as an index tensor into `true_loc`,
        # and np.array(list_of_ints) is int32 on some platforms (e.g. Windows).
        path_a = [torch.as_tensor(start_idxes, dtype=torch.long).unsqueeze(1)]

        for _turn in range(Param.max_turns):
            # -------- B: speaks, given A's current location ------------
            # One-hot over nodes marking A's latest candidate-view index.
            true_loc = torch.zeros((batch_size, Param.max_num_node))
            # squeeze(1) drops only the column axis, never the batch axis.
            true_loc[torch.arange(0, batch_size), path_a[-1].squeeze(1)] = 1.0
            history_b, msg_b, msg_b_prob = self.agent_b(history_b, env, scan_ids, end_ids, cand_view_ids, true_loc,
                                                        choose_method=choose_method)
            # B's message is appended to A's history after a "<msg B>" separator.
            sep_token = torch.full((history_a.shape[0], 1), Param.tokens["<msg B>"]["pos"])
            history_a = torch.cat([history_a, sep_token, msg_b], dim=1)

            # -------- A: speaks and returns its next view ids ------------
            states_a = list(zip(scan_ids, view_ids_a, poses[-1]))
            history_a, msg_a, msg_a_prob, view_ids_a, act_a_prob = self.agent_a(history_a, states_a, env,
                                                                                choose_method=choose_method)
            path_a.append(Util.id2idx1d(view_ids_a, cand_view_ids).unsqueeze(1))
            total_act_a_prob.append(act_a_prob)
            # A's message is appended to B's history after a "<msg A>" separator.
            sep_token = torch.full((history_b.shape[0], 1), Param.tokens["<msg A>"]["pos"])
            history_b = torch.cat([history_b, sep_token, msg_a], dim=1)

            # ----- collect per-turn probabilities ------
            total_msg_a_prob.append(msg_a_prob)
            total_msg_b_prob.append(msg_b_prob)

        total_msg_a_prob = torch.stack(total_msg_a_prob, dim=0).transpose(1, 0)  # (batch, turns, sent len)
        total_msg_b_prob = torch.stack(total_msg_b_prob, dim=0).transpose(1, 0)  # (batch, turns, sent len)
        total_act_a_prob = torch.stack(total_act_a_prob, dim=0).transpose(1, 0)  # (batch, turns, neigh len)
        path_a = torch.cat(path_a, dim=1)  # (batch, turns + 1): start index plus one per turn
        return history_a, history_b, total_msg_a_prob, total_msg_b_prob, total_act_a_prob, path_a

    @staticmethod
    def _backprop(loss_a, loss_b):
        """Call ``backward()`` on ``loss_a + loss_b``; return both losses via ``.item()``."""
        (loss_a + loss_b).backward()
        return loss_a.item(), loss_b.item()

    def backward_alu_blg(self, msg_a_prob, msg_b_prob, act_a_prob, rewards):
        """Backpropagate the agents' ``cal_loss_alu_blg`` losses.

        Returns:
            ``(loss_a, loss_b)`` as Python scalars (via ``.item()``).
        """
        loss_a = self.agent_a.cal_loss_alu_blg(msg_a_prob, act_a_prob, rewards)
        loss_b = self.agent_b.cal_loss_alu_blg(msg_b_prob, rewards)
        return self._backprop(loss_a, loss_b)

    def backward(self, msg_a_prob, msg_b_prob, act_a_prob, rewards):
        """Backpropagate the agents' ``cal_loss`` losses.

        Returns:
            ``(loss_a, loss_b)`` as Python scalars (via ``.item()``).
        """
        loss_a = self.agent_a.cal_loss(msg_a_prob, act_a_prob, rewards)
        loss_b = self.agent_b.cal_loss(msg_b_prob, rewards)
        return self._backprop(loss_a, loss_b)
