
from nle import nethack
from torch import nn
import torch.nn.functional as F
from omegaconf import DictConfig
from transformers import BertConfig, BertModel

from il_scale.nethack.utils.model import selectt


class TopLineEncoder(nn.Module):
    """
    Encode the top-line message into a vector of size ``msg_hdim``.

    Two encoders are available, selected by
    ``cfg.network.use_message_transformer``:

    * Transformer path: every character value is embedded with an
      ``nn.Embedding`` and the sequence is run through a ``BertModel``;
      the model's pooled output is returned.
    * MLP path: every character value is one-hot encoded over
      ``NUM_CHAR_VALUES`` classes, the encodings of all positions are
      concatenated, and the result is fed through a two-layer MLP.

    Adapted from https://github.com/dungeonsdatasubmission/dungeonsdata-neurips2022/blob/67139262966aa11555cf7aca15723375b36fbe42/experiment_code/hackrl/models/offline_chaotic_dwarf.py
    """

    # Number of distinct values a message character can take; used both as
    # the one-hot width and as the embedding table size.
    NUM_CHAR_VALUES = 256

    def __init__(self, cfg: DictConfig, msg_hdim: int):
        """
        Build the message encoder.

        Args:
            cfg: Configuration object. The following ``cfg.network`` fields
                are read: ``use_message_transformer`` always;
                ``message_tf_num_layers`` and ``message_tf_num_heads`` on the
                transformer path; ``add_norm_after_linear`` and
                ``fix_initialization`` on the MLP path.
            msg_hdim: Width of the hidden layers and of the returned encoding.
        """
        super().__init__()

        self.cfg = cfg

        self.msg_hdim = msg_hdim
        # Width of the flattened one-hot input consumed by the MLP path.
        self.i_dim = nethack.NLE_TERM_CO * self.NUM_CHAR_VALUES
        print('msg hdim', msg_hdim)

        if self.cfg.network.use_message_transformer:
            # embeddings
            self.char_embeddings = nn.Embedding(self.NUM_CHAR_VALUES, msg_hdim)

            bert_config = BertConfig(
                # forward() always passes `inputs_embeds`, so BERT's own
                # token embedding table is never indexed; keep it minimal.
                vocab_size=1,
                hidden_size=self.msg_hdim,
                num_hidden_layers=self.cfg.network.message_tf_num_layers,
                num_attention_heads=self.cfg.network.message_tf_num_heads,
                intermediate_size=4 * self.msg_hdim,
                max_position_embeddings=nethack.NLE_TERM_CO,
            )
            self.msg_fwd = BertModel(bert_config)
        else:
            use_norm = self.cfg.network.add_norm_after_linear
            if use_norm:
                print('Adding norm in topline ... ')

            # Two blocks of Linear -> [LayerNorm] -> ELU. The first Linear
            # acts on concatenated one-hot vectors, so it is pretty much an
            # embedding layer. Module order is kept stable so the indices
            # inside the Sequential (and hence state_dict keys) do not change.
            layers = []
            in_features = self.i_dim
            for _ in range(2):
                layers.append(nn.Linear(in_features, self.msg_hdim))
                if use_norm:
                    layers.append(nn.LayerNorm(self.msg_hdim))
                layers.append(nn.ELU(inplace=True))
                in_features = self.msg_hdim
            self.msg_fwd = nn.Sequential(*layers)

            # Re-initialization is only applied on the MLP path; the
            # transformer path keeps the initialization done by BertModel.
            if self.cfg.network.fix_initialization:
                self.apply(self._init_weights)

    def _init_weights(self, module, scale: float = 1.0, bias: bool = True):
        """
        Rescale the rows of an ``nn.Linear`` weight to a fixed L2 norm.

        Intended to be used with ``nn.Module.apply``; modules that are not
        ``nn.Linear`` are left untouched.

        Args:
            module: The sub-module visited by ``apply``.
            scale: Target L2 norm of every weight row (one row per output unit).
            bias: When true, the layer's bias (if it has one) is set to zero.
        """
        if not isinstance(module, nn.Linear):
            return

        print('fixing topline initialization ...')
        weight = module.weight.data
        # Clamp so that an all-zero row stays zero instead of becoming NaN.
        row_norms = weight.norm(dim=1, p=2, keepdim=True).clamp(min=1e-12)
        weight.mul_(scale / row_norms)

        # A Linear built with bias=False has `bias is None`; skip it.
        if bias and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, message):
        """
        Encode a batch of messages.

        Args:
            message: Integer tensor of character values.
                NOTE(review): assumes values lie in [0, NUM_CHAR_VALUES) and
                that the last dimension has NLE_TERM_CO entries, since the
                MLP path reshapes to rows of ``self.i_dim`` — confirm
                against the caller.

        Returns:
            On the MLP path, the output of ``self.msg_fwd`` for the flattened
            one-hot input; on the transformer path, the ``pooler_output`` of
            the BERT model.
        """
        if self.cfg.network.use_message_transformer:
            # NOTE(review): `selectt` is defined elsewhere; presumably it
            # looks up the embedding rows for the given indices — confirm.
            chars = selectt(self.char_embeddings, message.long(), True)
            return self.msg_fwd(inputs_embeds=chars).pooler_output

        # One-hot over all NUM_CHAR_VALUES possible values per position, then
        # flatten every message into a single row of length i_dim.
        one_hot = F.one_hot(message.long(), self.NUM_CHAR_VALUES)
        message_normed = one_hot.reshape(-1, self.i_dim).float()
        return self.msg_fwd(message_normed)