# -*- coding: utf-8 -*-

from typing import Optional

import torch
import torch.nn as nn


class Scaler(nn.Module):
    """Multiply the input by a scale factor.

    Args:
        scale: Fixed multiplier. When ``None`` (the default), the multiplier
            is instead a learnable ``nn.Parameter`` initialised to ``1.0``.
            A fixed value that is not itself an ``nn.Parameter`` is stored as
            a plain attribute (not a parameter or buffer). It is therefore not
            part of ``state_dict()``, and a tensor passed here will not follow
            ``.to(device)``.
    """

    def __init__(self, scale: Optional[float] = None):
        super().__init__()

        if scale is None:
            self.register_parameter("scale", nn.Parameter(torch.tensor(1.0)))
        else:
            self.scale = scale

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``input * scale`` (standard broadcasting rules apply)."""
        return input * self.scale

    def extra_repr(self) -> str:
        """Describe the scale value and whether it is learnable.

        Shown inside the module's ``repr``,
        e.g. ``Scaler(scale=1.0000, learnable=True)``.
        """
        learnable = isinstance(self.scale, nn.Parameter)
        value = self.scale
        if isinstance(value, torch.Tensor):
            # `nn.Parameter` does not accept format specs like ":.4f", and a
            # tensor with more (or fewer) than one element has no single
            # scalar value -- summarize its shape instead of raising.
            if value.numel() != 1:
                return f"scale_shape={tuple(value.shape)}, learnable={learnable}"
            value = value.item()
        return f"scale={float(value):.4f}, learnable={learnable}"