from email.policy import default
from fairseq.dataclass.utils import gen_parser_from_dataclass
from fairseq.models import (
    register_model,
    register_model_architecture,
)

import torch
import torch.nn as nn
from torch import Tensor

import time

import logging

from fairseq import utils

from fairseq.models.transformer.transformer_base import (
    TransformerModelBase,
)

from fairseq.models.transformer.transformer_config import (
    TransformerConfig,
    DEFAULT_MAX_SOURCE_POSITIONS,
    DEFAULT_MAX_TARGET_POSITIONS,
    DEFAULT_MIN_PARAMS_TO_WRAP,
)

from .byte_subword_embedding import (
    ByteSubwordEmbed,
)

from .byte_subword_one_embedding import (
    ByteSubwordOneEmbed,
)


from .byte_subword_custom import (
    ByteSubword,
)

from .byte_subword_combine import (
    ByteSubwordCombine,
)

from .byte_attention_subword import (
    ByteSubwordAttention
)

from .byte_subword_concate import (
    ByteSubwordConcate
)

from .byte_subword_concate_onehot import (
    ByteSubwordConcateOnehot
)

from .byte_decoder import (
    ByteWordTransformerDecoder,
)

from .byte_subword_config import ByteSubwordTransformerConfig

logger = logging.getLogger(__name__)

