#!/usr/bin/env python
# coding=utf-8
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
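"""Export a BART model wrapped in a scripted beam-search generator to ONNX and validate it with ONNX Runtime."""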

import argparse
import logging
import os
import sys

import numpy as np
import onnxruntime
import torch
from bart_onnx.generation_onnx import BARTBeamSearchGenerator
from bart_onnx.reduce_onnx_size import remove_dup_initializers

import transformers
from transformers import BartForConditionalGeneration, BartTokenizer


# Configure the root logger once at import time; the level can be overridden
# through the LOGLEVEL environment variable (defaults to INFO).
logging.basicConfig(
    stream=sys.stdout,
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s | %(levelname)s | %(name)s |  [%(filename)s:%(lineno)d] %(message)s",
)

logger = logging.getLogger(__name__)

# Checkpoint name -> class used to load it. Only checkpoints listed here are
# accepted by load_model_tokenizer (lookups are plain dict indexing).
model_dict = {"facebook/bart-base": BartForConditionalGeneration}
tokenizer_dict = {"facebook/bart-base": BartTokenizer}


def parse_args(argv=None):
    """Parse the command-line options for the BART + beam-search ONNX export.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Defaults to ``None`` (read the real command line),
            which keeps the original call style ``parse_args()`` working.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="Export Bart model + Beam Search to ONNX graph.")
    # NOTE(review): --validation_file and --config_name are accepted but not
    # read anywhere in this script's main().
    parser.add_argument(
        "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=5,
        help=(
            "The maximum length of the generated sequence. This is passed as ``max_length`` to "
            "``model.generate`` and as the ``max_length`` input of the exported ONNX graph."
        ),
    )
    parser.add_argument(
        "--num_beams",
        type=int,
        default=None,
        help=(
            "Number of beams to use for beam search. This is passed to ``model.generate`` and as the "
            "``num_beams`` input of the exported ONNX graph."
        ),
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
        required=True,
    )
    parser.add_argument(
        "--config_name",
        type=str,
        default=None,
        help="Pretrained config name or path if not the same as model_name",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device where the model will be run",
    )
    parser.add_argument("--output_file_path", type=str, default=None, help="Where to store the final ONNX file.")

    return parser.parse_args(argv)


def load_model_tokenizer(model_name, device="cpu"):
    """Load the model and tokenizer registered for ``model_name``.

    The model class is looked up in ``model_dict`` and the tokenizer class in
    ``tokenizer_dict``; an unregistered name raises ``KeyError``. The model is
    moved to ``device`` right after loading. For ``facebook/bart-base`` a few
    generation settings on the model config are reset
    (``no_repeat_ngram_size=0``, ``forced_bos_token_id=None``, ``min_length=0``).

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model_cls = model_dict[model_name]
    model = model_cls.from_pretrained(model_name).to(device)

    tokenizer_cls = tokenizer_dict[model_name]
    tokenizer = tokenizer_cls.from_pretrained(model_name)

    if model_name == "facebook/bart-base":
        cfg = model.config
        cfg.no_repeat_ngram_size = 0
        cfg.forced_bos_token_id = None
        cfg.min_length = 0

    return model, tokenizer


def export_and_validate_model(model, tokenizer, onnx_file_path, num_beams, max_length):
    """Export ``model`` with scripted beam search to ONNX and check it against PyTorch.

    Steps: script ``BARTBeamSearchGenerator(model)``, export it with
    ``torch.onnx.export``, deduplicate initializers via
    ``remove_dup_initializers``, then run the resulting graph in ONNX Runtime
    on a fixed sample sentence and require its token ids to equal those from
    ``model.generate``.

    Args:
        model: The PyTorch seq2seq model; ``model.config.decoder_start_token_id``
            is read and passed to both generation paths.
        tokenizer: Tokenizer used to encode the sample sentence.
        onnx_file_path: Where ``torch.onnx.export`` writes the graph.
        num_beams: Number of beams for beam search.
        max_length: Maximum length of the generated sequence.

    Raises:
        AssertionError: If ONNX Runtime and PyTorch produce different token ids.
    """
    model.eval()

    bart_script_model = torch.jit.script(BARTBeamSearchGenerator(model))
    decoder_start_token_id = model.config.decoder_start_token_id

    with torch.no_grad():
        article = "My friends are cool but they eat too many carbs."
        inputs = tokenizer([article], max_length=1024, return_tensors="pt").to(model.device)

        summary_ids = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            num_beams=num_beams,
            max_length=max_length,
            early_stopping=True,
            decoder_start_token_id=decoder_start_token_id,
        )

        # `example_outputs` is not passed: it was deprecated (and ignored) in
        # PyTorch 1.10 and removed in 1.11, where passing it raises TypeError.
        torch.onnx.export(
            bart_script_model,
            (
                inputs["input_ids"],
                inputs["attention_mask"],
                num_beams,
                max_length,
                decoder_start_token_id,
            ),
            onnx_file_path,
            opset_version=14,
            input_names=["input_ids", "attention_mask", "num_beams", "max_length", "decoder_start_token_id"],
            output_names=["output_ids"],
            dynamic_axes={
                "input_ids": {0: "batch", 1: "seq"},
                # The mask must share input_ids' dynamic axes; otherwise the graph
                # would only accept a mask of the export-time shape.
                "attention_mask": {0: "batch", 1: "seq"},
                "output_ids": {0: "batch", 1: "seq_out"},
            },
        )

        logger.info("Model exported to %s", onnx_file_path)

        new_onnx_file_path = remove_dup_initializers(os.path.abspath(onnx_file_path))

        logger.info("Deduplicated and optimized model written to %s", new_onnx_file_path)

        ort_sess = onnxruntime.InferenceSession(new_onnx_file_path)
        ort_out = ort_sess.run(
            None,
            {
                "input_ids": inputs["input_ids"].cpu().numpy(),
                "attention_mask": inputs["attention_mask"].cpu().numpy(),
                # Explicit int64: np.array(<python int>) is int32 on Windows, while
                # Python ints given to torch.onnx.export become int64 graph inputs.
                "num_beams": np.array(num_beams, dtype=np.int64),
                "max_length": np.array(max_length, dtype=np.int64),
                "decoder_start_token_id": np.array(decoder_start_token_id, dtype=np.int64),
            },
        )

        # Token ids are integers and must match exactly. assert_allclose with
        # rtol=1e-3 would accept an id that is off by ~50 when the id is ~50,000.
        np.testing.assert_array_equal(summary_ids.cpu().numpy(), ort_out[0])

        logger.info("Model outputs from torch and ONNX Runtime are similar.")
        logger.info("Success.")


def main():
    """Parse CLI options, load the model and tokenizer, then export and validate the ONNX graph.

    Falls back to ``max_length=5``, ``num_beams=4`` and output file
    ``BART.onnx`` when the corresponding options are absent or falsy.

    Raises:
        ValueError: If the loaded model config has no ``decoder_start_token_id``.
    """
    args = parse_args()

    # The root logger is already configured by the module-level
    # ``logging.basicConfig`` call, so calling ``basicConfig`` again here would be
    # a no-op. Only the levels are adjusted.
    logger.setLevel(logging.INFO)
    transformers.utils.logging.set_verbosity_error()

    device = torch.device(args.device)

    model, tokenizer = load_model_tokenizer(args.model_name_or_path, device)

    if model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

    model.to(device)

    # Falsy CLI values (None or 0) fall back to the script defaults.
    max_length = args.max_length or 5
    num_beams = args.num_beams or 4
    output_name = args.output_file_path or "BART.onnx"

    logger.info("Exporting model to ONNX")
    export_and_validate_model(model, tokenizer, output_name, num_beams, max_length)


if __name__ == "__main__":
    # Run the export only when executed as a script, not when imported.
    main()
