from typing import Any

import torch


class SafeModule(torch.nn.Module):
    """``nn.Module`` base that tracks a target dtype/device and coerces call inputs to them.

    ``forward`` moves every tensor argument (positional and keyword) to
    ``self.tdevice`` and casts it to ``self.dtype``, then delegates to
    ``safe_forward``. Note that *every* tensor is cast, including integer/bool ones.

    The ``to``/``cuda``/``cpu``/``type``/``float``/``double``/``half``/``bfloat16``
    overrides convert the parameters/buffers as ``nn.Module`` does and then record
    the new device/dtype on this module and on every nested ``SafeModule``.
    """

    dtype: Any
    tdevice: torch.device

    def __init__(self, dtype=None, device=None):
        super().__init__()

        self.dtype = dtype if dtype is not None else torch.float32
        # ``is not None`` rather than truthiness: ``device=0`` must mean cuda:0,
        # not silently fall back to the CPU.
        self.tdevice = self._as_device(device) if device is not None else torch.device("cpu")

    @staticmethod
    def _as_device(device) -> torch.device:
        """Normalize a device spec (``torch.device``, string, or CUDA index) to a ``torch.device``."""
        if isinstance(device, int):
            # A bare integer is a CUDA ordinal, as in ``nn.Module.cuda(device)``.
            return torch.device("cuda", device)
        return torch.device(device)

    def _sync_safe_state(self, *, device=None, dtype=None):
        """Record ``device`` and/or ``dtype`` (when not None) on this and all nested SafeModules."""
        for module in self.modules():
            if isinstance(module, SafeModule):
                if device is not None:
                    module.tdevice = device
                if dtype is not None:
                    module.dtype = dtype

    def convert_tensor(self, tensor: torch.Tensor, *, device_only=False, non_blocking=False):
        """Move ``tensor`` to ``self.tdevice`` and, unless ``device_only``, cast it to ``self.dtype``."""
        return tensor.to(device=self.tdevice, dtype=self.dtype if not device_only else None, non_blocking=non_blocking)

    def safe_forward(self, *inputs, **kwargs):
        """Hook called by ``forward`` with already-converted arguments.

        The default returns the converted positional inputs as a tuple and ignores ``kwargs``.
        """
        return inputs

    def set_device(self, device):
        """Move parameters/buffers to ``device`` and record it on all nested SafeModules.

        ``device`` may be a ``torch.device``, a device string, or a CUDA index.
        Returns ``self``.
        """
        device = self._as_device(device)

        # Convert first so the bookkeeping is untouched if the move fails.
        if device.type == "cpu":
            result = super().cpu()
        elif device.type == "cuda":
            result = super().cuda(device)
        else:
            # Other backends (e.g. "mps") go through the generic nn.Module.to.
            result = super().to(device)

        self._sync_safe_state(device=device)
        return result

    def cuda(self, device=None):
        return self.set_device(device if device is not None else torch.device("cuda"))

    def cpu(self):
        return self.set_device(torch.device("cpu"))

    def forward(self, *inputs, **kwargs):
        """Convert every tensor argument to ``self.tdevice``/``self.dtype``, then call ``safe_forward``."""

        def convert(value):
            if isinstance(value, torch.Tensor):
                return value.to(device=self.tdevice, dtype=self.dtype)
            return value

        converted_inputs = [convert(value) for value in inputs]
        converted_kwargs = {name: convert(value) for name, value in kwargs.items()}

        return self.safe_forward(*converted_inputs, **converted_kwargs)

    def type(self, dst_type):
        """Cast *all* parameters and buffers to ``dst_type`` (as ``nn.Module.type``) and record it."""
        super().type(dst_type)
        self._sync_safe_state(dtype=dst_type)
        return self

    # The float/double/half/bfloat16 overrides delegate to nn.Module so that only
    # floating-point parameters/buffers are cast (integer buffers such as
    # counters keep their dtype), then record the new target dtype.
    def float(self):
        super().float()
        self._sync_safe_state(dtype=torch.float32)
        return self

    def double(self):
        super().double()
        self._sync_safe_state(dtype=torch.float64)
        return self

    def half(self):
        super().half()
        self._sync_safe_state(dtype=torch.float16)
        return self

    def bfloat16(self):
        super().bfloat16()
        self._sync_safe_state(dtype=torch.bfloat16)
        return self

    def to(self, *args, **kwargs):
        """Same arguments as ``nn.Module.to``; also records the device/dtype on nested SafeModules."""
        device, dtype, _non_blocking, _memory_format = torch._C._nn._parse_to(*args, **kwargs)

        # Convert first: nn.Module.to raises TypeError for non-floating dtypes, and
        # the recorded state must not change when the conversion fails.
        result = super().to(*args, **kwargs)
        self._sync_safe_state(device=device, dtype=dtype)
        return result
