# encoding: utf-8
"""
@author:  xingyu liao
@contact: sherlockliao01@gmail.com
"""

import argparse
import os
import sys

import tensorrt as trt

from trt_calibrator import FeatEntropyCalibrator

# Put the current working directory on the import path, presumably so that the
# `fastreid` package imported below resolves when this script is launched from
# the repository root (NOTE(review): confirm the expected working directory).
sys.path.append('.')

from fastreid.utils.logger import setup_logger, PathManager

# Module-level logger; onnx2trt() reports its progress through it.
logger = setup_logger(name="trt_export")


def get_parser():
    """Build the command-line parser for the ONNX -> TensorRT conversion script.

    Returns:
        argparse.ArgumentParser: parser for the options below. ``--mode`` is
        lower-cased and restricted to 'fp32', 'fp16' or 'int8', so an invalid
        precision mode is rejected at parse time with a usage message.
    """
    parser = argparse.ArgumentParser(description="Convert ONNX to TRT model")

    parser.add_argument(
        '--name',
        default='baseline',
        help="name for converted model"
    )
    parser.add_argument(
        '--output',
        default='outputs/trt_model',
        help="path to save converted trt model"
    )
    parser.add_argument(
        '--mode',
        default='fp32',
        # Lower-case before checking `choices` so 'FP16' etc. keep working,
        # matching the case-insensitive handling in onnx2trt().
        type=str.lower,
        choices=['fp32', 'fp16', 'int8'],
        help="which mode is used in tensorRT engine, mode can be ['fp32', 'fp16', 'int8']"
    )
    parser.add_argument(
        '--batch-size',
        default=1,
        type=int,
        help="the maximum batch size of trt module"
    )
    parser.add_argument(
        '--height',
        default=256,
        type=int,
        help="input image height"
    )
    parser.add_argument(
        '--width',
        default=128,
        type=int,
        help="input image width"
    )
    parser.add_argument(
        '--channel',
        default=3,
        type=int,
        help="input image channel"
    )
    parser.add_argument(
        '--calib-data',
        default='Market1501',
        help="int8 calibrator dataset name"
    )
    parser.add_argument(
        "--onnx-model",
        default='outputs/onnx_model/baseline.onnx',
        help='path to onnx model'
    )
    return parser


def onnx2trt(
        onnx_file_path,
        save_path,
        mode,
        log_level='ERROR',
        max_workspace_size=1,
        strict_type_constraints=False,
        int8_calibrator=None,
):
    """Build a TensorRT engine from an ONNX model and serialize it to disk.

    Args:
        onnx_file_path (str or file-like object): path to the ONNX model, or an
            object whose ``read()`` returns the serialized ONNX bytes.
        save_path (str): path the serialized TensorRT engine is written to.
        mode (str): precision mode, one of 'fp32', 'fp16', 'int8'
            (case-insensitive). 'fp16' / 'int8' permit TensorRT to use FP16 /
            INT8 kernels during the engine build.
        log_level (str, default 'ERROR'): name of a ``trt.Logger`` severity;
            INTERNAL_ERROR, ERROR, WARNING, INFO and VERBOSE are supported.
        max_workspace_size (int, default 1): maximum GPU temporary memory, in
            GiB, that the engine may use (default 1 GiB).
        strict_type_constraints (bool, default False): if True, set
            ``trt.BuilderFlag.STRICT_TYPES`` so TensorRT keeps the requested
            precision; otherwise a higher-precision implementation may be
            chosen when it is faster.
        int8_calibrator (TensorRT INT8 calibrator or None): calibrator used in
            'int8' mode. It is handed to TensorRT as-is; no default calibrator
            is created here when it is None.

    Raises:
        AssertionError: if ``mode`` is unsupported, or the platform lacks fast
            FP16/INT8 support for the requested mode.
        RuntimeError: if the ONNX model cannot be parsed or the engine build
            fails.
    """
    mode = mode.lower()
    assert mode in ['fp32', 'fp16', 'int8'], "mode should be in ['fp32', 'fp16', 'int8'], " \
                                             "but got {}".format(mode)

    trt_logger = trt.Logger(getattr(trt.Logger, log_level))
    builder = trt.Builder(trt_logger)

    logger.info("Loading ONNX file from path {}...".format(onnx_file_path))
    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(explicit_batch)
    parser = trt.OnnxParser(network, trt_logger)
    if isinstance(onnx_file_path, str):
        with open(onnx_file_path, 'rb') as f:
            logger.info("Beginning ONNX file parsing")
            parsed = parser.parse(f.read())
    else:
        parsed = parser.parse(onnx_file_path.read())
    if not parsed:
        # Report every parser error, then stop: building an engine from a
        # partially parsed network would produce a wrong or failing engine.
        for i in range(parser.num_errors):
            logger.error(parser.get_error(i))
        raise RuntimeError("Failed to parse ONNX file {}".format(onnx_file_path))
    logger.info("Completed parsing of ONNX file.")

    # Replace every network output with the output of an Identity layer named
    # 'identity_<original name>'; outputs are re-marked in their original order.
    # NOTE(review): the motivation for this wrapping is not visible here —
    # presumably to give the engine stable output names; confirm before changing.
    output_tensors = [network.get_output(i) for i in range(network.num_outputs)]
    for tensor in output_tensors:
        network.unmark_output(tensor)
    for tensor in output_tensors:
        identity_out_tensor = network.add_identity(tensor).get_output(0)
        identity_out_tensor.name = 'identity_{}'.format(tensor.name)
        network.mark_output(tensor=identity_out_tensor)

    # All build options live on the builder config, which is passed to
    # build_engine() below (the old builder.*_mode attributes were ignored by it).
    config = builder.create_builder_config()
    config.max_workspace_size = max_workspace_size * (1 << 30)  # GiB -> bytes
    if mode == 'fp16':
        assert builder.platform_has_fast_fp16, "not support fp16"
        config.set_flag(trt.BuilderFlag.FP16)
    if mode == 'int8':
        assert builder.platform_has_fast_int8, "not support int8"
        config.set_flag(trt.BuilderFlag.INT8)
        config.int8_calibrator = int8_calibrator

    if strict_type_constraints:
        config.set_flag(trt.BuilderFlag.STRICT_TYPES)

    logger.info("Building an engine from file {}; this may take a while...".format(onnx_file_path))
    engine = builder.build_engine(network, config)
    if engine is None:
        # TensorRT signals build failure by returning None; details go to trt_logger.
        raise RuntimeError("Failed to build TensorRT engine from {}".format(onnx_file_path))
    logger.info("Create engine successfully!")

    logger.info("Saving TRT engine file to path {}".format(save_path))
    with open(save_path, 'wb') as f:
        f.write(engine.serialize())
    logger.info("Engine file has already saved to {}!".format(save_path))


if __name__ == '__main__':
    args = get_parser().parse_args()

    # A calibrator is only constructed when an INT8 engine is requested.
    calibrator = FeatEntropyCalibrator(args) if args.mode.lower() == 'int8' else None

    PathManager.mkdirs(args.output)
    engine_path = os.path.join(args.output, args.name + '.engine')
    onnx2trt(args.onnx_model, engine_path, args.mode, int8_calibrator=calibrator)