@register_model("byteword_transformer")
class TransformerModelByteWord(TransformerModelBase):
    """Encoder-decoder Transformer whose token embeddings can be built from bytes.

    Configuration is an argparse namespace generated from
    ``ByteSubwordTransformerConfig`` (see :meth:`add_args`). Token embeddings
    come from :meth:`build_embedding`, which picks either a plain lookup table
    or :func:`ByteWordEmbedding`. The decoder is always a
    ``ByteWordTransformerDecoder``.
    """

    def __init__(self, args, encoder, decoder):
        cfg = ByteSubwordTransformerConfig.from_namespace(args)
        super().__init__(cfg, encoder, decoder)
        # Keep the raw argparse namespace alongside the dataclass config.
        self.args = args

    @classmethod
    def add_args(cls, parser):
        """Add model-specific arguments to the parser.

        Every field of ``ByteSubwordTransformerConfig`` becomes a command-line
        flag. ``delete_default=True`` leaves the flags without defaults, so the
        architecture functions in this module can fill them in via ``getattr``.
        """
        gen_parser_from_dataclass(
            parser, ByteSubwordTransformerConfig(), delete_default=True, with_prefix=""
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance from ``args`` and the task's dictionaries.

        ``args`` is mutated in place:
        - architecture defaults are filled in;
        - layer counts are overridden by ``*_layers_to_keep``;
        - missing max positions get the fairseq defaults;
        - embedding-sharing flags may be switched on.

        Raises:
            ValueError: if ``--share-all-embeddings`` is requested with
                different dictionaries, embedding dims or embedding paths.
        """
        # make sure all arguments are present in older models
        byteword_base_architecture(args)

        if args.encoder_layers_to_keep:
            args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
        if args.decoder_layers_to_keep:
            args.decoder_layers = len(args.decoder_layers_to_keep.split(","))

        if getattr(args, "max_source_positions", None) is None:
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        if args.share_all_embeddings:
            # One embedding table shared by encoder input, decoder input and
            # decoder output projection.
            if src_dict != tgt_dict:
                raise ValueError("--share-all-embeddings requires a joined dictionary")
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            encoder_embed_tokens = cls.build_embedding(
                args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        elif args.merge_src_tgt_embed:
            # Merge the target vocabulary into the source dictionary (in place)
            # and share one embedding table built over the merged vocabulary.
            logger.info(f"source dict size: {len(src_dict)}")
            logger.info(f"target dict size: {len(tgt_dict)}")
            src_dict.update(tgt_dict)
            task.src_dict = src_dict
            task.tgt_dict = src_dict
            logger.info(f"merged dict size: {len(src_dict)}")
            encoder_embed_tokens = cls.build_embedding(
                args, src_dict, args.encoder_embed_dim
            )
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
            # NOTE(review): the decoder below is still built with the local,
            # pre-merge ``tgt_dict``. ``task.tgt_dict`` now points at the merged
            # dictionary. Confirm this mismatch is intended.
        else:
            encoder_embed_tokens = cls.build_embedding(
                args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = cls.build_embedding(
                args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
            )
        if args.offload_activations:
            args.checkpoint_activations = True  # offloading implies checkpointing
        encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
        decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
        return cls(args, encoder, decoder)

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        """Build the token embedding module for ``dictionary``.

        ``args.bw_byte_subword == -1`` selects a plain :func:`Embedding` lookup
        table. Any other value selects :func:`ByteWordEmbedding`, configured
        from the ``bw_*`` options. If ``path`` is given, pretrained vectors are
        parsed from it and loaded into the module with ``utils.load_embedding``.

        Returns:
            The embedding module.
        """
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()

        if args.bw_byte_subword == -1:
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
        else:
            # NOTE(review): the byte-level module takes no padding index, so
            # ``padding_idx`` is unused on this path.
            emb = ByteWordEmbedding(
                num_embeddings,
                embed_dim,
                args.bw_interdim,
                args.bw_relu_dropout,
                args.bw_aggre,
                args.bw_subword_bytes_file,
            )

        # if provided, load from preloaded dictionaries
        if path:
            embed_dict = utils.parse_embedding(path)
            utils.load_embedding(embed_dict, dictionary, emb)
        return emb

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        """Build the stock fairseq encoder from a ``TransformerConfig`` view of ``args``."""
        return super().build_encoder(
            TransformerConfig.from_namespace(args), src_dict, embed_tokens
        )

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        """Build a ``ByteWordTransformerDecoder``.

        Encoder attention is disabled when ``args.no_cross_attention`` is set.
        """
        return ByteWordTransformerDecoder(
            args,
            tgt_dict,
            embed_tokens,
            no_encoder_attn=getattr(args, "no_cross_attention", False),
        )

def Embedding(num_embeddings, embedding_dim, padding_idx):
    """Create an ``nn.Embedding`` with fairseq-style initialisation.

    Weights are drawn from N(0, embedding_dim ** -0.5). When a padding index
    is given, the padding row is reset to zeros.

    Args:
        num_embeddings: number of rows (vocabulary size).
        embedding_dim: size of each embedding vector.
        padding_idx: row reserved for padding, or ``None`` for no padding row.
            Negative indices are resolved by ``nn.Embedding``.

    Returns:
        The initialised ``nn.Embedding``.
    """
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    # ``m.weight[None]`` would select the whole matrix and wipe every row,
    # so only zero the padding row when one actually exists.
    if m.padding_idx is not None:
        nn.init.constant_(m.weight[m.padding_idx], 0)
    return m

def ByteWordEmbedding(num_embeddings, embedding_dim, interdim, relu_dropout, aggre, subword_bytes_file):
    """Build and initialise the byte-level subword embedding module.

    A ``ByteSubwordConcateOnehot`` module is constructed and initialised:
    its ``weight`` is drawn from N(0, embedding_dim ** -0.5), and every layer
    in its ``linears`` gets Xavier-uniform weights.

    The other byte/subword variants imported by this module were used here
    in earlier experiments and can be swapped in. Some of them take a
    different constructor signature (e.g. extra ``padding_idx`` /
    ``layernum`` / ``std`` arguments):
    ``ByteSubwordEmbed``, ``ByteSubwordOneEmbed``, ``ByteSubword``,
    ``ByteSubwordAttention``, ``ByteSubwordConcate`` and ``ByteSubwordCombine``.

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: output embedding size.
        interdim, relu_dropout, aggre, subword_bytes_file: forwarded unchanged
            to ``ByteSubwordConcateOnehot``.

    Returns:
        The initialised embedding module.
    """
    module = ByteSubwordConcateOnehot(
        num_embeddings, embedding_dim, interdim, relu_dropout, aggre, subword_bytes_file
    )
    init_std = embedding_dim ** -0.5
    nn.init.normal_(module.weight, mean=0, std=init_std)
    # Unlike ``Embedding``, no padding row is zeroed for this module.
    for linear in module.linears:
        nn.init.xavier_uniform_(linear.weight)
    return module



@register_model_architecture("byteword_transformer", "byteword_transformer_tiny")
def byteword_transformer_tiny(args):
    """Tiny preset for quick experiments.

    Sets 64-d embeddings and FFN, 2 layers and 2 heads on each side. Values
    already present on ``args`` are kept; the rest is delegated to
    :func:`byteword_base_architecture`.
    """
    tiny_defaults = {
        "encoder_embed_dim": 64,
        "encoder_ffn_embed_dim": 64,
        "encoder_layers": 2,
        "encoder_attention_heads": 2,
        "decoder_layers": 2,
        "decoder_attention_heads": 2,
    }
    for key, fallback in tiny_defaults.items():
        setattr(args, key, getattr(args, key, fallback))
    return byteword_base_architecture(args)

# architectures

@register_model_architecture("byteword_transformer", "byteword_transformer")
def byteword_base_architecture(args):
    """Fill in every hyper-parameter of the base byte-word Transformer.

    Attributes already present on ``args`` are kept; missing ones receive the
    defaults below. Some defaults are derived from values resolved earlier
    (decoder sizes mirror the encoder, decoder I/O dims mirror the decoder
    embed dim), so the call order matters. ``args`` is modified in place.
    """

    def _default(name, value):
        # Equivalent to ``args.<name> = getattr(args, name, value)``.
        setattr(args, name, getattr(args, name, value))

    # ---- encoder ---------------------------------------------------------
    _default("encoder_embed_path", None)
    _default("encoder_embed_dim", 512)
    _default("encoder_ffn_embed_dim", 2048)
    _default("encoder_layers", 6)
    _default("encoder_attention_heads", 8)
    _default("encoder_normalize_before", False)
    _default("encoder_learned_pos", False)

    # ---- decoder (sizes fall back to the resolved encoder sizes) ---------
    _default("decoder_embed_path", None)
    _default("decoder_embed_dim", args.encoder_embed_dim)
    _default("decoder_ffn_embed_dim", args.encoder_ffn_embed_dim)
    _default("decoder_layers", 6)
    _default("decoder_attention_heads", 8)
    _default("decoder_normalize_before", False)
    _default("decoder_learned_pos", False)

    # ---- dropout / activation / adaptive softmax -------------------------
    _default("attention_dropout", 0.0)
    _default("activation_dropout", 0.0)
    _default("activation_fn", "relu")
    _default("dropout", 0.1)
    _default("adaptive_softmax_cutoff", None)
    _default("adaptive_softmax_dropout", 0)

    # ---- embedding sharing, positions, attention wiring ------------------
    _default("share_decoder_input_output_embed", False)
    _default("share_all_embeddings", False)
    _default("merge_src_tgt_embed", False)
    _default("no_token_positional_embeddings", False)
    _default("adaptive_input", False)
    _default("no_cross_attention", False)
    _default("cross_self_attention", False)

    # ---- decoder I/O dims fall back to the resolved decoder embed dim ----
    _default("decoder_output_dim", args.decoder_embed_dim)
    _default("decoder_input_dim", args.decoder_embed_dim)

    _default("no_scale_embedding", False)
    _default("layernorm_embedding", False)
    _default("tie_adaptive_weights", False)

    # ---- activation checkpointing (offloading implies checkpointing) -----
    _default("checkpoint_activations", False)
    _default("offload_activations", False)
    if args.offload_activations:
        args.checkpoint_activations = True

    # ---- layer pruning, layerdrop, quantization noise --------------------
    _default("encoder_layers_to_keep", None)
    _default("decoder_layers_to_keep", None)
    _default("encoder_layerdrop", 0)
    _default("decoder_layerdrop", 0)
    _default("quant_noise_pq", 0)
    _default("quant_noise_pq_block_size", 8)
    _default("quant_noise_scalar", 0)


@register_model_architecture("byteword_transformer", "byteword_transformer_iwslt_de_en")
def byteword_transformer_iwslt_de_en(args):
    """IWSLT De-En preset: 512-d embeddings, 1024-d FFN, 4 heads, 6+6 layers.

    Only attributes missing from ``args`` are set here; everything else is
    delegated to :func:`byteword_base_architecture`.
    """
    preset = {
        "encoder_embed_dim": 512,
        "encoder_ffn_embed_dim": 1024,
        "encoder_attention_heads": 4,
        "encoder_layers": 6,
        "decoder_embed_dim": 512,
        "decoder_ffn_embed_dim": 1024,
        "decoder_attention_heads": 4,
        "decoder_layers": 6,
    }
    for attr_name, attr_default in preset.items():
        setattr(args, attr_name, getattr(args, attr_name, attr_default))
    byteword_base_architecture(args)


@register_model_architecture("byteword_transformer", "byteword_transformer_wmt_en_de")
def byteword_transformer_wmt_en_de(args):
    """WMT En-De preset: uses the base architecture defaults unchanged."""
    byteword_base_architecture(args)


# parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017)
@register_model_architecture("byteword_transformer", "byteword_transformer_vaswani_wmt_en_de_big")
def byteword_transformer_vaswani_wmt_en_de_big(args):
    """"Big" preset from Vaswani et al. (2017).

    Sets 1024-d embeddings, 4096-d FFN, 16 heads, post-norm encoder and
    dropout 0.3. Values already on ``args`` are kept; the rest comes from
    :func:`byteword_base_architecture`.
    """
    preset = {
        "encoder_embed_dim": 1024,
        "encoder_ffn_embed_dim": 4096,
        "encoder_attention_heads": 16,
        "encoder_normalize_before": False,
        "decoder_embed_dim": 1024,
        "decoder_ffn_embed_dim": 4096,
        "decoder_attention_heads": 16,
        "dropout": 0.3,
    }
    for attr_name, attr_default in preset.items():
        setattr(args, attr_name, getattr(args, attr_name, attr_default))
    byteword_base_architecture(args)


@register_model_architecture("byteword_transformer", "byteword_transformer_vaswani_wmt_en_fr_big")
def byteword_transformer_vaswani_wmt_en_fr_big(args):
    """WMT En-Fr big preset: the Vaswani big model with dropout defaulting to 0.1.

    ``args`` is modified in place; other defaults come from
    :func:`byteword_transformer_vaswani_wmt_en_de_big`.
    """
    args.dropout = getattr(args, "dropout", 0.1)
    byteword_transformer_vaswani_wmt_en_de_big(args)


# Backward-compatible alias for the original, copy-pasted function name.
alone_transformer_vaswani_wmt_en_fr_big = byteword_transformer_vaswani_wmt_en_fr_big


@register_model_architecture("byteword_transformer", "byteword_transformer_wmt_en_de_big")
def byteword_transformer_wmt_en_de_big(args):
    """WMT En-De big preset: the Vaswani big model with attention dropout defaulting to 0.1.

    ``args`` is modified in place; other defaults come from
    :func:`byteword_transformer_vaswani_wmt_en_de_big`.
    """
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    byteword_transformer_vaswani_wmt_en_de_big(args)


# Backward-compatible alias for the original, copy-pasted function name.
alone_transformer_wmt_en_de_big = byteword_transformer_wmt_en_de_big


# default parameters used in tensor2tensor implementation
@register_model_architecture("byteword_transformer", "byteword_transformer_wmt_en_de_big_t2t")
def byteword_transformer_wmt_en_de_big_t2t(args):
    """tensor2tensor-style big preset.

    Defaults to pre-norm in both encoder and decoder, with attention and
    activation dropout of 0.1, on top of the Vaswani big model. ``args`` is
    modified in place; other defaults come from
    :func:`byteword_transformer_vaswani_wmt_en_de_big`.
    """
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.activation_dropout = getattr(args, "activation_dropout", 0.1)
    byteword_transformer_vaswani_wmt_en_de_big(args)


# Backward-compatible alias for the original, copy-pasted function name.
alone_transformer_wmt_en_de_big_t2t = byteword_transformer_wmt_en_de_big_t2t

