import torch
from torch import nn
from Components import *
import torch.nn.functional as F
from Util import *


class Language(nn.Module):
    """Token encoder/decoder built around a ``CodecTransformer``.

    Keeps two separate embedding tables (one for encoding, one for decoding),
    each with ``len(Param.tokens_list)`` rows of size ``Param.emb_size``.
    """

    def __init__(self, role):
        """
        :param role: value formatted into the ``"<msg {}>"`` token that starts every
                     decoded message
        """
        super().__init__()
        self.codec = CodecTransformer()
        self.encode_embeddings = nn.Embedding(num_embeddings=len(Param.tokens_list), embedding_dim=Param.emb_size)
        self.decode_embeddings = nn.Embedding(num_embeddings=len(Param.tokens_list), embedding_dim=Param.emb_size)
        self.role = role

    def encode(self, history: torch.Tensor, mask=None):
        """
        Embed the history tokens and run them through the codec encoder.

        :param history: token idxes
        :param mask: forwarded unchanged to ``codec.encode``
        :return: encoded history (whatever ``codec.encode`` returns)
        """
        history_embs = self.encode_embeddings(history)
        return self.codec.encode(history_embs, mask=mask)

    def decode_msg(self, context_vec: torch.Tensor, choose_method="sample", mem_mask=None):
        """
        Decode a message of ``Param.sent_len`` tokens, one token per step.

        Decoding starts from the ``"<msg {role}>"`` token; after every step the
        embedding of the chosen token is appended to the decoder input.

        :param context_vec: vector including language and observation part; handed to
                            ``codec.decode`` as its initial second argument
        :param choose_method: sample or greedy (forwarded to ``choose_token``)
        :param mem_mask: forwarded unchanged to ``codec.decode``
        :return: ``(decoded_tokens, decoded_tokens_prob)``, the per-step results of
                 ``choose_token`` stacked along dim 1, i.e. (batch, sent len)
        """
        # Loop invariants, computed once instead of on every decoding step.
        start_token = "<msg {}>".format(self.role)
        emb_weight = self.decode_embeddings.weight
        vocab_size = len(Param.tokens_list)

        hx = context_vec
        # Create the start-token indices explicitly as int64 on the same device as the
        # embedding table. Without this the tensor is always allocated on the CPU (the
        # lookup then fails once the module lives on another device) and its dtype is
        # left to version-dependent inference, while nn.Embedding needs integer indices.
        # NOTE(review): assumes Param.tokens[...]["pos"] is an integer token index and
        # that callers always decode exactly Param.batch_size rows -- confirm.
        token_before = torch.full((Param.batch_size, 1), Param.tokens[start_token]["pos"],
                                  dtype=torch.long, device=emb_weight.device)
        token_before_emb = self.decode_embeddings(token_before)
        assert token_before_emb.shape == (Param.batch_size, 1, Param.emb_size)

        decoded_tokens = []
        decoded_tokens_prob = []
        for _ in range(Param.sent_len):
            # The second value returned by the codec replaces hx for the next step.
            output, hx = self.codec.decode(token_before_emb, hx, mem_mask=mem_mask)
            output = output[:, -1, :]  # keep only the last position along dim 1
            assert output.shape == (Param.batch_size, Param.emb_size)
            # output -> (batch, emb size), weight -> (cand len, emb size)
            # Score every candidate token against the decoder output, layer-normalise
            # across the candidates, then turn the scores into a distribution.
            logits = torch.mm(output, emb_weight.transpose(1, 0))
            next_token_score = F.softmax(F.layer_norm(logits, normalized_shape=(vocab_size,)), dim=-1)
            next_token, next_token_prob = choose_token(next_token_score, start_token,
                                                       choose_method=choose_method)
            decoded_tokens.append(next_token)
            decoded_tokens_prob.append(next_token_prob)
            next_token_emb = self.decode_embeddings(next_token)  # (batch, emb size)
            # Grow the decoder input by the newly chosen token.
            token_before_emb = torch.cat([token_before_emb, next_token_emb.unsqueeze(1)], dim=1)

        # Stacking along dim 1 gives the same values as stack(dim=0).transpose(1, 0).
        decoded_tokens = torch.stack(decoded_tokens, dim=1)  # (batch, sent len)
        decoded_tokens_prob = torch.stack(decoded_tokens_prob, dim=1)  # (batch, sent len)
        return decoded_tokens, decoded_tokens_prob

