import torch
import torch.nn as nn
import torch.optim as optim


class Identity(nn.Module):
    """No-op calibrator that divides logits by a fixed temperature.

    The temperature defaults to 1.0, so by default ``forward`` returns the
    logits unchanged (or their softmax). ``train(logits, labels)`` performs
    no fitting and returns ``(0.0, 0.0)``.
    """

    def __init__(self, temperature=None, *, device=None):
        """Create the temperature parameter.

        Args:
            temperature: Scalar temperature; ``None`` means 1.0. The value is
                cast to float because ``nn.Parameter`` requires a floating
                point tensor (an int would otherwise yield an int64 tensor
                that cannot require grad).
            device: Device for the temperature parameter. ``None`` places it
                on CUDA when CUDA is available and on CPU otherwise.
        """
        super().__init__()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        value = 1.0 if temperature is None else float(temperature)
        self.temperature = nn.Parameter(torch.tensor([value], device=device))

    def train(self, logits=True, labels=None):
        """Fit the calibrator (a no-op), or switch ``nn.Module`` training mode.

        This name overrides ``nn.Module.train(mode)``, which PyTorch itself
        calls as ``train(<bool>)`` from ``eval()`` and when a parent module
        propagates a mode change to its children. A lone boolean argument is
        therefore forwarded to ``nn.Module.train`` so those calls keep working.

        Args:
            logits: Calibration logits, or a bool training mode.
            labels: Calibration labels; ``None`` when used as a mode switch.

        Returns:
            ``(0.0, 0.0)`` when called with logits and labels; ``self`` when
            used as a mode switch.
        """
        if labels is None and isinstance(logits, bool):
            return super().train(logits)
        return 0.0, 0.0

    def forward(self, logits, softmax=True):
        """Scale ``logits`` by ``1 / temperature``, optionally applying softmax.

        Args:
            logits: Score tensor; softmax is taken over dim 1 (presumably the
                class dimension -- confirm against callers).
            softmax: If truthy, return softmax probabilities; otherwise return
                the scaled logits.
        """
        scaled = logits / self.temperature
        return torch.softmax(scaled, dim=1) if softmax else scaled
