#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderModel,
    register_model,
    register_model_architecture,
)
from fairseq.modules.fairseq_dropout import FairseqDropout


# Default encoder layout, stored as a string and eval()-ed by
# W2lConvGluEncoderModel.build_model. Each tuple configures one Conv1d layer
# as (out_channels, kernel_size, padding, dropout).
# out_channels must be even: every conv is followed by a GLU that halves the
# channel count, so the last layer (1816) yields 908 features per frame.
# Widths grow by roughly 10% and dropout by roughly 7% per layer.
# All convolutions use the default stride of 1. The 2 * 170 = 340 frames of
# padding on the first layer exactly offset the total shrinkage
# sum(kernel_size - 1) = 340 over the 17 layers, so the output has as many
# frames as the input (W2lConvGluEncoder.forward asserts this).
default_conv_enc_config = """[
    (400, 13, 170, 0.2),
    (440, 14, 0, 0.214),
    (484, 15, 0, 0.22898),
    (532, 16, 0, 0.2450086),
    (584, 17, 0, 0.262159202),
    (642, 18, 0, 0.28051034614),
    (706, 19, 0, 0.30014607037),
    (776, 20, 0, 0.321156295296),
    (852, 21, 0, 0.343637235966),
    (936, 22, 0, 0.367691842484),
    (1028, 23, 0, 0.393430271458),
    (1130, 24, 0, 0.42097039046),
    (1242, 25, 0, 0.450438317792),
    (1366, 26, 0, 0.481969000038),
    (1502, 27, 0, 0.51570683004),
    (1652, 28, 0, 0.551806308143),
    (1816, 29, 0, 0.590432749713),
]"""


