"""Encoder self-attention layer definition."""

import math
import pdb

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from vl_load.vita.model.multimodal_encoder.whale.module.layer.attention import (
    Conv1dLinear,
    MultiHeadedAttention,
    MultiLayeredConv1d,
    PositionalEncoding,
    PositionwiseFeedForward,
    RelPositionalEncoding,
)

# from vl_load.vita.model.multimodal_encoder.whale.module.component.utils import *
from vl_load.vita.model.multimodal_encoder.whale.utils import IGNORE_ID, add_optional_chunk_mask, strtobool


def repeat(N, fn):
    """Create ``N`` modules with ``fn`` and chain them in a MultiSequential.

    ``fn`` is called once for each layer index ``0, 1, ..., N - 1`` (in that
    order); its return values become the container's children in the same
    order.

    :param int N: number of modules to create
    :param function fn: factory called with the layer index
    :return: the created modules wrapped in a single container
    :rtype: MultiSequential
    """
    layers = map(fn, range(N))
    return MultiSequential(*layers)


class MultiSequential(torch.nn.Sequential):
    """Multi-input multi-output torch.nn.Sequential.

    Each child receives the whole tuple returned by the previous child, so
    every child must accept and return the same tuple layout.
    """

    def forward(self, x, masks, pos_emb):
        """Run every child's ``forward`` in order.

        :return: ``(x, masks, pos_emb)`` as returned by the last child (the
            inputs unchanged when the container is empty)
        """
        for m in self:
            x, masks, pos_emb = m(x, masks, pos_emb)
        return x, masks, pos_emb

    @torch.jit.export
    def infer(self, x, pos_emb, buffer, buffer_index, buffer_out):
        # type: (Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]
        # NOTE(review): TorchScript requires exactly one annotation per
        # parameter (excluding self). The element types follow the original
        # all-Tensor declaration; they cannot be verified from this file, so
        # confirm them against the children's ``infer`` before scripting.
        """Run every child's ``infer`` in order, threading the buffer state.

        :return: ``(x, pos_emb, buffer, buffer_index, buffer_out)`` as
            returned by the last child
        """
        for m in self:
            x, pos_emb, buffer, buffer_index, buffer_out = m.infer(
                x, pos_emb, buffer, buffer_index, buffer_out
            )
        return x, pos_emb, buffer, buffer_index, buffer_out

    @torch.jit.export
    def infer_hidden(self, x, pos_emb, buffer, buffer_index, buffer_out, hidden_out):
        # type: (Tensor, Tensor, Tensor, Tensor, Tensor, List[Tensor]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, List[Tensor]]
        # NOTE(review): ``hidden_out`` must be a list because ``.append`` is
        # called on it; the other element types follow the all-Tensor
        # declaration of ``infer`` and should be confirmed before scripting.
        """Like ``infer``, but also record every child's output ``x``.

        ``hidden_out`` is extended in place (one entry per child) and is also
        returned as the last tuple element.

        :return: ``(x, pos_emb, buffer, buffer_index, buffer_out, hidden_out)``
        """
        for m in self:
            x, pos_emb, buffer, buffer_index, buffer_out = m.infer(
                x, pos_emb, buffer, buffer_index, buffer_out
            )
            hidden_out.append(x)
        return x, pos_emb, buffer, buffer_index, buffer_out, hidden_out


