import torch
from torch import nn
from math import log


class LearnableLogCoeffient(nn.Module):
    """A learnable scalar coefficient stored through a reparameterization.

    The raw parameter ``p`` is mapped to the coefficient ``v`` by::

        v = exp(p)   if p < 0
        v = p + 1    otherwise

    and ``v`` is then clipped from above at ``max_value``. Both pieces equal 1
    and have slope 1 at ``p = 0``, so the mapping is smooth there. For positive
    ``p`` it grows linearly rather than exponentially. Once ``v`` reaches the
    cap, ``torch.clamp`` passes no gradient back to ``p``.

    The class name keeps its original spelling so existing imports keep working.

    Args:
        p0: Initial raw parameter value. The default ``log(0.2)`` makes the
            initial coefficient 0.2. Ints are accepted as well as floats.
        max_value: Upper bound applied to the coefficient (default 10.0).
    """

    # Class-level default: modules pickled before ``max_value`` existed have no
    # instance attribute, so attribute lookup falls back to the old cap here.
    max_value = 10.0

    def __init__(self, p0=log(0.2), max_value=10.0):
        super().__init__()
        # float() makes integer initial values usable: torch.tensor(0) would be
        # an int64 tensor, and integer tensors cannot require gradients.
        self.p = nn.Parameter(torch.tensor(float(p0)), requires_grad=True)
        self.max_value = float(max_value)

    def forward(self, x, grad=True):
        """Return ``coefficient * x``.

        Args:
            x: Value to scale (tensor or number).
            grad: When False the coefficient is detached, so no gradient
                flows back into ``p`` through this call.
        """
        return self.__get(grad) * x

    def __get(self, grad=True):
        """Compute the current (clipped) coefficient as a 0-dim tensor."""
        # Branch in Python instead of using torch.where: torch.where evaluates
        # both sides, so exp(p) for a large positive p could overflow to inf
        # and poison the backward pass with NaN.
        v = torch.exp(self.p) if self.p < 0 else self.p + 1
        v = torch.clamp(v, max=self.max_value)
        if not grad:
            v = v.detach()
        return v

    def __mul__(self, other):
        """Return ``coefficient * other`` with the coefficient detached."""
        return self.__get(False) * other

    # Scaling by a scalar commutes, so ``other * module`` gives the same result.
    # Without this alias, e.g. ``2.0 * module`` raises TypeError.
    __rmul__ = __mul__

    def value(self):
        """Return the current coefficient as a Python float."""
        return self.__get(False).item()


if __name__ == "__main__":
    # No script entry point: executing this file directly does nothing beyond
    # defining the class above.
    ...
