import torch
import torch.nn as nn
import torch.nn.functional as F



def spectral_projector(prev_dim, dim, hidden_dim=2048):
    """Build a two-layer MLP projection head.

    The returned module applies, in order:
    ``Linear(prev_dim, hidden_dim) -> BatchNorm1d(hidden_dim) -> ReLU ->
    Linear(hidden_dim, dim, bias=False) -> BatchNorm1d(dim)``.

    Args:
        prev_dim: Size of the last dimension of the incoming features.
        dim: Size of the projected output features.
        hidden_dim: Width of the intermediate hidden layer.

    Returns:
        An ``nn.Sequential`` containing the five layers above, indexed 0-4.

    Note:
        Presumably intended for 2-D inputs of shape ``(batch, prev_dim)``
        (TODO confirm against callers). In training mode ``BatchNorm1d``
        needs more than one sample per batch.
    """
    # First stage: expand to the hidden width, normalize, then apply the non-linearity.
    # ReLU runs in place on the BatchNorm output to save an activation buffer.
    hidden_stage = [
        nn.Linear(prev_dim, hidden_dim),
        nn.BatchNorm1d(hidden_dim),
        nn.ReLU(inplace=True),
    ]

    # Second stage: project to the output width. The Linear bias is omitted
    # because the BatchNorm1d that follows subtracts the per-feature mean,
    # which would cancel any constant offset the bias added.
    output_stage = [
        nn.Linear(hidden_dim, dim, bias=False),
        nn.BatchNorm1d(dim),
    ]

    return nn.Sequential(*hidden_stage, *output_stage)