"""
TNF: Taking Notes on the Fly Helps Language Pretraining.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import utils
from fairseq.models import (
    FairseqDecoder,
    FairseqLanguageModel,
    register_model,
    register_model_architecture,
)
from fairseq.modules import (
    LayerNorm,
    TnfTransformerSentenceEncoder,
)
from fairseq.modules.transformer_sentence_encoder import init_bert_params

from fairseq import distributed_utils

from .hub_interface import TnfHubInterface


@register_model('tnf')
class TnfModel(FairseqLanguageModel):
    """TNF language model, registered with fairseq under the name ``tnf``.

    The model wraps a :class:`TnfEncoder` (reachable as ``self.decoder``) and
    keeps a dictionary of optional sentence-level classification heads that
    can be registered after construction.
    """

    @classmethod
    def hub_models(cls):
        """Return the archive map used by :meth:`from_pretrained` (currently empty)."""
        return {
            # 'roberta.base': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz',
            # 'roberta.large': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz',
            # 'roberta.large.mnli': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz',
            # 'roberta.large.wsc': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.wsc.tar.gz',
        }

    def __init__(self, args, encoder, source_dictionary):
        """Store the encoder, initialise weights and create the head registry.

        Args:
            args: parsed model arguments; ``tnf_emb_zero_init`` is read here.
            encoder: the :class:`TnfEncoder` instance to wrap.
            source_dictionary: dictionary whose special-symbol indices are
                used by :meth:`forward` when ``fix_dict_shift`` is enabled.
        """
        super().__init__(encoder)
        self.args = args

        # We follow BERT's random weight initialization
        self.apply(init_bert_params)
        # The note embeddings are (re-)initialised separately, after the
        # generic BERT init above: all zeros, or N(0, 0.02).
        if args.tnf_emb_zero_init == 1:
            self.decoder.tnf_embeddings.weight.data.zero_()
        else:
            self.decoder.tnf_embeddings.weight.data.normal_(mean=0.0, std=0.02)

        self.classification_heads = nn.ModuleDict()
        self.source_dictionary = source_dictionary

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""

        def parse_bool(value):
            """Convert a command-line string such as 'true', 'False' or '0' to a bool.

            ``type=bool`` must not be used with argparse: ``bool('False')`` is
            ``True`` because every non-empty string is truthy.
            """
            if isinstance(value, bool):
                return value
            text = str(value).strip().lower()
            if text in ('true', 't', 'yes', 'y', '1'):
                return True
            if text in ('false', 'f', 'no', 'n', '0', ''):
                return False
            # argparse turns a ValueError raised by a ``type`` callable into
            # a regular usage error.
            raise ValueError('expected a boolean value, got {!r}'.format(value))

        parser.add_argument('--encoder-layers', type=int, metavar='L',
                            help='num encoder layers')
        parser.add_argument('--encoder-embed-dim', type=int, metavar='H',
                            help='encoder embedding dimension')
        parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='F',
                            help='encoder embedding dimension for FFN')
        parser.add_argument('--encoder-attention-heads', type=int, metavar='A',
                            help='num encoder attention heads')
        parser.add_argument('--activation-fn',
                            choices=utils.get_available_activation_fns(),
                            help='activation function to use')
        parser.add_argument('--pooler-activation-fn',
                            choices=utils.get_available_activation_fns(),
                            help='activation function to use for pooler layer')
        parser.add_argument('--encoder-normalize-before', action='store_true',
                            help='apply layernorm before each encoder block')
        parser.add_argument('--embedding-normalize', action='store_true',
                            help='add layernorm after the embedding layer')
        parser.add_argument('--dropout', type=float, metavar='D',
                            help='dropout probability')
        parser.add_argument('--attention-dropout', type=float, metavar='D',
                            help='dropout probability for attention weights')
        parser.add_argument('--activation-dropout', type=float, metavar='D',
                            help='dropout probability after activation in FFN')
        parser.add_argument('--pooler-dropout', type=float, metavar='D',
                            help='dropout probability in the masked_lm pooler layers')
        parser.add_argument('--max-positions', type=int,
                            help='number of positional embeddings to learn')
        parser.add_argument('--load-checkpoint-heads', action='store_true',
                            help='(re-)register and load heads when loading checkpoints')
        parser.add_argument('--tnf-emb-zero-init', type=int, metavar='A',
                            help='')
        parser.add_argument('--update-tnf-lambda', type=int, metavar='A',
                            help='')
        parser.add_argument('--update-tnf-emb', type=str, metavar='A',
                            help='')
        parser.add_argument('--ctx', type=str, metavar='A', default=None,
                            help='')
        parser.add_argument('--ctx-window-size', type=int, metavar='A', default=None,
                            help='')
        parser.add_argument('--fix-dict-shift', type=parse_bool, metavar='A', default=False,
                            help='')
        parser.add_argument('--glue-tnf-bp', type=parse_bool, metavar='A', default=False,
                            help='')

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present
        base_architecture(args)

        if not hasattr(args, 'max_positions'):
            args.max_positions = args.tokens_per_sample

        encoder = TnfEncoder(args, task.source_dictionary, task.tnf_source_dictionary)
        return cls(args, encoder, task.source_dictionary)

    def forward(
        self,
        src_tokens,
        tnf_src_tokens,
        tnf_src_tokens_nomask=None,
        features_only=False,
        return_all_hiddens=False,
        classification_head_name=None,
        **kwargs
    ):
        """Run the encoder and, optionally, one registered classification head.

        Args:
            src_tokens (LongTensor): input tokens of shape `(batch, src_len)`.
            tnf_src_tokens: note-dictionary token ids forwarded to the encoder.
            tnf_src_tokens_nomask: optional second set of note-dictionary ids
                forwarded to the encoder.
            features_only (bool, optional): skip the LM head. Forced to True
                when ``classification_head_name`` is given.
            return_all_hiddens (bool, optional): forwarded to the encoder.
            classification_head_name (str, optional): name of a head
                registered via :meth:`register_classification_head`.

        Returns:
            tuple: ``(x, extra)`` where ``x`` is the encoder output (or the
            classification-head output) and ``extra`` is the encoder's
            dictionary of additional data.
        """
        if classification_head_name is not None:
            features_only = True

        # fix the dictionary shift bug for downstream tasks
        if getattr(self.args, 'fix_dict_shift', False) is True:
            dictionary = self.source_dictionary
            # Every token that is not unk/pad/bos/eos has its id lowered by 3.
            shift_mask = (
                (src_tokens != dictionary.unk())
                & (src_tokens != dictionary.pad())
                & (src_tokens != dictionary.bos())
                & (src_tokens != dictionary.eos())
            )
            # Build a new tensor instead of writing into the caller's batch:
            # an in-place shift would be applied again if the same batch
            # tensor were forwarded a second time.
            src_tokens = torch.where(shift_mask, src_tokens - 3, src_tokens)
        x, extra = self.decoder(src_tokens, tnf_src_tokens, tnf_src_tokens_nomask, features_only, return_all_hiddens, **kwargs)

        if classification_head_name is not None:
            x = self.classification_heads[classification_head_name](x)
        return x, extra

    def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
        """Register a classification head.

        An existing head with the same name is replaced; a warning is printed
        when the replacement has different dimensions.
        """
        if name in self.classification_heads:
            prev_num_classes = self.classification_heads[name].out_proj.out_features
            prev_inner_dim = self.classification_heads[name].dense.out_features
            if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
                print(
                    'WARNING: re-registering head "{}" with num_classes {} (prev: {}) '
                    'and inner_dim {} (prev: {})'.format(
                        name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
                    )
                )
        self.classification_heads[name] = RobertaClassificationHead(
            self.args.encoder_embed_dim,
            inner_dim or self.args.encoder_embed_dim,
            num_classes,
            self.args.pooler_activation_fn,
            self.args.pooler_dropout,
        )

    @property
    def supported_targets(self):
        """Set of target types supported by this model."""
        return {'self'}

    @classmethod
    def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_name_or_path='.', bpe='gpt2', **kwargs):
        """Load a checkpoint through ``hub_utils`` and wrap it in a :class:`TnfHubInterface`."""
        from fairseq import hub_utils
        x = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            bpe=bpe,
            load_checkpoint_heads=True,
            **kwargs,
        )
        return TnfHubInterface(x['args'], x['task'], x['models'][0])

    def upgrade_state_dict_named(self, state_dict, name):
        """Reconcile classification heads between ``state_dict`` and this model.

        Heads found in the checkpoint are either registered on the model (when
        ``args.load_checkpoint_heads`` is set) or dropped from ``state_dict``
        if they are unknown or have mismatching dimensions. Heads that exist
        only on the model are copied into ``state_dict`` with their current
        weights. ``state_dict`` is modified in place.
        """
        prefix = name + '.' if name != '' else ''
        # NOTE: this is a live view, so heads registered inside the loop
        # below are seen as "current" on later iterations.
        current_head_names = [] if not hasattr(self, 'classification_heads') else \
            self.classification_heads.keys()

        # Handle new classification heads present in the state dict.
        keys_to_delete = []
        for k in state_dict.keys():
            if not k.startswith(prefix + 'classification_heads.'):
                continue

            head_name = k[len(prefix + 'classification_heads.'):].split('.')[0]
            num_classes = state_dict[prefix + 'classification_heads.' + head_name + '.out_proj.weight'].size(0)
            inner_dim = state_dict[prefix + 'classification_heads.' + head_name + '.dense.weight'].size(0)

            if getattr(self.args, 'load_checkpoint_heads', False):
                if head_name not in current_head_names:
                    self.register_classification_head(head_name, num_classes, inner_dim)
            else:
                if head_name not in current_head_names:
                    print(
                        'WARNING: deleting classification head ({}) from checkpoint '
                        'not present in current model: {}'.format(head_name, k)
                    )
                    keys_to_delete.append(k)
                elif (
                    num_classes != self.classification_heads[head_name].out_proj.out_features
                    or inner_dim != self.classification_heads[head_name].dense.out_features
                ):
                    print(
                        'WARNING: deleting classification head ({}) from checkpoint '
                        'with different dimensions than current model: {}'.format(head_name, k)
                    )
                    keys_to_delete.append(k)
        # Deletion is deferred so the dict is not resized while iterating.
        for k in keys_to_delete:
            del state_dict[k]

        # Copy any newly-added classification heads into the state dict
        # with their current weights.
        if hasattr(self, 'classification_heads'):
            cur_state = self.classification_heads.state_dict()
            for k, v in cur_state.items():
                if prefix + 'classification_heads.' + k not in state_dict:
                    print('Overwriting', prefix + 'classification_heads.' + k)
                    state_dict[prefix + 'classification_heads.' + k] = v

    def before_update(self):
        """Delegate to the encoder's ``before_update``."""
        self.decoder.before_update()

    def reset_state(self):
        """Delegate to the encoder's ``reset_state``."""
        self.decoder.reset_state()


