import time 
import contextlib
import numpy as np
from PIL import Image
from collections import OrderedDict

import onnx
import torch 
import onnx_graphsurgeon


def to_binary_data(path, size=(640, 640), output_name='input_tensor.bin'):
    """Dump an image as a raw float32 NCHW tensor file.

    The written file has no header. It can be fed to a ``--loadInputs`` style
    flag, e.g. ``--loadInputs='image:input_tensor.bin'``.

    Args:
        path: Image file path (anything ``PIL.Image.open`` accepts).
        size: Target ``(width, height)`` passed to ``Image.resize``.
        output_name: Destination file for the raw bytes.

    Returns:
        The array that was written: shape ``(1, 3, height, width)``, dtype
        float32, values scaled from 0..255 to 0..1.
    """
    # Use a context manager so the underlying file handle is released.
    # convert('RGB') forces exactly three channels, so grayscale and palette
    # images no longer crash the transpose and RGBA no longer yields 4 channels.
    # convert() loads the pixel data, so `rgb` stays valid after the file closes.
    with Image.open(path) as src:
        rgb = src.convert('RGB').resize(size)
    # HWC uint8 -> CHW float32, then add a leading batch axis.
    data = np.asarray(rgb, dtype=np.float32).transpose(2, 0, 1)[None] / 255.
    # Raw C-order bytes in native byte order.
    data.tofile(output_name)
    return data


def yolo_insert_nms(path, score_threshold=0.01, iou_threshold=0.7, max_output_boxes=300, simplify=False,
                    input_name='image', input_shape=(1, 3, 640, 640), output_path='yolo_w_nms.onnx'):
    """Append a TensorRT ``EfficientNMS_TRT`` node to an ONNX detection model.

    The model's first two graph outputs become the NMS node's inputs.
    Presumably these are boxes and per-class scores, in the order the plugin
    expects; confirm this against the exporter. The graph outputs are then
    replaced by the four NMS outputs: ``num_dets``, ``det_boxes``,
    ``det_scores`` and ``det_classes``.

    Args:
        path: Source ONNX model path.
        score_threshold: Value of the plugin's ``score_threshold`` attribute.
        iou_threshold: Value of the plugin's ``iou_threshold`` attribute.
        max_output_boxes: Value of the plugin's ``max_output_boxes`` attribute.
            It is also the static second dimension of the box/score/class outputs.
        simplify: If true, run ``onnxsim`` on the model first. This requires
            the optional ``onnxsim`` package.
        input_name: Graph input whose shape is overwritten during simplification.
        input_shape: Shape assigned to ``input_name`` during simplification.
        output_path: Destination file for the modified model.

    Returns:
        ``output_path``.

    Raises:
        ValueError: If the graph has fewer than two outputs after cleanup.
    """
    onnx_model = onnx.load(path)

    if simplify:
        import warnings
        # Aliased on purpose: a bare `simplify` import would shadow the flag.
        from onnxsim import simplify as onnxsim_simplify
        onnx_model, check_ok = onnxsim_simplify(
            onnx_model, overwrite_input_shapes={input_name: list(input_shape)})
        if not check_ok:
            warnings.warn('onnxsim could not validate the simplified model', RuntimeWarning, stacklevel=2)

    graph = onnx_graphsurgeon.import_onnx(onnx_model)
    graph.toposort()
    graph.fold_constants()
    graph.cleanup()

    if len(graph.outputs) < 2:
        raise ValueError(
            f'expected at least 2 graph outputs to feed EfficientNMS_TRT, got {len(graph.outputs)}')

    topk = int(max_output_boxes)
    # graphsurgeon infers the ONNX attribute type from the Python value type.
    # Cast explicitly so that e.g. score_threshold=0 is not exported as INT.
    attrs = OrderedDict(
        plugin_version='1',
        background_class=-1,  # NOTE(review): -1 presumably means "no background class"; see plugin docs
        max_output_boxes=topk,
        score_threshold=float(score_threshold),
        iou_threshold=float(iou_threshold),
        score_activation=False,
        box_coding=0,  # NOTE(review): 0 is the plugin's corner (x1, y1, x2, y2) coding; confirm the model emits that
    )

    outputs = [
        onnx_graphsurgeon.Variable('num_dets', np.int32, [-1, 1]),
        onnx_graphsurgeon.Variable('det_boxes', np.float32, [-1, topk, 4]),
        onnx_graphsurgeon.Variable('det_scores', np.float32, [-1, topk]),
        onnx_graphsurgeon.Variable('det_classes', np.int32, [-1, topk]),
    ]

    graph.layer(op='EfficientNMS_TRT',
                name='batched_nms',
                inputs=[graph.outputs[0], graph.outputs[1]],
                outputs=outputs,
                attrs=attrs)

    # Replace the raw model outputs with the NMS outputs. cleanup() then prunes
    # nodes that no longer feed any graph output.
    graph.outputs = outputs
    graph.cleanup().toposort()

    onnx.save(onnx_graphsurgeon.export_onnx(graph), output_path)
    return output_path


class TimeProfiler(contextlib.ContextDecorator):
    """Accumulate elapsed seconds across ``with`` blocks or decorated calls.

    When CUDA is available, pending CUDA work on the current device is
    synchronized before every clock read. This way, asynchronously launched
    GPU kernels are included in the measured interval.

    The instance is not reentrant: nested or recursive use of the *same*
    instance overwrites the stored start time.

    Example::

        profiler = TimeProfiler()
        with profiler:
            run()
        print(profiler.total)
    """

    def __init__(self):
        self.total = 0  # seconds accumulated over all completed intervals

    def __enter__(self):
        self.start = self.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # The interval is counted even if the body raised. Returning None
        # lets any exception propagate.
        self.total += self.time() - self.start

    def reset(self):
        """Clear the accumulated total."""
        self.total = 0

    def time(self):
        """Return a high-resolution timestamp, syncing CUDA first if available.

        Uses ``time.perf_counter`` (monotonic, highest available resolution)
        instead of ``time.time``, which can jump when the system clock is
        adjusted. The value is only meaningful as a difference between two calls.
        """
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return time.perf_counter()
