# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp

import mmcv
import numpy as np
import torch
from mmcv.runner import load_checkpoint
from torch import nn

from openmixup.models import build_model

torch.manual_seed(0)


def _demo_inputs(input_shape, num_classes):
    """Create deterministic dummy inputs for running a test or train batch.

    A fixed-seed NumPy ``RandomState`` is used, so repeated calls with the
    same arguments produce identical tensors.

    Args:
        input_shape (tuple): Input batch dimensions, unpacked as
            ``(N, C, H, W)``.
        num_classes (int): Number of semantic classes; labels are drawn
            from ``[0, num_classes)``.

    Returns:
        dict: ``'img'`` is a float tensor of shape ``input_shape`` with
        ``requires_grad`` enabled, and ``'gt_label'`` is a long tensor of
        shape ``(N, 1)``.
    """
    # The 4-way unpack also rejects shapes that are not (N, C, H, W).
    batch_size, _, _, _ = input_shape
    rng = np.random.RandomState(0)
    imgs = rng.rand(*input_shape)
    # Use int64 (the dtype of LongTensor) instead of uint8: a uint8 cast
    # silently wraps label values modulo 256 when num_classes > 256.
    gt_label = rng.randint(
        low=0, high=num_classes, size=(batch_size, 1)).astype(np.int64)
    inputs = {
        'img': torch.FloatTensor(imgs).requires_grad_(True),
        'gt_label': torch.LongTensor(gt_label),
    }
    return inputs


def pytorch2torchscript(model: nn.Module,
                        input_shape: tuple,
                        output_file: str,
                        verify: bool):
    """Export a PyTorch model to TorchScript through ``torch.jit.trace``.

    The model is moved to CPU, switched to eval mode and traced with a
    deterministic dummy image while ``model.forward`` is temporarily
    replaced by ``model.forward_inference``. The traced module is written
    to ``output_file``; optionally it is reloaded and its output compared
    with the eager model's output.

    Args:
        model (nn.Module): PyTorch model we want to export. It must provide
            ``forward_inference`` and expose ``num_classes`` on its head or
            (with a positive value) on its backbone.
        input_shape (tuple): Use this input shape to construct
            the corresponding dummy input and execute the model.
        output_file (str): The path to where we store the output
            TorchScript model. Missing parent directories are created.
        verify (bool): Whether to compare the outputs between PyTorch
            and TorchScript through loading the generated output_file.

    Raises:
        AttributeError: If ``num_classes`` cannot be found on either the
            head or the backbone.
        ValueError: If ``verify`` is set and the outputs differ.
    """
    model.cpu().eval()

    if hasattr(model.head, 'num_classes'):
        num_classes = model.head.num_classes
    # Some backbones use `num_classes=-1` to disable top classifier.
    elif getattr(model.backbone, 'num_classes', -1) > 0:
        num_classes = model.backbone.num_classes
    else:
        raise AttributeError('Cannot find "num_classes" in both head and '
                             'backbone, please check the config file.')

    inputs = _demo_inputs(input_shape, num_classes)

    img = inputs.pop('img')

    # Temporarily replace the original forward function so that tracing
    # follows the inference path. The finally clause guarantees the model
    # is handed back unmodified even if tracing or saving fails.
    origin_forward = model.forward
    model.forward = model.forward_inference
    try:
        with torch.no_grad():
            trace_model = torch.jit.trace(model, img)
            save_dir, _ = osp.split(output_file)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            trace_model.save(output_file)
            print(f'Successfully exported TorchScript model: {output_file}')
    finally:
        model.forward = origin_forward

    if verify:
        # load by torch.jit
        jit_model = torch.jit.load(output_file)

        # check the numerical value
        # get pytorch output
        pytorch_result = model(img, mode='inference').detach().numpy()

        # get jit output
        # NOTE(review): only the traced output is indexed with [0]; the
        # comparison below then relies on NumPy broadcasting against the
        # eager output -- confirm this asymmetry is intended.
        jit_result = jit_model(img)[0].detach().numpy()
        if not np.allclose(pytorch_result, jit_result):
            raise ValueError(
                'The outputs are different between Pytorch and TorchScript')
        print('The outputs are same between Pytorch and TorchScript')


def parse_args():
    """Parse the command-line options of the TorchScript export script.

    Returns:
        argparse.Namespace: Parsed options with the attributes ``config``,
        ``checkpoint``, ``verify``, ``output_file`` and ``shape``.
    """
    cli = argparse.ArgumentParser(description='Convert MMCls to TorchScript')

    # Required positional argument.
    cli.add_argument('config', help='test config file path')

    # Optional arguments.
    cli.add_argument('--checkpoint', type=str, help='checkpoint file')
    # `store_true` already defaults to False when the flag is absent.
    cli.add_argument(
        '--verify', action='store_true', help='verify the TorchScript model')
    cli.add_argument('--output-file', default='tmp.pt', type=str)
    cli.add_argument(
        '--shape',
        nargs='+',
        type=int,
        default=[224, 224],
        help='input image size')

    return cli.parse_args()


if __name__ == '__main__':
    args = parse_args()

    # `--shape` takes one value (square input) or two values (the last two
    # dimensions of the batch); anything else is rejected before any model
    # or config work is done.
    num_shape_values = len(args.shape)
    if num_shape_values not in (1, 2):
        raise ValueError('invalid input shape')
    if num_shape_values == 1:
        spatial_size = (args.shape[0], args.shape[0])
    else:
        spatial_size = tuple(args.shape)
    # Batch size 1 with 3 channels is prepended to the spatial size.
    input_shape = (1, 3) + spatial_size

    # Read the config and clear the `pretrained` entry before building.
    cfg = mmcv.Config.fromfile(args.config)
    cfg.model.pretrained = None

    # Build the model; weights are loaded only when a checkpoint is given.
    classifier = build_model(cfg.model)
    if args.checkpoint:
        load_checkpoint(classifier, args.checkpoint, map_location='cpu')

    # Convert the model to a TorchScript file.
    pytorch2torchscript(
        classifier,
        input_shape,
        output_file=args.output_file,
        verify=args.verify)

# Usage: docs/en/tools/pytorch2onnx.md