class RobertaLMHead(nn.Module):
    """Head for masked language modeling."""

    def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
        super().__init__()
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.layer_norm = LayerNorm(embed_dim)

        if weight is None:
            weight = nn.Linear(embed_dim, output_dim, bias=False).weight
        self.weight = weight
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features, masked_tokens=None, **kwargs):
        # Only project the unmasked tokens while training,
        # saves both memory and computation
        if masked_tokens is not None:
            features = features[masked_tokens, :]

        x = self.dense(features)
        x = self.activation_fn(x)
        x = self.layer_norm(x)
        # project back to size of vocabulary with bias
        x = F.linear(x, self.weight) + self.bias
        return x


class RobertaClassificationHead(nn.Module):
    """Sentence-level classification head operating on the first token's features."""

    def __init__(self, input_dim, inner_dim, num_classes, activation_fn, pooler_dropout):
        """Build the head.

        Args:
            input_dim: size of the incoming feature vectors.
            inner_dim: width of the hidden dense layer.
            num_classes: number of output logits.
            activation_fn: activation name resolved by ``utils.get_activation_fn``.
            pooler_dropout: dropout probability applied before each linear layer.
        """
        super().__init__()
        self.dense = nn.Linear(input_dim, inner_dim)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, features, **kwargs):
        """Return class logits computed from position 0 of ``features``."""
        # position 0 holds the <s> token (equiv. to [CLS])
        first_token = features[:, 0, :]
        hidden = self.activation_fn(self.dense(self.dropout(first_token)))
        return self.out_proj(self.dropout(hidden))


