#!/usr/bin/python3

'''
Standardization layer: a learnable, per-feature shift-and-scale transformation for PyTorch models.
'''
__copyright__ = '''This code is licensed under the 3-clause BSD license.
Copyright ETH Zurich, Department of Chemistry and Applied Biosciences, Reiher Group.
See LICENSE.txt for details.'''

import torch


# Public names exported by `from <module> import *`.
__all__ = ['Standardization']


class Standardization(torch.nn.modules.module.Module):
    r'''Applies a learnable, per-feature standardization to the incoming data:
    :math:`y = a \odot (x - b)`, where :math:`a` is :attr:`weight`,
    :math:`b` is :attr:`bias`, and :math:`\odot` is element-wise multiplication.

    The operation is purely element-wise and broadcasts over the last
    dimension, so any number of leading (batch) dimensions is supported.

    Args:
        features: size of the last dimension of the input/output samples
        device: device on which the parameters are allocated (default: ``None``,
                i.e. PyTorch's current default device)
        dtype: dtype of the parameters (default: ``None``, i.e. PyTorch's
               current default dtype)

    Shape:
        - input:  :math:`(*, H_{in})` where :math:`*` means any number of
          leading dimensions (including none) and :math:`H_{in} = \text{features}`.
        - output: :math:`(*, H_{out})` where :math:`H_{out} = \text{features}`.

    Attributes:
        weight: the learnable scale :math:`a` of shape :math:`(\text{features})`.
                The values are initialized as ones.
        bias:   the learnable shift :math:`b` of shape :math:`(\text{features})`.
                The values are initialized as zeros.

    With the default initialization the module is the identity map.

    Examples::

        >>> m = Standardization(10)
        >>> input = torch.randn(10)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([10])
        >>> print(m(torch.randn(4, 10)).size())
        torch.Size([4, 10])
    '''
    __constants__ = ['features']
    features: int
    weight: torch.Tensor
    bias: torch.Tensor

    def __init__(self, features: int, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.features = features
        # Ones/zeros make the freshly constructed module an identity transform.
        self.weight = torch.nn.parameter.Parameter(torch.ones(features, **factory_kwargs))
        self.bias = torch.nn.parameter.Parameter(torch.zeros(features, **factory_kwargs))

    def reset_parameters(self) -> None:
        '''Re-initialize the parameters in place: ``weight`` to ones, ``bias`` to zeros.

        Restores the identity transform set up by ``__init__``.
        '''
        # torch.nn.init.* operate in-place under no_grad, so autograd is unaffected.
        torch.nn.init.ones_(self.weight)
        torch.nn.init.zeros_(self.bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        '''Return ``weight * (input - bias)``.

        ``weight`` and ``bias`` broadcast against the last dimension of ``input``.
        The parameter is named ``input`` (shadowing the builtin) to keep keyword
        callers working.
        '''
        return self.weight * (input - self.bias)

    def extra_repr(self) -> str:
        '''Extra information shown in ``repr(module)``.'''
        return f'features={self.features}'
