import torch
from torchvision.models.vision_transformer import VisionTransformer

class VIT_B_16(VisionTransformer):
    """ViT-B/16 classifier whose forward pass also returns the class-token feature.

    Builds a torchvision ``VisionTransformer`` with the ViT-B/16 hyperparameters
    (224x224 images, 16x16 patches, 12 encoder layers, 12 attention heads,
    hidden dim 768, MLP dim 3072). It overrides ``forward`` so callers receive
    both the classification logits and the encoder output at the class-token
    position.
    """

    def __init__(self, num_classes=1000):
        """Create the model.

        Args:
            num_classes: Output size of the classification head. Defaults to
                1000, the value that was previously hard-coded here.
        """
        super().__init__(
            image_size=224,
            patch_size=16,
            num_layers=12,
            num_heads=12,
            hidden_dim=768,
            mlp_dim=3072,
            num_classes=num_classes,
        )

    def forward(self, x):
        """Run the transformer and return both logits and the class-token feature.

        Args:
            x: Input batch handed to ``self._process_input``; presumably an
                image tensor of shape (N, C, 224, 224) -- TODO confirm against callers.

        Returns:
            A tuple ``(logit, feat)``. ``feat`` is the encoder output at
            position 0 along dim 1, where the class token was prepended.
            ``logit`` is ``self.heads(feat)``.
        """
        # NOTE(review): `_process_input` is a private torchvision helper; its
        # behavior may change between torchvision versions.
        x = self._process_input(x)
        n = x.shape[0]

        # Prepend one copy of the learned class token per batch item.
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # Classifier "token" as used by standard language architectures:
        # the class token sits at index 0 because it was concatenated first.
        feat = x[:, 0]
        logit = self.heads(feat)

        return logit, feat