class TnfEncoder(FairseqDecoder):
    """Tnf encoder.

    Implements the :class:`~fairseq.models.FairseqDecoder` interface required
    by :class:`~fairseq.models.FairseqLanguageModel`.

    Besides the Transformer sentence encoder and the LM head, the encoder
    owns a table of note embeddings (``tnf_embeddings``) indexed by the TNF
    dictionary, plus accumulators that collect detached hidden states per
    note token between calls to :meth:`before_update`.
    """

    def __init__(self, args, dictionary, tnf_dictionary):
        """Build the sentence encoder, LM head and note-embedding state.

        Args:
            args: parsed model arguments.
            dictionary: token dictionary for the sentence encoder / LM head.
            tnf_dictionary: dictionary indexing the note embeddings; must
                contain a ``'<mask>'`` entry.
        """
        super().__init__(dictionary)
        self.args = args
        self.sentence_encoder = TnfTransformerSentenceEncoder(
            padding_idx=dictionary.pad(),
            vocab_size=len(dictionary),
            num_encoder_layers=args.encoder_layers,
            embedding_dim=args.encoder_embed_dim,
            ffn_embedding_dim=args.encoder_ffn_embed_dim,
            num_attention_heads=args.encoder_attention_heads,
            dropout=args.dropout,
            attention_dropout=args.attention_dropout,
            activation_dropout=args.activation_dropout,
            max_seq_len=args.max_positions,
            num_segments=0,
            encoder_normalize_before=args.encoder_normalize_before,
            embedding_normalize=args.embedding_normalize,
            apply_bert_init=True,
            activation_fn=args.activation_fn,
        )
        # The LM head shares its projection matrix with the token embeddings.
        self.lm_head = RobertaLMHead(
            embed_dim=args.encoder_embed_dim,
            output_dim=len(dictionary),
            activation_fn=args.activation_fn,
            weight=self.sentence_encoder.embed_tokens.weight,
        )

        # tnf special tokens
        self.tnf_unk_idx = tnf_dictionary.unk()
        self.tnf_padding_idx = tnf_dictionary.pad()
        self.tnf_bos_idx = tnf_dictionary.bos()
        self.tnf_eos_idx = tnf_dictionary.eos()
        self.tnf_mask_idx = tnf_dictionary.indices['<mask>']
        self.tnf_embedding_dim = args.encoder_embed_dim

        # init tnf embeddings, tnf embeddings are detached from the computing graph while pre-training.
        if args.update_tnf_lambda == 1:
            # lambda is learned, starting from zero
            self.tnf_lambda = nn.Parameter(torch.zeros(1), requires_grad=True)
        else:
            # lambda is a fixed value taken from the arguments (may be None)
            self.tnf_lambda = args.tnf_lambda
        # if self.tnf_lambda is not None:
        #     assert (self.tnf_lambda <= 1 and self.tnf_lambda >= 0)
        self.tnf_gamma = args.tnf_gamma
        self.tnf_vocab_size = len(tnf_dictionary)

        self.tnf_embeddings = nn.Embedding(self.tnf_vocab_size, self.tnf_embedding_dim)

        # Accumulators are plain tensor attributes (not parameters/buffers):
        #   tnf_embed_updates      - per-token sum of collected hidden states
        #   tnf_embed_updates_cnts - per-token number of collected states
        #   tnf_words_updates_cnts - running total of the counts above
        self.tnf_embed_updates = torch.zeros_like(self.tnf_embeddings.weight.data)
        self.tnf_embed_updates_cnts = torch.zeros(self.tnf_vocab_size).float()
        self.tnf_words_updates_cnts = torch.zeros(self.tnf_vocab_size).float()
        if torch.cuda.is_available():
            self.tnf_embed_updates = self.tnf_embed_updates.cuda()
            self.tnf_embed_updates_cnts = self.tnf_embed_updates_cnts.cuda()
            self.tnf_words_updates_cnts = self.tnf_words_updates_cnts.cuda()
        # Unless glue_tnf_bp is set, the note embeddings receive no gradient
        # and are only changed by before_update().
        if args.glue_tnf_bp is not True:
            self.tnf_embeddings.weight.requires_grad = False
        if hasattr(args, 'update_tnf_emb'):
            self.update_tnf_emb = args.update_tnf_emb
            self.ctx = args.ctx
            self.ctx_window_size = args.ctx_window_size
        else:
            self.update_tnf_emb = None

    def forward(self, src_tokens, tnf_src_tokens, tnf_src_tokens_nomask=None, features_only=False, return_all_hiddens=False, masked_tokens=None, **unused):
        """
        Args:
            src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
            tnf_src_tokens (LongTensor): TNF-dictionary ids used to look up
                note embeddings (assumed aligned with ``src_tokens`` -
                TODO confirm against the task)
            tnf_src_tokens_nomask (LongTensor, optional): TNF-dictionary ids
                used instead of ``tnf_src_tokens`` when collecting updates in
                ``'mask'`` mode
            features_only (bool, optional): skip LM head and just return
                features. If True, the output will be of shape
                `(batch, src_len, embed_dim)`.
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            masked_tokens (optional): index passed to the LM head so that
                only the selected positions are projected.

        Returns:
            tuple:
                - the LM output of shape `(batch, src_len, vocab)`
                - a dictionary of additional data, where 'inner_states'
                  is a list of hidden states.
        """
        # make sure dtypes match
        # if self.tnf_embeddings.dtype != self.sentence_encoder.embed_tokens.weight.data.dtype:
        #     self.tnf_embeddings = self.tnf_embeddings.type_as(
        #         self.sentence_encoder.embed_tokens.weight.data
        #     )

        x, extra = self.extract_features(src_tokens, return_all_hiddens, tnf_src_tokens)
        # update tnf embeddings
        if self.update_tnf_emb is not None:
            self.update_tnf_prep(x, tnf_src_tokens, tnf_src_tokens_nomask)

        if not features_only:
            x = self.output_layer(x, masked_tokens=masked_tokens)
        return x, extra

    def extract_features(self, src_tokens, return_all_hiddens=False, tnf_src_tokens=None, **unused):
        """Run the sentence encoder and return ``(features, extra)``.

        ``features`` is the last inner state; ``extra['inner_states']`` holds
        all inner states when ``return_all_hiddens`` is True, else None.
        """
        # embed tnf tokens
        # Positions holding a TNF special symbol are flagged and handed to
        # the sentence encoder as ``excluded_position``.
        tnf_mask_position = (tnf_src_tokens == self.tnf_mask_idx)
        tnf_unk_position = (tnf_src_tokens == self.tnf_unk_idx)
        tnf_pad_position = (tnf_src_tokens == self.tnf_padding_idx)
        tnf_bos_position = (tnf_src_tokens == self.tnf_bos_idx)
        tnf_eos_position = (tnf_src_tokens == self.tnf_eos_idx)
        excluded_position = (
            tnf_mask_position | tnf_pad_position |
            tnf_unk_position | tnf_bos_position | tnf_eos_position
        )

        # Note embeddings are only looked up when a lambda is configured.
        embed_tnf_tokens = (
            self.tnf_embeddings(tnf_src_tokens)
            if self.tnf_lambda is not None
            else None
        )

        inner_states, _ = self.sentence_encoder(
            src_tokens,
            last_state_only=not return_all_hiddens,
            tnf_lambda=self.tnf_lambda,
            embed_tnf_tokens=embed_tnf_tokens,
            excluded_position=excluded_position,
        )
        features = inner_states[-1]
        return features, {'inner_states': inner_states if return_all_hiddens else None}

    def output_layer(self, features, masked_tokens=None, **unused):
        """Project ``features`` through the LM head."""
        return self.lm_head(features, masked_tokens)

    def max_positions(self):
        """Maximum output length supported by the encoder."""
        return self.args.max_positions

    def update_tnf_prep(self, x, tnf_src_tokens, tnf_src_tokens_nomask=None):
        """Accumulate detached context vectors for the note tokens in a batch.

        Does nothing unless ``tnf_lambda`` is set and the module is in
        training mode. The per-token sums go to ``tnf_embed_updates`` and the
        per-token counts to ``tnf_embed_updates_cnts``; the embeddings
        themselves are only changed later by :meth:`before_update`.

        Args:
            x: encoder features, indexed as ``(batch, position, dim)``.
            tnf_src_tokens: TNF-dictionary ids, used in ``'default'`` mode.
            tnf_src_tokens_nomask: TNF-dictionary ids, used in ``'mask'`` mode.

        Raises:
            ValueError: for an unknown ``update_tnf_emb`` or ``ctx`` setting,
                or when ``'mask'`` mode is used without
                ``tnf_src_tokens_nomask``.
        """

        if (self.tnf_lambda is not None) and self.training:
            if self.tnf_embed_updates.device != x.device:
                # Move every accumulator together so later arithmetic between
                # them never mixes devices.
                self.tnf_embed_updates = self.tnf_embed_updates.to(x.device)
                self.tnf_embed_updates_cnts = self.tnf_embed_updates_cnts.to(x.device)
                self.tnf_words_updates_cnts = self.tnf_words_updates_cnts.to(x.device)
                self.tnf_embed_updates = self.tnf_embed_updates.type_as(x)

            # Pick the positions whose states are collected, and flatten the
            # matching token ids to a 1-D tensor.
            if self.update_tnf_emb == 'mask':
                if tnf_src_tokens_nomask is None:
                    raise ValueError(
                        "update_tnf_emb='mask' requires tnf_src_tokens_nomask"
                    )
                tnf_src_tokens_positions = (tnf_src_tokens_nomask != self.tnf_unk_idx) & \
                    (tnf_src_tokens_nomask != self.tnf_bos_idx) & \
                    (tnf_src_tokens_nomask != self.tnf_eos_idx)
                tnf_src_tokens = tnf_src_tokens_nomask[tnf_src_tokens_positions]
            elif self.update_tnf_emb == 'default':
                tnf_src_tokens_positions = (tnf_src_tokens != self.tnf_unk_idx) & \
                    (tnf_src_tokens != self.tnf_bos_idx) & \
                    (tnf_src_tokens != self.tnf_eos_idx) & (tnf_src_tokens != self.tnf_mask_idx)
                tnf_src_tokens = tnf_src_tokens[tnf_src_tokens_positions]
            else:
                raise ValueError(
                    'unsupported update_tnf_emb mode: {!r}'.format(self.update_tnf_emb)
                )

            if tnf_src_tokens.size()[0] <= 0:
                return

            # detach x
            # ``x_detached`` has the same (batch, position, dim) layout as x
            # and holds, for every position, the context vector to collect.
            if self.ctx is not None:
                if self.ctx == 'cls':
                    # state of position 0, repeated along the sequence
                    x_detached = (x[:, 0, :].detach()
                                            .view(x.size(0), 1, -1)
                                            .expand(-1, x.size(1), -1))
                elif self.ctx == 'avg':
                    # mean over all positions of each sequence
                    x_detached = (x.detach()
                                   .mean(dim=1)
                                   .view(x.size(0), 1, -1)
                                   .expand(-1, x.size(1), -1))
                elif self.ctx == 'tnfavg':
                    # mean over the selected positions of the whole batch
                    x_detached = (x.detach()[tnf_src_tokens_positions]
                                  .mean(dim=0)
                                  .view(1, 1, -1)
                                  .expand(x.size(0), x.size(1), -1))
                elif self.ctx == 'windowavg':
                    # mean over consecutive non-overlapping windows; a
                    # shorter trailing window is averaged on its own
                    num_windows = x.size(1) // self.ctx_window_size
                    remainder = x.size(1) % self.ctx_window_size
                    if num_windows != 0:
                        x_detached_main = x.detach()[:, :(num_windows * self.ctx_window_size), :]
                        x_detached_main = x_detached_main.view(
                            x.size(0), num_windows, self.ctx_window_size, -1)
                        x_detached_main = x_detached_main.mean(dim=2, keepdim=True).expand(-1, -1, self.ctx_window_size, -1)
                        x_detached_main = x_detached_main.reshape(x.size(0), num_windows * self.ctx_window_size, -1)
                        if remainder != 0:
                            x_remainder = x.detach()[:, (num_windows * self.ctx_window_size):, :]
                            x_remainder = x_remainder.mean(dim=1, keepdim=True).expand(-1, remainder, -1)
                            x_detached = torch.cat((x_detached_main, x_remainder), 1)
                        else:
                            x_detached = x_detached_main
                    else:
                        # sequence shorter than one window
                        x_remainder = x.detach()[:, (num_windows * self.ctx_window_size):, :]
                        x_detached = x_remainder.mean(dim=1, keepdim=True).expand(-1, remainder, -1)
                elif self.ctx == 'movingwindow':
                    if self.ctx_window_size >= x.size(1):
                        x_detached = x.detach().mean(dim=1).unsqueeze(dim=1).expand(-1, x.size(1), -1)
                    else:
                        # Sliding-window sums from a prefix sum; the first and
                        # last window sums are repeated at the edges so the
                        # result keeps the original sequence length.
                        x_cumsum = torch.cumsum(x.detach(), dim=1)
                        x_sum = x_cumsum[:, self.ctx_window_size:, :] - x_cumsum[:, :-self.ctx_window_size, :]
                        if self.ctx_window_size % 2 == 0:
                            num_left = self.ctx_window_size // 2
                            num_right = num_left
                        else:
                            num_left = self.ctx_window_size // 2 + 1
                            num_right = num_left - 1
                        x_sum_left = x_cumsum[:, self.ctx_window_size - 1, :].unsqueeze(dim=1).expand(-1, num_left, -1)
                        x_sum_right = x_sum[:, -1, :].unsqueeze(dim=1).expand(-1, num_right, -1)
                        x_detached = torch.cat([x_sum_left, x_sum, x_sum_right], dim=1)
                        x_detached = x_detached / self.ctx_window_size
                else:
                    raise ValueError('unsupported ctx mode: {!r}'.format(self.ctx))
            else:
                # each position contributes its own hidden state
                x_detached = x.detach()

            # Sum the context vectors per distinct token without a Python
            # loop: sort by token id, take a running sum over the sorted rows
            # and difference it at the end of each run of equal ids.
            sorted_tokens, sorted_index = torch.sort(tnf_src_tokens.view(-1))
            sorted_x = x_detached[tnf_src_tokens_positions, :].contiguous().view(-1, x.size(-1))[sorted_index]
            uni_tokens, cnts = torch.unique_consecutive(sorted_tokens, return_counts=True)
            cum_x = torch.cumsum(sorted_x.float(), dim=0)
            ends = torch.cumsum(cnts, dim=0) - 1
            a = torch.index_select(cum_x, 0, ends)
            b = torch.cat([torch.zeros((1, x.size(-1)), dtype=a.dtype, device=a.device), a[:-1, :]], dim=0)
            uni_emb = (a - b)
            self.tnf_embed_updates[uni_tokens] += uni_emb.type_as(x)
            self.tnf_embed_updates_cnts[uni_tokens] += cnts.float()

    def before_update(self):
        """Fold the accumulated states into ``tnf_embeddings`` and clear the accumulators.

        For every token with a positive count the embedding becomes
        ``(1 - tnf_gamma) * old + tnf_gamma * mean(collected states)``.
        When the world size is greater than one, counts and sums are
        all-reduced across workers first. No-op unless ``tnf_lambda`` is set
        and the module is in training mode.
        """
        if (self.tnf_lambda is not None) and self.training:
            def filter_based_on_cnt(cnts):
                update_tokens = cnts > 0
                return update_tokens
            if distributed_utils.get_world_size() > 1:
                # use fp32 to sync, to avoid int rounding error
                distributed_utils.all_reduce(self.tnf_embed_updates_cnts)
            self.tnf_words_updates_cnts += self.tnf_embed_updates_cnts
            update_tokens = filter_based_on_cnt(self.tnf_embed_updates_cnts)
            new_embs = self.tnf_embed_updates[update_tokens]
            if distributed_utils.get_world_size() > 1:
                # synced in the accumulator's own dtype (original note: fp16)
                distributed_utils.all_reduce(new_embs)
            to_update = (new_embs / self.tnf_embed_updates_cnts[update_tokens].unsqueeze(1)).type_as(self.tnf_embeddings.weight.data)
            self.tnf_embeddings.weight.data[update_tokens] = (1.0 - self.tnf_gamma) * self.tnf_embeddings.weight.data[update_tokens] + self.tnf_gamma * to_update
            self.tnf_embed_updates.zero_()
            self.tnf_embed_updates_cnts.zero_()

    def reset_state(self):
        """Discard accumulated sums and counts without applying them."""
        if self.tnf_lambda is not None:
            self.tnf_embed_updates.zero_()
            self.tnf_embed_updates_cnts.zero_()


