# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     XXXX
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Bros checkpoints."""

import argparse

import bros  # original repo
import torch

from transformers import BrosConfig, BrosModel, BrosProcessor
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


def get_configs(model_name):
    bros_config = BrosConfig.from_pretrained(model_name)
    return bros_config


def remove_ignore_keys_(state_dict):
    ignore_keys = [
        "embeddings.bbox_sinusoid_emb.inv_freq",
    ]
    for k in ignore_keys:
        state_dict.pop(k, None)


def rename_key(name):
    if name == "embeddings.bbox_projection.weight":
        name = "bbox_embeddings.bbox_projection.weight"

    if name == "embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq":
        name = "bbox_embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq"

    if name == "embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq":
        name = "bbox_embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq"

    return name


def convert_state_dict(orig_state_dict, model):
    # rename keys
    for key in orig_state_dict.copy().keys():
        val = orig_state_dict.pop(key)
        orig_state_dict[rename_key(key)] = val

    # remove ignore keys
    remove_ignore_keys_(orig_state_dict)

    return orig_state_dict


def convert_bros_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_hub=False):
    # load original model
    original_model = bros.BrosModel.from_pretrained(model_name).eval()

    # load HuggingFace Model
    bros_config = get_configs(model_name)
    model = BrosModel.from_pretrained(model_name, config=bros_config)
    model.eval()

    state_dict = original_model.state_dict()
    new_state_dict = convert_state_dict(state_dict, model)
    model.load_state_dict(new_state_dict)

    # verify results

    # original BROS model require 4 points (8 float values) for each bbox, prepare bbox with [batch_size, seq_len, 8] shape
    bbox = torch.tensor(
        [
            [
                [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                [0.4396, 0.6720, 0.4659, 0.6720, 0.4659, 0.6850, 0.4396, 0.6850],
                [0.4698, 0.6720, 0.4843, 0.6720, 0.4843, 0.6850, 0.4698, 0.6850],
                [0.4698, 0.6720, 0.4843, 0.6720, 0.4843, 0.6850, 0.4698, 0.6850],
                [0.2047, 0.6870, 0.2730, 0.6870, 0.2730, 0.7000, 0.2047, 0.7000],
                [0.2047, 0.6870, 0.2730, 0.6870, 0.2730, 0.7000, 0.2047, 0.7000],
                [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
            ]
        ]
    )

    processor = BrosProcessor.from_pretrained(model_name)

    encoding = processor("His name is Rocco.", return_tensors="pt")
    encoding["bbox"] = bbox

    original_hidden_states = original_model(**encoding).last_hidden_state
    # pixel_values = processor(image, return_tensors="pt").pixel_values

    last_hidden_states = model(**encoding).last_hidden_state

    assert torch.allclose(original_hidden_states, last_hidden_states, atol=1e-4)

    if pytorch_dump_folder_path is not None:
        print(f"Saving model and processor to {pytorch_dump_folder_path}")
        model.save_pretrained(pytorch_dump_folder_path)
        processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        model.push_to_hub("jinho8345/" + model_name.split("/")[-1], commit_message="Update model")
        processor.push_to_hub("jinho8345/" + model_name.split("/")[-1], commit_message="Update model")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--model_name",
        default="jinho8345/bros-base-uncased",
        required=False,
        type=str,
        help="Name of the original model you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        required=False,
        type=str,
        help="Path to the output PyTorch model directory.",
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether or not to push the converted model and processor to the 🤗 hub.",
    )

    args = parser.parse_args()
    convert_bros_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
