import torch
import torch.nn as nn
from torchvision.models.video import swin3d_t


# Ft model loading function.
def load_model(arch='swin_t', saved_model_file=None, num_classes=400, kin_pretrained=False):
    """Build the video model and optionally restore weights from a checkpoint.

    Args:
        arch: Architecture name. Only ``'swin_t'`` is recognised.
        saved_model_file: Optional checkpoint handed to ``torch.load``. The
            weights are read from its ``'model_state_dict'`` entry. Falsy
            values skip loading.
        num_classes: Number of output classes forwarded to ``wrapper_swin``.
        kin_pretrained: Forwarded to ``wrapper_swin`` as its ``pretrained``
            flag.

    Returns:
        The constructed model, or ``None`` when ``arch`` is not recognised.
        A checkpoint that fails to load is reported on stdout and the model
        is still returned.
    """
    if arch == 'swin_t':
        model = wrapper_swin(num_classes=num_classes, pretrained=kin_pretrained, size='tiny')
    else:
        print(f'Architecture {arch} invalid for model. Try \'swin_t\'.')
        return None

    if not saved_model_file:
        print(f'model freshly initialized! Pretrained: {kin_pretrained}')
        return model

    # Load in saved model.
    try:
        # Map to CPU so a checkpoint saved on a GPU still loads on a host
        # without that device; load_state_dict copies the values into the
        # model's existing parameters.
        saved_dict = torch.load(saved_model_file, map_location='cpu')
        model.load_state_dict(saved_dict['model_state_dict'], strict=True)
    except Exception as err:
        # Best-effort loading is kept, but only ordinary errors are caught
        # (not KeyboardInterrupt/SystemExit) and the reason is reported.
        print(f'Error loading model from {saved_model_file}.')
        print(f'Reason: {err!r}')
        # NOTE(review): a strict load_state_dict failure may leave some
        # tensors already overwritten — confirm before trusting this message.
        print(f'model freshly initialized! Pretrained: {kin_pretrained}')
    else:
        print(f'model loaded from {saved_model_file} successfully!')

    return model


# MLP layer.
class MLP(nn.Module):
    """Two-layer projection head: Linear -> ReLU -> Linear.

    Args:
        initial_embedding_size: Input width; also the hidden width of the
            first linear layer.
        final_embedding_size: Output width of the second linear layer.
        use_normalization: When true, the output is L2-normalised along
            dim 1. When false, the raw output of the second layer is
            returned.
    """

    def __init__(self, initial_embedding_size=2048, final_embedding_size=128, use_normalization=True):
        super().__init__()
        self.initial_embedding_size = initial_embedding_size
        self.final_embedding_size = final_embedding_size
        self.use_normalization = use_normalization
        self.fc1 = nn.Linear(self.initial_embedding_size, self.initial_embedding_size, bias=True)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(self.initial_embedding_size, self.final_embedding_size, bias=True)

    def forward(self, x):
        """Project ``x`` and, if enabled, L2-normalise the result.

        Assumes ``x`` is (batch, initial_embedding_size), since the
        normalisation runs along dim 1 — TODO confirm against callers.
        """
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        # Honour the constructor flag; it used to be stored but ignored.
        if self.use_normalization:
            x = nn.functional.normalize(x, p=2, dim=1)
        return x


class wrapper_swin(nn.Module):
    """Video Swin (``swin3d_t``) backbone with a classifier and a projection head.

    The backbone's own ``head`` is detached and replaced by ``nn.Identity``,
    so the backbone returns features. Those features feed both the
    classifier (``self.head``) and the ``MLP`` projection (``self.proj``).

    Args:
        num_classes: Number of classifier outputs. With 400 the backbone's
            original head is reused; any other value gets a new
            ``nn.Linear(768, num_classes)``.
        pretrained: When true, the backbone is built with
            ``weights='DEFAULT'``; otherwise with ``weights=None``.
        size: Accepted but not used by this class.
    """

    def __init__(self, num_classes=400, pretrained=True, size='tiny'):
        super().__init__()

        backbone_weights = 'DEFAULT' if pretrained else None
        self.backbone = swin3d_t(weights=backbone_weights)

        # Reuse the stock classifier only when the class count matches it.
        original_head = self.backbone.head
        self.head = original_head if num_classes == 400 else nn.Linear(768, num_classes)

        # Make the backbone emit features instead of class scores.
        self.backbone.head = nn.Identity()
        self.proj = MLP(768, 128)

    def forward(self, x):
        """Return ``(pred, feature)``: classifier output and projected features."""
        backbone_out = self.backbone(x)
        logits = self.head(backbone_out)
        projected = self.proj(backbone_out)
        return logits, projected


if __name__ == '__main__':
    # Smoke test: build the model, run one random batch, print output shapes.
    # Use the GPU when present, otherwise fall back to CPU so the script
    # does not crash on hosts without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = load_model(arch='swin_t', num_classes=101, kin_pretrained=True)
    print(model)
    model.eval()
    model.to(device)
    # Presumably (batch, channels, frames, height, width) — TODO confirm.
    inputs = torch.rand((4, 3, 16, 224, 224)).to(device)

    with torch.no_grad():
        output, feat = model(inputs)

    print(f'Output shape is: {output.shape}')
    print(f'Feature shape is: {feat.shape}')
