# encoding: utf-8
"""
@author:  xingyu liao
@contact: sherlockliao01@gmail.com

Custom calibrator used to calibrate an INT8 TensorRT model.
It subclasses trt.IInt8EntropyCalibrator2 and overrides get_batch_size, get_batch,
read_calibration_cache and write_calibration_cache.
"""

# based on:
# https://github.com/qq995431104/Pytorch2TensorRT/blob/master/myCalibrator.py

import os
import sys

import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit

import numpy as np
import torchvision.transforms as T

sys.path.append('../..')

from fastreid.data.build import _root
from fastreid.data.data_utils import read_image
from fastreid.data.datasets import DATASET_REGISTRY
import logging

from fastreid.data.transforms import ToTensor


# Module-level logger; the calibrator reports per-batch progress through it.
logger = logging.getLogger(name='trt_export.calibrator')


class FeatEntropyCalibrator(trt.IInt8EntropyCalibrator2):
    """INT8 entropy calibrator that feeds images from a fastreid dataset to TensorRT.

    The train, query and gallery items of the dataset named ``args.calib_data``
    are shuffled once. They are resized to ``(args.height, args.width)``,
    converted with ``ToTensor`` and uploaded one batch at a time into a single
    preallocated device buffer. Only full batches are produced; a trailing
    partial batch is dropped.

    Args:
        args: namespace providing ``batch_size``, ``channel``, ``height``,
            ``width`` and ``calib_data`` (a key of ``DATASET_REGISTRY``).
        cache_file: path of the calibration cache. If it exists it is returned
            from ``read_calibration_cache``, and ``write_calibration_cache``
            writes to it.
    """

    def __init__(self, args, cache_file='reid_feat.cache'):
        trt.IInt8EntropyCalibrator2.__init__(self)

        self.cache_file = cache_file

        self.batch_size = args.batch_size
        self.channel = args.channel
        self.height = args.height
        self.width = args.width
        self.transform = T.Compose([
            # [h, w]; interpolation=3 is PIL's BICUBIC code.
            T.Resize((self.height, self.width), interpolation=3),
            ToTensor(),
        ])

        dataset = DATASET_REGISTRY.get(args.calib_data)(root=_root)
        self._data_items = dataset.train + dataset.query + dataset.gallery
        np.random.shuffle(self._data_items)
        # The first element of each dataset item is used as the image path.
        self.imgs = [item[0] for item in self._data_items]

        self.batch_idx = 0
        # Floor division: the last incomplete batch is never fed to TensorRT.
        self.max_batch_idx = len(self.imgs) // self.batch_size

        # One float32 buffer on the device, reused for every batch.
        self.data_size = self.batch_size * self.channel * self.height * self.width * trt.float32.itemsize
        self.device_input = cuda.mem_alloc(self.data_size)

    def next_batch(self):
        """Load and preprocess the next full batch of images.

        Returns:
            A contiguous float32 array of shape
            ``(batch_size, channel, height, width)``, or an empty array once all
            full batches have been consumed.

        Raises:
            ValueError: if a preprocessed image does not occupy exactly one
                slot of the device buffer.
        """
        if self.batch_idx >= self.max_batch_idx:
            return np.array([])

        start = self.batch_idx * self.batch_size
        batch_files = self.imgs[start:start + self.batch_size]
        batch_imgs = np.zeros((self.batch_size, self.channel, self.height, self.width),
                              dtype=np.float32)
        expected_nbytes = self.data_size // self.batch_size
        for i, f in enumerate(batch_files):
            img = read_image(f)
            img = self.transform(img).numpy()
            # Explicit check rather than ``assert`` so it also runs under ``python -O``.
            if img.nbytes != expected_nbytes:
                raise ValueError('not valid img! {} ({} bytes, expected {})'.format(
                    f, img.nbytes, expected_nbytes))
            batch_imgs[i] = img
        self.batch_idx += 1
        logger.info("batch:[{}/{}]".format(self.batch_idx, self.max_batch_idx))
        return np.ascontiguousarray(batch_imgs)

    def get_batch_size(self):
        """Return the calibration batch size reported to TensorRT."""
        return self.batch_size

    def get_batch(self, names, p_str=None):
        """Upload the next batch to the device buffer.

        Args:
            names: input binding names passed by TensorRT. They are not
                inspected; the single device buffer is assumed to back the only
                network input (TODO confirm for multi-input models).
            p_str: unused; accepted for signature compatibility.

        Returns:
            ``[device_pointer]`` for the uploaded batch. Returns ``None`` when
            the data is exhausted or the batch could not be prepared, which ends
            calibration.
        """
        try:
            batch_imgs = self.next_batch().ravel()
            expected_size = self.batch_size * self.channel * self.height * self.width
            if batch_imgs.size == 0 or batch_imgs.size != expected_size:
                return None
            cuda.memcpy_htod(self.device_input, batch_imgs.astype(np.float32))
            return [int(self.device_input)]
        except Exception:
            # Called from TensorRT: don't let the error propagate into the native
            # caller, but don't hide it either - log it with its traceback.
            logger.exception("Failed to prepare calibration batch; stopping calibration.")
            return None

    def read_calibration_cache(self):
        """Return the cached calibration table, or ``None`` to force calibration."""
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()
        return None

    def write_calibration_cache(self, cache):
        """Persist the calibration table produced by TensorRT to ``cache_file``."""
        with open(self.cache_file, "wb") as f:
            f.write(cache)
