import torch
import torch.nn as nn
import torch.nn.functional as F


class VirtualSoftmax(nn.Module):
    """Final linear (FC) layer that appends a "virtual class" logit in training.

    In training mode the output is ``inputs @ kernel`` with one extra column
    holding ``||W_{y_i}|| * ||x_i||``, where ``W_{y_i}`` is the kernel column
    selected by sample ``i``'s label. That column is clamped to
    ``[virt_min, virt_max]`` for numerical stability. In eval mode only
    ``inputs @ kernel`` is returned.

    Args:
        input_dim: Size of the last dimension of ``inputs``.
        num_classes: Number of real classes (columns of the kernel).
        dtype: dtype of the kernel parameter.
        virt_min: Lower clamp bound for the virtual logit (default ``1e-10``).
        virt_max: Upper clamp bound for the virtual logit (default ``15.0``).
    """

    def __init__(self, input_dim, num_classes, dtype=torch.float32,
                 virt_min=1e-10, virt_max=15.0):
        super().__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.dtype = dtype
        # Clamp bounds applied to the virtual logit (numerical stability).
        self.virt_min = virt_min
        self.virt_max = virt_max

        # Laid out as (input_dim, num_classes) so that logits = inputs @ kernel.
        self.kernel = nn.Parameter(torch.empty(input_dim, num_classes, dtype=dtype))
        nn.init.xavier_normal_(self.kernel)

    def forward(self, inputs, labels=None):
        """Compute logits, adding the virtual-class column in training mode.

        Args:
            inputs: Feature tensor. In training mode it is expected to be 2-D
                ``(batch, input_dim)``, since the virtual column is
                concatenated on dim 1.
            labels: Integer class indices, one per sample (shape ``(batch,)``
                or ``(batch, 1)``; lists are accepted too). Required in
                training mode, ignored in eval mode.

        Returns:
            ``(batch, num_classes + 1)`` logits in training mode, otherwise
            ``inputs @ kernel``.

        Raises:
            ValueError: In training mode, if ``labels`` is missing or its
                length does not match the batch size.
        """
        kernel = self.kernel

        if self.dtype == torch.float16:
            # Re-assert half precision for the kernel when the layer was built as fp16.
            kernel = kernel.half()

        # Normal WX (output of the final FC layer).
        WX = torch.matmul(inputs, kernel)

        if not self.training:
            return WX

        if labels is None:
            raise ValueError("VirtualSoftmax requires `labels` in training mode")
        # Flatten so (batch, 1) labels don't broadcast into a (batch, batch) product.
        labels = torch.as_tensor(labels, device=kernel.device).reshape(-1)
        if labels.numel() != inputs.shape[0]:
            raise ValueError(
                f"expected {inputs.shape[0]} labels (one per sample), "
                f"got {labels.numel()}"
            )

        # ||W_{y_i}||: take every column norm once, then pick one per sample.
        # Same values as kernel[:, labels].norm(dim=0) without the gathered copy.
        W_yi_norm = kernel.norm(dim=0)[labels]
        X_i_norm = inputs.norm(dim=1)

        # Virtual class logit, clamped for numerical stability.
        WX_virt = torch.clamp(W_yi_norm * X_i_norm, min=self.virt_min, max=self.virt_max)

        # Append the virtual logit as one extra class column.
        return torch.cat([WX, WX_virt.unsqueeze(1)], dim=1)


if __name__ == '__main__':
    # Example usage: build the layer and run one forward pass per mode.
    input_dim = 128  # example input dimension
    num_classes = 10  # example number of classes
    inputs = torch.randn(32, input_dim)  # example batch of inputs
    labels = torch.randint(0, num_classes, (32,))  # example batch of labels

    virtual_softmax = VirtualSoftmax(input_dim, num_classes)

    # forward() takes no `mode` argument: the train/eval branch is selected by
    # nn.Module's `training` flag, so set it explicitly with .train()/.eval().
    virtual_softmax.train()
    logits = virtual_softmax(inputs, labels)
    print(logits.shape)  # (32, num_classes + 1): last column is the virtual class

    virtual_softmax.eval()
    with torch.no_grad():
        eval_logits = virtual_softmax(inputs)
    print(eval_logits.shape)  # (32, num_classes)
