
"""Normalizer for score models."""

from __future__ import annotations

from abc import abstractmethod
from typing import Any, Literal

import torch
from torch import nn
from torch.types import Number


# Names of the transforms that `Normalizer.normalize` can apply:
#   'affine'    -> (data - mean) / (std + epsilon)
#   'scale'     -> data / (std + epsilon)
#   'translate' -> data - mean
#   'identity'  -> data is returned unchanged
NormalizeFunction = Literal['affine', 'scale', 'translate', 'identity']
# Names accepted by `Normalizer.instantiate` to pick the statistics-tracking
# subclass (`None` is also accepted there and selects `IdentityNormalizer`).
NormalizerType = Literal['RunningMeanStd', 'ExponentialMovingAverage']


class Normalizer(nn.Module):
    """Base module that normalizes inputs using tracked mean and variance.

    The statistics live in three registered buffers: ``mean`` and ``var`` (both
    of the configured ``shape``) and ``count`` (a single-element long tensor).
    Subclasses decide how a batch changes them by implementing :meth:`update`.
    """

    mean: torch.Tensor
    var: torch.Tensor
    count: torch.LongTensor
    normalize_function: NormalizeFunction

    def __init__(
        self,
        normalize_function: NormalizeFunction,
        shape: tuple[int, ...],
        device: torch.device | str | None = None,
    ) -> None:
        """Initialize the buffers with mean 0, variance 1 and count 0.

        Args:
            normalize_function: Which transform :meth:`normalize` applies; one
                of ``'affine'``, ``'scale'``, ``'translate'``, ``'identity'``.
            shape: Shape of the ``mean`` and ``var`` buffers.
            device: Device on which the buffers are created.

        Raises:
            ValueError: If ``normalize_function`` is not a supported name.
        """
        super().__init__()
        if normalize_function not in {'affine', 'scale', 'translate', 'identity'}:
            # The two literals are concatenated into ONE message. Separating
            # them with a comma would pass two args and render as a tuple.
            raise ValueError(
                f'Invalid normalization function type: {normalize_function}. '
                'Expected one of "affine", "scale", "translate", "identity".',
            )
        self.normalize_function = normalize_function
        self.register_buffer('mean', torch.zeros(shape, device=device))
        self.register_buffer('var', torch.ones(shape, device=device))
        self.register_buffer('count', torch.zeros(1, dtype=torch.long, device=device))

    @abstractmethod
    def update(self, data: torch.Tensor) -> None:
        """Update the tracked statistics from a batch.

        The base implementation only raises; subclasses provide the logic.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @property
    def std(self) -> torch.Tensor:
        """Return the standard deviation, i.e. the square root of ``var``."""
        return self.var.sqrt()

    def set_mean_var(
        self,
        mean: torch.Tensor | list[float] | tuple[float, ...] | None,
        var: torch.Tensor | list[float] | tuple[float, ...] | None,
    ) -> None:
        """Overwrite the tracked mean and/or variance.

        Args:
            mean: New mean, converted to the dtype and device of the current
                ``mean`` buffer. ``None`` keeps the current value.
            var: New variance, converted to the dtype and device of the
                current ``var`` buffer. ``None`` keeps the current value.

        Raises:
            AssertionError: If a provided value does not have the same shape
                as the buffer it replaces.
        """
        mean = (
            torch.as_tensor(mean, dtype=self.mean.dtype, device=self.mean.device)
            if mean is not None
            else self.mean
        )
        var = (
            torch.as_tensor(var, dtype=self.var.dtype, device=self.var.device)
            if var is not None
            else self.var
        )

        # Assertions are kept (rather than another exception type) so callers
        # observe the same failure type as before.
        assert (
            mean.shape == self.mean.shape
        ), f'mean has shape {tuple(mean.shape)}, expected {tuple(self.mean.shape)}'
        assert (
            var.shape == self.var.shape
        ), f'var has shape {tuple(var.shape)}, expected {tuple(self.var.shape)}'

        self.mean = mean
        self.var = var

    def forward(
        self,
        data: torch.Tensor,
        epsilon: Number = 1e-8,
    ) -> torch.Tensor:
        """Normalize ``data``, first updating the statistics in training mode.

        Args:
            data: Input passed to :meth:`update` (only when ``self.training``)
                and then to :meth:`normalize`.
            epsilon: Forwarded to :meth:`normalize`.

        Returns:
            The normalized input.
        """
        if self.training:
            self.update(data)
        return self.normalize(data, epsilon=epsilon)

    def normalize(
        self,
        data: torch.Tensor,
        epsilon: Number = 1e-8,
    ) -> torch.Tensor:
        """Apply the configured transform without touching the statistics.

        The statistics are detached, so no gradient flows into the buffers.

        Args:
            data: Input to transform.
            epsilon: Added to the standard deviation before dividing
                (``'affine'`` and ``'scale'`` only).

        Returns:
            ``(data - mean) / (std + epsilon)`` for ``'affine'``,
            ``data / (std + epsilon)`` for ``'scale'``, ``data - mean`` for
            ``'translate'``, and ``data`` itself for ``'identity'``.

        Raises:
            ValueError: If ``normalize_function`` holds an unsupported name
                (e.g. it was reassigned after construction).
        """
        if self.normalize_function == 'affine':
            return (data - self.mean.detach()) / (self.std.detach() + epsilon)
        if self.normalize_function == 'scale':
            return data / (self.std.detach() + epsilon)
        if self.normalize_function == 'translate':
            return data - self.mean.detach()
        if self.normalize_function == 'identity':
            return data
        # Single concatenated message (no comma between the literals).
        raise ValueError(
            f'Invalid normalization function type: {self.normalize_function}. '
            'Expected one of "affine", "scale", "translate", "identity".',
        )

    @classmethod
    def instantiate(
        cls,
        normalizer_type: NormalizerType | None,
        normalize_function: NormalizeFunction,
        shape: tuple[int, ...],
        device: torch.device | str | None = None,
        **kwargs: Any,
    ) -> Normalizer:
        """Build the normalizer subclass selected by ``normalizer_type``.

        Args:
            normalizer_type: ``'RunningMeanStd'``,
                ``'ExponentialMovingAverage'``, or ``None`` for
                ``IdentityNormalizer``.
            normalize_function: Passed to the selected constructor.
            shape: Passed to the selected constructor.
            device: Passed to the selected constructor.
            **kwargs: Extra constructor arguments; only forwarded when
                ``normalizer_type`` is ``'ExponentialMovingAverage'``.

        Returns:
            The newly constructed normalizer.

        Raises:
            ValueError: If ``normalizer_type`` is not a supported value.
        """
        if normalizer_type == 'RunningMeanStd':
            return RunningMeanStd(
                normalize_function,
                shape=shape,
                device=device,
            )
        if normalizer_type == 'ExponentialMovingAverage':
            return ExponentialMovingAverage(
                normalize_function,
                shape=shape,
                device=device,
                **kwargs,
            )
        if normalizer_type is None:
            return IdentityNormalizer(
                normalize_function,
                shape=shape,
                device=device,
            )
        raise ValueError(
            f'Invalid normalizer type: {normalizer_type}. '
            'Expected one of "RunningMeanStd", "ExponentialMovingAverage", or None.',
        )


class RunningMeanStd(Normalizer):
    """Running mean and variance accumulated over every sample seen so far."""

    def update(self, data: torch.Tensor) -> None:
        """Merge a batch into the running mean, variance and count.

        The batch statistics are combined with the stored ones using the
        pairwise (parallel) mean/variance merge, weighting each side by its
        sample count.

        Args:
            data: Batch whose first dimension indexes samples. An empty batch
                is ignored.
        """
        batch_count = data.size(0)
        if batch_count == 0:
            # Mean/variance of an empty batch are NaN and would poison the buffers.
            return

        # The statistics are never differentiated through (`normalize` detaches
        # them), so drop the autograd graph instead of storing it in buffers.
        data = data.detach()
        batch_mean = data.mean(dim=0)
        # The merge below treats `var * count` as a sum of squared deviations,
        # which only holds for the population variance. The default
        # Bessel-corrected estimate would bias the result and is NaN for a
        # single-sample batch.
        batch_var = data.var(dim=0, unbiased=False)

        delta = batch_mean - self.mean
        total_count = self.count + batch_count

        new_mean = self.mean + delta * batch_count / total_count
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        # Sum of squared deviations of the merged sample set.
        m2 = (  # pylint: disable=invalid-name
            m_a + m_b + torch.square(delta) * (self.count * batch_count / total_count)
        )
        new_var = m2 / total_count

        self.mean = new_mean
        self.var = new_var
        self.count = total_count


class ExponentialMovingAverage(Normalizer):
    """Mean and variance tracked as exponential moving averages of batch statistics."""

    def __init__(
        self,
        normalize_function: NormalizeFunction,
        shape: tuple[int, ...],
        device: torch.device | str | None = None,
        momentum: float = 0.9,
    ) -> None:
        """Initialize the buffers and store the smoothing factor.

        Args:
            normalize_function: Passed to the base class constructor.
            shape: Passed to the base class constructor.
            device: Passed to the base class constructor.
            momentum: Weight given to the previous statistic on each update;
                the new batch statistic receives ``1 - momentum``.
        """
        super().__init__(normalize_function, shape=shape, device=device)
        self.momentum = momentum

    def update(self, data: torch.Tensor) -> None:
        """Blend the batch mean and variance into the tracked statistics.

        Args:
            data: Batch whose first dimension indexes samples. An empty batch
                is ignored.
        """
        batch_count = data.size(0)
        if batch_count == 0:
            # Mean/variance of an empty batch are NaN and would poison the buffers.
            return

        # The statistics are never differentiated through (`normalize` detaches
        # them), so drop the autograd graph instead of storing it in buffers.
        data = data.detach()
        batch_mean = data.mean(dim=0)
        # Keep the Bessel-corrected estimate for regular batches, but fall back
        # to the population variance for a single sample, where the corrected
        # estimate is NaN and would permanently corrupt `var`.
        batch_var = data.var(dim=0, unbiased=batch_count > 1)

        self.mean = self.momentum * self.mean + (1.0 - self.momentum) * batch_mean
        self.var = self.momentum * self.var + (1.0 - self.momentum) * batch_var
        self.count += batch_count  # pylint: disable=no-member


class IdentityNormalizer(Normalizer):
    """Normalizer whose update step leaves ``mean`` and ``var`` untouched."""

    def update(self, data: torch.Tensor) -> None:
        """Add the batch size to ``count``; no statistics are recomputed."""
        num_samples = data.size(0)
        self.count.add_(num_samples)  # pylint: disable=no-member
