# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, Union

import torch

from mmengine.optim import OptimWrapper
from mmhug.registry import MODEL_WRAPPERS
from mmengine.model import MMDistributedDataParallel as _MMDistributedDataParallel
from mmengine.model.utils import detect_anomalous_params


@MODEL_WRAPPERS.register_module(force=True)
class MMDistributedDataParallel(_MMDistributedDataParallel):
    """Distributed model wrapper that delegates training to the wrapped module.

    Unlike the parent wrapper, this class does not run the forward pass or the
    optimizer update itself: :meth:`train_step` forwards the call to
    ``self.module.train_step`` and :meth:`_run_forward` is disabled.

    The class is registered in ``MODEL_WRAPPERS`` with ``force=True``.
    NOTE(review): presumably this is meant to replace an existing registration
    of the same name -- confirm against the registry's contents.

    Args:
        module: The model to wrap; passed through to the parent constructor.
        detect_anomalous_params (bool): Stored on the instance as
            ``self.detect_anomalous_params``. It is not consulted anywhere in
            this class. Defaults to False.
        **kwargs: Extra keyword arguments forwarded unchanged to the parent
            constructor.
    """

    def __init__(self, module, detect_anomalous_params: bool = False, **kwargs):
        super().__init__(module=module, **kwargs)
        # Kept for interface compatibility; no method of this class reads it.
        self.detect_anomalous_params = detect_anomalous_params

    def train_step(
        self, data: Union[dict, tuple, list], optim_wrapper: OptimWrapper
    ) -> Dict[str, torch.Tensor]:
        """Run one training iteration by delegating to the wrapped module.

        Everything that happens during the step is decided by
        ``self.module.train_step``; this wrapper adds nothing before or after
        the call.

        Args:
            data (dict or tuple or list): Data sampled from dataset.
            optim_wrapper (OptimWrapper): A wrapper of optimizer to
                update parameters.

        Returns:
            Dict[str, torch.Tensor]: Whatever ``self.module.train_step``
            returns (annotated as a ``dict`` of tensors for logging).
        """
        # NOTE(review): the call goes straight to the inner module rather than
        # through this wrapper's own forward -- confirm that this is intended
        # for the distributed setup in use.
        return self.module.train_step(data, optim_wrapper)

    def _run_forward(
        self, data: Union[dict, tuple, list], mode: str
    ) -> Union[Dict[str, torch.Tensor], list]:
        """Disabled on this wrapper; always raises.

        Callers are expected to use the wrapped module's ``_run_forward``
        instead.

        Args:
            data (dict or tuple or list): Unused.
            mode (str): Unused.

        Raises:
            NotImplementedError: Always.
        """
        # Derive the class name at runtime so the message names the wrapper
        # that was actually used (including subclasses).
        raise NotImplementedError(
            f"Don't call {type(self).__name__}'s _run_forward directly, "
            "call its module's _run_forward"
        )
