# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert data2vec checkpoint."""


import argparse
import os
import pathlib

import fairseq
import torch
from fairseq.modules import TransformerSentenceEncoderLayer
from packaging import version

from transformers_local import (
    Data2VecTextConfig,
    Data2VecTextForMaskedLM,
    Data2VecTextForSequenceClassification,
    Data2VecTextModel,
)
from transformers_local.models.bert.modeling_bert import (
    BertIntermediate,
    BertLayer,
    BertOutput,
    BertSelfAttention,
    BertSelfOutput,
)

# IMPORTANT: In order for this script to run, please make sure to download the dictionary `dict.txt`, e.g. with `wget https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz`
# File copied from https://github.com/pytorch/fairseq/blob/main/examples/data2vec/models/data2vec_text.py
from transformers_local.utils import logging


# This script needs fairseq 0.9.0 or newer; fail at import time on older installs.
if version.parse("0.9.0") > version.parse(fairseq.__version__):
    raise Exception("requires fairseq >= 0.9.0")


# Switch the library logging verbosity to "info" and create this module's logger.
logging.set_verbosity_info()
logger = logging.get_logger(__name__)

# Text that is encoded and fed to both models when comparing their outputs.
SAMPLE_TEXT = "Hello world! cécé herlolip"


def _require_shape(name: str, actual: "torch.Size", expected: "torch.Size") -> None:
    """Raise `ValueError` unless `actual` equals `expected`.

    An explicit exception is used instead of `assert` so that the check is not stripped when Python runs
    with `-O`.
    """
    if actual != expected:
        raise ValueError(f"Shape for {name} should be {expected}, but got {actual}")


