import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import transforms as T


class SwinTransformer(nn.Module):
    """timm Swin backbone followed by an optional linear projection.

    ``forward`` returns L2-normalised embedding vectors, one row per input.

    Args:
        num_features: Output embedding size. When ``<= 0`` no projection is
            applied and the pooled backbone features are normalised directly.
        model_name: timm model identifier for the backbone. Defaults to the
            Swin-B (patch 4, window 7, 224px) model used originally.
    """

    def __init__(self, num_features=512, model_name='swin_base_patch4_window7_224'):
        super().__init__()
        self.model = timm.create_model(model_name)
        self.num_features = num_features
        # Width of the backbone's pooled feature vector (1024 for Swin-B).
        # Read it from the model instead of hard-coding it, falling back to
        # the original constant if the attribute is missing.
        backbone_dim = getattr(self.model, 'num_features', 1024)
        self.feat = nn.Linear(backbone_dim, num_features) if num_features > 0 else None

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Image batch accepted by the backbone's ``forward_features``
               (presumably ``(B, 3, 224, 224)`` normalised images — confirm
               against the preprocessor used by callers).

        Returns:
            Tensor of shape ``(B, D)`` with unit L2 norm along the last axis,
            where ``D`` is ``num_features`` if a projection is configured,
            otherwise the backbone feature width.
        """
        x = self.model.forward_features(x)
        # Depending on the timm version, forward_features returns either
        # pooled (B, C) features or unpooled token/spatial maps shaped
        # (B, L, C) or channels-last (B, H, W, C). Average over every axis
        # between batch and channel so the projection always sees (B, C).
        if x.dim() > 2:
            x = x.flatten(1, -2).mean(dim=1)
        if self.feat is not None:
            x = self.feat(x)
        # Normalise over the feature axis explicitly (not a positional dim).
        return F.normalize(x, dim=-1)


def load_preprocessor_isr(h=224, w=224):
    """Build the image preprocessing pipeline for the ISR model.

    Each image is resized to exactly ``(h, w)`` (the aspect ratio is not
    preserved). It is then converted to a tensor and normalised per channel.

    Args:
        h: Target height in pixels.
        w: Target width in pixels.

    Returns:
        A ``Compose`` transform applying resize, to-tensor and normalise.
    """
    # Per-channel statistics: the standard ImageNet RGB mean / std values.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        T.Resize((h, w)),
        T.ToTensor(),
        T.Normalize(mean=channel_mean, std=channel_std),
    ]
    return T.Compose(steps)


def load_model_isr(path, device='cuda'):
    """Build the ISR ``SwinTransformer`` and load weights from a checkpoint.

    Args:
        path: Path to a checkpoint saved with ``torch.save``. It must contain
            a ``'state_dict'`` entry whose keys match ``SwinTransformer``
            exactly (loading is strict).
        device: Device the returned model is moved to. Defaults to
            ``'cuda'``, matching the previous hard-coded ``.cuda()``.

    Returns:
        The model on ``device``, in eval mode.

    Raises:
        KeyError: If the checkpoint has no ``'state_dict'`` entry.
        RuntimeError: If the state-dict keys or shapes do not match.
    """
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    # map_location='cpu' keeps the saved tensors off whichever GPU they were
    # saved from. That avoids a failure when that device is absent and an
    # extra GPU copy of the weights during loading.
    checkpoint = torch.load(path, map_location='cpu')
    model = SwinTransformer(num_features=512)
    model.load_state_dict(checkpoint['state_dict'], strict=True)
    model = model.to(device)
    model.eval()
    return model
