#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert gpt2."""

import argparse
import pickle

import numpy as np
import paddle
import paddle.fluid as fluid
import torch


def setup_args():
    """Parse the command-line options of the conversion script.

    Returns:
        argparse.Namespace: carries ``param_path`` (input checkpoint),
        ``save_path`` (output location) and ``convert_type`` (conversion
        direction; ``torch2paddle`` is the only accepted choice).
    """
    option_specs = (
        ("--param_path", {"type": str, "default": "CPM/pytorch_model.bin"}),
        ("--save_path", {"type": str, "default": "paddle_CPM"}),
        (
            "--convert_type",
            {"type": str, "default": "torch2paddle", "choices": ["torch2paddle"]},
        ),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()


# Layer-norm parameters inside a transformer block, keyed by
# (sub-module name, parameter name).
_LAYER_NORM_NAMES = {
    ("ln_1", "weight"): "pre_att_layer_norm_scale",
    ("ln_1", "bias"): "pre_att_layer_norm_bias",
    ("ln_2", "weight"): "pre_ffn_layer_norm_scale",
    ("ln_2", "bias"): "pre_ffn_layer_norm_bias",
}

# Feed-forward parameters, keyed by the last two components of the key.
_MLP_NAMES = {
    "c_fc.weight": "ffn_fc_0.w_0",
    "c_fc.bias": "ffn_fc_0.b_0",
    "c_proj.weight": "ffn_fc_1.w_0",
    "c_proj.bias": "ffn_fc_1.b_0",
}

# Attention output projection, keyed by the last two components of the key.
_ATTN_OUTPUT_NAMES = {
    "c_proj.weight": "multi_head_att_output_fc.w_0",
    "c_proj.bias": "multi_head_att_output_fc.b_0",
}

# Attention entries that are dropped rather than converted.
_ATTN_IGNORED = ("attn.bias", "attn.masked_bias")

# Final layer norm, keyed by parameter name.
_FINAL_NORM_NAMES = {
    "weight": "post_encoder_layer_norm_scale",
    "bias": "post_encoder_layer_norm_bias",
}

# Order in which the fused ``c_attn`` parameter is split.
_QKV_ROLES = ("query", "key", "value")


def _unsupported(key):
    """Build the error raised for a checkpoint key that cannot be mapped."""
    return ValueError(f"Unsupported parameter name: {key!r}")


def _map_layer_param(key, names, weight, new_state_dict):
    """Map one ``transformer.h.<N>.*`` entry.

    Returns the new parameter name, or ``None`` when nothing is left for the
    caller to store (the fused QKV parameter has already been written to
    ``new_state_dict``, or the entry is deliberately ignored).
    """
    if len(names) < 5:
        raise _unsupported(key)

    # int() raises ValueError when the layer index is not a number.
    prefix = f"encoder_layer_{int(names[2])}"
    module = names[3]
    tail = ".".join(names[-2:])

    if module in ("ln_1", "ln_2"):
        suffix = _LAYER_NORM_NAMES.get((module, names[4]))
    elif module == "mlp":
        suffix = _MLP_NAMES.get(tail)
    elif module == "attn":
        if tail in _ATTN_IGNORED:
            return None
        if tail == "c_attn.weight":
            # The fused projection is split into three equal parts along
            # axis 1 and stored without transposition (presumably the
            # checkpoint already uses the layout the target expects --
            # TODO confirm).
            parts = np.split(weight, 3, axis=1)
            print(key, weight.shape, "->", *(part.shape for part in parts))
            kind = "w_0"
        elif tail == "c_attn.bias":
            # np.transpose is kept from the original; it is a no-op on
            # 1-D arrays.
            parts = [np.transpose(part) for part in np.split(weight, 3, axis=0)]
            kind = "b_0"
        else:
            parts = None
        if parts is not None:
            for role, part in zip(_QKV_ROLES, parts):
                new_state_dict[f"{prefix}_multi_head_att_{role}_fc.{kind}"] = part
            return None
        suffix = _ATTN_OUTPUT_NAMES.get(tail)
    else:
        suffix = None

    if suffix is None:
        raise _unsupported(key)
    return f"{prefix}_{suffix}"


def convert_fn(state_dict):
    """Rename the entries of a GPT-2 style PyTorch state dict.

    Args:
        state_dict: mapping from dotted parameter names (``transformer.*`` or
            ``lm_head.*``) to objects exposing ``.numpy()``.

    Returns:
        dict mapping the new flat parameter names to numpy arrays. The fused
        ``c_attn`` weight/bias of every layer is split into separate query,
        key and value entries. ``lm_head.*``, ``attn.bias`` and
        ``attn.masked_bias`` entries are dropped.

    Raises:
        ValueError: if a key does not match any known parameter pattern.
            Unknown ``transformer.<module>`` names, unknown per-layer
            sub-modules and unknown ``ln_f`` parameters are rejected instead
            of being stored under an empty or truncated name.
    """
    new_state_dict = {}
    for key, tensor in state_dict.items():
        weight = tensor.numpy()
        names = key.split(".")
        root = names[0]

        if root == "lm_head":
            continue  # the output head is not exported
        if root != "transformer" or len(names) < 2:
            raise _unsupported(key)

        module = names[1]
        if module == "wte":
            new_key = "word_embedding"
        elif module == "wpe":
            new_key = "pos_embedding"
        elif module == "h":
            new_key = _map_layer_param(key, names, weight, new_state_dict)
            if new_key is None:
                continue
        elif module == "ln_f":
            new_key = _FINAL_NORM_NAMES.get(names[2]) if len(names) > 2 else None
            if new_key is None:
                raise _unsupported(key)
        else:
            raise _unsupported(key)

        print(key, "->", new_key, weight.shape)
        # 2-D weights are stored as-is, without transposition.
        new_state_dict[new_key] = weight

    return new_state_dict


def main(args):
    """Convert a PyTorch checkpoint and save it as Paddle parameters.

    Args:
        args: namespace with ``convert_type``, ``param_path`` (checkpoint
            file to read) and ``save_path`` (passed to
            ``fluid.io.save_params``).

    Raises:
        ValueError: if ``args.convert_type`` is not ``"torch2paddle"``.
    """
    if args.convert_type != "torch2paddle":
        raise ValueError(f"convert_type: {args.convert_type} is not supported now.")

    paddle.enable_static()
    program = fluid.Program()

    # NOTE: torch.load unpickles the file, so only trusted checkpoints should
    # be passed here. The context manager closes the handle after loading.
    with open(args.param_path, "rb") as checkpoint_file:
        state_dict = torch.load(checkpoint_file, map_location="cpu")
    state_dict = convert_fn(state_dict)

    # Declare one parameter per converted array so save_params picks it up.
    global_block = program.global_block()
    for name, weight in state_dict.items():
        global_block.create_parameter(
            shape=weight.shape,
            dtype=weight.dtype,
            name=name)

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(program)

    # Write the converted values into the variables held by the global scope.
    for name, weight in state_dict.items():
        param_tensor = fluid.global_scope().find_var(name).get_tensor()
        param_tensor.set(weight, exe.place)

    fluid.io.save_params(exe, args.save_path, main_program=program)


if __name__ == "__main__":
    # Script entry point: parse the CLI options and run the conversion.
    main(setup_args())
