import torch
from torch import nn

from torchvision.models import vgg16, resnet18, resnet50, convnext_tiny
from torchvision.models import VGG16_Weights, ResNet18_Weights, ResNet50_Weights
from torchvision.models.mobilenetv2 import MobileNetV2
from torchvision.models.maxvit import maxvit_t
from torchvision.models.vision_transformer import vit_b_16, vit_l_16
from torchvision.models.vision_transformer import ViT_B_16_Weights, ViT_L_16_Weights
from torchvision.models.swin_transformer import swin_v2_t, swin_v2_b
from torchvision.models.convnext import LayerNorm2d


class MaxVit(nn.Module):
    """Classifier built around torchvision's ``maxvit_t`` backbone.

    The backbone's own ``classifier`` is replaced by ``nn.Identity`` and a
    local head is used instead: adaptive max-pooling to 1x1, flatten,
    LayerNorm, Linear and Tanh, followed by a bias-free linear layer that
    produces ``output_size`` values.

    Args:
        output_size: Number of output features of the final linear layer.
    """

    def __init__(self, output_size: int):
        super().__init__()

        # Backbone is built with torchvision's default arguments.
        self.model = maxvit_t()
        self.model.classifier = nn.Identity()

        # NOTE(review): 512 is presumably the channel count of the feature
        # map returned by the headless backbone -- confirm with torchvision.
        self.pooling = nn.Sequential(
            nn.AdaptiveMaxPool2d(output_size=1),
            nn.Flatten(start_dim=1, end_dim=-1),
            nn.LayerNorm(512, eps=1e-05, elementwise_affine=True),
            nn.Linear(in_features=512, out_features=512, bias=True),
            nn.Tanh()
        )

        self.clf = nn.Linear(512, out_features=output_size, bias=False)

    def forward(self, im: torch.Tensor) -> torch.Tensor:
        """Run ``im`` through the backbone, the pooling head and ``clf``.

        Args:
            im: Input tensor handed to the backbone unchanged.

        Returns:
            Output of the final linear layer (the logits).
        """
        emb = self.model(im)
        emb = self.pooling(emb)

        logits = self.clf(emb)

        return logits


class MobileNet_v2(nn.Module):
    """``MobileNetV2`` backbone followed by dropout and a linear classifier.

    The backbone's ``classifier`` is replaced by ``nn.Identity`` so that the
    backbone returns its embedding, which is then regularised with
    ``Dropout(0.2)`` and projected to ``output_size`` values.

    Args:
        output_size: Number of output features of the final linear layer.
    """

    def __init__(self, output_size: int):
        super().__init__()
        backbone = MobileNetV2()
        backbone.classifier = nn.Identity()
        self.model = backbone

        self.dropout = nn.Dropout(0.2)

        # NOTE(review): 1280 is presumably the embedding width of the
        # headless MobileNetV2 -- confirm with torchvision.
        self.clf = nn.Linear(1280, output_size)

    def forward(self, im: torch.Tensor) -> torch.Tensor:
        """Return ``clf(dropout(model(im)))``.

        Args:
            im: Input tensor handed to the backbone unchanged.

        Returns:
            Output of the final linear layer (the logits).
        """
        features = self.dropout(self.model(im))
        return self.clf(features)


class VGG(nn.Module):
    """``vgg16`` convolutional features with a freshly built classifier.

    Only ``vgg16().features`` is kept from the torchvision model; the
    classifier is rebuilt locally (Flatten, two 4096-wide Linear/ReLU/Dropout
    stages and a final Linear producing ``output_size`` values).

    Args:
        output_size: Number of output features of the final linear layer.
    """

    def __init__(self, output_size: int):
        super().__init__()
        model = vgg16()
        # Keep the convolutional stack only; the rest of the torchvision
        # model (its pooling and classifier) is not stored.
        self.feature_extractor = model.features

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(25088, 4096),
            nn.ReLU(),
            nn.Dropout(p=0.5, inplace=False),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(p=0.5, inplace=False),
            nn.Linear(4096, output_size)
        )

    def forward(self, im: torch.Tensor) -> torch.Tensor:
        """Extract features from ``im`` and classify them.

        Args:
            im: Input tensor handed to the feature extractor unchanged.

        Returns:
            Output of the classifier (the logits).
        """
        emb = self.feature_extractor(im)

        # The first Linear expects 25088 = 512 * 7 * 7 inputs, i.e. a 7x7
        # feature map. Pooling adaptively to 7x7 is the identity when the map
        # already has that size and lets other input resolutions work instead
        # of failing with a shape mismatch. It adds no parameters, so
        # existing state dicts still load.
        emb = nn.functional.adaptive_avg_pool2d(emb, (7, 7))

        logits = self.classifier(emb)

        return logits


