# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
RoBERTa: A Robustly Optimized BERT Pretraining Approach.
"""

import logging
from numpy import False_

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, TransformerEncoder
from fairseq.modules import LayerNorm
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
from fairseq.modules.transformer_sentence_encoder import init_bert_params
from fairseq.models.roberta import RobertaEncoder, RobertaModel, base_architecture

from fairseq.models.xformer import S4Encoder, S4plusEncoder, S4_GLUEncoder

class RobertaS4Encoder(RobertaEncoder):
    """``RobertaEncoder`` subclass that builds an ``S4Encoder`` in ``build_encoder``."""

    def __init__(self, args, dictionary):
        """Pass ``args`` and ``dictionary`` straight through to ``RobertaEncoder``."""
        super().__init__(args, dictionary)

    def build_encoder(self, args, dictionary, embed_tokens):
        """Construct the ``S4Encoder`` and apply ``init_bert_params`` to it.

        Args:
            args: model configuration forwarded to ``S4Encoder``.
            dictionary: dictionary forwarded to ``S4Encoder``.
            embed_tokens: token embedding forwarded to ``S4Encoder``.

        Returns:
            The ``S4Encoder`` instance after ``init_bert_params`` was applied.
        """
        s4_encoder = S4Encoder(args, dictionary, embed_tokens)
        s4_encoder.apply(init_bert_params)
        return s4_encoder

@register_model("roberta_s4")
class RobertaS4(RobertaModel):
    """``RobertaModel`` registered as ``roberta_s4`` and built on ``RobertaS4Encoder``."""

    def __init__(self, args, encoder):
        """Pass ``args`` and ``encoder`` straight through to ``RobertaModel``."""
        super().__init__(args, encoder)

    @classmethod
    def build_model(cls, args, task):
        """Create a model instance from ``args`` and ``task``.

        Args:
            args: model configuration; mutated in place by
                ``base_architecture`` and the ``max_positions`` default.
            task: object providing ``source_dictionary`` for the encoder.

        Returns:
            A new ``cls`` instance wrapping a ``RobertaS4Encoder``.
        """
        # Populate whatever architecture arguments are still unset.
        base_architecture(args)

        # Default ``max_positions`` to ``tokens_per_sample`` when absent.
        if not hasattr(args, "max_positions"):
            args.max_positions = args.tokens_per_sample

        return cls(args, RobertaS4Encoder(args, task.source_dictionary))

@register_model_architecture("roberta_s4", "roberta_s4_base")
def roberta_s4_base(args):
    """Write the ``roberta_s4_base`` hyper-parameters onto ``args`` in place.

    ``base_architecture`` is applied first; afterwards the S4-specific
    values below are assigned. ``encoder_embed_dim`` ends up at twice the
    value it holds after ``base_architecture``.

    The doubling happens at most once per config: a marker attribute
    (``encoder_embed_dim_doubled``) is stored on ``args`` so that calling
    this function again on the same, or a restored, config leaves
    ``encoder_embed_dim`` untouched instead of doubling it a second time.

    Args:
        args: namespace-like configuration object; mutated in place.
    """
    base_architecture(args)
    args.d_state = 64
    args.bidirectional = True
    # ``x = x * 2`` is not idempotent. fairseq may invoke architecture
    # functions more than once on the same config (e.g. while parsing
    # arguments and again while building the model), which would keep
    # growing the width, so the scaling is guarded by a marker.
    if not getattr(args, "encoder_embed_dim_doubled", False):
        args.encoder_embed_dim = args.encoder_embed_dim * 2
        args.encoder_embed_dim_doubled = True
    args.encoder_layers = 9
    args.no_token_positional_embeddings = True
    
class RobertaS4plusEncoder(RobertaEncoder):
    """``RobertaEncoder`` subclass that builds an ``S4plusEncoder`` in ``build_encoder``."""

    def __init__(self, args, dictionary):
        """Pass ``args`` and ``dictionary`` straight through to ``RobertaEncoder``."""
        super().__init__(args, dictionary)

    def build_encoder(self, args, dictionary, embed_tokens):
        """Construct the ``S4plusEncoder`` and apply ``init_bert_params`` to it.

        Args:
            args: model configuration forwarded to ``S4plusEncoder``.
            dictionary: dictionary forwarded to ``S4plusEncoder``.
            embed_tokens: token embedding forwarded to ``S4plusEncoder``.

        Returns:
            The ``S4plusEncoder`` instance after ``init_bert_params`` was applied.
        """
        s4plus_encoder = S4plusEncoder(args, dictionary, embed_tokens)
        s4plus_encoder.apply(init_bert_params)
        return s4plus_encoder

@register_model("roberta_s4plus")
class RobertaS4plus(RobertaModel):
    """``RobertaModel`` registered as ``roberta_s4plus`` and built on ``RobertaS4plusEncoder``."""

    def __init__(self, args, encoder):
        """Pass ``args`` and ``encoder`` straight through to ``RobertaModel``."""
        super().__init__(args, encoder)

    @classmethod
    def build_model(cls, args, task):
        """Create a model instance from ``args`` and ``task``.

        Args:
            args: model configuration; mutated in place by
                ``base_architecture`` and the ``max_positions`` default.
            task: object providing ``source_dictionary`` for the encoder.

        Returns:
            A new ``cls`` instance wrapping a ``RobertaS4plusEncoder``.
        """
        # Populate whatever architecture arguments are still unset.
        base_architecture(args)

        # Default ``max_positions`` to ``tokens_per_sample`` when absent.
        if not hasattr(args, "max_positions"):
            args.max_positions = args.tokens_per_sample

        return cls(args, RobertaS4plusEncoder(args, task.source_dictionary))

@register_model_architecture("roberta_s4plus", "roberta_s4plus_base")
def roberta_s4plus_base(args):
    """Write the ``roberta_s4plus_base`` hyper-parameters onto ``args`` in place.

    ``base_architecture`` is applied first; afterwards the S4-specific
    values below are assigned. ``encoder_embed_dim`` ends up at twice the
    value it holds after ``base_architecture``.

    The doubling happens at most once per config: a marker attribute
    (``encoder_embed_dim_doubled``) is stored on ``args`` so that calling
    this function again on the same, or a restored, config leaves
    ``encoder_embed_dim`` untouched instead of doubling it a second time.

    Args:
        args: namespace-like configuration object; mutated in place.
    """
    base_architecture(args)
    args.d_state = 64
    args.bidirectional = True
    # ``x = x * 2`` is not idempotent. fairseq may invoke architecture
    # functions more than once on the same config (e.g. while parsing
    # arguments and again while building the model), which would keep
    # growing the width, so the scaling is guarded by a marker.
    if not getattr(args, "encoder_embed_dim_doubled", False):
        args.encoder_embed_dim = args.encoder_embed_dim * 2
        args.encoder_embed_dim_doubled = True
    args.encoder_layers = 4
    args.no_token_positional_embeddings = True


# RoBERTa variant built on S4_GLUEncoder.
class RobertaS4_GLUEncoder(RobertaEncoder):
    """``RobertaEncoder`` subclass that builds an ``S4_GLUEncoder`` in ``build_encoder``."""

    def __init__(self, args, dictionary):
        """Pass ``args`` and ``dictionary`` straight through to ``RobertaEncoder``."""
        super().__init__(args, dictionary)

    def build_encoder(self, args, dictionary, embed_tokens):
        """Construct the ``S4_GLUEncoder`` and apply ``init_bert_params`` to it.

        Args:
            args: model configuration forwarded to ``S4_GLUEncoder``.
            dictionary: dictionary forwarded to ``S4_GLUEncoder``.
            embed_tokens: token embedding forwarded to ``S4_GLUEncoder``.

        Returns:
            The ``S4_GLUEncoder`` instance after ``init_bert_params`` was applied.
        """
        s4_glu_encoder = S4_GLUEncoder(args, dictionary, embed_tokens)
        s4_glu_encoder.apply(init_bert_params)
        return s4_glu_encoder

@register_model("roberta_s4_glu")
class RobertaS4_GLU(RobertaModel):
    """``RobertaModel`` registered as ``roberta_s4_glu`` and built on ``RobertaS4_GLUEncoder``."""

    def __init__(self, args, encoder):
        """Pass ``args`` and ``encoder`` straight through to ``RobertaModel``."""
        super().__init__(args, encoder)

    @classmethod
    def build_model(cls, args, task):
        """Create a model instance from ``args`` and ``task``.

        Args:
            args: model configuration; mutated in place by
                ``base_architecture`` and the ``max_positions`` default.
            task: object providing ``source_dictionary`` for the encoder.

        Returns:
            A new ``cls`` instance wrapping a ``RobertaS4_GLUEncoder``.
        """
        # Populate whatever architecture arguments are still unset.
        base_architecture(args)

        # Default ``max_positions`` to ``tokens_per_sample`` when absent.
        if not hasattr(args, "max_positions"):
            args.max_positions = args.tokens_per_sample

        return cls(args, RobertaS4_GLUEncoder(args, task.source_dictionary))

@register_model_architecture("roberta_s4_glu", "roberta_s4_glu_base")
def roberta_glu_s4_base(args):
    """Write the ``roberta_s4_glu_base`` hyper-parameters onto ``args`` in place.

    ``base_architecture`` is applied first; afterwards the S4-specific
    values below are assigned. ``encoder_embed_dim`` ends up at twice the
    value it holds after ``base_architecture``.

    The doubling happens at most once per config: a marker attribute
    (``encoder_embed_dim_doubled``) is stored on ``args`` so that calling
    this function again on the same, or a restored, config leaves
    ``encoder_embed_dim`` untouched instead of doubling it a second time.

    Args:
        args: namespace-like configuration object; mutated in place.
    """
    base_architecture(args)
    args.d_state = 64
    args.bidirectional = True
    # ``x = x * 2`` is not idempotent. fairseq may invoke architecture
    # functions more than once on the same config (e.g. while parsing
    # arguments and again while building the model), which would keep
    # growing the width, so the scaling is guarded by a marker.
    if not getattr(args, "encoder_embed_dim_doubled", False):
        args.encoder_embed_dim = args.encoder_embed_dim * 2
        args.encoder_embed_dim_doubled = True
    args.encoder_layers = 6
    args.no_token_positional_embeddings = True
