import math

import torch.nn as nn
import torch
from omegaconf import DictConfig

from il_scale.nethack.utils.model import selectt

class CharColorEncoderResnet(nn.Module):
    """
    Inspired by network from IMPALA https://arxiv.org/pdf/1802.01561.pdf

    Embeds char / color / glyph-derived observation grids, runs them through
    up to three conv stages (Conv2d -> MaxPool2d -> ``ResBlock`` x N) and a
    fully connected head, producing one ``resnet_hdim``-sized vector per
    batch element.

    The glyph-derived embedding modules (``glyph_emb`` ... ``sacrifice_emb``)
    and ``id_pairs_table`` are supplied by the caller and only stored here;
    the char, color and optional row/column position embeddings are created
    by this module.
    """
    def __init__(
        self,
        obs_shape,
        cfg: DictConfig,
        id_pairs_table: torch.Tensor,
        glyph_emb: nn.Embedding,
        group_emb: nn.Embedding,
        level_emb: nn.Embedding,
        difficulty_emb: nn.Embedding,
        weight_emb: nn.Embedding,
        speed_emb: nn.Embedding,
        corpse_emb: nn.Embedding,
        sacrifice_emb: nn.Embedding
    ):
        """
        Args:
            obs_shape: ``(height, width)`` of the grid fed to the conv stack.
            cfg: config; every option used here is read from ``cfg.network``.
            id_pairs_table: lookup table indexed by raw glyph id along dim 0;
                columns 0..7 are read in ``_embed`` as glyph, group,
                difficulty, level, weight, speed, corpse and sacrifice ids.
            glyph_emb, group_emb, level_emb, difficulty_emb, weight_emb,
            speed_emb, corpse_emb, sacrifice_emb: embedding modules for the
                corresponding id columns.

        Raises:
            ValueError: if ``cfg.network.obs_conv_blocks`` selects no conv
                stage (the output size would be undefined).
        """
        super(CharColorEncoderResnet, self).__init__()

        self.cfg = cfg
        net_cfg = self.cfg.network
        self.resnet_scale_channels = net_cfg.resnet_scale_channels
        self.hdim = net_cfg.resnet_hdim
        self.h, self.w = obs_shape
        self.glyph_embed_with_linear = net_cfg.glyph_embed_with_linear

        if self.glyph_embed_with_linear:
            # Projects the concatenation of all ten embeddings down to glyph_edim.
            concat_edim = (
                net_cfg.char_edim
                + net_cfg.color_edim
                + net_cfg.glyph_edim
                + net_cfg.group_edim
                + net_cfg.level_edim
                + net_cfg.difficulty_edim
                + net_cfg.weight_edim
                + net_cfg.speed_edim
                + net_cfg.corpse_edim
                + net_cfg.sacrifice_edim
            )
            self.emb_linear = nn.Linear(concat_edim, net_cfg.glyph_edim)

        # NOTE(review): stored as a plain attribute (not a registered buffer), so
        # it does not follow ``module.to(device)`` - presumably the caller keeps
        # it on the right device; confirm.
        self.id_pairs_table = id_pairs_table
        self.glyph_emb = glyph_emb
        self.group_emb = group_emb
        self.level_emb = level_emb
        self.difficulty_emb = difficulty_emb
        self.weight_emb = weight_emb
        self.speed_emb = speed_emb
        self.corpse_emb = corpse_emb
        self.sacrifice_emb = sacrifice_emb

        self.blocks = []

        # One row per conv stage: [in_channels, out_channels, kernel size, #res blocks].
        # NOTE(review): the first stage expects char_edim channels per stacked
        # frame, so whichever path builds the conv input must emit that many
        # channels - confirm against the config in use.
        self.conv_params = [
            [net_cfg.char_edim * net_cfg.obs_frame_stack, 16 * self.resnet_scale_channels, net_cfg.obs_kernel_size, net_cfg.resnet_num_blocks],
            [16 * self.resnet_scale_channels, 32 * self.resnet_scale_channels, net_cfg.obs_kernel_size, net_cfg.resnet_num_blocks],
            [32 * self.resnet_scale_channels, 32 * self.resnet_scale_channels, net_cfg.obs_kernel_size, net_cfg.resnet_num_blocks]
        ]

        print('resnet scale channels', self.resnet_scale_channels)
        self.conv_params = self.conv_params[:net_cfg.obs_conv_blocks]
        if not self.conv_params:
            # Without any stage there is no channel count to size the FC head with.
            raise ValueError(
                'cfg.network.obs_conv_blocks must select at least one conv stage, '
                f'got {net_cfg.obs_conv_blocks!r}'
            )

        for (
            in_channels,
            out_channels,
            filter_size,
            num_res_blocks
        ) in self.conv_params:
            # "Same"-style padding for the stride-1 conv; the max pool below does
            # the actual downsampling.
            first_conv = nn.Conv2d(
                in_channels,
                out_channels,
                filter_size,
                stride=1,
                padding=(filter_size // 2)
            )

            if net_cfg.fix_initialization:
                print('fixing resnet first conv initialization ...')
                # Rescale every output filter to unit L2 norm and zero the bias.
                first_conv.weight.data *= 1.0 / first_conv.weight.norm(
                    dim=tuple(range(1, first_conv.weight.data.ndim)), p=2, keepdim=True
                )

                first_conv.bias.data *= 0

            block = [
                first_conv,
                nn.MaxPool2d(
                    kernel_size=3,
                    stride=2
                ),
            ]
            # Track the spatial size after the pool so the FC head can be sized.
            self.h = self._pool_out_size(self.h)
            self.w = self._pool_out_size(self.w)

            # Residual block(s)
            for _ in range(num_res_blocks):
                block.append(ResBlock(cfg, out_channels, out_channels, filter_size, net_cfg.resnet_num_layers))
            self.blocks.append(nn.Sequential(*block))

        self.conv_net = nn.Sequential(*self.blocks)
        self.out_size = self.h * self.w * out_channels

        print('resnet out size', self.out_size)
        use_norm = net_cfg.add_norm_after_linear
        if use_norm:
            print('Adding norm resnet linears ... ')

        # At least one Linear is always built; resnet_num_fc_layers - 1 more follow.
        fc_in_dims = [self.out_size] + [net_cfg.resnet_hdim] * max(net_cfg.resnet_num_fc_layers - 1, 0)
        fc_layers = []
        for fc_in_dim in fc_in_dims:
            fc_layers.append(nn.Linear(fc_in_dim, net_cfg.resnet_hdim))
            if use_norm:
                fc_layers.append(nn.LayerNorm(net_cfg.resnet_hdim))
            fc_layers.append(nn.ELU(inplace=True))
        self.fc_head = nn.Sequential(*fc_layers)

        self.char_embeddings = nn.Embedding(256, net_cfg.char_edim)
        self.color_embeddings = nn.Embedding(128, net_cfg.color_edim)

        if net_cfg.resnet_use_pos_encodings:
            self.row_embeddings = nn.Embedding(obs_shape[0], net_cfg.resnet_pos_edim)
            self.col_embeddings = nn.Embedding(obs_shape[1], net_cfg.resnet_pos_edim)

        if net_cfg.fix_initialization:
            self.apply(self._init_weights)

    @staticmethod
    def _pool_out_size(size, kernel_size: int = 3, stride: int = 2):
        """Output length of an unpadded, undilated max pool (formula from the PyTorch docs)."""
        return math.floor((size - 1 * (kernel_size - 1) - 1) / stride + 1)

    @staticmethod
    def _pad_last_col(emb):
        """Append one all-zero slice at the end of dim -2 of ``emb``.

        Only the single slice is allocated (same dtype/device as ``emb``)
        rather than a full-size zero tensor.
        """
        return torch.cat([emb, torch.zeros_like(emb[..., -1:, :])], dim=-2)

    def _init_weights(self, module, scale: float = 1.0, bias: bool = True):
        """Rescale each row of an ``nn.Linear`` weight to L2 norm ``scale`` and,
        when ``bias`` is true, zero its bias. Other module types are untouched.
        """
        if isinstance(module, nn.Linear):
            print('fixing resnet linear initialization ...')
            module.weight.data *= scale / module.weight.norm(dim=1, p=2, keepdim=True)

            if bias and module.bias is not None:
                module.bias.data *= 0

    def forward(self, chars, colors, glyphs):
        """Encode one batch of observations.

        Args:
            chars, colors: integer-valued id grids (cast to long before lookup).
            glyphs: raw glyph ids, unpacked as ``(B, C, H, W)`` in ``_embed``.

        Returns:
            Output of the FC head: one ``resnet_hdim``-sized row per batch element.
        """
        (
            chars, colors, glyphs, groups, levels, difficulties,
            weights, speeds, corpses, sacrifices
        ) = self._embed(chars, colors, glyphs) # 21 x 80
        if self.cfg.network.add_char_color:
            # Sum all embeddings instead of concatenating them (requires every
            # embedding to share the same size).
            glyphs = glyphs + groups + levels + difficulties + weights + speeds + corpses + sacrifices
            glyphs = self._pad_last_col(glyphs)
            x = chars + colors + glyphs
            # (B, C, H, W, E) -> (B, C * E, H, W)
            x = x.permute(0, 1, 4, 2, 3).flatten(1, 2).contiguous()
        else:
            x = self._stack(chars, colors, glyphs, groups, levels, difficulties, weights, speeds, corpses, sacrifices)

        if self.cfg.network.resnet_use_pos_encodings:
            B, C, H, W = x.shape
            # Build the index tensors directly on the target device.
            row_pos = torch.arange(H, device=x.device).unsqueeze(0)
            col_pos = torch.arange(W, device=x.device).unsqueeze(0)
            row_pos = selectt(self.row_embeddings, row_pos, True) # 1 x H x E
            col_pos = selectt(self.col_embeddings, col_pos, True) # 1 x W x E
            # Broadcast rows over W and columns over H; E must match x's channel count.
            x = x + row_pos.permute(0, 2, 1).unsqueeze(-1) + col_pos.permute(0, 2, 1).unsqueeze(-2)

        x = self.conv_net(x)
        x = x.view(-1, self.out_size)
        x = self.fc_head(x)
        return x

    def _embed(self, chars, colors, glyphs):
        """Look up every embedding used by the encoder.

        Raw glyph ids are first mapped through ``id_pairs_table`` to eight id
        columns, each of which is then embedded with its own module.

        Returns:
            Tuple ``(chars, colors, glyphs, groups, levels, difficulties,
            weights, speeds, corpses, sacrifices)`` of embedded tensors.
        """
        chars = selectt(self.char_embeddings, chars.long(), True)
        colors = selectt(self.color_embeddings, colors.long(), True)

        # get all features
        B, C, H, W = glyphs.shape
        ids = self.id_pairs_table.index_select(0, glyphs.reshape(-1).long())

        def column(idx):
            # One id column of the table, reshaped back to the glyph grid.
            return ids.select(1, idx).view(B, C, H, W).long()

        glyphs = column(0)
        groups = column(1)
        difficulties = column(2)
        levels = column(3)
        weights = column(4)
        speeds = column(5)
        corpses = column(6)
        sacrifices = column(7)

        # embed
        glyph_embs = selectt(self.glyph_emb, glyphs, True)
        group_embs = selectt(self.group_emb, groups, True)
        level_embs = selectt(self.level_emb, levels, True)
        difficulty_embs = selectt(self.difficulty_emb, difficulties, True)
        weight_embs = selectt(self.weight_emb, weights, True)
        speed_embs = selectt(self.speed_emb, speeds, True)
        corpse_embs = selectt(self.corpse_emb, corpses, True)
        sacrifice_embs = selectt(self.sacrifice_emb, sacrifices, True)

        return chars, colors, glyph_embs, group_embs, level_embs, difficulty_embs, weight_embs, speed_embs, corpse_embs, sacrifice_embs

    def _stack(self, chars, colors, glyphs, groups, levels, difficulties, weights, speeds, corpses, sacrifices):
        """Concatenate all embeddings along the last dim and move them to channels.

        The eight glyph-derived tensors each get one zero slice appended along
        dim -2 before concatenation (presumably because the glyph grid is one
        column narrower than the char/color grid - confirm against the data
        pipeline). When ``glyph_embed_with_linear`` is set, the concatenation
        is projected with ``emb_linear``.

        Returns:
            Contiguous tensor of shape ``(B, C * E, H, W)`` for 5-D
            ``(B, C, H, W, E)`` inputs.
        """
        padded = [
            self._pad_last_col(emb)
            for emb in (glyphs, groups, levels, difficulties, weights, speeds, corpses, sacrifices)
        ]

        obs = torch.cat([chars, colors, *padded], dim=-1)
        if self.glyph_embed_with_linear:
            obs = self.emb_linear(obs)

        return obs.permute(0, 1, 4, 2, 3).flatten(1, 2).contiguous()

class ResBlock(nn.Module):
    """Residual block: ``num_layers`` x (Conv2d -> BatchNorm2d -> ELU) plus a skip connection.

    Every conv is built as ``in_channels -> out_channels`` and the input is
    added back to the output, so the block is only usable when the two
    channel counts are equal.
    """

    def __init__(self, cfg: DictConfig, in_channels: int, out_channels: int, filter_size: int, num_layers: int):
        """
        Args:
            cfg: config; ``cfg.network.fix_initialization`` toggles the custom
                init, which also reads ``obs_conv_blocks`` and
                ``resnet_num_blocks``.
            in_channels: input channels of every conv layer.
            out_channels: output channels of every conv / batch-norm layer.
            filter_size: conv kernel size; padding is ``filter_size // 2``.
            num_layers: number of Conv2d -> BatchNorm2d -> ELU triples.
        """
        super().__init__()

        self.cfg = cfg

        half_kernel = filter_size // 2
        stack = []
        for _layer_idx in range(num_layers):
            stack.extend((
                nn.Conv2d(in_channels, out_channels, filter_size, stride=1, padding=half_kernel),
                nn.BatchNorm2d(out_channels),
                nn.ELU(inplace=True),
            ))
        self.net = nn.Sequential(*stack)

        if self.cfg.network.fix_initialization:
            self.apply(self._init_weights)

    def _init_weights(self, module, scale: float = 1.0, bias: bool = True):
        """Re-initialize a ``nn.Conv2d``; any other module type is left alone.

        Each output filter is rescaled to a fixed L2 norm and, when ``bias``
        is true, the bias is multiplied by zero. The ``scale`` argument is
        not used: the target norm is always recomputed from the config.
        """
        if not isinstance(module, nn.Conv2d):
            return

        print('fixing resblock initialization ...')
        # NOTE: from https://github.com/openai/Video-Pre-Training/blob/077ba2b9885ff696051df8348dc760d9699139ca/lib/util.py#L68-L73
        # and https://github.com/openai/Video-Pre-Training/blob/077ba2b9885ff696051df8348dc760d9699139ca/lib/impala_cnn.py#L164-L168
        # and https://github.com/openai/Video-Pre-Training/blob/077ba2b9885ff696051df8348dc760d9699139ca/lib/impala_cnn.py#L105
        net_cfg = self.cfg.network
        scale = math.sqrt(net_cfg.obs_conv_blocks) / math.sqrt(net_cfg.resnet_num_blocks)

        # Norm over every dim except the output-channel one -> one value per filter.
        reduce_dims = tuple(range(1, module.weight.data.ndim))
        per_filter_norm = module.weight.norm(dim=reduce_dims, p=2, keepdim=True)
        module.weight.data *= scale / per_filter_norm

        # Init Bias
        if bias:
            module.bias.data *= 0

    def forward(self, x):
        """Return ``net(x) + x``."""
        transformed = self.net(x)
        return transformed + x