class ResNet_18(nn.Module):
    """``resnet18`` backbone with its ``fc`` layer replaced by a new linear head.

    The transform bundled with ``ResNet18_Weights.IMAGENET1K_V1`` is stored
    on the instance as ``preprocess``. It is *not* applied inside
    ``forward``; callers that want it have to invoke it themselves.

    Args:
        output_size: Number of output features of the final linear layer.
    """

    def __init__(self, output_size: int):
        super().__init__()
        backbone = resnet18()
        backbone.fc = nn.Identity()
        self.model = backbone
        # Kept as an attribute only; see the class docstring.
        self.preprocess = ResNet18_Weights.IMAGENET1K_V1.transforms(antialias=True)

        # NOTE(review): 512 is presumably the embedding width of the
        # headless resnet18 -- confirm with torchvision.
        self.clf = nn.Linear(512, output_size)

    def forward(self, im: torch.Tensor) -> torch.Tensor:
        """Return ``clf(model(im))``.

        Args:
            im: Input tensor handed to the backbone unchanged.

        Returns:
            Output of the final linear layer (the logits).
        """
        return self.clf(self.model(im))


class ResNet_50(nn.Module):
    """``resnet50`` backbone with dropout, a linear head and normalised output.

    The backbone's ``fc`` layer is replaced by ``nn.Identity``. The embedding
    goes through ``Dropout(0.5)`` and a linear layer, and the result is
    L2-normalised along dimension 1 before being returned.

    Args:
        output_size: Number of output features of the final linear layer.
    """

    def __init__(self, output_size: int):
        super().__init__()
        backbone = resnet50()
        backbone.fc = nn.Identity()
        self.model = backbone

        self.dropout = nn.Dropout(0.5)
        # NOTE(review): 2048 is presumably the embedding width of the
        # headless resnet50 -- confirm with torchvision.
        self.clf = nn.Linear(2048, output_size)

    def forward(self, im: torch.Tensor) -> torch.Tensor:
        """Return the normalised output of the linear head.

        Args:
            im: Input tensor handed to the backbone unchanged.

        Returns:
            ``clf(dropout(model(im)))`` scaled to unit L2 norm along dim 1.
        """
        features = self.dropout(self.model(im))
        scores = self.clf(features)
        # Same as calling normalize() with its defaults (p=2, dim=1).
        return nn.functional.normalize(scores, p=2.0, dim=1)


class VIT_B(nn.Module):
    """``vit_b_16`` backbone followed by dropout and a linear classifier.

    The backbone's ``heads`` module is replaced by ``nn.Identity``; the
    resulting embedding goes through ``Dropout(0.5)`` and a linear layer
    producing ``output_size`` values.

    Args:
        output_size: Number of output features of the final linear layer.
    """

    def __init__(self, output_size: int):
        super().__init__()
        backbone = vit_b_16()
        backbone.heads = nn.Identity()
        self.model = backbone

        # NOTE(review): 768 is presumably the embedding width of the
        # headless vit_b_16 -- confirm with torchvision.
        self.clf = nn.Linear(768, output_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, im: torch.Tensor) -> torch.Tensor:
        """Return ``clf(dropout(model(im)))``.

        Args:
            im: Input tensor handed to the backbone unchanged.

        Returns:
            Output of the final linear layer (the logits).
        """
        features = self.dropout(self.model(im))
        return self.clf(features)


class VIT_L(nn.Module):
    """``vit_l_16`` backbone with its ``heads`` replaced by a new linear head.

    Args:
        output_size: Number of output features of the final linear layer.
    """

    def __init__(self, output_size: int):
        super().__init__()
        backbone = vit_l_16()
        backbone.heads = nn.Identity()
        self.model = backbone

        # NOTE(review): 1024 is presumably the embedding width of the
        # headless vit_l_16 -- confirm with torchvision.
        self.clf = nn.Linear(1024, output_size)

    def forward(self, im: torch.Tensor) -> torch.Tensor:
        """Return ``clf(model(im))``.

        Args:
            im: Input tensor handed to the backbone unchanged.

        Returns:
            Output of the final linear layer (the logits).
        """
        return self.clf(self.model(im))


