# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

import torch
import torch.distributed as dist
from mmengine.optim import OptimWrapper
from mmengine.registry import MODEL_WRAPPERS
from mmengine.structures import BaseDataElement
from mmengine.model import MMFullyShardedDataParallel as _MMFullyShardedDataParallel
from mmengine.model import BaseModel


@MODEL_WRAPPERS.register_module(force=True)
class MMFullyShardedDataParallel(_MMFullyShardedDataParallel):
    """FSDP model wrapper that forwards the preprocessed batch as one dict.

    This subclass overrides :meth:`train_step` and :meth:`_run_forward` so
    that the output of ``module.data_preprocessor`` is passed to the wrapped
    forward as a single positional ``dict`` (``self(data, mode=...)``)
    rather than being unpacked. Any other preprocessor output type is
    rejected with a ``TypeError``.

    The class is registered in ``MODEL_WRAPPERS`` with ``force=True`` under
    the same class name as the parent wrapper it derives from.
    """

    # The wrapped model; ``data_preprocessor`` and ``parse_losses`` are
    # looked up on it below.
    module: BaseModel

    def train_step(
        self, data: dict, optim_wrapper: OptimWrapper
    ) -> Dict[str, torch.Tensor]:
        """Interface for model forward, backward and parameters updating during
        training process.

        :meth:`train_step` will perform the following steps in order:

        - Call ``module.data_preprocessor(data, training=True)`` to
            pre-process the data.
        - Call ``self(data, mode="loss")`` (via :meth:`_run_forward`) with
            the preprocessed ``dict`` and get losses.
        - Parse losses with ``module.parse_losses``.
        - Call ``optim_wrapper.update_params`` to update parameters.
        - Return log messages of losses.

        Args:
            data (dict): Data sampled by dataloader.
            optim_wrapper (OptimWrapper): A wrapper of optimizer to
                update parameters.

        Returns:
            Dict[str, torch.Tensor]: A ``dict`` of tensor for logging.

        Raises:
            TypeError: If the output of ``module.data_preprocessor`` is not
                a ``dict``.
        """
        # Preprocessing and the forward pass both run inside the optimizer
        # wrapper's context (e.g. automatic mixed precision training).
        with optim_wrapper.optim_context(self):
            data = self.module.data_preprocessor(data, training=True)
            # Delegate to `_run_forward` so the dict-only check and its
            # error message live in exactly one place.
            losses = self._run_forward(data, mode="loss")
        parsed_loss, log_vars = self.module.parse_losses(losses)
        optim_wrapper.update_params(parsed_loss)
        return log_vars

    def _run_forward(
        self, data: Union[dict, tuple, list], mode: str
    ) -> Union[Dict[str, torch.Tensor], list]:
        """Validate preprocessed data and pass it to :meth:`forward`.

        The ``dict`` is handed over as a single positional argument; it is
        not unpacked into keyword arguments.

        Args:
            data (dict or tuple or list): Output of the data preprocessor.
                Only ``dict`` is accepted; the wider annotation is kept for
                signature compatibility.
            mode (str): Mode of forward.

        Returns:
            dict or list: Results of training or testing mode.

        Raises:
            TypeError: If ``data`` is not a ``dict``.
        """
        if not isinstance(data, dict):
            raise TypeError(
                "Output of `data_preprocessor` should be "
                f"dict, but got {type(data)}"
            )
        return self(data, mode=mode)