@register_model_architecture('tnf', 'tnf')
def base_architecture(args):
    """Fill in the default hyper-parameters of the ``tnf`` architecture.

    Every attribute listed below keeps its current value when ``args``
    already has it and receives the listed default otherwise.
    """
    defaults = (
        # encoder size
        ('encoder_layers', 12),
        ('encoder_embed_dim', 768),
        ('encoder_ffn_embed_dim', 3072),
        ('encoder_attention_heads', 12),
        # activations
        ('activation_fn', 'gelu'),
        ('pooler_activation_fn', 'tanh'),
        # dropout
        ('dropout', 0.1),
        ('attention_dropout', 0.1),
        ('activation_dropout', 0.0),
        ('pooler_dropout', 0.0),
        # layer-norm placement
        ('encoder_normalize_before', False),
        ('embedding_normalize', False),
    )
    for attr_name, fallback in defaults:
        setattr(args, attr_name, getattr(args, attr_name, fallback))


@register_model_architecture('tnf', 'tnf_base')
def tnf_base_architecture(args):
    """Architecture ``tnf_base``: identical to the defaults of ``base_architecture``."""
    base_architecture(args)


@register_model_architecture('tnf', 'tnf_small')
def tnf_small_architecture(args):
    """Architecture ``tnf_small``: 12 layers, 256-d embeddings, 1024-d FFN, 4 heads.

    Values already present on ``args`` are kept; remaining settings come
    from ``base_architecture``.
    """
    size_defaults = (
        ('encoder_layers', 12),
        ('encoder_embed_dim', 256),
        ('encoder_ffn_embed_dim', 1024),
        ('encoder_attention_heads', 4),
    )
    for attr_name, fallback in size_defaults:
        setattr(args, attr_name, getattr(args, attr_name, fallback))
    base_architecture(args)


