import torch
from torch import nn
from Param import *


class PosObservation(nn.Module):
    """Frozen wrapper around a pre-trained observation encoder.

    The checkpoint is chosen by ``role``: ``'A'`` loads
    ``Param.obs_encode_model_a_pth``; any other value loads
    ``Param.obs_encode_model_b_pth``. The checkpoint is either a pickled
    module, or a dict with keys ``"model"`` (the encoder) and
    ``"stop_token"``. Encoder parameters are frozen and every public
    encoding method returns a detached tensor.
    """

    def __init__(self, role):
        super(PosObservation, self).__init__()
        import inspect  # local: only used to probe torch.load's signature

        model_pth = (Param.obs_encode_model_a_pth if role == 'A'
                     else Param.obs_encode_model_b_pth)

        # The checkpoint holds pickled nn.Module objects, not a state_dict.
        # PyTorch >= 2.6 defaults torch.load(weights_only=True), which refuses
        # to unpickle modules, so opt out explicitly whenever the installed
        # torch knows the argument (older releases do not accept it).
        # NOTE: weights_only=False executes arbitrary pickle code; only point
        # Param at trusted checkpoint files.
        load_kwargs = {}
        if "weights_only" in inspect.signature(torch.load).parameters:
            load_kwargs["weights_only"] = False
        temp = torch.load(model_pth, **load_kwargs)

        if isinstance(temp, dict):
            self.encoder = temp["model"]
            # stop_token is only defined for dict checkpoints; plain-module
            # checkpoints leave this attribute unset.
            self.stop_token = temp["stop_token"]
            self.stop_token.weight.requires_grad = False
        else:
            self.encoder = temp
        # Freeze every encoder parameter. Writing `module.requires_grad = False`
        # would only set a plain attribute on the Module and leave its
        # parameters trainable; requires_grad_() recurses into all of them.
        self.encoder.requires_grad_(False)

    def forward(self, obs, view_idx=None):
        """Encode observations with the frozen encoder.

        :param obs: (batch, n, 1000)
        :param view_idx: (batch, n)
        :return: (batch, emb dim), detached from the autograd graph
        """
        return self.encoder(obs, view_idx).detach()

    def encode_with_mask(self, obs, mask, view_idx=None):
        """Delegate to ``self.encoder.encode_with_mask`` and detach the result.

        Argument shapes/semantics are defined by the wrapped encoder
        (presumably ``obs`` as in ``forward`` — TODO confirm).
        """
        return self.encoder.encode_with_mask(obs, mask, view_idx).detach()

    def encode(self, obs, view_idx=None):
        """Return per-position encoder outputs (no pooling), detached.

        Re-runs the encoder's internal pipeline: ``linear`` projection plus
        positional embedding ``pe``, then ``self.encoder.encoder``.

        :param obs: presumably (batch, n, 1000) as in ``forward`` — TODO confirm
        :param view_idx: optional (batch, n) indices into ``self.encoder.pe``.
            When None, the entire ``pe`` table is added, broadcast over the
            batch (assumes n equals the number of ``pe`` rows — TODO confirm).
        :return: (batch, n, emb dim)
        """
        projected = self.encoder.linear(obs)
        if view_idx is None:
            positional = self.encoder.pe.weight.unsqueeze(0)
        else:
            positional = self.encoder.pe(view_idx)
        res = self.encoder.encoder(projected + positional)  # (batch, n, emb dim)
        return res.detach()
