# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert data2vec checkpoint."""


import argparse
import os
import pathlib

import fairseq
import torch
from fairseq.modules import TransformerSentenceEncoderLayer
from packaging import version

from transformers import Data2VecTextConfig, Data2VecTextForMaskedLM, Data2VecTextForSequenceClassification
from transformers.models.bert.modeling_bert import (
    BertIntermediate,
    BertLayer,
    BertOutput,
    BertSelfAttention,
    BertSelfOutput,
)

# IMPORTANT: In order for this script to run, please make sure to download the dictionary: `dict.txt` from wget https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz
# File copied from https://github.com/pytorch/fairseq/blob/main/examples/data2vec/models/data2vec_text.py
from transformers.models.data2vec.data2vec_text import Data2VecTextModel
from transformers.utils import logging


if version.parse(fairseq.__version__) < version.parse("0.9.0"):
    raise Exception("requires fairseq >= 0.9.0")


logging.set_verbosity_info()
logger = logging.get_logger(__name__)

SAMPLE_TEXT = "Hello world! cécé herlolip"


def convert_data2vec_checkpoint_to_pytorch(
    data2vec_checkpoint_path: str, pytorch_dump_folder_path: str, classification_head: bool
):
    """
    Copy/paste/tweak data2vec's weights to our BERT structure.
    """
    data2vec_checkpoint_dir, data2vec_checkpoint_file_name = os.path.split(data2vec_checkpoint_path)
    data2vec = Data2VecTextModel.from_pretrained(
        data2vec_checkpoint_dir, checkpoint_file=data2vec_checkpoint_file_name
    )
    data2vec.eval()  # disable dropout
    data2vec_model = data2vec.models[0]
    data2vec_sent_encoder = data2vec_model.encoder.sentence_encoder
    config = Data2VecTextConfig(
        vocab_size=data2vec_sent_encoder.embed_tokens.num_embeddings,
        hidden_size=data2vec_model.args.encoder_embed_dim,
        num_hidden_layers=data2vec_model.args.encoder_layers,
        num_attention_heads=data2vec_model.args.encoder_attention_heads,
        intermediate_size=data2vec_model.args.encoder_ffn_embed_dim,
        max_position_embeddings=514,
        type_vocab_size=1,
        layer_norm_eps=1e-5,  # PyTorch default used in fairseq
    )
    if classification_head:
        config.num_labels = data2vec.model.classification_heads["mnli"].out_proj.weight.shape[0]
    print("Our BERT config:", config)

    model = Data2VecTextForSequenceClassification(config) if classification_head else Data2VecTextForMaskedLM(config)
    model.eval()

    # Now let's copy all the weights.
    # Embeddings
    model.data2vec_text.embeddings.word_embeddings.weight = data2vec_sent_encoder.embed_tokens.weight
    model.data2vec_text.embeddings.position_embeddings.weight = data2vec_sent_encoder.embed_positions.weight
    model.data2vec_text.embeddings.token_type_embeddings.weight.data = torch.zeros_like(
        model.data2vec_text.embeddings.token_type_embeddings.weight
    )  # just zero them out b/c data2vec doesn't use them.
    model.data2vec_text.embeddings.LayerNorm.weight = data2vec_sent_encoder.layernorm_embedding.weight
    model.data2vec_text.embeddings.LayerNorm.bias = data2vec_sent_encoder.layernorm_embedding.bias

    for i in range(config.num_hidden_layers):
        # Encoder: start of layer
        layer: BertLayer = model.data2vec_text.encoder.layer[i]
        data2vec_layer: TransformerSentenceEncoderLayer = data2vec_sent_encoder.layers[i]

        # self attention
        self_attn: BertSelfAttention = layer.attention.self
        assert data2vec_layer.self_attn.k_proj.weight.data.shape == torch.Size(
            (config.hidden_size, config.hidden_size)
        ), (
            "Shape for data2vec_layer.self_attn.k_proj.weight.data should be"
            f" {torch.Size((config.hidden_size, config.hidden_size))}"
        )
        assert data2vec_layer.self_attn.q_proj.weight.data.shape == torch.Size(
            (config.hidden_size, config.hidden_size)
        ), (
            "Shape for data2vec_layer.self_attn.q_proj.weight.data should be"
            f" {torch.Size((config.hidden_size, config.hidden_size))}"
        )
        assert data2vec_layer.self_attn.v_proj.weight.data.shape == torch.Size(
            (config.hidden_size, config.hidden_size)
        ), (
            "Shape for data2vec_layer.self_attn.v_proj.weight.data should be"
            f" {torch.Size((config.hidden_size, config.hidden_size))}"
        )

        self_attn.query.weight.data = data2vec_layer.self_attn.q_proj.weight
        self_attn.query.bias.data = data2vec_layer.self_attn.q_proj.bias
        self_attn.key.weight.data = data2vec_layer.self_attn.k_proj.weight
        self_attn.key.bias.data = data2vec_layer.self_attn.k_proj.bias
        self_attn.value.weight.data = data2vec_layer.self_attn.v_proj.weight
        self_attn.value.bias.data = data2vec_layer.self_attn.v_proj.bias

        # self-attention output
        self_output: BertSelfOutput = layer.attention.output
        assert (
            self_output.dense.weight.shape == data2vec_layer.self_attn.out_proj.weight.shape
        ), f"Shape for self_output.dense.weight should be {data2vec_layer.self_attn.out_proj.weight.shape}"
        self_output.dense.weight = data2vec_layer.self_attn.out_proj.weight
        self_output.dense.bias = data2vec_layer.self_attn.out_proj.bias
        self_output.LayerNorm.weight = data2vec_layer.self_attn_layer_norm.weight
        self_output.LayerNorm.bias = data2vec_layer.self_attn_layer_norm.bias

        # intermediate
        intermediate: BertIntermediate = layer.intermediate
        assert (
            intermediate.dense.weight.shape == data2vec_layer.fc1.weight.shape
        ), f"Shape for intermediate.dense.weight should be {data2vec_layer.fc1.weight.shape}"
        intermediate.dense.weight = data2vec_layer.fc1.weight
        intermediate.dense.bias = data2vec_layer.fc1.bias

        # output
        bert_output: BertOutput = layer.output
        assert (
            bert_output.dense.weight.shape == data2vec_layer.fc2.weight.shape
        ), f"Shape for bert_output.dense.weight should be {data2vec_layer.fc2.weight.shape}"
        bert_output.dense.weight = data2vec_layer.fc2.weight
        bert_output.dense.bias = data2vec_layer.fc2.bias
        bert_output.LayerNorm.weight = data2vec_layer.final_layer_norm.weight
        bert_output.LayerNorm.bias = data2vec_layer.final_layer_norm.bias
        # end of layer

    if classification_head:
        model.classifier.dense.weight = data2vec.model.classification_heads["mnli"].dense.weight
        model.classifier.dense.bias = data2vec.model.classification_heads["mnli"].dense.bias
        model.classifier.out_proj.weight = data2vec.model.classification_heads["mnli"].out_proj.weight
        model.classifier.out_proj.bias = data2vec.model.classification_heads["mnli"].out_proj.bias
    else:
        # LM Head
        model.lm_head.dense.weight = data2vec_model.encoder.lm_head.dense.weight
        model.lm_head.dense.bias = data2vec_model.encoder.lm_head.dense.bias
        model.lm_head.layer_norm.weight = data2vec_model.encoder.lm_head.layer_norm.weight
        model.lm_head.layer_norm.bias = data2vec_model.encoder.lm_head.layer_norm.bias
        model.lm_head.decoder.weight = data2vec_model.encoder.lm_head.weight
        model.lm_head.decoder.bias = data2vec_model.encoder.lm_head.bias

    # Let's check that we get the same results.
    input_ids: torch.Tensor = data2vec.encode(SAMPLE_TEXT).unsqueeze(0)  # batch of size 1

    our_output = model(input_ids)[0]
    if classification_head:
        their_output = data2vec.model.classification_heads["mnli"](data2vec.extract_features(input_ids))
    else:
        their_output = data2vec_model(input_ids)[0]
    print(our_output.shape, their_output.shape)
    max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
    print(f"max_absolute_diff = {max_absolute_diff}")  # ~ 1e-7
    success = torch.allclose(our_output, their_output, atol=1e-3)
    print("Do both models output the same tensors?", "🔥" if success else "💩")
    if not success:
        raise Exception("Something went wRoNg")

    pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
    print(f"Saving model to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump."
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--classification_head", action="store_true", help="Whether to convert a final classification head."
    )
    args = parser.parse_args()
    convert_data2vec_checkpoint_to_pytorch(
        args.checkpoint_path, args.pytorch_dump_folder_path, args.classification_head
    )
