
from nle import nethack
from nle.nethack import INV_SIZE
import torch
from torch import nn
import torch.nn.functional as F
from transformers import BertConfig, BertModel
from omegaconf import DictConfig

from il_scale.nethack.utils.model import selectt

class InventoryNet(nn.Module):
    """
    Encodes the inventory glyphs into one vector per batch element.

    Each inventory glyph id is mapped through ``id_pairs_table`` to eight
    integer feature ids (glyph, group, difficulty, level, weight, speed,
    corpse, sacrifice). Every feature id is embedded with its own embedding
    module. The per-feature embeddings are combined per inventory slot in one
    of two ways, selected by ``cfg.network.glyph_embed_with_linear``:

    * concatenated and projected down to ``glyph_edim`` with ``emb_linear``;
    * otherwise, summed element-wise.

    The flattened slot embeddings are then passed through a small MLP.
    """

    # Entries are (feature name, column in ``id_pairs_table``, cfg.network key
    # of the feature's embedding dim). Tuple order is the concatenation /
    # summation order and must stay fixed, because trained ``emb_linear``
    # weights depend on it. Note that ``level`` is read from column 3 and
    # ``difficulty`` from column 2. The embedding module for each feature is
    # the attribute ``f"{name}_emb"``.
    _FEATURES = (
        ("glyph", 0, "glyph_edim"),
        ("group", 1, "group_edim"),
        ("level", 3, "level_edim"),
        ("difficulty", 2, "difficulty_edim"),
        ("weight", 4, "weight_edim"),
        ("speed", 5, "speed_edim"),
        ("corpse", 6, "corpse_edim"),
        ("sacrifice", 7, "sacrifice_edim"),
    )

    def __init__(
        self,
        cfg: DictConfig,
        id_pairs_table: torch.Tensor,
        glyph_emb: nn.Embedding,
        group_emb: nn.Embedding,
        level_emb: nn.Embedding,
        difficulty_emb: nn.Embedding,
        weight_emb: nn.Embedding,
        speed_emb: nn.Embedding,
        corpse_emb: nn.Embedding,
        sacrifice_emb: nn.Embedding,
        inv_hdim: int,
    ):
        """
        Args:
            cfg: config. ``cfg.network`` must provide ``glyph_embed_with_linear``
                and the ``*_edim`` entries listed in ``_FEATURES``.
            id_pairs_table: lookup table indexed by glyph id along dim 0, with
                at least 8 columns of feature ids (see ``_FEATURES``).
            glyph_emb ... sacrifice_emb: embedding modules for each feature.
                They are passed in rather than created here (presumably
                shared with other encoders; confirm with the caller).
            inv_hdim: hidden / output width of the inventory MLP.

        Raises:
            ValueError: if ``glyph_embed_with_linear`` is False and some
                feature embedding dim differs from ``glyph_edim``. The
                summation path needs equal widths.
        """
        super().__init__()

        self.cfg = cfg
        # NOTE(review): stored as a plain attribute, not a registered buffer,
        # so ``.to(device)`` / ``state_dict`` do not track it. Left unchanged
        # to keep checkpoint keys stable.
        self.id_pairs_table = id_pairs_table
        self.inv_hdim = inv_hdim
        self.glyph_embed_with_linear = self.cfg.network.glyph_embed_with_linear
        net_cfg = self.cfg.network

        # Registration order is kept as-is. It determines ``parameters()``
        # order, which saved optimizer state depends on.
        self.glyph_emb = glyph_emb
        self.group_emb = group_emb
        self.level_emb = level_emb
        self.difficulty_emb = difficulty_emb
        self.weight_emb = weight_emb
        self.speed_emb = speed_emb
        self.corpse_emb = corpse_emb
        self.sacrifice_emb = sacrifice_emb

        if self.glyph_embed_with_linear:
            # Project the concatenation of all feature embeddings to glyph_edim.
            self.emb_linear = nn.Linear(
                sum(getattr(net_cfg, edim_key) for _, _, edim_key in self._FEATURES),
                net_cfg.glyph_edim,
            )
        else:
            # Summation requires every feature embedding to be glyph_edim wide.
            # Fail here with a clear message instead of a shape error in forward.
            mismatched = {
                edim_key: getattr(net_cfg, edim_key)
                for _, _, edim_key in self._FEATURES
                if getattr(net_cfg, edim_key) != net_cfg.glyph_edim
            }
            if mismatched:
                raise ValueError(
                    "InventoryNet sums the per-feature embeddings when "
                    "network.glyph_embed_with_linear is False, so every "
                    f"embedding dim must equal glyph_edim={net_cfg.glyph_edim}; "
                    f"got {mismatched}"
                )

        self.mlp = nn.Sequential(
            nn.Linear(INV_SIZE[0] * net_cfg.glyph_edim, self.inv_hdim),
            nn.LayerNorm(self.inv_hdim),
            nn.ELU(),
            nn.Linear(self.inv_hdim, self.inv_hdim),
        )

    def forward(self, inv_glyphs: torch.Tensor) -> torch.Tensor:
        """
        Embed a batch of inventories.

        Args:
            inv_glyphs: glyph ids of the inventory slots. Must contain
                ``B * INV_SIZE[0]`` elements, where ``B = inv_glyphs.shape[0]``
                (e.g. shape ``(B, INV_SIZE[0])``).

        Returns:
            Tensor of shape ``(B, inv_hdim)``.
        """
        B = inv_glyphs.shape[0]
        n_slots = INV_SIZE[0]
        net_cfg = self.cfg.network

        # Look up the feature ids of every slot's glyph. ``reshape`` (unlike
        # ``view``) also accepts non-contiguous input. ``index_select`` needs
        # an int32/int64 index, so cast; this is a no-op for int64 input.
        ids = self.id_pairs_table.index_select(0, inv_glyphs.reshape(-1).long())

        # Embed each feature; list order follows ``_FEATURES``.
        feature_embs = []
        for name, column, _ in self._FEATURES:
            feature_ids = ids.select(1, column).view(-1, 1).long()
            embedding = getattr(self, f"{name}_emb")
            feature_embs.append(selectt(embedding, feature_ids, True))

        if self.glyph_embed_with_linear:
            # Per feature: (B, n_slots, edim_i). Concatenate along the last dim,
            # project to glyph_edim, then flatten the slots.
            per_slot = [
                emb.reshape(B, n_slots, getattr(net_cfg, edim_key))
                for emb, (_, _, edim_key) in zip(feature_embs, self._FEATURES)
            ]
            inv_emb = self.emb_linear(torch.cat(per_slot, dim=-1))
            inv_emb = inv_emb.reshape(B, n_slots * net_cfg.glyph_edim)
        else:
            # Flatten each feature to (B, n_slots * edim) and sum left to right.
            # All edims equal glyph_edim, as validated in __init__.
            flat = [
                emb.reshape(B, n_slots * getattr(net_cfg, edim_key))
                for emb, (_, _, edim_key) in zip(feature_embs, self._FEATURES)
            ]
            inv_emb = flat[0]
            for emb in flat[1:]:
                inv_emb = inv_emb + emb

        return self.mlp(inv_emb)
        

        