@register_model("asr_w2l_conv_glu_encoder")
class W2lConvGluEncoderModel(FairseqEncoderModel):
    """Encoder-only ASR model wrapping a :class:`W2lConvGluEncoder`."""

    def __init__(self, encoder):
        super().__init__(encoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument(
            "--input-feat-per-channel",
            type=int,
            metavar="N",
            help="encoder input dimension per input channel",
        )
        parser.add_argument(
            "--in-channels",
            type=int,
            metavar="N",
            help="number of encoder input channels",
        )
        parser.add_argument(
            "--conv-enc-config",
            type=str,
            metavar="EXPR",
            help="""
    an array of tuples each containing the configuration of one conv layer
    [(out_channels, kernel_size, padding, dropout), ...]
            """,
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        Args:
            args: parsed arguments. ``conv_enc_config`` is a string that is
                evaluated into a list of
                ``(out_channels, kernel_size, padding, dropout)`` tuples; when
                it is missing or ``None`` the module default is used.
            task: supplies ``target_dictionary``, whose length sets the size
                of the encoder's output layer.

        Returns:
            A new ``W2lConvGluEncoderModel``.
        """
        conv_enc_config = getattr(args, "conv_enc_config", None)
        if conv_enc_config is None:
            # An attribute that exists but is None (e.g. a Namespace built
            # without argparse.SUPPRESS) would otherwise crash eval().
            conv_enc_config = default_conv_enc_config
        encoder = W2lConvGluEncoder(
            vocab_size=len(task.target_dictionary),
            input_feat_per_channel=args.input_feat_per_channel,
            in_channels=args.in_channels,
            # SECURITY NOTE(review): eval() executes arbitrary Python, so only
            # trusted configuration strings may be passed here. It is kept
            # rather than ast.literal_eval because that would reject any
            # non-literal expressions existing configs might rely on.
            conv_enc_config=eval(conv_enc_config),
        )
        return cls(encoder)

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Return (log-)probabilities from the base class, tagged as time-major.

        The encoder emits a (T, B, vocab_size) tensor, so ``batch_first`` is
        set to False on the result for downstream consumers.
        """
        lprobs = super().get_normalized_probs(net_output, log_probs, sample)
        lprobs.batch_first = False
        return lprobs


class W2lConvGluEncoder(FairseqEncoder):
    """Wav2letter-style encoder built from weight-normalized convs and GLUs.

    A stack of 1-D convolutions (each followed by GLU and dropout) is followed
    by two weight-normalized linear layers. The output is a
    (T, B, vocab_size) tensor of per-frame scores plus a (T, B) padding mask.
    """

    def __init__(
        self, vocab_size, input_feat_per_channel, in_channels, conv_enc_config
    ):
        """
        Args:
            vocab_size: number of output units of the final linear layer.
            input_feat_per_channel: features per input frame; used as the
                input channel count of the first convolution.
            in_channels: number of input channels (only 1 is supported).
            conv_enc_config: iterable of
                ``(out_channels, kernel_size, padding, dropout)`` tuples, one
                per convolution layer.

        Raises:
            ValueError: if ``in_channels != 1`` or any ``out_channels`` is odd.
        """
        super().__init__(None)

        self.input_dim = input_feat_per_channel
        if in_channels != 1:
            raise ValueError("only 1 input channel is currently supported")

        self.conv_layers = nn.ModuleList()
        self.linear_layers = nn.ModuleList()
        # Must be a ModuleList, not a plain list. This registers the dropout
        # modules as submodules so they follow model.train()/model.eval().
        # With a plain list they would stay in training mode forever and keep
        # dropping activations at inference. FairseqDropout has no parameters
        # or buffers, so checkpoint state_dicts are unaffected.
        self.dropouts = nn.ModuleList()
        cur_channels = input_feat_per_channel

        for out_channels, kernel_size, padding, dropout in conv_enc_config:
            # Validate before allocating: GLU splits the channels in half.
            if out_channels % 2 != 0:
                raise ValueError("odd # of out_channels is incompatible with GLU")
            layer = nn.Conv1d(cur_channels, out_channels, kernel_size, padding=padding)
            layer.weight.data.mul_(math.sqrt(3))  # match wav2letter init
            self.conv_layers.append(nn.utils.weight_norm(layer))
            self.dropouts.append(
                FairseqDropout(dropout, module_name=self.__class__.__name__)
            )
            cur_channels = out_channels // 2  # halved by GLU

        # First linear projects to 2 * cur_channels so the GLU that follows it
        # in forward() returns to cur_channels. Second maps to the vocabulary.
        for out_channels in [2 * cur_channels, vocab_size]:
            layer = nn.Linear(cur_channels, out_channels)
            layer.weight.data.mul_(math.sqrt(3))
            self.linear_layers.append(nn.utils.weight_norm(layer))
            cur_channels = out_channels // 2

    def forward(self, src_tokens, src_lengths, **kwargs):
        """
        src_tokens: padded tensor (B, T, C * feat)
        src_lengths: tensor of original lengths of input utterances (B,)

        Returns a dict with ``encoder_out`` (T, B, vocab_size) and
        ``encoder_padding_mask`` (T, B), True at padded positions.
        """
        B, T, _ = src_tokens.size()
        x = src_tokens.transpose(1, 2).contiguous()  # (B, feat, T) assuming C == 1

        for conv, dropout in zip(self.conv_layers, self.dropouts):
            x = conv(x)
            x = F.glu(x, dim=1)
            x = dropout(x)

        # (B, T, last conv out_channels // 2), e.g. 908 for the default config
        x = x.transpose(1, 2).contiguous()
        x = self.linear_layers[0](x)
        x = F.glu(x, dim=2)
        x = self.dropouts[-1](x)  # reuses the last conv layer's dropout
        x = self.linear_layers[1](x)

        # The conv config is expected to preserve the number of frames.
        assert x.size(0) == B
        assert x.size(1) == T

        encoder_out = x.transpose(0, 1)  # (T, B, vocab_size)

        # Frame t of utterance b is padding iff t >= src_lengths[b];
        # broadcasting (1, T) against (B, 1) gives (B, T).
        frame_idx = torch.arange(T, device=x.device)
        encoder_padding_mask = (
            frame_idx.unsqueeze(0) >= src_lengths.unsqueeze(1)
        ).t()  # (B x T) -> (T x B)

        return {
            "encoder_out": encoder_out,  # (T, B, vocab_size)
            "encoder_padding_mask": encoder_padding_mask,  # (T, B)
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder the batch dimension (dim 1) of both outputs by ``new_order``.

        The given dict is updated in place and also returned.
        """
        encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
            1, new_order
        )
        encoder_out["encoder_padding_mask"] = encoder_out[
            "encoder_padding_mask"
        ].index_select(1, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return (1e6, 1e6)  # an arbitrary large number


@register_model_architecture("asr_w2l_conv_glu_encoder", "w2l_conv_glu_enc")
def w2l_conv_glu_enc(args):
    """Fill in architecture defaults for any option not already set on args."""
    fallbacks = (
        ("input_feat_per_channel", 80),
        ("in_channels", 1),
        ("conv_enc_config", default_conv_enc_config),
    )
    for option, fallback in fallbacks:
        # Keep a user-provided value if present; otherwise install the default.
        current = getattr(args, option, fallback)
        setattr(args, option, current)
