import argparse
import os.path as osp

import numpy as np
import onnx
import os
#import onnxruntime as rt
import torch

from mmdet.core import (build_model_from_cfg, generate_inputs_and_wrap_model,
                        preprocess_example_input)

#from mmdet.models import build

def pytorch2onnx(config_path,
                 checkpoint_path,
                 input_img,
                 input_shape,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False,
                 simplify=True,
                 dynamic=True,
                 normalize_cfg=None,
                 dataset='coco',
                 test_img=None):
    """Export an MMDetection model to ONNX, optionally post-processing it.

    The model is traced with ``torch.onnx.export``. When ``simplify`` or
    ``dynamic`` is set, the raw export goes to an intermediate
    ``<output stem>_ori.onnx`` file. That file is post-processed, written to
    ``output_file``, and then deleted.

    Args:
        config_path (str): Path to the MMDetection config file.
        checkpoint_path (str): Path to the PyTorch checkpoint. If it
            contains an ``'optimizer'`` entry, a temporary copy without that
            entry is written next to it, used for loading, then removed.
        input_img (str): Image path, forwarded to
            ``generate_inputs_and_wrap_model`` as ``input_path``.
        input_shape (tuple): Example input shape, forwarded as
            ``input_shape``. Also passed to onnxsim when ``dynamic`` is set.
        opset_version (int): ONNX opset passed to ``torch.onnx.export``.
        show (bool): Currently unused; kept for CLI compatibility.
        output_file (str): Path of the final ONNX file.
        verify (bool): Currently unused; kept for CLI compatibility.
        simplify (bool): Run ``onnxsim.simplify`` on the exported graph.
        dynamic (bool): Replace dims 2 and 3 of the first graph input with
            the symbolic dim ``'?'``.
        normalize_cfg (dict, optional): Forwarded as ``normalize_cfg`` to
            ``generate_inputs_and_wrap_model``.
        dataset (str): Currently unused; kept for CLI compatibility.
        test_img (str, optional): Currently unused; kept for CLI
            compatibility.

    Raises:
        AssertionError: If onnxsim reports the simplified model failed its
            check.
    """
    input_config = {
        'input_shape': input_shape,
        'input_path': input_img,
        'normalize_cfg': normalize_cfg
    }

    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    tmp_ckpt_file = None
    # Remove optimizer state so the checkpoint that gets loaded is smaller.
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']
        tmp_ckpt_file = checkpoint_path + "_slim.pth"
        torch.save(checkpoint, tmp_ckpt_file)
        print('remove optimizer params and save to', tmp_ckpt_file)
        checkpoint_path = tmp_ckpt_file
    # The model is rebuilt from checkpoint_path below; drop this copy.
    del checkpoint

    try:
        model, tensor_data = generate_inputs_and_wrap_model(
            config_path, checkpoint_path, input_config)
    finally:
        # Clean up the temporary slim checkpoint even if model building fails.
        if tmp_ckpt_file is not None:
            os.remove(tmp_ckpt_file)

    postprocess = simplify or dynamic
    if postprocess:
        # Use splitext rather than split('.') so dots in directory names
        # (e.g. '../onnx/model.onnx') do not truncate the path.
        ori_output_file = osp.splitext(output_file)[0] + "_ori.onnx"
    else:
        ori_output_file = output_file

    torch.onnx.export(
        model,
        tensor_data,
        ori_output_file,
        keep_initializers_as_inputs=False,
        verbose=False,
        opset_version=opset_version)

    if postprocess:
        onnx_model = onnx.load(ori_output_file)
        if dynamic:
            # Setting dim_param replaces the fixed dim_value (they share a
            # protobuf oneof), making these two input dims symbolic.
            dims = onnx_model.graph.input[0].type.tensor_type.shape.dim
            dims[2].dim_param = '?'
            dims[3].dim_param = '?'
        if simplify:
            # Alias the import so it does not rebind the `simplify` flag.
            from onnxsim import simplify as onnxsim_simplify
            if dynamic:
                input_shapes = {
                    onnx_model.graph.input[0].name: list(input_shape)
                }
                onnx_model, check = onnxsim_simplify(
                    onnx_model,
                    input_shapes=input_shapes,
                    dynamic_input_shape=True)
                print('onnxsim with dynamic input')
            else:
                onnx_model, check = onnxsim_simplify(onnx_model)
                print('onnxsim with fixed input')
            assert check, "Simplified ONNX model could not be validated"
        onnx.save(onnx_model, output_file)
        os.remove(ori_output_file)

    print(f'Successfully exported ONNX model: {output_file}')