def convert_data2vec_checkpoint_to_pytorch(
    data2vec_checkpoint_path: str, pytorch_dump_folder_path: str, classification_head: bool
) -> None:
    """
    Copy/paste/tweak data2vec's weights to our BERT structure.

    The source checkpoint is loaded, an equivalent target model is built, every weight is copied over, both
    models are run on `SAMPLE_TEXT` to check that their outputs match, and the target model is saved.

    Args:
        data2vec_checkpoint_path: Path to the data2vec checkpoint file. It is split into a directory and a
            file name, which are both handed to `Data2VecTextModel.from_pretrained`.
        pytorch_dump_folder_path: Folder the converted model is saved to (created if it does not exist).
        classification_head: If `True`, a `Data2VecTextForSequenceClassification` is built and the `"mnli"`
            classification head is copied; otherwise a `Data2VecTextForMaskedLM` is built and the LM head is
            copied.

    Raises:
        ValueError: If a source weight does not have the shape expected by the target layer.
        Exception: If the outputs of the two models differ by more than an absolute tolerance of 1e-3.
    """
    data2vec_checkpoint_dir, data2vec_checkpoint_file_name = os.path.split(data2vec_checkpoint_path)
    data2vec = Data2VecTextModel.from_pretrained(
        data2vec_checkpoint_dir, checkpoint_file=data2vec_checkpoint_file_name
    )
    data2vec.eval()  # disable dropout
    # NOTE(review): this function reads both `data2vec.models[0]` (here) and `data2vec.model` (classification
    # head branches below) -- confirm that the loaded object really exposes both attributes.
    data2vec_model = data2vec.models[0]
    data2vec_sent_encoder = data2vec_model.encoder.sentence_encoder
    config = Data2VecTextConfig(
        vocab_size=data2vec_sent_encoder.embed_tokens.num_embeddings,
        hidden_size=data2vec_model.args.encoder_embed_dim,
        num_hidden_layers=data2vec_model.args.encoder_layers,
        num_attention_heads=data2vec_model.args.encoder_attention_heads,
        intermediate_size=data2vec_model.args.encoder_ffn_embed_dim,
        max_position_embeddings=514,
        type_vocab_size=1,
        layer_norm_eps=1e-5,  # PyTorch default used in fairseq
    )
    if classification_head:
        config.num_labels = data2vec.model.classification_heads["mnli"].out_proj.weight.shape[0]
    print("Our BERT config:", config)

    model = Data2VecTextForSequenceClassification(config) if classification_head else Data2VecTextForMaskedLM(config)
    model.eval()

    # Now let's copy all the weights.
    # Embeddings
    model.data2vec_text.embeddings.word_embeddings.weight = data2vec_sent_encoder.embed_tokens.weight
    model.data2vec_text.embeddings.position_embeddings.weight = data2vec_sent_encoder.embed_positions.weight
    model.data2vec_text.embeddings.token_type_embeddings.weight.data = torch.zeros_like(
        model.data2vec_text.embeddings.token_type_embeddings.weight
    )  # just zero them out b/c data2vec doesn't use them.
    model.data2vec_text.embeddings.LayerNorm.weight = data2vec_sent_encoder.layernorm_embedding.weight
    model.data2vec_text.embeddings.LayerNorm.bias = data2vec_sent_encoder.layernorm_embedding.bias

    # Every q/k/v projection weight must be square in the hidden size; the expected shape is loop-invariant.
    attention_weight_shape = torch.Size((config.hidden_size, config.hidden_size))

    for i in range(config.num_hidden_layers):
        # Encoder: start of layer
        layer: BertLayer = model.data2vec_text.encoder.layer[i]
        data2vec_layer: TransformerSentenceEncoderLayer = data2vec_sent_encoder.layers[i]

        # self attention
        self_attn: BertSelfAttention = layer.attention.self
        for proj_name in ("k_proj", "q_proj", "v_proj"):
            _require_shape(
                f"data2vec_layer.self_attn.{proj_name}.weight.data",
                getattr(data2vec_layer.self_attn, proj_name).weight.data.shape,
                attention_weight_shape,
            )

        self_attn.query.weight.data = data2vec_layer.self_attn.q_proj.weight
        self_attn.query.bias.data = data2vec_layer.self_attn.q_proj.bias
        self_attn.key.weight.data = data2vec_layer.self_attn.k_proj.weight
        self_attn.key.bias.data = data2vec_layer.self_attn.k_proj.bias
        self_attn.value.weight.data = data2vec_layer.self_attn.v_proj.weight
        self_attn.value.bias.data = data2vec_layer.self_attn.v_proj.bias

        # self-attention output
        self_output: BertSelfOutput = layer.attention.output
        _require_shape(
            "self_output.dense.weight",
            self_output.dense.weight.shape,
            data2vec_layer.self_attn.out_proj.weight.shape,
        )
        self_output.dense.weight = data2vec_layer.self_attn.out_proj.weight
        self_output.dense.bias = data2vec_layer.self_attn.out_proj.bias
        self_output.LayerNorm.weight = data2vec_layer.self_attn_layer_norm.weight
        self_output.LayerNorm.bias = data2vec_layer.self_attn_layer_norm.bias

        # intermediate
        intermediate: BertIntermediate = layer.intermediate
        _require_shape(
            "intermediate.dense.weight",
            intermediate.dense.weight.shape,
            data2vec_layer.fc1.weight.shape,
        )
        intermediate.dense.weight = data2vec_layer.fc1.weight
        intermediate.dense.bias = data2vec_layer.fc1.bias

        # output
        bert_output: BertOutput = layer.output
        _require_shape(
            "bert_output.dense.weight",
            bert_output.dense.weight.shape,
            data2vec_layer.fc2.weight.shape,
        )
        bert_output.dense.weight = data2vec_layer.fc2.weight
        bert_output.dense.bias = data2vec_layer.fc2.bias
        bert_output.LayerNorm.weight = data2vec_layer.final_layer_norm.weight
        bert_output.LayerNorm.bias = data2vec_layer.final_layer_norm.bias
        # end of layer

    if classification_head:
        model.classifier.dense.weight = data2vec.model.classification_heads["mnli"].dense.weight
        model.classifier.dense.bias = data2vec.model.classification_heads["mnli"].dense.bias
        model.classifier.out_proj.weight = data2vec.model.classification_heads["mnli"].out_proj.weight
        model.classifier.out_proj.bias = data2vec.model.classification_heads["mnli"].out_proj.bias
    else:
        # LM Head
        model.lm_head.dense.weight = data2vec_model.encoder.lm_head.dense.weight
        model.lm_head.dense.bias = data2vec_model.encoder.lm_head.dense.bias
        model.lm_head.layer_norm.weight = data2vec_model.encoder.lm_head.layer_norm.weight
        model.lm_head.layer_norm.bias = data2vec_model.encoder.lm_head.layer_norm.bias
        model.lm_head.decoder.weight = data2vec_model.encoder.lm_head.weight
        model.lm_head.decoder.bias = data2vec_model.encoder.lm_head.bias

    # Let's check that we get the same results.
    input_ids: torch.Tensor = data2vec.encode(SAMPLE_TEXT).unsqueeze(0)  # batch of size 1

    # The comparison only needs forward passes, so skip autograd bookkeeping.
    with torch.no_grad():
        our_output = model(input_ids)[0]
        if classification_head:
            their_output = data2vec.model.classification_heads["mnli"](data2vec.extract_features(input_ids))
        else:
            their_output = data2vec_model(input_ids)[0]
    print(our_output.shape, their_output.shape)
    max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
    print(f"max_absolute_diff = {max_absolute_diff}")  # ~ 1e-7
    success = torch.allclose(our_output, their_output, atol=1e-3)
    print("Do both models output the same tensors?", "🔥" if success else "💩")
    if not success:
        raise Exception("Something went wRoNg")

    pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
    print(f"Saving model to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    # Command-line entry point: collect the conversion options and run the converter.
    cli = argparse.ArgumentParser()

    # Required parameters: both are mandatory string options, registered in this order.
    required_path_options = (
        ("--checkpoint_path", "Path the official PyTorch dump."),
        ("--pytorch_dump_folder_path", "Path to the output PyTorch model."),
    )
    for flag, description in required_path_options:
        cli.add_argument(flag, default=None, type=str, required=True, help=description)

    # Optional switch, off unless given on the command line.
    cli.add_argument(
        "--classification_head", action="store_true", help="Whether to convert a final classification head."
    )

    options = cli.parse_args()
    convert_data2vec_checkpoint_to_pytorch(
        data2vec_checkpoint_path=options.checkpoint_path,
        pytorch_dump_folder_path=options.pytorch_dump_folder_path,
        classification_head=options.classification_head,
    )
