import math

import torch.nn as nn
import torch
from omegaconf import DictConfig

from il_scale.nethack.utils.model import selectt
from il_scale.nethack.networks.convnext import ConvNeXt

class EncoderConvNeXt(nn.Module):
    """
    ConvNeXt encoder over embedded character / color / glyph observations.

    Character and color codes are embedded with embeddings owned by this
    module. Glyph ids are mapped through ``id_pairs_table`` to six integer
    feature ids (glyph, group, difficulty, level, weight, speed), each of
    which is embedded with an embedding supplied by the caller.

    The per-cell embeddings are then either
      * summed (``cfg.network.add_char_color``), which requires all the
        embedding widths to be equal, or
      * concatenated along the embedding axis, optionally projected back to
        ``glyph_edim`` by a linear layer (``cfg.network.glyph_embed_with_linear``),
    and fed channels-first to a ConvNeXt whose last stage has width
    ``cfg.network.convnext_hdim`` (exposed as ``self.out_size``).
    """

    def __init__(
        self,
        cfg: DictConfig,
        id_pairs_table: torch.Tensor,
        glyph_emb: nn.Embedding,
        group_emb: nn.Embedding,
        level_emb: nn.Embedding,
        difficulty_emb: nn.Embedding,
        weight_emb: nn.Embedding,
        speed_emb: nn.Embedding,
    ):
        """
        Args:
            cfg: configuration; ``cfg.network`` must provide ``convnext_hdim``,
                ``glyph_embed_with_linear``, ``add_char_color`` and the
                ``char/color/glyph/group/level/difficulty/weight/speed_edim``
                sizes.
            id_pairs_table: table indexed (dim 0) by glyph id; columns 0-5 hold
                the glyph, group, difficulty, level, weight and speed ids.
            glyph_emb, group_emb, level_emb, difficulty_emb, weight_emb,
            speed_emb: embeddings applied to the corresponding looked-up ids.
        """
        super().__init__()

        self.cfg = cfg
        net = self.cfg.network
        self.hdim = net.convnext_hdim
        self.glyph_embed_with_linear = net.glyph_embed_with_linear

        # Width of the concatenated per-cell features built by `_stack`.
        stacked_edim = (
            net.char_edim
            + net.color_edim
            + net.glyph_edim
            + net.group_edim
            + net.level_edim
            + net.difficulty_edim
            + net.weight_edim
            + net.speed_edim
        )

        if self.glyph_embed_with_linear:
            self.emb_linear = nn.Linear(stacked_edim, net.glyph_edim)

        # NOTE(review): stored as a plain attribute (not a registered buffer),
        # so `module.to(device)` will not move it; presumably the caller keeps
        # it on the same device as the inputs — verify.
        self.id_pairs_table = id_pairs_table
        self.glyph_emb = glyph_emb
        self.group_emb = group_emb
        self.level_emb = level_emb
        self.difficulty_emb = difficulty_emb
        self.weight_emb = weight_emb
        self.speed_emb = speed_emb

        self.char_embeddings = nn.Embedding(256, net.char_edim)
        self.color_embeddings = nn.Embedding(128, net.color_edim)

        # The ConvNeXt input width must match what `forward` actually feeds it:
        # summed embeddings keep the common per-feature width (char_edim), while
        # stacked features are either projected to glyph_edim or left at the
        # full concatenated width.
        if net.add_char_color:
            in_chans = net.char_edim
        elif self.glyph_embed_with_linear:
            in_chans = net.glyph_edim
        else:
            in_chans = stacked_edim

        self.conv_net = ConvNeXt(in_chans=in_chans, depths=[3, 9, 3], dims=[96, 384, self.hdim])
        self.out_size = self.hdim

    def forward(self, chars, colors, glyphs):
        """
        Embed the observations and run them through the ConvNeXt.

        Args:
            chars: character codes (embedded with a 256-entry table).
            colors: color codes (embedded with a 128-entry table).
            glyphs: glyph ids; indexes dim 0 of ``id_pairs_table``.

        Returns:
            Whatever ``self.conv_net`` returns for the channels-first input.
        """
        # Original author's note on the grid size: 21 x 80 (after padding).
        (
            chars, colors, glyphs, groups, levels, difficulties, weights, speeds,
        ) = self._embed(chars, colors, glyphs)

        if self.cfg.network.add_char_color:
            glyphs = glyphs + groups + levels + difficulties + weights + speeds
            x = chars + colors + self._pad_width(glyphs)
            x = self._to_channels_first(x)
        else:
            x = self._stack(chars, colors, glyphs, groups, levels, difficulties, weights, speeds)
        return self.conv_net(x)

    def _embed(self, chars, colors, glyphs):
        """
        Embed chars/colors directly and glyphs via their ``id_pairs_table`` ids.

        Returns:
            Tuple ``(chars, colors, glyph, group, level, difficulty, weight,
            speed)`` of embedded tensors (note: level precedes difficulty here,
            although the table stores difficulty in column 2 and level in 3).
        """
        chars = selectt(self.char_embeddings, chars.long(), True)
        colors = selectt(self.color_embeddings, colors.long(), True)

        # One gather for all feature ids; the table's columns become a
        # trailing axis on top of the original glyph grid shape.
        ids = self.id_pairs_table.index_select(0, glyphs.reshape(-1).long())
        ids = ids.view(*glyphs.shape, ids.shape[-1])
        glyph_ids = ids[..., 0].long()
        group_ids = ids[..., 1].long()
        difficulty_ids = ids[..., 2].long()
        level_ids = ids[..., 3].long()
        weight_ids = ids[..., 4].long()
        speed_ids = ids[..., 5].long()

        glyph_embs = selectt(self.glyph_emb, glyph_ids, True)
        group_embs = selectt(self.group_emb, group_ids, True)
        level_embs = selectt(self.level_emb, level_ids, True)
        difficulty_embs = selectt(self.difficulty_emb, difficulty_ids, True)
        weight_embs = selectt(self.weight_emb, weight_ids, True)
        speed_embs = selectt(self.speed_emb, speed_ids, True)

        return chars, colors, glyph_embs, group_embs, level_embs, difficulty_embs, weight_embs, speed_embs

    def _stack(self, chars, colors, glyphs, groups, levels, difficulties, weights, speeds):
        """
        Concatenate all embeddings along the last axis (optionally projected
        by ``emb_linear``) and return them channels-first.
        """
        glyph_features = [
            self._pad_width(t) for t in (glyphs, groups, levels, difficulties, weights, speeds)
        ]
        obs = torch.cat([chars, colors, *glyph_features], dim=-1)
        if self.glyph_embed_with_linear:
            obs = self.emb_linear(obs)
        return self._to_channels_first(obs)

    @staticmethod
    def _pad_width(x):
        """
        Append one all-zero slice at the end of dim -2.

        Presumably this widens the glyph grid to the char grid's width
        (see the "21 x 80" note in `forward`) — confirm against the inputs.
        """
        # Pad spec is (last dim: 0 before, 0 after; dim -2: 0 before, 1 after).
        return nn.functional.pad(x, (0, 0, 0, 1))

    @staticmethod
    def _to_channels_first(x):
        """Permute a 5-D (d0, d1, d2, d3, d4) tensor to contiguous (d0, d1 * d4, d2, d3)."""
        return x.permute(0, 1, 4, 2, 3).flatten(1, 2).contiguous()