import torch
from torch import nn
from Param import *
from PositionEncoding import *


class Codec(nn.Module):
    """
    Common parent class for every codec (an encoder/decoder pair).

    Concrete codecs override ``encode`` and ``decode``. The implementations
    here are placeholders: they do no work and return ``None``.
    """

    def __init__(self):
        super().__init__()

    def encode(self, history_emb: torch.Tensor) -> torch.Tensor:
        """
        Encode an embedded history.

        :param history_emb: embedded history to encode
        :return: the encoded representation (``None`` in this base class)
        """
        return None

    def decode(self, history_emb: torch.Tensor, hx: torch.Tensor) -> (torch.Tensor, torch.Tensor):
        """
        Decode the information of a whole turn.

        :param history_emb: history to be encoded and then decoded
        :param hx: hidden state
        :return: ``(output, hidden state)`` in subclasses; ``None`` in this base class
        """
        return None


class CodecTransformer(Codec):
    """
    Transformer codec: an ``nn.TransformerEncoder`` for the history and an
    ``nn.TransformerDecoder`` that cross-attends to a memory sequence.

    Width, head count, feed-forward size, dropout and depth are read from
    ``Param``. Both stacks are built with ``batch_first=True``, so inputs are
    batch-major. ``self.pe`` (``PositionalEncoding`` from the
    ``PositionEncoding`` module) is applied to the inputs of both stacks.
    Presumably it injects positional information; confirm in that module.
    """

    def __init__(self):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=Param.emb_size, nhead=Param.trans_nhead,
                                                   dim_feedforward=Param.trans_feedforward,
                                                   dropout=Param.trans_dropout, batch_first=True)
        decoder_layer = nn.TransformerDecoderLayer(d_model=Param.emb_size, nhead=Param.trans_nhead,
                                                   dim_feedforward=Param.trans_feedforward,
                                                   dropout=Param.trans_dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=Param.trans_layernum)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=Param.trans_layernum)
        self.pe = PositionalEncoding(d_model=Param.emb_size)

    def encode(self, history_emb, mask=None):
        """
        Run the transformer encoder over an embedded history.

        :param history_emb: embedded history, batch dimension first
            (assumed (batch, seq, emb_size) -- TODO confirm against callers)
        :param mask: optional key-padding mask for ``history_emb``, forwarded as
            ``src_key_padding_mask`` (PyTorch semantics: for a bool mask, True
            marks positions that must not be attended to)
        :return: encoder output, same leading shape as ``history_emb``
        """
        pos_history_emb = self.pe(history_emb)
        # ``None`` is also the encoder's own default, so the mask can be
        # forwarded unconditionally.
        return self.encoder(pos_history_emb, src_key_padding_mask=mask)

    def decode(self, token_before_emb, hx, mask=None, mem_mask=None, tgt_mask=None):
        """
        Run the transformer decoder over the embedded target tokens.

        :param token_before_emb: embedded target tokens fed to the decoder
        :param hx: memory sequence the decoder cross-attends to; it is
            returned unchanged so the signature matches ``Codec.decode``
        :param mask: optional key-padding mask for ``token_before_emb``
            (``tgt_key_padding_mask``)
        :param mem_mask: optional key-padding mask for ``hx``
            (``memory_key_padding_mask``)
        :param tgt_mask: optional attention mask over target positions, e.g. a
            causal mask from ``nn.Transformer.generate_square_subsequent_mask``.
            The default ``None`` keeps full (non-causal) self-attention.
        :return: ``(decoder output, hx)``
        """
        pos_token_before_emb = self.pe(token_before_emb)
        # Every mask defaults to None, which is also the decoder's default, so
        # all of them are forwarded together. Any combination of masks is honoured.
        output = self.decoder(pos_token_before_emb, hx,
                              tgt_mask=tgt_mask,
                              tgt_key_padding_mask=mask,
                              memory_key_padding_mask=mem_mask)
        return output, hx


class NavObsTransEncoder(nn.Module):
    """
    Pool a set of per-view observation features into one vector per batch item.

    Each observation feature is projected to the model width by Linear+ReLU.
    A learned per-view embedding is added, the sequence runs through a
    transformer encoder (``batch_first=True``), and the outputs are averaged
    over the view axis.
    """

    def __init__(self, emb_size=None, obs_dim=1000, n_views=36, nhead=8,
                 dim_feedforward=128, dropout=0.2, num_layers=3):
        """
        All arguments are optional; the defaults reproduce the original
        configuration.

        :param emb_size: model width; ``None`` means ``Param.emb_size``
        :param obs_dim: width of each raw observation feature
        :param n_views: number of learned view embeddings (rows of ``self.pe``)
        :param nhead: attention heads per encoder layer
        :param dim_feedforward: feed-forward width per encoder layer
        :param dropout: dropout probability inside the encoder layers
        :param num_layers: number of stacked encoder layers
        """
        super().__init__()
        if emb_size is None:
            emb_size = Param.emb_size
        # Learned embedding per view index; looked up by ``view_idx``, or used
        # in full (rows 0..n_views-1, in order) when no indices are given.
        self.pe = nn.Embedding(n_views, emb_size)
        self.linear = nn.Sequential(
            nn.Linear(obs_dim, emb_size),
            nn.ReLU()
        )
        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_size, nhead=nhead,
                                                   dim_feedforward=dim_feedforward,
                                                   dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def _embed(self, obs, view_idx):
        """Project ``obs`` to the model width and add the view embeddings."""
        projected = self.linear(obs)
        if view_idx is None:
            # (1, n_views, emb) is broadcast against (batch, n, emb), so the
            # number of observations must line up with n_views.
            return projected + self.pe.weight.unsqueeze(0)
        return projected + self.pe(view_idx)

    def forward(self, obs, view_idx=None):
        """
        :param obs: (batch, n, obs_dim) observation features
        :param view_idx: optional (batch, n) integer view indices; when omitted,
            all ``n_views`` embeddings are added in order
        :return: (batch, emb dim) mean of the encoder outputs over the n axis
        """
        return torch.mean(self.encoder(self._embed(obs, view_idx)), dim=1)

    def encode_with_mask(self, obs, mask, view_idx=None):
        """
        Encode observations while ignoring padded positions.

        :param obs: (batch, n, obs_dim) observation features
        :param mask: (batch, n) padding mask; True / non-zero marks padding.
            It may be bool, integer or float 0/1.
        :param view_idx: optional (batch, n) integer view indices
        :return: (batch, emb dim)
        """
        # Normalise to bool. PyTorch treats a float key-padding mask as an
        # additive attention bias rather than a padding flag, and ``1 - mask``
        # is not defined for bool tensors.
        pad = mask.bool()
        res = self.encoder(self._embed(obs, view_idx), src_key_padding_mask=pad)  # (batch, n, emb dim)
        # Zero padded outputs. masked_fill (unlike multiplying by 0) also
        # clears NaNs produced for rows in which every position is padded.
        res = res.masked_fill(pad.unsqueeze(2), 0.0)
        # NOTE(review): as before, the mean divides by the full n, padded slots
        # included, rather than by the number of valid positions.
        return torch.mean(res, dim=1)