# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Callback to export model for inference."""

from __future__ import annotations

import logging
from copy import deepcopy
from typing import Any, Optional, Sequence, Union

import torch.nn as nn

from composer.core import Callback, State
from composer.loggers import Logger
from composer.utils import ExportFormat, ObjectStore, Transform, export_with_logger

# Public API of this module.
__all__ = ['ExportForInferenceCallback']

# Logger named after this module.
log = logging.getLogger(__name__)


class ExportForInferenceCallback(Callback):
    """Callback to export model for inference.

    When training finishes (:meth:`fit_end`), the model is exported with
    :func:`~composer.utils.export_with_logger`. If the model is wrapped in DDP, the underlying module is
    exported. For ONNX export without an explicit ``sample_input``, a deep copy of the first batch observed
    at the ``after_dataloader`` event is used as the sample input.

    Example:
        .. doctest::

            >>> from composer import Trainer
            >>> from composer.callbacks import ExportForInferenceCallback
            >>> # constructing trainer object with this callback
            >>> trainer = Trainer(
            ...     model=model,
            ...     train_dataloader=train_dataloader,
            ...     eval_dataloader=eval_dataloader,
            ...     optimizers=optimizer,
            ...     max_duration="1ep",
            ...     callbacks=[ExportForInferenceCallback(save_format='torchscript', save_path='/tmp/model.pth')],
            ... )

    Args:
        save_format (Union[str, ExportFormat]): Format to export to. Either ``"torchscript"`` or ``"onnx"``.
        save_path (str): The path for storing the exported model. It can be a path to a file on the local disk,
            a URL, or if ``save_object_store`` is set, the object name
            in a cloud bucket. For example, ``my_run/exported_model``.
        save_object_store (ObjectStore, optional): If the ``save_path`` is an object name in a cloud bucket
            (e.g. AWS S3 or Google Cloud Storage), an instance of
            :class:`~.ObjectStore` which will be used
            to store the exported model. If this is set to ``None``, will save to ``save_path`` using the logger.
            (default: ``None``)
        sample_input (Any, optional): Example model inputs used for tracing. This is needed for ``"onnx"``
            export. If ``None`` and the format is ``"onnx"``, a deep copy of the first batch seen at the
            ``after_dataloader`` event is used instead. (default: ``None``)
        transforms (Sequence[Transform], optional): transformations (usually optimizations) that should
            be applied to the model. Each Transform should be a callable that takes a model and returns a modified model.
        input_names (Sequence[str], optional): names to assign to the input nodes of the graph, in order. If set
            to ``None``, the keys from the ``sample_input`` will be used. Falls back to ``["input"]``.
        output_names (Sequence[str], optional): names to assign to the output nodes of the graph, in order. If set
            to ``None``, it defaults to ``["output"]``.
    """

    def __init__(
        self,
        save_format: Union[str, ExportFormat],
        save_path: str,
        save_object_store: Optional[ObjectStore] = None,
        sample_input: Optional[Any] = None,
        transforms: Optional[Sequence[Transform]] = None,
        input_names: Optional[Sequence[str]] = None,
        output_names: Optional[Sequence[str]] = None,
    ):
        self.save_format = save_format
        self.save_path = save_path
        self.save_object_store = save_object_store
        self.sample_input = sample_input
        self.transforms = transforms
        self.input_names = input_names
        self.output_names = output_names

    def _is_onnx_export(self) -> bool:
        """Return whether ``save_format`` denotes ONNX export.

        ``save_format`` may be either a string or an :class:`~.ExportFormat` member, so it is normalized
        through :class:`~.ExportFormat` before comparing. A plain comparison against the string ``'onnx'``
        would not match an enum member. An unrecognized format presumably raises ``ValueError`` from the
        enum lookup — confirm against ``ExportFormat``.
        """
        return ExportFormat(self.save_format) == ExportFormat.ONNX

    def after_dataloader(self, state: State, logger: Logger) -> None:
        """Capture the current batch as the ONNX sample input if none has been provided yet."""
        del logger  # unused
        if self.sample_input is None and self._is_onnx_export():
            # Deep copy so the stored sample is independent of the live batch object.
            self.sample_input = deepcopy(state.batch)

    def fit_end(self, state: State, logger: Logger) -> None:
        """Export the model once training has finished."""
        self.export_model(state, logger)

    def export_model(self, state: State, logger: Logger) -> None:
        """Export the (DDP-unwrapped) model using the configured format, path and options.

        Args:
            state (State): Training state holding the model to export.
            logger (Logger): Logger passed through to :func:`~composer.utils.export_with_logger`.

        Raises:
            ValueError: If the model to export is not a :class:`torch.nn.Module`.
        """
        # Export the wrapped module rather than the DDP wrapper itself.
        export_model = state.model.module if state.is_model_ddp else state.model
        if not isinstance(export_model, nn.Module):
            raise ValueError(f'Exporting Model requires type torch.nn.Module, got {type(export_model)}')
        # The trailing empty dict presumably follows the ``torch.onnx.export`` ``args`` convention (explicitly
        # "no keyword arguments", so a dict batch is treated as one positional input) — confirm against
        # export_for_inference. When no sample is available, pass ``None`` through rather than ``(None, {})``.
        # This lets the exporter recognize that no sample input exists.
        sample_input = (self.sample_input, {}) if self.sample_input is not None else None
        export_with_logger(
            model=export_model,
            save_format=self.save_format,
            save_path=self.save_path,
            logger=logger,
            save_object_store=self.save_object_store,
            sample_input=sample_input,
            transforms=self.transforms,
            input_names=self.input_names,
            output_names=self.output_names,
        )
