import torch

__all__ = ['NoneScaler', 'TempScaler', 'VectorScaler', 'AuxScaler']


class NoneScaler(torch.nn.Module):
    """Pass-through module: ``forward`` returns its input unchanged.

    Holds no parameters or buffers.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` itself (the same object, not a copy)."""
        return x


class TempScaler(torch.nn.Module):
    """Divide the input by a single learnable scalar ``T``.

    Args:
        init_T: Initial value of ``T``. Ints are accepted and converted to
            float, since a parameter that requires gradients must have a
            floating-point dtype.

    Attributes are limited to ``T``, a ``torch.nn.Parameter`` of shape
    ``(1,)`` that broadcasts against any input.
    """

    def __init__(self, init_T: float = 1.0) -> None:
        super().__init__()
        # float() keeps torch.tensor from inferring int64 for e.g. init_T=2,
        # which torch.nn.Parameter would reject (ints cannot require grad).
        self.T = torch.nn.Parameter(torch.tensor([float(init_T)]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x / T`` (element-wise, ``T`` broadcast over ``x``)."""
        return x / self.T


class VectorScaler(torch.nn.Module):
    """Divide the input by a learnable vector ``T`` with one entry per class.

    Args:
        num_classes: Length of ``T``. The input's last dimension must be
            broadcast-compatible with it.
        init_T: Initial value for every entry of ``T``. Ints are accepted
            and converted to float.

    ``T`` is a ``torch.nn.Parameter`` of shape ``(num_classes,)``.
    """

    def __init__(self, num_classes: int, init_T: float = 1.0) -> None:
        super().__init__()
        # Build a length-num_classes vector filled with init_T. Multiplying a
        # one-element tensor by num_classes would instead yield a single
        # value init_T * num_classes, i.e. no per-class parameters at all.
        self.T = torch.nn.Parameter(
            torch.full((num_classes,), float(init_T))
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x / T`` with ``T`` broadcast along the last dimension."""
        return x / self.T


class AuxScaler(torch.nn.Module):
    """Two square linear layers with a leaky-ReLU between them.

    Both layers map ``num_classes`` features to ``num_classes`` features and
    start with identity weights and zero biases.

    Args:
        num_classes: Size of the input, hidden and output feature dimension.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Hidden width equals the class count, so both weights are square.
        self.linear1 = torch.nn.Linear(num_classes, num_classes, bias=True)
        self.linear2 = torch.nn.Linear(num_classes, num_classes, bias=True)

        # Overwrite the default random init: identity weight, zero bias.
        for layer in (self.linear1, self.linear2):
            layer.weight.data = torch.eye(num_classes)
            layer.bias.data = torch.zeros_like(layer.bias.data)

    def forward(self, x):
        """Apply ``linear1``, leaky-ReLU (default slope), then ``linear2``."""
        activated = torch.nn.functional.leaky_relu(self.linear1(x))
        return self.linear2(activated)
