
from nle import nethack
from torch import nn
import torch
from omegaconf import DictConfig
from transformers import BertConfig, BertModel

from il_scale.nethack.utils.model import conv_outdim
from il_scale.nethack.utils.model import selectt

class BottomLineEncoder(nn.Module):
    """
    Encode the bottom terminal lines of a NetHack observation into ``hdim`` features.

    Adapted from https://github.com/dungeonsdatasubmission/dungeonsdata-neurips2022/blob/67139262966aa11555cf7aca15723375b36fbe42/experiment_code/hackrl/models/offline_chaotic_dwarf.py

    Two encoders are available, selected by ``cfg.network.use_blstats_transformer``:

    * transformer: per-position character and colour embeddings are summed and
      fed to a ``BertModel``; its ``pooler_output`` is returned.
    * conv (default): characters are mapped to two normalised channels (all
      characters, digits only), passed through two 1-D convolutions, optionally
      concatenated with blstats or game-time features, and projected by a
      two-layer MLP. The conv stack is sized for inputs of width
      ``2 * nethack.NLE_TERM_CO``.

    Args:
        cfg: config; the ``cfg.network.*`` flags read below select the variant.
        hdim: width of the output features.
    """
    def __init__(self, cfg: DictConfig, hdim: int = 128):
        super().__init__()

        self.cfg = cfg
        self.use_true_blstats = cfg.network.use_true_blstats
        self.add_raw_time_encoding = cfg.network.add_raw_time_encoding
        self.add_log_time_encoding = cfg.network.add_log_time_encoding
        self.add_emb_time_encoding = cfg.network.add_emb_time_encoding
        self.add_layernorm_in_blstats_conv = cfg.network.add_layernorm_in_blstats_conv
        self.hdim = hdim

        if self.use_true_blstats:
            # Embeds the full blstats vector (NLE_BLSTATS_SIZE entries).
            self.embed_blstats = nn.Sequential(
                nn.Linear(nethack.NLE_BLSTATS_SIZE, self.hdim),
                nn.LayerNorm(self.hdim),
                nn.ELU()
            )

        if self.add_raw_time_encoding or self.add_log_time_encoding:
            # NOTE(review): LayerNorm over a single feature always outputs its
            # bias ((x - mean) is 0 when the mean is taken over one element), so
            # this encoding is a learned constant that does not depend on the
            # time input. It is kept as-is so existing checkpoints still load;
            # revisit if the time signal is meant to be informative.
            self.time_encoding = nn.Sequential(
                nn.Linear(1, 1),
                nn.LayerNorm(1),
                nn.ELU()
            )

        if self.add_emb_time_encoding:
            # One 16-d embedding per turn; forward() clamps larger turn counts
            # to the last row.
            self.time_encoding = nn.Embedding(1000000, 16)

        if self.cfg.network.use_blstats_transformer:
            # embeddings
            self.char_embeddings = nn.Embedding(256, self.hdim)
            self.color_embeddings = nn.Embedding(128, self.hdim)

            bert_config = BertConfig(
                # forward() passes inputs_embeds, so BERT's own token embedding
                # table is unused and can be minimal.
                vocab_size=1,
                hidden_size=self.hdim,
                num_hidden_layers=self.cfg.network.blstats_tf_num_layers,
                num_attention_heads=self.cfg.network.blstats_tf_num_heads,
                intermediate_size=4*self.hdim,
                max_position_embeddings=2*nethack.NLE_TERM_CO
            )
            self.blstats_fwd = BertModel(bert_config)
        else:
            self.conv_layers = []
            w = nethack.NLE_TERM_CO * 2
            for in_ch, out_ch, filter_size, stride in [[2, 32, 8, 4], [32, 64, 4, 1]]:
                self.conv_layers.append(nn.Conv1d(in_ch, out_ch, filter_size, stride=stride))
                # Output width of this conv; also used for the LayerNorm shape so
                # it stays consistent with out_dim (39 then 36 for 160-wide input).
                w = conv_outdim(w, filter_size, padding=0, stride=stride)
                if self.add_layernorm_in_blstats_conv:
                    self.conv_layers.append(nn.LayerNorm([out_ch, w]))
                self.conv_layers.append(nn.ELU(inplace=True))

            self.out_dim = w * out_ch
            self.conv_net = nn.Sequential(*self.conv_layers)

            # Width of whatever forward() concatenates after the conv features.
            if self.use_true_blstats:
                mlp_input_dim = self.out_dim + self.hdim
            elif self.add_raw_time_encoding or self.add_log_time_encoding:
                mlp_input_dim = self.out_dim + 1
            elif self.add_emb_time_encoding:
                mlp_input_dim = self.out_dim + 16
            else:
                mlp_input_dim = self.out_dim

            if self.cfg.network.add_norm_after_linear:
                self.fwd_net = nn.Sequential(
                    nn.Linear(mlp_input_dim, self.hdim),
                    nn.LayerNorm(self.hdim),
                    nn.ELU(),
                    nn.Linear(self.hdim, self.hdim),
                    nn.LayerNorm(self.hdim),
                    nn.ELU(),
                )
            else:
                self.fwd_net = nn.Sequential(
                    nn.Linear(mlp_input_dim, self.hdim),
                    nn.ELU(),
                    nn.Linear(self.hdim, self.hdim),
                    nn.ELU(),
                )

            if self.cfg.network.fix_initialization:
                self.apply(self._init_weights)

    def _init_weights(self, module, scale: float = 1.0, bias: bool = True):
        """Rescale a Linear/Conv1d weight by ``scale / ||W||`` (norm over dim 1) and zero its bias.

        Intended for ``self.apply``; modules of other types are left untouched.

        Args:
            module: submodule visited by ``nn.Module.apply``.
            scale: target norm of each dim-1 slice of the weight.
            bias: whether to zero the bias (skipped if the layer has none).
        """
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            print('fixing bottom line initialization ...')
            with torch.no_grad():
                # NOTE(review): the norm is over dim 1 only. For Linear (out, in)
                # that is each output row; for Conv1d (out, in, k) it normalises
                # each (out, k) slice across input channels rather than each whole
                # filter — confirm this is intended.
                module.weight.mul_(scale / module.weight.norm(dim=1, p=2, keepdim=True))

                # Layers built with bias=False have module.bias set to None.
                if bias and module.bias is not None:
                    module.bias.zero_()

    def _extra_features(self, blstats):
        """Return the features appended to the conv output, or None if none are configured.

        Args:
            blstats: per-sample blstats, indexed as ``blstats[:, i]``.
        """
        if self.use_true_blstats:
            return self.embed_blstats(blstats)
        if self.add_raw_time_encoding:
            time = blstats[:, nethack.NLE_BL_TIME: nethack.NLE_BL_TIME + 1]
            return self.time_encoding(time)
        if self.add_log_time_encoding:
            time = blstats[:, nethack.NLE_BL_TIME: nethack.NLE_BL_TIME + 1]
            return self.time_encoding(torch.log(time + 1))
        if self.add_emb_time_encoding:
            time = blstats[:, nethack.NLE_BL_TIME].long()
            # Keep indices inside the embedding table: turn counts at or beyond
            # num_embeddings would otherwise raise an index error.
            time = time.clamp(0, self.time_encoding.num_embeddings - 1)
            return self.time_encoding(time)
        return None

    def forward(self, bottom_lines, blstats, bottom_line_colors):
        """Encode a batch of bottom lines.

        Args:
            bottom_lines: (B, D) character codes.
            blstats: per-sample blstats; only used by the conv encoder when
                blstats or a time encoding is enabled.
            bottom_line_colors: colour codes aligned with ``bottom_lines``; only
                used by the transformer encoder.

        Returns:
            (B, hdim) features.
        """
        B, _ = bottom_lines.shape
        if self.cfg.network.use_blstats_transformer:
            # NOTE(review): selectt is an opaque project helper, presumably an
            # embedding lookup — confirm.
            chars = selectt(self.char_embeddings, bottom_lines.long(), True)
            colors = selectt(self.color_embeddings, bottom_line_colors.long(), True)
            return self.blstats_fwd(inputs_embeds=chars + colors).pooler_output

        # Integer (e.g. uint8) inputs would wrap around in the subtractions below.
        if not torch.is_floating_point(bottom_lines):
            bottom_lines = bottom_lines.float()

        # ASCII 32 (' ') -> 0 and ASCII 128 -> 1; printable characters are 33-126.
        chars_normalised = (bottom_lines - 32) / 96

        # ASCII [45-57]: -./0123456789
        numbers_mask = (bottom_lines > 44) * (bottom_lines < 58)
        # Subtracting 47 (not 48) maps '0' to 0.1 instead of 0, presumably so a
        # zero digit stays distinguishable from masked-out positions (which are 0).
        # '-' and '.' map to -0.2 and -0.1, and '/' maps to 0.
        digits_normalised = numbers_mask * (bottom_lines - 47) / 10

        # Put in different channels & conv (B, 2, D)
        x = torch.stack([chars_normalised, digits_normalised], dim=1)
        x = self.conv_net(x).view(B, -1)

        extra = self._extra_features(blstats)
        if extra is not None:
            x = torch.cat([x, extra], dim=1)

        return self.fwd_net(x)