class TransformerLayer(nn.Module):
    """Transformer layer module.

    One encoder block: a self-attention sub-block followed by a feed-forward
    sub-block, each wrapped in a residual connection. LayerNorm is applied
    before each sub-block when ``normalize_before`` is True, otherwise after
    the residual addition.

    :param int size: input dim
    :param self_attn: self attention module
    :param feed_forward: feed forward module
    :param float dropout_rate: dropout rate
    :param bool normalize_before: whether to use layer_norm before the first block
    :param bool concat_after: whether to concat attention layer's input and output
        if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x)))
        if False, no additional linear will be applied. i.e. x -> x + att(x)

    """

    def __init__(
        self, size, self_attn, feed_forward, dropout_rate, normalize_before=True, concat_after=False
    ):
        """Construct a TransformerLayer object."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = torch.nn.LayerNorm(size)
        self.norm2 = torch.nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            # Projects concat(x, att(x)) (width 2 * size) back to width size.
            self.concat_linear = nn.Linear(size + size, size)
        else:
            # Placeholder so ``concat_linear`` exists in both configurations
            # (presumably because TorchScript compiles both branches of the
            # ``if self.concat_after`` checks) -- TODO confirm.
            self.concat_linear = nn.Identity()

    @torch.jit.unused
    def forward(self, x, mask, pos_emb):
        """Compute encoded features.

        :param torch.Tensor x: encoded source features (batch, max_time_in, size)
        :param torch.Tensor mask: attention mask, passed to ``self_attn`` and
            returned unchanged
        :param torch.Tensor pos_emb: positional embedding, passed to
            ``self_attn`` and returned unchanged
        :return: ``(x, mask, pos_emb)`` where ``x`` is the encoded features
        :rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        """
        residual = x
        if self.normalize_before:
            x = self.norm1(x)
        if self.concat_after:
            x_concat = torch.cat((x, self.self_attn(x, x, x, mask, pos_emb)), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x = residual + self.dropout(self.self_attn(x, x, x, mask, pos_emb))
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        return x, mask, pos_emb

    @torch.jit.export
    def infer(self, x, pos_emb, buffer, buffer_index, buffer_out):
        # type: (Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]
        # NOTE(review): TorchScript requires exactly one annotation per
        # parameter (excluding self) and the return must match the 5-tuple
        # below. Element types follow the original all-Tensor declaration;
        # confirm them against self_attn.infer / feed_forward.infer.
        """Inference counterpart of ``forward`` that threads buffer state.

        ``self_attn.infer`` and ``feed_forward.infer`` each receive and return
        the ``(buffer, buffer_index, buffer_out)`` state. ``self.dropout`` is
        not applied on this path (nn.Dropout is an identity in eval mode).

        :return: ``(x, pos_emb, buffer, buffer_index, buffer_out)``
        """
        # clone() keeps ``residual`` independent of ``x`` in case a submodule
        # modifies its input in place (not visible from this file).
        residual = x.clone()
        if self.normalize_before:
            x = self.norm1(x)
        if self.concat_after:
            x_att, buffer, buffer_index, buffer_out = self.self_attn.infer(
                x, x, x, pos_emb, buffer, buffer_index, buffer_out
            )
            x_concat = torch.cat((x, x_att), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x_att, buffer, buffer_index, buffer_out = self.self_attn.infer(
                x, x, x, pos_emb, buffer, buffer_index, buffer_out
            )
            x = residual + x_att
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x.clone()
        if self.normalize_before:
            x = self.norm2(x)
        x_feed, buffer, buffer_index, buffer_out = self.feed_forward.infer(
            x, buffer, buffer_index, buffer_out
        )
        x = residual + x_feed
        if not self.normalize_before:
            x = self.norm2(x)

        return x, pos_emb, buffer, buffer_index, buffer_out


class Transformer(torch.nn.Module):
    """Transformer encoder: input layer, positional encoding and TransformerLayer stack.

    Settings are read from an argparse namespace (``args``, holding the
    ``transformer_*`` options registered by ``add_arguments``) or, when
    ``args`` is None, from the explicit keyword arguments of ``__init__``.
    """

    @staticmethod
    def add_arguments(group):
        """Register the Transformer encoder command-line options.

        :param group: object with an argparse-style ``add_argument`` method
            (e.g. an ``argparse`` argument group)
        :return: ``group`` with the options added
        """
        group.add_argument(
            "--transformer-input-dim", default=256, type=int, help="Input dim of Transformer."
        )
        group.add_argument(
            "--transformer-output-dim", default=4, type=int, help="Output dim of Transformer."
        )
        group.add_argument(
            "--transformer-attention-dim", default=256, type=int, help="Dimention of attention."
        )
        group.add_argument(
            "--transformer-attention-heads",
            default=4,
            type=int,
            help="The number of heads of multi head attention.",
        )
        group.add_argument(
            "--transformer-linear-units",
            default=1024,
            type=int,
            help="The number of units of position-wise feed forward.",
        )
        group.add_argument(
            "--transformer-num-blocks", default=6, type=int, help="The number of attention blocks."
        )
        group.add_argument(
            "--transformer-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate in Transformer.",
        )
        group.add_argument(
            "--transformer-attention-dropout-rate",
            default=0.0,
            type=float,
            help="Dropout rate in attention.",
        )
        group.add_argument(
            "--transformer-positional-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate after adding positional encoding.",
        )
        group.add_argument(
            "--transformer-input-layer", default="linear", type=str, help="Type of input layer"
        )
        group.add_argument("--transformer-pos-enc-class", default="abs-enc", type=str, help="")
        group.add_argument(
            "--transformer-normalize-before",
            default=True,
            type=strtobool,
            help="Whether to use layer-norm before the first block.",
        )
        group.add_argument(
            "--transformer-concat-after",
            default=False,
            type=strtobool,
            help="Whether to concat attention layer's input and output.",
        )
        group.add_argument(
            "--transformer-positionwise-layer-type",
            default="linear",
            type=str,
            help="Linear of conv1d.",
        )
        group.add_argument(
            "--transformer-positionwise-conv-kernel_size",
            default=1,
            type=int,
            help="Kernel size of positionwise conv1d layer.",
        )
        group.add_argument("--transformer-chunk_size", default=-1, type=int, help="")
        group.add_argument("--transformer-left_chunks", default=-1, type=int, help="")
        group.add_argument("--transformer-dynamic-chunks", default=True, type=strtobool, help="")
        return group

    def __init__(
        self,
        args,
        input_dim=None,
        output_dim=None,
        attention_dim=None,
        attention_heads=None,
        linear_units=None,
        num_blocks=None,
        dropout_rate=None,
        positional_dropout_rate=None,
        attention_dropout_rate=None,
        input_layer=None,
        pos_enc_class=None,
        normalize_before=None,
        concat_after=None,
        positionwise_layer_type=None,
        positionwise_conv_kernel_size=None,
        chunk_size=None,
        left_chunks=None,
        dynamic_chunks=True,
    ):
        """Construct a Transformer encoder.

        :param args: argparse namespace with the ``transformer_*`` options, or
            None to take every setting from the keyword arguments below
        :param int input_dim: in-features of the ``"linear"`` input layer, or
            the number of embeddings for ``"embed"``
        :param int output_dim: stored as ``self.output_dim``; not used here
        :param int attention_dim: model (attention) dimension
        :param int attention_heads: number of attention heads
        :param int linear_units: hidden units of the position-wise layer
        :param float dropout_rate: dropout in the input layer, position-wise
            layer and residual paths
        :param float positional_dropout_rate: dropout of the positional encoding
        :param float attention_dropout_rate: dropout inside attention
        :param str input_layer: ``"linear"``, ``"embed"`` or ``"none"``
        :param str pos_enc_class: ``"abs-enc"`` or ``"rel-enc"``
        :param bool normalize_before: pre-LayerNorm blocks plus a final LayerNorm
        :param bool concat_after: see ``TransformerLayer``
        :param str positionwise_layer_type: ``"linear"``, ``"conv1d"`` or
            ``"conv1d-linear"``
        :param int positionwise_conv_kernel_size: kernel size for conv variants
        :param int chunk_size: chunk size for static chunk masks / attention
        :param int left_chunks: number of left chunks for static chunk masks
        :param bool dynamic_chunks: whether ``forward`` builds dynamic chunk
            masks; only used when ``args`` is None (default True, like
            ``--transformer-dynamic-chunks``)
        :raises ValueError: on an unknown ``pos_enc_class`` or ``input_layer``
        :raises NotImplementedError: on an unknown ``positionwise_layer_type``
        """
        super().__init__()
        if args is None:
            self.input_dim = input_dim
            self.output_dim = output_dim
            self.attention_dim = attention_dim
            self.attention_heads = attention_heads
            self.linear_units = linear_units
            self.num_blocks = num_blocks
            self.dropout_rate = dropout_rate
            self.positional_dropout_rate = positional_dropout_rate
            self.attention_dropout_rate = attention_dropout_rate
            self.input_layer = input_layer
            self.pos_enc_class = pos_enc_class
            self.normalize_before = normalize_before
            self.concat_after = concat_after
            self.positionwise_layer_type = positionwise_layer_type
            self.positionwise_conv_kernel_size = positionwise_conv_kernel_size
            self.chunk_size = chunk_size
            self.left_chunks = left_chunks
            # forward() reads this attribute unconditionally, so it must be
            # set on this path too (it used to be missing -> AttributeError).
            self.transformer_dynamic_chunks = dynamic_chunks
        else:
            self.input_dim = args.transformer_input_dim
            self.output_dim = args.transformer_output_dim
            self.attention_dim = args.transformer_attention_dim
            self.attention_heads = args.transformer_attention_heads
            self.linear_units = args.transformer_linear_units
            self.num_blocks = args.transformer_num_blocks
            self.dropout_rate = args.transformer_dropout_rate
            self.positional_dropout_rate = args.transformer_positional_dropout_rate
            self.attention_dropout_rate = args.transformer_attention_dropout_rate
            self.input_layer = args.transformer_input_layer
            self.pos_enc_class = args.transformer_pos_enc_class
            self.normalize_before = args.transformer_normalize_before
            self.concat_after = args.transformer_concat_after
            self.positionwise_layer_type = args.transformer_positionwise_layer_type
            self.positionwise_conv_kernel_size = args.transformer_positionwise_conv_kernel_size
            self.chunk_size = args.transformer_chunk_size
            self.left_chunks = args.transformer_left_chunks
            self.transformer_dynamic_chunks = args.transformer_dynamic_chunks

        if self.pos_enc_class == "abs-enc":
            pos_enc_args = (self.attention_dim, self.positional_dropout_rate)
            pos_enc_class = PositionalEncoding
        elif self.pos_enc_class == "rel-enc":
            pos_enc_args = (
                self.attention_dim,
                self.positional_dropout_rate,
                self.chunk_size,
                self.left_chunks,
            )
            pos_enc_class = RelPositionalEncoding
        else:
            # Without this branch pos_enc_args stays unbound and the failure
            # surfaces later as an UnboundLocalError.
            raise ValueError(f"unknown pos_enc_class: {self.pos_enc_class!r}")

        if self.input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(self.input_dim, self.attention_dim),
                torch.nn.LayerNorm(self.attention_dim),
                torch.nn.Dropout(self.dropout_rate),
                torch.nn.ReLU(),
            )
        elif self.input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(self.input_dim, self.attention_dim, padding_idx=IGNORE_ID)
            )
        elif self.input_layer == "none":
            self.embed = torch.nn.Sequential(torch.nn.Identity())
        else:
            # str() so a None/non-str value still yields a ValueError.
            raise ValueError("unknown input_layer: " + str(self.input_layer))
        self.pe = pos_enc_class(*pos_enc_args)
        self.embed_layer_num = len(self.embed)

        if self.positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (self.attention_dim, self.linear_units, self.dropout_rate)
        elif self.positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                self.attention_dim,
                self.linear_units,
                self.positionwise_conv_kernel_size,
                self.dropout_rate,
            )
        elif self.positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                self.attention_dim,
                self.linear_units,
                self.positionwise_conv_kernel_size,
                self.dropout_rate,
            )
        else:
            raise NotImplementedError(
                "Support only linear, conv1d or conv1d-linear "
                f"(got {self.positionwise_layer_type!r})."
            )

        # The factory is called once per block, so every block gets its own
        # attention and position-wise modules.
        self.encoders = repeat(
            self.num_blocks,
            lambda lnum: TransformerLayer(
                self.attention_dim,
                MultiHeadedAttention(
                    self.attention_heads,
                    self.attention_dim,
                    self.attention_dropout_rate,
                    self.chunk_size,
                    self.left_chunks,
                    self.pos_enc_class,
                ),
                positionwise_layer(*positionwise_layer_args),
                self.dropout_rate,
                self.normalize_before,
                self.concat_after,
            ),
        )
        if self.normalize_before:
            # Pre-norm blocks leave the final output un-normalized.
            self.after_norm = torch.nn.LayerNorm(self.attention_dim)

    @torch.jit.unused
    def forward(self, xs, ilens=None, masks=None):
        """Encode a batch of input features.

        :param torch.Tensor xs: input tensor
        :param ilens: returned unchanged
        :param torch.Tensor masks: input mask, used to build the chunk mask
        :return: ``(xs, ilens, masks)`` -- the encoded features plus the
            *input* ``ilens`` and ``masks`` (not the chunk mask)
        :rtype: Tuple[torch.Tensor, Any, torch.Tensor]
        """
        # NOTE: the ``and self.training`` condition is disabled, so dynamic
        # chunk masks are also built in eval mode when the flag is set.
        if self.transformer_dynamic_chunks:  # and self.training:
            chunk_masks = add_optional_chunk_mask(xs, masks, True, True, 0, 0, -1)
        else:
            chunk_masks = add_optional_chunk_mask(
                xs, masks, False, False, self.chunk_size, self.chunk_size, self.left_chunks
            ).to(xs.device)
        xs = self.embed(xs)
        xs, pos_emb = self.pe(xs)
        xs, chunk_masks, pos_emb = self.encoders(xs, chunk_masks, pos_emb)
        if self.normalize_before:
            xs = self.after_norm(xs)
        return xs, ilens, masks

    @torch.jit.export
    def infer(self, xs, buffer, buffer_index, buffer_out):
        """Inference path threading (buffer, buffer_index, buffer_out) through every layer.

        :return: ``(xs, buffer, buffer_index, buffer_out)``
        """
        xs = self.embed(xs)

        # Disabled: would read/write the positional-encoding argument through
        # ``buffer``; the active code always passes 0 to ``pe.infer``
        # (presumably the start position -- confirm in PositionalEncoding).
        # pe_index = buffer[buffer_index: buffer_index + 1].reshape([1]).to(torch.int64)
        # xs, pos_emb, pe_index[0] = self.pe.infer(xs, pe_index[0])
        # buffer_out.append(pe_index.reshape(-1).to(torch.float32))
        # buffer_index = buffer_index + 1
        xs, pos_emb, _ = self.pe.infer(xs, 0)
        xs, pos_emb, buffer, buffer_index, buffer_out = self.encoders.infer(
            xs, pos_emb, buffer, buffer_index, buffer_out
        )

        if self.normalize_before:
            xs = self.after_norm(xs)
        return xs, buffer, buffer_index, buffer_out

    @torch.jit.export
    def infer_hidden(self, xs, buffer, buffer_index, buffer_out, hidden_out):
        """Like ``infer``, but also collect every layer's output in ``hidden_out``.

        ``hidden_out`` is appended to (one entry per layer, before the final
        ``after_norm``) and returned.

        :return: ``(xs, buffer, buffer_index, buffer_out, hidden_out)``
        """
        xs = self.embed(xs)

        # Disabled positional-encoding bookkeeping; see ``infer``.
        # pe_index = buffer[buffer_index: buffer_index + 1].reshape([1]).to(torch.int64)
        # xs, pos_emb, pe_index[0] = self.pe.infer(xs, pe_index[0])
        # buffer_out.append(pe_index.reshape(-1).to(torch.float32))
        # buffer_index = buffer_index + 1
        xs, pos_emb, _ = self.pe.infer(xs, 0)
        xs, pos_emb, buffer, buffer_index, buffer_out, hidden_out = self.encoders.infer_hidden(
            xs, pos_emb, buffer, buffer_index, buffer_out, hidden_out
        )

        if self.normalize_before:
            xs = self.after_norm(xs)
        return xs, buffer, buffer_index, buffer_out, hidden_out
