from typing import Optional

import torch
from torch import Tensor
from torch.nn import Module

from dynamic_architecture import DynamicArchitecture
from transformer_util.decoder import Decoder
from transformer_util.encoder import Encoder


# https://github.com/hyunwoongko/transformer/
class DynamicTransformer(DynamicArchitecture):
    """
    Encoder/decoder Transformer whose layers can be resized at runtime.

    The model owns ``n_layers`` encoder layers followed by ``n_layers``
    decoder layers. Externally they are addressed with a single flat index:
    ids ``0 .. n_layers - 1`` refer to the encoder and ids
    ``n_layers .. 2 * n_layers - 1`` refer to the decoder.
    """

    def __init__(
        self,
        dim_input_features: int,
        dim_target: int,
        config: dict,
    ):
        """
        Builds the encoder and decoder stacks.
        :param dim_input_features: used as the maximum sequence length
            (``max_len``) of both encoder and decoder
        :param dim_target: used as the vocabulary size of both encoder
            and decoder
        :param config: configuration dictionary. Required keys:
            ``embedding_dim``, ``num_attention_heads``,
            ``num_enc_dec_layers``, ``dropout``. Optional key:
            ``num_hidden_neurons`` (None when absent). The whole dictionary
            is also forwarded to the Encoder and the Decoder.
        """
        super().__init__(dim_input_features, dim_target, config)
        self.max_len = dim_input_features
        self.d_model = config["embedding_dim"]

        # assume same tokenizer for encoding and decoding
        self.enc_vocab_size = dim_target
        self.dec_vocab_size = dim_target

        self.n_head = config["num_attention_heads"]

        self.ffn_hidden = config.get("num_hidden_neurons", None)

        self.n_layers = config["num_enc_dec_layers"]

        self.drop_prob = config["dropout"]

        self.encoder = Encoder(
            d_model=self.d_model,
            n_head=self.n_head,
            max_len=self.max_len,
            ffn_hidden=self.ffn_hidden,
            enc_voc_size=self.enc_vocab_size,
            drop_prob=self.drop_prob,
            n_layers=self.n_layers,
            dynamic=True,
            config=config,
        )

        self.decoder = Decoder(
            d_model=self.d_model,
            n_head=self.n_head,
            max_len=self.max_len,
            ffn_hidden=self.ffn_hidden,
            dec_voc_size=self.dec_vocab_size,
            drop_prob=self.drop_prob,
            n_layers=self.n_layers,
            dynamic=True,
            config=config,
        )

    def to(self, device):
        """
        Moves the model, its encoder and its decoder to ``device``.
        :param device: the target device
        :return: this model, so that ``model = model.to(device)`` works as
            it does for a regular torch Module
        """
        super().to(device)
        # NOTE(review): if encoder/decoder are registered submodules the
        # call above presumably moves them already; the explicit calls are
        # kept to preserve the original behavior.
        self.encoder.to(device)
        self.decoder.to(device)
        return self

    def _resolve_layer(self, layer_id: int) -> tuple:
        """
        Maps a flat layer id to the stack that owns it.
        :param layer_id: flat id in ``[0, 2 * n_layers)``
        :return: a pair ``(stack, local_id)`` where ``stack`` is the encoder
            or the decoder and ``local_id`` is the index inside that stack
        """
        # Negative ids are rejected explicitly: Python's modulo would
        # otherwise silently map them onto an encoder layer.
        assert 0 <= layer_id < self.n_layers * 2, (
            f"layer_id must be in [0, {self.n_layers * 2}), got {layer_id}"
        )
        stack = self.decoder if layer_id >= self.n_layers else self.encoder
        return stack, layer_id % self.n_layers

    def get_layer(self, layer_id: int) -> Module:
        """
        Retrieves the layer to be modified.
        :param layer_id: id of the layer to retrieve
        :return: a Module object holding the layer
        """
        stack, local_id = self._resolve_layer(layer_id)
        return stack.get_layer(local_id)

    def set_layer(self, layer_id: int, layer):
        """
        Replace the current layer with the modified one
        :param layer_id: id of the layer to replace
        :param layer: a Module object holding the new layer
        :return: whatever the owning stack's ``set_layer`` returns
        """
        stack, local_id = self._resolve_layer(layer_id)
        return stack.set_layer(local_id, layer)

    def change_shape(
        self,
        layer_id: int,
        num_neurons: int,
        change_output: bool,
        neurons_probs: Optional[Tensor] = None,
    ):
        """
        Changes input or output dimension of the layer.
        :param layer_id: id of the layer to be modified
        :param num_neurons: the new number of neurons for the
            input or output dimension
        :param change_output: whether to change the output dimension.
            If false, the call is a no-op (see the comment below)
        :param neurons_probs: optional tensor of neuron probabilities
            for weight initialization
        :return:
        """
        # PATCH TO WORK WITH AWN, which assumes a simple structure
        if not change_output:
            # we are already taking care of l+1 inside the transformer
            # since the layer index here is used to refer to either the
            # encoder or decoder layer, and not the internal FFN
            return

        stack, local_id = self._resolve_layer(layer_id)
        stack.change_shape(local_id, num_neurons, change_output, neurons_probs)

    def forward(self, src_trg, qW_probs):
        """
        Runs the encoder on the source and the decoder on the target.
        :param src_trg: 4-dimensional tensor. The last axis selects tokens
            (index 0) or attention mask (index 1); the third axis selects
            source (index 0) or target (index 1). The first two axes are
            presumably (batch, sequence) - TODO confirm against the caller
        :param qW_probs: sequence with one entry per adaptive layer
            (``2 * n_layers`` entries); the first ``n_layers`` go to the
            encoder and the remaining ones to the decoder
        :return: a tuple ``(decoder output, encoder output, None)``
        """
        assert len(qW_probs) == self.n_layers * 2

        tok, att_mask = src_trg[:, :, :, 0], src_trg[:, :, :, 1]
        src, trg = tok[:, :, 0], tok[:, :, 1]
        src_mask, trg_mask = att_mask[:, :, 0], att_mask[:, :, 1]

        # preprocess mask specific to this Transformer implementation
        src_mask = self.make_src_mask(src_mask)
        trg_mask = self.make_trg_mask(trg_mask, device=tok.device)

        enc_src = self.encoder(src, src_mask, qW_probs[: self.n_layers])
        output = self.decoder(
            trg, enc_src, trg_mask, src_mask, qW_probs[self.n_layers :]
        )
        return output, enc_src, None

    def make_src_mask(self, src_mask):
        """
        Inserts two singleton axes (at positions 1 and 2) in the source mask.
        :param src_mask: the source attention mask
        :return: the reshaped mask
        """
        # src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        return src_mask.unsqueeze(1).unsqueeze(2)

    def make_trg_mask(self, trg_mask, device):
        """
        Combines the target padding mask with a lower-triangular mask.
        :param trg_mask: the target attention mask
        :param device: device on which the triangular mask is placed
        :return: bitwise AND of the padding mask (with singleton axes
            inserted at positions 1 and 3) and a ``max_len x max_len``
            lower-triangular uint8 matrix
        """
        # trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(3)
        trg_pad_mask = trg_mask.unsqueeze(1).unsqueeze(3)
        trg_len = self.max_len  # trg.shape[1]
        # Allocate directly as uint8 rather than going through the legacy
        # ``.type(torch.ByteTensor)`` cast; dtype and values are unchanged.
        trg_sub_mask = torch.tril(
            torch.ones(trg_len, trg_len, dtype=torch.uint8)
        ).to(device)
        return trg_pad_mask & trg_sub_mask

    def __len__(self):
        """
        :return: the number of adaptive layers (encoder plus decoder)
        """
        # NOTE: THIS WILL UNIQUELY DETERMINE THE NUMBER OF ADAPTIVE LAYERS
        #  INCLUDING ENCODER AND DECODER
        return self.n_layers * 2
