import torch
import torchvision.models as models
import onnx
import os
import subprocess

import torch.nn as nn

# Use the externally configured GPU index from build_index when that module
# is importable; otherwise fall back to GPU 0.
try:
    from build_index import gpu_id
except ImportError:
    gpu_id = 0

class ResNet50_Feature(nn.Module):
    """ResNet-50 backbone with the final fully connected layer removed.

    ``forward`` runs the remaining layers and returns one flattened feature
    vector per sample, i.e. a ``(batch, features)`` tensor (2048 features for
    the stock torchvision ResNet-50).
    """

    def __init__(self):
        super().__init__()
        # `pretrained=True` is deprecated since torchvision 0.13 in favour of
        # the `weights=` enum API. IMAGENET1K_V1 is the checkpoint that
        # `pretrained=True` maps to, so the loaded weights are unchanged.
        # Older torchvision releases lack the enum, hence the fallback.
        weights_enum = getattr(models, 'ResNet50_Weights', None)
        if weights_enum is not None:
            m = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
        else:
            m = models.resnet50(pretrained=True)
        self.features = nn.Sequential(*(list(m.children())[:-1]))  # Remove fc layer

    def forward(self, x):
        """Return the per-sample features of ``x`` flattened to 2-D."""
        x = self.features(x)
        # flatten() keeps the batch axis and collapses the rest; unlike
        # view(x.size(0), -1) it also works for empty batches and for
        # non-contiguous feature maps.
        return torch.flatten(x, 1)  # (batch, 2048)

def export_onnx(onnx_path='resnet50.onnx', input_shape=(1, 3, 112, 112)):
    """Export the ResNet-50 feature extractor to ONNX.

    Args:
        onnx_path: Destination of the ONNX model. When it is a filesystem
            path, missing parent directories are created first.
        input_shape: Shape of the random dummy tensor used to run the
            export. Axis 0 of both 'input' and 'output' is declared dynamic
            under the name 'batch_size'.
    """
    # Create the target directory up front so the export does not fail at
    # write time. Non-path destinations (e.g. file-like objects) are passed
    # through untouched.
    if isinstance(onnx_path, (str, os.PathLike)):
        out_dir = os.path.dirname(os.fspath(onnx_path))
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    model = ResNet50_Feature()
    model.eval()  # inference mode: fixed BatchNorm statistics
    dummy = torch.randn(*input_shape)
    torch.onnx.export(model, dummy, onnx_path, input_names=['input'], output_names=['output'],
                      opset_version=11, do_constant_folding=True, dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})
    print(f'Exported ONNX to {onnx_path} (feature output, 2048-d)')

def export_trt(onnx_path, engine_path, fp16=True, max_batch=32):
    """Build a TensorRT engine from an ONNX model by invoking ``trtexec``.

    The command line is printed before it is run.

    Args:
        onnx_path: Path of the ONNX model handed to ``--onnx``.
        engine_path: Path handed to ``--saveEngine``. Missing parent
            directories are created before the build starts.
        fp16: When true, ``--fp16`` is appended to the command.
        max_batch: Batch size used for both the optimal and the maximum
            shape of the 'input' tensor; the minimum batch size is 1.

    Raises:
        subprocess.CalledProcessError: If ``trtexec`` exits with a non-zero
            status (``check=True``).
    """
    # Make sure the output directory exists before the build runs, so a
    # missing directory cannot make the final save step fail.
    engine_dir = os.path.dirname(str(engine_path))
    if engine_dir:
        os.makedirs(engine_dir, exist_ok=True)

    # Need to install trtexec tool (comes with TensorRT)
    trtexec = 'trtexec'
    cmd = [
        trtexec,
        f'--onnx={onnx_path}',
        f'--saveEngine={engine_path}',
        '--minShapes=input:1x3x112x112',
        f'--optShapes=input:{max_batch}x3x112x112',
        f'--maxShapes=input:{max_batch}x3x112x112',
        f'--device={gpu_id}',
    ]
    if fp16:
        cmd.append('--fp16')
    print(' '.join(cmd))
    subprocess.run(cmd, check=True)
    print(f'Exported TensorRT engine to {engine_path}')

if __name__ == '__main__':
    # Script entry point: export the model to ONNX, then build an FP16
    # TensorRT engine from that ONNX file. Both paths are relative.
    onnx_path, engine_path = '../tensorrt/resnet50.onnx', '../tensorrt/resnet50_fp16.engine'
    export_onnx(onnx_path)
    export_trt(onnx_path=onnx_path, engine_path=engine_path, fp16=True, max_batch=2048)
