from __future__ import annotations



from typing import Dict, Optional, Tuple



import torch

import torch.nn as nn





# Names accepted by ``resolve_dtype`` mapped to the torch dtype they select.
_DTYPE_MAP = dict(
    float32=torch.float32,
    float16=torch.float16,
    bfloat16=torch.bfloat16,
)


def resolve_dtype(value: Optional[str]) -> Optional[torch.dtype]:
    """Translate a dtype name into a ``torch.dtype``.

    The lookup is case-insensitive.

    Args:
        value: A dtype name (``"float32"``, ``"float16"``, ``"bfloat16"``),
            one of the placeholders ``"model"``, ``"param"`` or ``"none"``,
            or ``None``.

    Returns:
        The matching ``torch.dtype``, or ``None`` when ``value`` is ``None``
        or one of the placeholder names.

    Raises:
        ValueError: If the lowercased name is neither a placeholder nor a
            key of ``_DTYPE_MAP``.
    """
    key = None if value is None else value.lower()
    if key is None or key in {"model", "param", "none"}:
        return None

    resolved = _DTYPE_MAP.get(key)
    if resolved is None:
        raise ValueError(f"Unsupported EMA dtype: {key}")
    return resolved





class EMAHelper:
    """Keep an exponential moving average (EMA) of a module's trainable parameters.

    For every parameter with ``requires_grad=True`` a "shadow" tensor is kept
    in ``self.shadow`` under the parameter's ``named_parameters()`` name and
    updated as ``shadow = decay * shadow + (1 - decay) * param``.

    Typical evaluation flow: ``store(model)`` -> ``copy_to(model)`` ->
    evaluate -> ``restore(model)``.
    """

    def __init__(self, decay: float, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
        """Create an empty helper.

        Args:
            decay: EMA decay factor; must lie strictly between 0 and 1.
            device: Device on which shadows are kept. ``None`` keeps each
                shadow on its parameter's device.
            dtype: Dtype in which shadows are kept. ``None`` keeps each
                shadow in its parameter's dtype.

        Raises:
            ValueError: If ``decay`` is not in the open interval (0, 1).
        """
        if not 0.0 < decay < 1.0:
            raise ValueError(f"EMA decay must be in (0, 1), got {decay}")
        self.decay = decay
        self.device = device
        self.dtype = dtype
        # Parameter name -> running average tensor.
        self.shadow: Dict[str, torch.Tensor] = {}
        # Number of completed ``update`` calls.
        self.num_updates = 0
        # Parameter copies saved by ``store`` until ``restore`` consumes them.
        self._backup: Optional[Dict[str, torch.Tensor]] = None

    def _target_device_dtype(self, param: torch.Tensor) -> Tuple[torch.device, torch.dtype]:
        """Return the (device, dtype) a shadow of ``param`` should live in.

        The helper's configured device/dtype win; otherwise the parameter's
        own device/dtype are used.
        """
        device = self.device if self.device is not None else param.device
        dtype = self.dtype if self.dtype is not None else param.dtype
        return device, dtype

    def _get_or_init_shadow(self, name: str, param: torch.Tensor) -> torch.Tensor:
        """Return the shadow for ``name``, creating or migrating it as needed.

        A missing shadow is initialised to a copy of ``param``. An existing
        shadow whose device or dtype differs from the target (for example
        after ``load_state_dict``) is converted and re-stored.
        """
        device, dtype = self._target_device_dtype(param)
        shadow = self.shadow.get(name)
        if shadow is None:
            # ``copy=True`` guarantees fresh storage in a single pass, even
            # when the device and dtype already match the parameter's.
            shadow = param.detach().to(device=device, dtype=dtype, copy=True)
            self.shadow[name] = shadow
        elif shadow.device != device or shadow.dtype != dtype:
            shadow = shadow.to(device=device, dtype=dtype)
            self.shadow[name] = shadow
        return shadow

    def register(self, module: nn.Module) -> None:
        """Create shadows for every trainable parameter of ``module``.

        Parameters with ``requires_grad=False`` are skipped. Shadows that
        already exist keep their values.
        """
        with torch.no_grad():
            for name, param in module.named_parameters():
                if param.requires_grad:
                    self._get_or_init_shadow(name, param)

    def update(self, module: nn.Module) -> None:
        """Fold the current parameters of ``module`` into the shadows.

        Each trainable parameter's shadow becomes
        ``decay * shadow + (1 - decay) * param`` (in place). A parameter
        without a shadow gets one initialised from its current value first.
        ``num_updates`` is incremented once per call.
        """
        keep = self.decay
        blend = 1.0 - self.decay
        with torch.no_grad():
            for name, param in module.named_parameters():
                if not param.requires_grad:
                    continue
                shadow = self._get_or_init_shadow(name, param)
                # Bring the parameter to the shadow's device/dtype; ``to``
                # returns the input unchanged when nothing needs converting.
                data = param.detach().to(device=shadow.device, dtype=shadow.dtype)
                shadow.mul_(keep).add_(data, alpha=blend)
        self.num_updates += 1

    def store(self, module: nn.Module) -> None:
        """Save a copy of every parameter of ``module`` that has a shadow.

        Any previously stored backup is replaced.
        """
        self._backup = {
            name: param.detach().clone()
            for name, param in module.named_parameters()
            if name in self.shadow
        }

    def copy_to(self, module: nn.Module) -> None:
        """Overwrite the parameters of ``module`` with their shadow values.

        Parameters without a shadow are left untouched.
        """
        for name, param in module.named_parameters():
            shadow = self.shadow.get(name)
            if shadow is not None:
                # ``copy_`` converts device and dtype itself, so no
                # intermediate converted tensor is needed.
                param.data.copy_(shadow)

    def restore(self, module: nn.Module) -> None:
        """Write back the parameters saved by ``store`` and drop the backup.

        Does nothing when there is no stored backup.
        """
        if self._backup is None:
            return
        for name, param in module.named_parameters():
            saved = self._backup.get(name)
            if saved is not None:
                param.data.copy_(saved)
        self._backup = None

    def state_dict(self) -> Dict[str, torch.Tensor]:
        """Return a serialisable snapshot of the helper.

        Returns:
            A dict with three entries: ``"decay"`` (float64 scalar tensor),
            ``"num_updates"`` (int64 scalar tensor) and ``"shadow"`` (a dict
            mapping parameter names to CPU copies of the shadow tensors).
        """
        return {
            # float64 so that ``float(...)`` on load reproduces the Python
            # float exactly; float32 would turn 0.9999 into 0.99989998...
            "decay": torch.tensor(self.decay, dtype=torch.float64),
            "num_updates": torch.tensor(self.num_updates, dtype=torch.int64),
            # ``copy=True`` makes the snapshot independent of the live
            # shadows even when they already reside on the CPU; otherwise a
            # later in-place ``update`` would silently change the snapshot.
            "shadow": {
                name: tensor.detach().to(device="cpu", copy=True)
                for name, tensor in self.shadow.items()
            },
        }

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
        """Restore the helper from a dict produced by ``state_dict``.

        Entries that are missing (or ``None``) leave the corresponding
        attribute unchanged. Loaded shadows are cloned, so the caller's
        tensors are never modified by later updates; they are moved to the
        target device/dtype the next time ``register`` or ``update`` touches
        them.
        """
        decay = state_dict.get("decay")
        if decay is not None:
            self.decay = float(decay)

        num_updates = state_dict.get("num_updates")
        if num_updates is not None:
            self.num_updates = int(num_updates)

        shadow = state_dict.get("shadow")
        if shadow is not None:
            self.shadow = {name: tensor.detach().clone() for name, tensor in shadow.items()}