class SWIN_T(nn.Module):
    """``swin_v2_t`` backbone with its ``head`` replaced by a new linear head.

    Args:
        output_size: Number of output features of the final linear layer.
    """

    def __init__(self, output_size: int):
        super().__init__()
        backbone = swin_v2_t()
        backbone.head = nn.Identity()
        self.model = backbone

        # NOTE(review): 768 is presumably the embedding width of the
        # headless swin_v2_t -- confirm with torchvision.
        self.clf = nn.Linear(768, output_size)

    def forward(self, im: torch.Tensor) -> torch.Tensor:
        """Return ``clf(model(im))``.

        Args:
            im: Input tensor handed to the backbone unchanged.

        Returns:
            Output of the final linear layer (the logits).
        """
        return self.clf(self.model(im))


class SWIN_B(nn.Module):
    """``swin_v2_b`` backbone with its ``head`` replaced by a new linear head.

    Args:
        output_size: Number of output features of the final linear layer.
    """

    def __init__(self, output_size: int):
        super().__init__()
        backbone = swin_v2_b()
        backbone.head = nn.Identity()
        self.model = backbone

        # NOTE(review): 1024 is presumably the embedding width of the
        # headless swin_v2_b -- confirm with torchvision.
        self.clf = nn.Linear(1024, output_size)

    def forward(self, im: torch.Tensor) -> torch.Tensor:
        """Return ``clf(model(im))``.

        Args:
            im: Input tensor handed to the backbone unchanged.

        Returns:
            Output of the final linear layer (the logits).
        """
        return self.clf(self.model(im))


class Poor_Net(nn.Module):
    """Small four-stage convolutional network with a linear classifier.

    Each stage is a 3x3 ``Conv2d`` followed by ``MaxPool2d(2)``; the first
    three stages additionally apply ``ReLU`` and ``Dropout2d``. The final
    16-channel feature map is averaged over its spatial positions and passed
    to a linear layer producing ``output_size`` values.

    Args:
        output_size: Number of output features of the final linear layer.
    """

    def __init__(self, output_size: int):
        super().__init__()
        # (in_channels, out_channels) of the consecutive convolution stages.
        channel_plan = ((3, 16), (16, 32), (32, 16), (16, 16))
        final_stage = len(channel_plan) - 1

        layers = []
        for stage, (c_in, c_out) in enumerate(channel_plan):
            layers.append(nn.Conv2d(c_in, c_out, 3))
            layers.append(nn.MaxPool2d(2))
            # The last stage ends right after pooling.
            if stage != final_stage:
                layers.append(nn.ReLU())
                layers.append(nn.Dropout2d())
        self.feature_extractor = nn.Sequential(*layers)

        self.clf = nn.Linear(16, output_size)

    def forward(self, im: torch.Tensor) -> torch.Tensor:
        """Return the logits for ``im``.

        Args:
            im: Input tensor handed to the feature extractor unchanged.

        Returns:
            Output of the final linear layer (the logits).
        """
        feature_map = self.feature_extractor(im)

        # Collapse everything after the 16 channels into one axis and average
        # over it, leaving one value per channel.
        per_channel = feature_map.reshape(feature_map.shape[0], 16, -1)
        pooled = per_channel.mean(dim=-1)

        return self.clf(pooled)


class ConvNext(nn.Module):
    """``convnext_tiny`` backbone with a locally defined classification head.

    The backbone's ``classifier`` is replaced by ``nn.Identity``. Its output
    is normalised with ``LayerNorm2d(768)``, flattened and projected by a
    linear layer to ``output_size`` values.

    Args:
        output_size: Number of output features of the final linear layer.
    """

    def __init__(self, output_size: int):
        super().__init__()
        backbone = convnext_tiny()
        backbone.classifier = nn.Identity()
        self.model = backbone

        # NOTE(review): assumes the headless backbone returns a
        # (batch, 768, 1, 1) tensor so that Flatten yields 768 features --
        # confirm with torchvision.
        self.clf_input = nn.Sequential(
            LayerNorm2d(768),
            nn.Flatten()
        )

        self.clf = nn.Linear(768, output_size)

    def forward(self, im: torch.Tensor) -> torch.Tensor:
        """Return ``clf(clf_input(model(im)))``.

        Args:
            im: Input tensor handed to the backbone unchanged.

        Returns:
            Output of the final linear layer (the logits).
        """
        normalised = self.clf_input(self.model(im))
        return self.clf(normalised)