@register_model_architecture('tnf', 'tnf_large')
def tnf_large_architecture(args):
    """Architecture ``tnf_large``: 24 layers, 1024-d embeddings, 4096-d FFN, 16 heads.

    Values already present on ``args`` are kept; remaining settings come
    from ``base_architecture``.
    """
    size_defaults = (
        ('encoder_layers', 24),
        ('encoder_embed_dim', 1024),
        ('encoder_ffn_embed_dim', 4096),
        ('encoder_attention_heads', 16),
    )
    for attr_name, fallback in size_defaults:
        setattr(args, attr_name, getattr(args, attr_name, fallback))
    base_architecture(args)


@register_model_architecture('tnf', 'tnf_L3A12')
def tnf_L3A12_architecture(args):
    """Architecture ``tnf_L3A12``: 3 layers, 768-d embeddings, 3072-d FFN, 12 heads.

    Values already present on ``args`` are kept; remaining settings come
    from ``base_architecture``.
    """
    size_defaults = (
        ('encoder_layers', 3),
        ('encoder_embed_dim', 768),
        ('encoder_ffn_embed_dim', 3072),
        ('encoder_attention_heads', 12),
    )
    for attr_name, fallback in size_defaults:
        setattr(args, attr_name, getattr(args, attr_name, fallback))
    base_architecture(args)


@register_model_architecture('tnf', 'tnf_L6A8')
def tnf_L6A8_architecture(args):
    """Architecture ``tnf_L6A8``: 6 layers, 512-d embeddings, 2048-d FFN, 8 heads.

    Values already present on ``args`` are kept; remaining settings come
    from ``base_architecture``.
    """
    size_defaults = (
        ('encoder_layers', 6),
        ('encoder_embed_dim', 512),
        ('encoder_ffn_embed_dim', 2048),
        ('encoder_attention_heads', 8),
    )
    for attr_name, fallback in size_defaults:
        setattr(args, attr_name, getattr(args, attr_name, fallback))
    base_architecture(args)
