import torch
from torchvision.models.swin_transformer import SwinTransformer

class Swin_T(SwinTransformer):
    """Swin-T classifier whose forward pass also returns the pooled feature.

    Builds torchvision's ``SwinTransformer`` with a fixed backbone
    configuration (4x4 patches, embedding dim 96, depths 2-2-6-2,
    heads 3-6-12-24, 7x7 windows, stochastic depth 0.2) and overrides
    ``forward`` so that callers receive the vector fed into the
    classification head in addition to the logits.
    """

    def __init__(self, num_classes: int = 1000) -> None:
        """Construct the backbone and classification head.

        Args:
            num_classes: Forwarded to ``SwinTransformer`` as ``num_classes``.
                Defaults to 1000, the value that used to be hard-coded, so
                ``Swin_T()`` builds the same model as before.
        """
        super().__init__(
            patch_size=[4, 4],
            embed_dim=96,
            depths=[2, 2, 6, 2],
            num_heads=[3, 6, 12, 24],
            window_size=[7, 7],
            stochastic_depth_prob=0.2,
            num_classes=num_classes,
        )

    def forward(self, x: torch.Tensor):
        """Run the network and return both the logits and the pooled feature.

        The steps mirror the submodules inherited from ``SwinTransformer``
        (``features`` -> ``norm`` -> ``permute`` -> ``avgpool`` ->
        ``flatten`` -> ``head``); the only difference from a plain
        classifier is that the flattened tensor is kept and returned.

        Args:
            x: Input batch passed straight to ``self.features``.
                Presumably an image tensor of shape (batch, 3, H, W) --
                confirm against the caller.

        Returns:
            A ``(logit, feature)`` tuple: ``logit`` is the output of
            ``self.head`` and ``feature`` is the flattened, pooled tensor
            that was given to ``self.head``.
        """
        x = self.features(x)
        x = self.norm(x)
        x = self.permute(x)
        x = self.avgpool(x)
        # Keep the pre-head representation so it can be returned to callers.
        feature = self.flatten(x)
        logit = self.head(feature)
        return logit, feature