def parse_args(argv=None):
    """Parse command-line options for the ONNX export script.

    Args:
        argv (list[str], optional): Arguments to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser(
        description='Convert MMDetection models to ONNX')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--input-img', type=str, help='Images for input')
    parser.add_argument('--show', action='store_true', help='show onnx graph')
    parser.add_argument(
        '--output-file',
        type=str,
        default='',
        help='output ONNX file path (derived from the config name if empty)')
    parser.add_argument('--opset-version', type=int, default=11)
    parser.add_argument(
        '--test-img', type=str, default=None, help='Images for test')
    parser.add_argument(
        '--dataset', type=str, default='coco', help='Dataset name')
    parser.add_argument(
        '--verify',
        action='store_true',
        help='verify the onnx model output against pytorch output')
    # One value means a square size, two values mean height and width.
    # The script's entry point treats non-positive values (the default)
    # as a request for dynamic spatial input dims.
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[-1, -1],
        help='input image size')
    parser.add_argument(
        '--mean',
        type=float,
        nargs='+',
        default=[127.5, 127.5, 127.5],
        help='mean value used for preprocess input data')
    parser.add_argument(
        '--std',
        type=float,
        nargs='+',
        default=[128.0, 128.0, 128.0],
        help='std value used for preprocess input data')
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_args()

    # Explicit checks rather than `assert`, so they still run under `python -O`.
    if args.opset_version != 11:
        raise ValueError('MMDet only support opset 11 now')

    if not args.input_img:
        args.input_img = osp.join(
            osp.dirname(__file__), '../tests/data/t1.jpg')

    if len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (1, 3) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    if len(args.mean) != 3:
        raise ValueError(f'--mean expects 3 values, got {len(args.mean)}')
    if len(args.std) != 3:
        raise ValueError(f'--std expects 3 values, got {len(args.std)}')

    simplify = True
    dynamic = False
    # A non-positive height or width requests dynamic spatial dims. Tracing
    # still needs a concrete example tensor, so a fixed dummy size is used.
    if input_shape[2] <= 0 or input_shape[3] <= 0:
        input_shape = (1, 3, 320, 320)
        dynamic = True
        print('set to dynamic input with dummy shape:', input_shape)

    normalize_cfg = {'mean': args.mean, 'std': args.std}

    if not args.output_file:
        output_dir = osp.join(osp.dirname(__file__), '../onnx')
        os.makedirs(output_dir, exist_ok=True)
        # Config basename without its extension. splitext handles names with
        # no '.', where rfind('.') == -1 would have dropped the last char.
        cfg_name = osp.splitext(osp.basename(args.config))[0]
        if dynamic:
            args.output_file = osp.join(output_dir, f'{cfg_name}.onnx')
        else:
            args.output_file = osp.join(
                output_dir,
                f'{cfg_name}_shape{input_shape[2]}x{input_shape[3]}.onnx')

    # convert model to onnx file
    pytorch2onnx(
        args.config,
        args.checkpoint,
        args.input_img,
        input_shape,
        opset_version=args.opset_version,
        show=args.show,
        output_file=args.output_file,
        verify=args.verify,
        simplify=simplify,
        dynamic=dynamic,
        normalize_cfg=normalize_cfg,
        dataset=args.dataset,
        test_img=args.test_img)
