from mmengine.model.base_model import BaseDataPreprocessor as _BaseDataPreprocessor
from torch import nn
import torch
from mmhug.registry import MODELS


@MODELS.register_module(force=True)
class BaseDataPreprocessor(_BaseDataPreprocessor):
    """mmengine ``BaseDataPreprocessor`` whose ``to`` tolerates a bare NPU device.

    Registered in ``MODELS`` with ``force=True`` (an existing registration
    under the same name is replaced).
    """

    @staticmethod
    def _pin_npu_index(device):
        """Return ``device`` with a bare ``"npu"`` expanded to ``"npu:<idx>"``.

        ``<idx>`` is ``torch.npu.current_device()``. Anything else is
        returned unchanged, including:

        - NPU devices that already carry an index (e.g. ``"npu:1"``);
        - non-NPU devices;
        - arguments that are not a ``str``/``torch.device`` at all
          (dtypes, tensors, ``None``).

        ``torch.npu`` is only touched when a bare ``"npu"`` is seen.
        """
        if isinstance(device, (str, torch.device)) and str(device) == "npu":
            return f"npu:{torch.npu.current_device()}"
        return device

    def to(self, *args, **kwargs) -> nn.Module:
        """Overrides this method to set the :attr:`device`.

        Accepts the same arguments as :meth:`torch.nn.Module.to`. A bare
        ``"npu"`` device is pinned to the current NPU index before parsing.
        This applies whether it is the first positional argument or the
        ``device=`` keyword, and whether it is a ``str`` or a
        ``torch.device``. When the parsed device is not ``None``, it is
        stored in ``self._device``.

        Returns:
            nn.Module: The model itself.
        """
        # Since Torch has not officially merged the npu-related fields,
        # using ``_parse_to`` directly will cause the NPU to not be found.
        # Pin a bare "npu" to an explicit index first. Only the device
        # argument is rewritten: any dtype / non_blocking arguments that
        # follow it are preserved, and already-indexed "npu:N" is left as is.
        if args:
            args = (self._pin_npu_index(args[0]), *args[1:])
        if "device" in kwargs:
            kwargs["device"] = self._pin_npu_index(kwargs["device"])

        device = torch._C._nn._parse_to(*args, **kwargs)[0]
        if device is not None:
            self._device = torch.device(device)
        # The explicit ``super(_BaseDataPreprocessor, self)`` skips the
        # parent's own ``to`` and dispatches to the next ``to`` in the MRO
        # (presumably ``nn.Module.to``).
        # NOTE(review): assumed intentional, so that the parent's NPU
        # handling is not applied on top of this one — confirm.
        return super(_BaseDataPreprocessor, self).to(*args, **kwargs)
