#!/usr/bin/env python3

import logging
import math
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from fairseq import checkpoint_utils, utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
    FairseqEncoder,
    register_model,
    register_model_architecture,
)
from fairseq.models.speech_to_text import (
    S2TTransformerEncoder,
    S2TTransformerModel,
    Conv1dSubsampler,
    TransformerDecoderScriptable,
)
from fairseq.models.transformer import Embedding, TransformerEncoder
from fairseq.modules import (
    FairseqDropout,
    PositionalEmbedding,
    GradMultiply,
    LayerNorm,
)

logger = logging.getLogger(__name__)


@register_model("s2t_dcm")
class S2TDCMModel(S2TTransformerModel):
    """Encoder-decoder model registered under the name ``s2t_dcm``.

    The encoder is an :class:`S2TDCMEncoder` and the decoder is a
    ``TransformerDecoderScriptable``. Depending on which inputs are given,
    the encoder yields one output (speech only or text only) or a pair of
    outputs (speech and text); in the latter case the decoder is run once
    per encoder output and the results are concatenated along dim 0.
    """

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)
        # Mirrors the value most recently passed to ``set_num_updates``.
        self.num_updates = 0

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        S2TTransformerModel.add_args(parser)

        parser.add_argument(
            "--text-encoder-layers",
            default=6,
            type=int,
            help="layers of the text encoder",
        )
        parser.add_argument(
            "--adapter",
            default="e2e",
            type=str,
            help="adapter type",
        )
        # ``--temperature`` is already taken by GenerationConfig, see
        # https://github.com/facebookresearch/fairseq/blob/main/fairseq/dataclass/configs.py#L908
        parser.add_argument(
            "--latent-temp",
            default="(1.0,0.5,0.99995)",  # w2v2 base: (2.0,0.5,0.999995)
            type=str,
            help="temperature config of the gumbel softmax",
        )
        parser.add_argument(
            "--enc-grad-mult",
            type=float,
            metavar="V",
            default=1.0,
            help="multiply speech enc and text enc gradient by V",
        )
        parser.add_argument(
            "--load-pretrained-acoustic-encoder-from",
            type=str,
            metavar="STR",
            help="model to take acoustic encoder weights from (for initialization)",
        )
        parser.add_argument(
            "--load-pretrained-text-encoder-from",
            type=str,
            metavar="STR",
            help="model to take text encoder weights from (for initialization)",
        )
        parser.add_argument(
            "--load-pretrained-decoder-from",
            type=str,
            metavar="STR",
            help="model to take decoder weights from (for initialization)",
        )
        parser.add_argument(
            "--text-input-cost-ratio",
            type=float,
            default=1.0,
            metavar="V",
            help="text input cost ratio relative to speech input cost",
        )
        parser.add_argument(
            "--add-speech-eos",
            action="store_true",
            help="add eos token at the end of input feature",
        )
        parser.add_argument(
            "--shrink-ctc",
            action="store_true",
            help="remove the blank tokens and average the repeated tokens of ctc sequence",
        )
        parser.add_argument(
            "--adapter-with-gumbel",
            action="store_true",
            help="whether to use gumbel sampling trick within adapter",
        )
        parser.add_argument(
            "--hard-prob",
            action="store_true",
            help="whether to use hard prob for adapter",
        )
        parser.add_argument(
            "--sample-path",
            action="store_true",
            help="ctc path is sampled from soft distribution rather than argmax",
        )

    @classmethod
    def build_encoder(cls, args, task, embed_tokens):
        """Build the :class:`S2TDCMEncoder`, optionally loading pretrained parts.

        The acoustic encoder and the text encoder are each initialised from
        their own checkpoint when the matching ``load_pretrained_*_from``
        argument is set to a non-empty value.
        """
        encoder = S2TDCMEncoder(args, task, embed_tokens)

        if getattr(args, "load_pretrained_acoustic_encoder_from", None):
            encoder.acoustic_encoder = checkpoint_utils.load_pretrained_component_from_model(
                encoder.acoustic_encoder, args.load_pretrained_acoustic_encoder_from
            )
            logger.info(f"loaded pretrained speech encoder from: {args.load_pretrained_acoustic_encoder_from}")
        if getattr(args, "load_pretrained_text_encoder_from", None):
            encoder.text_encoder = checkpoint_utils.load_pretrained_component_from_model(
                encoder.text_encoder, args.load_pretrained_text_encoder_from
            )
            logger.info(f"loaded pretrained text encoder from: {args.load_pretrained_text_encoder_from}")
        return encoder

    @classmethod
    def build_decoder(cls, args, task, embed_tokens):
        """Build the decoder, optionally loading pretrained weights.

        A configured checkpoint path that does not exist is skipped with a
        warning instead of raising.
        """
        decoder = TransformerDecoderScriptable(args, task.target_dictionary, embed_tokens)
        pretraining_path = getattr(args, "load_pretrained_decoder_from", None)
        if pretraining_path is not None:
            if not Path(pretraining_path).exists():
                logger.warning(
                    f"skipped pretraining because {pretraining_path} does not exist"
                )
            else:
                decoder = checkpoint_utils.load_pretrained_component_from_model(
                    component=decoder, checkpoint=pretraining_path
                )
                logger.info(f"loaded pretrained decoder from: {pretraining_path}")
        return decoder

    def set_num_updates(self, num_updates):
        """Set the number of parameters updates."""
        super().set_num_updates(num_updates)
        self.num_updates = num_updates

    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens,
        use_encoder_outputs=False,
        src_txt_tokens=None,
        src_txt_lengths=None,
        mode="sup_speech",
        **kwargs,
    ):
        """
        Run the forward pass for an encoder-decoder model.

        First feed a batch of source tokens through the encoder. Then, feed the
        encoder output and previous decoder outputs (i.e., teacher forcing) to
        the decoder to produce the next outputs::

            encoder_out = self.encoder(src_tokens, src_lengths)
            return self.decoder(prev_output_tokens, encoder_out)

        Args:
            src_tokens: source batch. With ``mode="text"`` it is treated as
                the text input and forwarded to the encoder as
                ``src_txt_tokens``; otherwise it is the speech input.
            src_lengths: source lengths of shape `(batch)`
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            use_encoder_outputs (bool): also return the encoder output
            src_txt_tokens: optional text input accompanying the speech input
            src_txt_lengths: lengths of ``src_txt_tokens``
            mode = 'sup_speech' or 'text'

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs

            When the encoder returns a pair of outputs, the decoder outputs
            for both are concatenated along dim 0 and the second tuple element
            is ``None``. With ``use_encoder_outputs=True`` the return value is
            ``(decoder_out, encoder_out)``.

        Raises:
            ValueError: if ``mode == "text"`` and ``src_txt_tokens`` is given.
        """
        if mode == "text":
            # Explicit check rather than ``assert``: asserts vanish under
            # ``python -O`` and the caller's text batch would then be
            # silently overwritten below.
            if src_txt_tokens is not None:
                raise ValueError(
                    "src_txt_tokens must be None when mode='text'; "
                    "pass the text batch through src_tokens instead"
                )
            src_txt_tokens, src_txt_lengths = src_tokens, src_lengths
            src_tokens, src_lengths = None, None

        encoder_out = self.encoder(
            src_tokens,
            src_lengths=src_lengths,
            src_txt_tokens=src_txt_tokens,
            src_txt_lengths=src_txt_lengths,
            **kwargs
        )

        if isinstance(encoder_out, tuple):  # training with multiple inputs
            assert len(encoder_out) == 2
            # Decode each encoder output with the same teacher-forcing input
            # and stack the decoder outputs along dim 0.
            dec_outs = [self.decoder(prev_output_tokens, eo)[0] for eo in encoder_out]
            decoder_out = (torch.cat(dec_outs, dim=0), None)
        else:
            decoder_out = self.decoder(prev_output_tokens, encoder_out)

        if use_encoder_outputs:
            return decoder_out, encoder_out
        return decoder_out


