# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocessing script before training the distilled model.
Specific to RoBERTa -> DistilRoBERTa and GPT2 -> DistilGPT2.
"""
import argparse

import torch

from transformers import GPT2LMHeadModel, RobertaForMaskedLM


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned Distillation"
    )
    parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"])
    parser.add_argument("--model_name", default="roberta-large", type=str)
    parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_roberta_048131723.pth", type=str)
    parser.add_argument("--vocab_transform", action="store_true")
    args = parser.parse_args()

    if args.model_type == "roberta":
        model = RobertaForMaskedLM.from_pretrained(args.model_name)
        prefix = "roberta"
    elif args.model_type == "gpt2":
        model = GPT2LMHeadModel.from_pretrained(args.model_name)
        prefix = "transformer"

    state_dict = model.state_dict()
    compressed_sd = {}

    # Embeddings #
    if args.model_type == "gpt2":
        for param_name in ["wte.weight", "wpe.weight"]:
            compressed_sd[f"{prefix}.{param_name}"] = state_dict[f"{prefix}.{param_name}"]
    else:
        for w in ["word_embeddings", "position_embeddings", "token_type_embeddings"]:
            param_name = f"{prefix}.embeddings.{w}.weight"
            compressed_sd[param_name] = state_dict[param_name]
        for w in ["weight", "bias"]:
            param_name = f"{prefix}.embeddings.LayerNorm.{w}"
            compressed_sd[param_name] = state_dict[param_name]

    # Transformer Blocks #
    std_idx = 0
    for teacher_idx in [0, 2, 4, 7, 9, 11]:
        if args.model_type == "gpt2":
            for layer in ["ln_1", "attn.c_attn", "attn.c_proj", "ln_2", "mlp.c_fc", "mlp.c_proj"]:
                for w in ["weight", "bias"]:
                    compressed_sd[f"{prefix}.h.{std_idx}.{layer}.{w}"] = state_dict[
                        f"{prefix}.h.{teacher_idx}.{layer}.{w}"
                    ]
            compressed_sd[f"{prefix}.h.{std_idx}.attn.bias"] = state_dict[f"{prefix}.h.{teacher_idx}.attn.bias"]
        else:
            for layer in [
                "attention.self.query",
                "attention.self.key",
                "attention.self.value",
                "attention.output.dense",
                "attention.output.LayerNorm",
                "intermediate.dense",
                "output.dense",
                "output.LayerNorm",
            ]:
                for w in ["weight", "bias"]:
                    compressed_sd[f"{prefix}.encoder.layer.{std_idx}.{layer}.{w}"] = state_dict[
                        f"{prefix}.encoder.layer.{teacher_idx}.{layer}.{w}"
                    ]
        std_idx += 1

    # Language Modeling Head ###s
    if args.model_type == "roberta":
        for layer in ["lm_head.decoder.weight", "lm_head.bias"]:
            compressed_sd[f"{layer}"] = state_dict[f"{layer}"]
        if args.vocab_transform:
            for w in ["weight", "bias"]:
                compressed_sd[f"lm_head.dense.{w}"] = state_dict[f"lm_head.dense.{w}"]
                compressed_sd[f"lm_head.layer_norm.{w}"] = state_dict[f"lm_head.layer_norm.{w}"]
    elif args.model_type == "gpt2":
        for w in ["weight", "bias"]:
            compressed_sd[f"{prefix}.ln_f.{w}"] = state_dict[f"{prefix}.ln_f.{w}"]
        compressed_sd["lm_head.weight"] = state_dict["lm_head.weight"]

    print(f"N layers selected for distillation: {std_idx}")
    print(f"Number of params transfered for distillation: {len(compressed_sd.keys())}")

    print(f"Save transfered checkpoint to {args.dump_checkpoint}.")
    torch.save(compressed_sd, args.dump_checkpoint)