class S2TDCMEncoder(FairseqEncoder):
    """Encoder chaining an acoustic encoder, an adapter and a text encoder.

    Speech input goes through the acoustic encoder, is turned into a CTC
    distribution (optionally shrunk by removing blanks and merging repeats),
    mapped by the :class:`Adapter`, and finally passed through the layers of
    the text encoder. Text input is fed to the text encoder directly.
    """

    def __init__(self, args, task=None, embed_tokens=None):
        super().__init__(None)
        self.acoustic_encoder = S2TTransformerEncoder(args, task, embed_tokens)
        self.adapter = Adapter(args, task.source_dictionary, embed_tokens)
        self.text_encoder = TextEncoder(args, task.source_dictionary, embed_tokens)

        self.enc_grad_mult = args.enc_grad_mult

        # Two EOS frames per entry in the comma-separated conv kernel list.
        self.eos_num = 2 * len(args.conv_kernel_sizes.split(","))
        self.add_speech_eos = getattr(args, "add_speech_eos", False)
        # Learnable EOS feature frame; only created when it will be used.
        self.eos_emb = (
            nn.Parameter(torch.zeros(1, args.input_feat_per_channel), requires_grad=True)
            if self.add_speech_eos and self.eos_num > 0
            else None
        )

        self.shrink_ctc = getattr(args, "shrink_ctc", False)
        # ``latent_temp`` is "(max, min, decay)", given as a string or a tuple.
        temp = getattr(args, "latent_temp", "(1.0,0.5,0.99995)")
        if isinstance(temp, str):
            import ast
            temp = ast.literal_eval(temp)
        assert len(temp) == 3, f"{temp}, {len(temp)}"
        self.max_temp, self.min_temp, self.temp_decay = temp
        self.cur_temp = self.max_temp

        self.use_gumbel = getattr(args, "adapter_with_gumbel", False) and getattr(args, "adapter", "e2e") in ["cascade", "e2e"]
        self.hard_prob = getattr(args, "hard_prob", False)
        self.sample_path = getattr(args, "sample_path", False)

    def mult_rst_grad(self, rst, ratio):
        """Scale the gradient flowing into ``rst["encoder_out"][0]`` by ``ratio``.

        ``rst`` is modified in place and returned.
        """
        assert isinstance(rst, dict)  # instead of EncoderOut
        assert len(rst["encoder_out"]) == 1
        rst["encoder_out"][0] = GradMultiply.apply(rst["encoder_out"][0], ratio)
        return rst

    def add_speech_eos_tokens(self, src_tokens, src_lengths):
        """Append ``self.eos_num`` copies of the EOS frame after each utterance.

        Args:
            src_tokens: features, unpacked here as (batch, time, feat_dim)
            src_lengths: per-utterance lengths

        Returns:
            ``(src_tokens, src_lengths)`` with the time axis and the lengths
            grown by ``self.eos_num``; the inputs are returned unchanged when
            ``self.eos_num`` is not positive.
        """
        bsz, max_seq_len, fdim = src_tokens.size()
        if self.eos_num > 0:
            src_token_eos = torch.zeros(
                [bsz, max_seq_len + self.eos_num, fdim],
                dtype=src_tokens.dtype,
                device=src_tokens.device,
            )
            src_token_eos[:, :max_seq_len] = src_tokens
            # Each utterance has its own length, so the EOS frames are
            # written right after its last real frame one row at a time.
            for bi in range(bsz):
                src_token_eos[bi][
                    src_lengths[bi] : src_lengths[bi] + self.eos_num
                ] = self.eos_emb.expand(self.eos_num, fdim)
            src_lengths = src_lengths + self.eos_num
            src_tokens = src_token_eos
        return src_tokens, src_lengths

    def shrink_ctc_sequence(self, enc_out, ctc_prob, blank_id=0, pad_id=1):
        """Merge repeated CTC tokens and drop blank/pad positions.

        A token path is taken from ``ctc_prob`` (sampled when
        ``self.sample_path`` is set and the module is training, argmax
        otherwise). Consecutive frames carrying the same token are averaged
        into one position, then positions whose token is ``blank_id`` or
        ``pad_id`` are removed.

        Args:
            enc_out: T x B x C encoder states
            ctc_prob: T x B x V token distribution
            blank_id: index of the CTC blank token
            pad_id: index of the padding token

        Returns:
            tuple ``(shrink_enc_out, shrink_ctc_prob, shrink_padding_mask,
            onehot_shrink_path)``. The first two are T" x B x C/V, the mask
            is built from the shrunk lengths, and the last one is ``None``
            unless ``self.hard_prob`` is set.
        """
        # inputs are T x B x C/V
        with torch.no_grad():
            if self.sample_path and self.training:
                best_path = torch.distributions.categorical.Categorical(ctc_prob.transpose(0, 1)).sample()
            else:
                best_path = torch.argmax(ctc_prob.transpose(0, 1), dim=-1)  # B x T
            batch_size, ctc_max_len = best_path.size()

            non_blank_ids = []
            unique_path = []
            info = []
            for p in best_path:
                # remove duplicate tokens
                unique_p, dup_ids, cnt = torch.unique_consecutive(p, return_inverse=True, return_counts=True)
                # each frame contributes 1 / (length of its run) to the merged position
                w = 1.0 / torch.gather(cnt, dim=0, index=dup_ids)
                info.append((dup_ids, w))
                unique_path.append(unique_p)

                # remove blank and pad
                _ids = torch.logical_and(unique_p != blank_id, unique_p != pad_id).nonzero(as_tuple=True)[0]
                if len(_ids) == 0:
                    # train data is noisy. corner case: ctc seq is empty. At least add one token
                    _ids = torch.as_tensor([0], device=_ids.device)
                non_blank_ids.append(_ids)

            unique_max_len = max([len(unique_p) for unique_p in unique_path])
            avg_weight = torch.zeros((batch_size, ctc_max_len, unique_max_len), dtype=ctc_prob.dtype, device=ctc_prob.device)  # B x T x T'
            for b, (dup_ids, w) in enumerate(info):
                # ``w`` is float32 (1.0 / integer counts) while ``avg_weight``
                # follows ``ctc_prob``; scatter_ needs matching dtypes, so
                # cast first to keep non-float32 inputs working.
                avg_weight[b].scatter_(1, dup_ids.unsqueeze(1), w.to(avg_weight.dtype).unsqueeze(1))

            padded_non_blank_ids = pad_sequence(non_blank_ids, batch_first=True, padding_value=unique_max_len-1)  # B x T"

            shrink_lengths = [len(_ids) for _ids in non_blank_ids]
            shrink_max_len = max(shrink_lengths)
            shrink_padding_mask = lengths_to_padding_mask(
                torch.as_tensor(shrink_lengths, device=ctc_prob.device)
            )

        def avg_duplicate(x, weight):
            # x: T x B x C/V
            # weight: B x T x T'
            # return: T' x B x C/V
            return x.permute(1, 2, 0).bmm(weight).permute(2, 0, 1)

        def remove_blank(x, ids):
            # x: T' x B x C/V
            # ids: B x T"
            # return: T" x B x C/V
            _, bs, last_dim = x.size()
            return torch.gather(
                x,
                dim=0,
                index=ids.transpose(0,1).unsqueeze(-1).expand(shrink_max_len, bs, last_dim),
            )

        unique_enc_out = avg_duplicate(enc_out, avg_weight)
        unique_ctc_prob = avg_duplicate(ctc_prob, avg_weight)
        shrink_enc_out = remove_blank(unique_enc_out, padded_non_blank_ids)
        shrink_ctc_prob = remove_blank(unique_ctc_prob, padded_non_blank_ids)

        onehot_shrink_path = None
        if self.hard_prob:
            padded_unique_path = pad_sequence(unique_path, batch_first=True, padding_value=pad_id)  # B x T'
            onehot_unique_path = torch.zeros_like(unique_ctc_prob, memory_format=torch.legacy_contiguous_format).scatter_(-1, padded_unique_path.transpose(0, 1).unsqueeze(-1), 1.0)
            # straight through
            if self.training:
                onehot_unique_path = onehot_unique_path - unique_ctc_prob.detach() + unique_ctc_prob
            onehot_shrink_path = remove_blank(onehot_unique_path, padded_non_blank_ids)

        return shrink_enc_out, shrink_ctc_prob, shrink_padding_mask, onehot_shrink_path

    def gumbel_reparametrization(self, logits, tau, dim=-1):
        """Return ``softmax((logits + g) / tau)`` along ``dim``.

        ``g`` is noise computed as ``-log(e)`` with ``e`` drawn in place from
        an exponential distribution.
        """
        gumbels = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
        gumbels = (logits + gumbels) / tau
        return gumbels.softmax(dim)

    def set_num_updates(self, num_updates):
        """Record the update count and anneal the temperature.

        The temperature decays as ``max_temp * temp_decay ** num_updates``,
        floored at ``min_temp``; it only changes when gumbel sampling is
        enabled.
        """
        super().set_num_updates(num_updates)
        if self.use_gumbel:
            self.cur_temp = max(
                self.max_temp * self.temp_decay**num_updates, self.min_temp
            )

    def add_eos_before_shrink_ctc(self, enc_out, ctc_prob, eos_id=2):
        """
        For NMT pretraining, the eos_id (feature) by default is appended to the end of source tokens (features).

        The i-th batch level enc_out (T x C) has implicit ctc tokens (derived by ctc_prob, T x V)
          e.g, [4,4,0,5,0,0,6,6,1,1], where blank=0, pad=1
        A standard way to add eos_id (feature) is to insert it before pads, see add_eos()
          e.g., [4,4,0,5,0,0,6,6,2,1,1], where eos=2
        However, this is not efficient for batch level features (loop required for each example in the batch)

        So this method appends eos at the very end, which only works when shrink ctc is applied afterwards, see shrink_ctc_sequence().
          e.g., [4,4,0,5,0,0,6,6,1,1,2]
        After removing duplicate tokens, we obtain
          e.g., [4,0,5,0,6,1,2]
        After removing blank and pad, we obtain
          e.g., [4,5,6,2]
        Then pads are added to the end until the batch max length.
          e.g., [4,5,6,2,1,1,....]

        Input: T x B x C/V
        Return: (T + 1) x B x C/V
        """
        _, batch_size, vocab_size = ctc_prob.size()
        eos_token = torch.as_tensor([eos_id], device=enc_out.device)
        # The EOS state is the text encoder's embedding of the EOS token.
        eos_out = self.text_encoder.embed_tokens(eos_token).expand(1, batch_size, -1)
        enc_out = torch.cat([enc_out, eos_out], dim=0)
        eos_prob = F.one_hot(eos_token, num_classes=vocab_size).type(ctc_prob.dtype).expand(1, batch_size, -1)
        ctc_prob = torch.cat([ctc_prob, eos_prob], dim=0)
        return enc_out, ctc_prob

    def add_eos(self, enc_out, ctc_prob, blank_id=0, pad_id=1, eos_id=2):
        """
        General method for adding eos
          e.g., [4,4,0,5,0,0,6,6,1,1] as implicit ctc tokens
        after adding eos, we have
          e.g., [4,4,0,5,0,0,6,6,2,1,1]

        For every utterance the EOS state/probability is inserted right after
        the last position whose argmax token is neither blank nor pad. When an
        utterance has no such position, EOS is inserted at position 0.

        Input: T x B x C/V
        Return: (T + 1) x B x C/V
        """
        max_len, batch_size, vocab_size = ctc_prob.size()
        _, _, dim_size = enc_out.size()
        eos_token = torch.as_tensor([eos_id], device=enc_out.device)
        eos_out = self.text_encoder.embed_tokens(eos_token) # 1 x C
        eos_prob = F.one_hot(eos_token, num_classes=vocab_size).type(ctc_prob.dtype) # 1 x V
        enc_out_eos = torch.zeros([max_len + 1, batch_size, dim_size],
                                  dtype=enc_out.dtype,
                                  device=enc_out.device)
        ctc_prob_eos = torch.zeros([max_len + 1, batch_size, vocab_size],
                                   dtype=ctc_prob.dtype,
                                   device=ctc_prob.device)
        best_path = torch.argmax(ctc_prob, dim=-1)  # T x B
        for b in range(batch_size):
            cur_path = best_path[:, b]
            token_pos = torch.logical_and(cur_path != blank_id, cur_path != pad_id).nonzero(as_tuple=True)[0]
            # An all-blank/pad path has no "last token"; indexing [-1] on the
            # empty result would raise, so fall back to the sequence start.
            insert_pos = token_pos[-1] + 1 if token_pos.numel() > 0 else 0
            enc_out_eos[:insert_pos, b] = enc_out[:insert_pos, b]
            enc_out_eos[insert_pos, b] = eos_out
            enc_out_eos[(insert_pos+1):, b] = enc_out[insert_pos:, b]
            ctc_prob_eos[:insert_pos, b] = ctc_prob[:insert_pos, b]
            ctc_prob_eos[insert_pos, b] = eos_prob
            ctc_prob_eos[(insert_pos+1):, b] = ctc_prob[insert_pos:, b]
        return enc_out_eos, ctc_prob_eos

    def forward(
        self,
        src_tokens,
        src_lengths=None,
        src_txt_tokens=None,
        src_txt_lengths=None,
        **kwargs
    ):
        """
        Args:
            src_tokens: padded tensor (B, T, C * feat)
            src_lengths: tensor of original lengths of input utterances (speech) (B,)
            src_txt_tokens: padded tensor (B, T)
            src_txt_lengths: tensor of original lengths of input utterances (text) (B,)

        Returns:
            The speech encoder output dict when only speech is given, the
            text encoder output when only text is given, or the tuple
            ``(speech_out, text_out)`` when both are given. The speech dict
            additionally carries ``"ctc_logits"`` and ``"ctc_padding_mask"``.

        Raises:
            ValueError: if both ``src_tokens`` and ``src_txt_tokens`` are None.
        """
        if src_tokens is None and src_txt_tokens is None:
            raise ValueError(
                "src_tokens and src_txt_tokens cannot be None at the same time"
            )
        ret1 = None
        ret2 = None
        return_all_hiddens = False
        if src_tokens is not None:
            if self.add_speech_eos:
                src_tokens, src_lengths = self.add_speech_eos_tokens(src_tokens, src_lengths)

            ret1 = self.acoustic_encoder(
                src_tokens, src_lengths, return_all_hiddens=return_all_hiddens
            )
            encoder_out = ret1["encoder_out"][0]
            encoder_padding_mask = ret1["encoder_padding_mask"][0]
            # ctc
            ctc_logits = self.acoustic_encoder.compute_ctc_logits(encoder_out)
            if self.use_gumbel and self.training:
                ctc_probs = self.gumbel_reparametrization(ctc_logits, self.cur_temp, dim=-1)
            else:
                ctc_probs = F.softmax(ctc_logits / self.cur_temp, dim=-1)

            # Keep the acoustic-level mask: shrinking below replaces
            # ``encoder_padding_mask`` but ``ctc_logits`` stay unshrunk.
            ctc_padding_mask = encoder_padding_mask
            onehot_shrink_path = None
            if self.shrink_ctc:
                # when applying shrink ctc, we can add eos here efficiently. No need to add_speech_eos_tokens() with for loop.
                encoder_out, ctc_probs = self.add_eos_before_shrink_ctc(encoder_out, ctc_probs)
                encoder_out, ctc_probs, encoder_padding_mask, onehot_shrink_path = self.shrink_ctc_sequence(encoder_out, ctc_probs)

            # adapter
            x, encoder_padding_mask = self.adapter(encoder_out, ctc_probs, encoder_padding_mask, onehot_shrink_path)

            # text encoder
            x = self.text_encoder.forward_after_embedding(x, encoder_padding_mask)
            ret1["encoder_out"] = [x]
            ret1["encoder_padding_mask"] = [encoder_padding_mask]
            ret1["ctc_logits"] = [ctc_logits]
            ret1["ctc_padding_mask"] = [ctc_padding_mask]

        if src_txt_tokens is not None:
            ret2 = self.text_encoder(
                src_txt_tokens, src_txt_lengths, return_all_hiddens=return_all_hiddens
            )

        def merge_output(rst1, rst2):
            # A single available output is returned as-is; with both, the
            # gradient multiplier is applied (training only) and a pair is
            # returned.
            if rst1 is None:
                return rst2
            if rst2 is None:
                return rst1
            if self.enc_grad_mult != 1.0 and self.training:
                rst1 = self.mult_rst_grad(rst1, self.enc_grad_mult)
                rst2 = self.mult_rst_grad(rst2, self.enc_grad_mult)
            rst = (rst1, rst2)
            return rst

        return merge_output(ret1, ret2)

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder the encoder output along the batch axis (inference only).

        When ``encoder_out`` is a tuple, only its first element (the speech
        output) is kept and its ``"encoder_states"`` entry is reset to an
        empty list. The ``"encoder_states"`` list is updated in place.
        """
        assert self.training is False  # used for inference only
        if isinstance(encoder_out, tuple):
            # this is for valid during training
            encoder_out = encoder_out[0]    # extract speech encoder_out
            encoder_out["encoder_states"] = []

        def reorder(key, batch_dim):
            # An empty list stays empty; otherwise select along the batch axis.
            return [x.index_select(batch_dim, new_order) for x in encoder_out[key]]

        new_encoder_out = reorder("encoder_out", 1)
        new_encoder_padding_mask = reorder("encoder_padding_mask", 0)
        new_encoder_embedding = reorder("encoder_embedding", 0)
        src_tokens = reorder("src_tokens", 0)
        src_lengths = reorder("src_lengths", 0)

        encoder_states = encoder_out["encoder_states"]
        if len(encoder_states) > 0:
            for idx, state in enumerate(encoder_states):
                encoder_states[idx] = state.index_select(1, new_order)

        return {
            "encoder_out": new_encoder_out,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,  # B x T
            "encoder_embedding": new_encoder_embedding,  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
            "src_tokens": src_tokens,  # B x T
            "src_lengths": src_lengths,  # B x 1
        }


class Adapter(nn.Module):
    """Turn acoustic states and token distributions into text-encoder inputs.

    Two adapter types are supported:

    * ``"cascade"``: the token distribution multiplied by the embedding matrix
    * ``"e2e"``: the cascade term plus a linear projection of the acoustic
      representation

    The result is scaled, given positional embeddings and passed through
    dropout.
    """

    def __init__(self, args, dictionary, embed_tokens):
        """
        Args:
            args: model arguments (reads ``no_scale_embedding``, ``dropout``,
                ``adapter``, ``adapter_with_gumbel``, ``hard_prob``,
                ``encoder_embed_dim`` and ``max_source_positions``)
            dictionary: source dictionary; only consulted when
                ``embed_tokens`` is None
            embed_tokens: embedding module to reuse, or None to create one
        """
        super().__init__()

        # Size and padding index come from the shared embedding when there is
        # one. Reading them unconditionally would fail for ``None`` before a
        # fresh embedding could be created below.
        if embed_tokens is None:
            embed_dim = args.encoder_embed_dim
            self.padding_idx = dictionary.pad()
        else:
            embed_dim = embed_tokens.embedding_dim
            self.padding_idx = embed_tokens.padding_idx
        self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)

        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )

        self.adapter_type = getattr(args, "adapter", "e2e")
        self.use_gumbel = getattr(args, "adapter_with_gumbel", False) and self.adapter_type in ["e2e", "cascade"]
        self.hard_prob = getattr(args, "hard_prob", False)

        if self.adapter_type == "e2e":
            self.linear_adapter = nn.Sequential(
                nn.Linear(embed_dim, embed_dim),
                LayerNorm(args.encoder_embed_dim),
                self.dropout_module,
                nn.ReLU(),
            )

        if embed_tokens is None:
            self.embed_adapter = Embedding(len(dictionary), embed_dim, self.padding_idx)
        else:
            self.embed_adapter = embed_tokens

        self.embed_positions = PositionalEmbedding(
            args.max_source_positions, args.encoder_embed_dim, self.padding_idx,
        )

    def forward(self, representation, distribution, padding, onehot=None):
        """
        Args:
            representation: acoustic states fed to the linear adapter
                (only used by the ``"e2e"`` type)
            distribution: token distribution multiplied with the embedding
                weights
            padding: padding mask; passed to the positional embedding and
                returned unchanged
            onehot: distribution used in place of ``distribution`` when
                ``self.hard_prob`` is set

        Returns:
            tuple ``(out, padding)``

        Raises:
            ValueError: if the adapter type is unsupported, or if
                ``self.hard_prob`` is set but ``onehot`` is None.
        """
        if self.hard_prob:
            if onehot is None:
                # Fail with a clear message instead of an obscure matmul
                # error on None further down.
                raise ValueError(
                    "hard_prob is enabled but no onehot distribution was given"
                )
            distribution = onehot

        if self.adapter_type == "cascade":
            out = torch.matmul(distribution, self.embed_adapter.weight)
        elif self.adapter_type == "e2e":
            linear_out = self.linear_adapter(representation)
            emb_out = torch.matmul(distribution, self.embed_adapter.weight)
            out = linear_out + emb_out
        else:
            # Raise rather than log and continue: there is no output to
            # scale, so carrying on could only fail on ``None`` below.
            raise ValueError("Unsupported adapter type: {}.".format(self.adapter_type))

        out = self.embed_scale * out
        # The transpose swaps the first two axes of the positional embedding
        # so that it can be added to ``out``.
        positions = self.embed_positions(padding).transpose(0, 1)
        out = positions + out

        out = self.dropout_module(out)

        return out, padding


class TextEncoder(TransformerEncoder):
    """Transformer encoder that can also be entered after the embedding step.

    NOTE(review): this class does not override the constructor, and nothing
    in this file reads ``args.text_encoder_layers`` when building it, so the
    ``--text-encoder-layers`` option appears to have no effect here --
    confirm whether that is intended.
    """

    def forward_after_embedding(self, x, encoder_padding_mask):
        """Run already-embedded inputs through the layer stack.

        Args:
            x: embedded inputs handed to the first layer
            encoder_padding_mask: padding mask forwarded to every layer

        Returns:
            The output of the last layer, passed through ``self.layer_norm``
            when one is configured.
        """
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, encoder_padding_mask)

        if self.layer_norm is None:
            return hidden
        return self.layer_norm(hidden)


@register_model_architecture(model_name="s2t_dcm", arch_name="s2t_dcm")
def base_architecture(args):
    """Fill in defaults for every architecture option not already set on ``args``.

    Existing attributes are never overwritten. Several decoder defaults are
    derived from the corresponding encoder values, so the order of the
    assignments below matters.
    """

    def default(name, value):
        # Keep a value that is already present, otherwise install ``value``.
        setattr(args, name, getattr(args, name, value))

    default("encoder_freezing_updates", 0)
    # Convolutional subsampler
    default("conv_kernel_sizes", "5,5")
    default("conv_channels", 1024)
    # Transformer
    default("encoder_embed_dim", 512)
    default("encoder_ffn_embed_dim", 2048)
    default("encoder_layers", 12)
    default("encoder_attention_heads", 8)
    default("encoder_normalize_before", True)
    default("encoder_layerdrop", 0)
    default("encoder_learned_pos", False)
    default("decoder_embed_dim", args.encoder_embed_dim)
    default("decoder_ffn_embed_dim", args.encoder_ffn_embed_dim)
    default("decoder_layers", 6)
    default("decoder_attention_heads", 8)
    default("decoder_normalize_before", True)
    default("decoder_learned_pos", False)
    default("dropout", 0.1)
    default("attention_dropout", args.dropout)
    default("activation_dropout", args.dropout)
    default("activation_fn", "relu")
    default("adaptive_softmax_cutoff", None)
    default("adaptive_softmax_dropout", 0)
    default("share_decoder_input_output_embed", False)
    default("no_token_positional_embeddings", False)
    default("adaptive_input", False)
    default("decoder_layerdrop", 0.0)
    default("decoder_output_dim", args.decoder_embed_dim)
    default("decoder_input_dim", args.decoder_embed_dim)
    default("no_scale_embedding", False)
    default("quant_noise_pq", 0)
    # dcm
    default("text_encoder_layers", 6)
    default("adapter", "e2e")
    # The temperature option is exposed as ``--latent-temp``; give it the
    # same default here as in the argument parser.
    default("latent_temp", "(1.0,0.5,0.99995)")
    # Older attribute name, kept so existing readers of it still find it.
    default("temperature", 1.0)
    default("enc_grad_mult", 1.0)


@register_model_architecture("s2t_dcm", "s2t_dcm_s")
def s2t_dcm_s(args):
    """Small preset: apply its own defaults, then the base architecture's."""
    small_defaults = (
        ("encoder_embed_dim", 256),
        ("encoder_ffn_embed_dim", 256 * 8),
        ("encoder_attention_heads", 4),
        ("decoder_attention_heads", 4),
        ("dropout", 0.1),
    )
    # Values already present on ``args`` win over the preset.
    for option, value in small_defaults:
        setattr(args, option, getattr(args, option, value))
    base_architecture(args)


@register_model_architecture("s2t_dcm", "s2t_dcm_m")
def s2t_dcm_m(args):
    """Medium preset: apply its own defaults, then the base architecture's."""
    medium_defaults = (
        ("encoder_embed_dim", 512),
        ("encoder_ffn_embed_dim", 512 * 4),
        ("encoder_attention_heads", 8),
        ("decoder_attention_heads", 8),
        ("dropout", 0.15),
    )
    # Values already present on ``args`` win over the preset.
    for option, value in medium_defaults:
        setattr(args, option, getattr(args, option, value))
    base_architecture(args)
