import torch.nn as nn
from torchvision.models import (
    resnet50,
    ResNet50_Weights,
    resnet101,
    ResNet101_Weights,
    swin_v2_b,
    Swin_V2_B_Weights,
)
from src.model.utils import ModelOutput, ResNet9, BertRepr


class MnistConvNet(nn.Module):
    """Compact CNN classifier for single-channel images with 10 classes.

    The model is split into two stages so the learned representation can be
    reused on its own:

    * ``repr_model`` -- two unpadded 3x3 convolutions (1 -> 32 -> 64 channels),
      a 2x2 max-pool, dropout (p=0.25) and a flatten. For a 1x28x28 input
      this yields 64 * 12 * 12 = 9216 features per sample.
    * ``classifier`` -- a 9216 -> 128 -> 10 MLP with dropout (p=0.5) that
      produces the logits.

    ``in_features`` records the representation width (9216).
    """

    def __init__(self):
        super().__init__()
        self.name = "conv_net"

        # Layers are created in the same order as the Sequential indices, so
        # state_dict keys (repr_model.0, repr_model.2, ...) stay stable.
        self.repr_model = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25),
            nn.Flatten(),
        )

        feature_dim = 64 * 12 * 12  # channels x height x width after pooling
        self.classifier = nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 10),
        )
        self.in_features = feature_dim

    def forward(self, x):
        """Run ``x`` through both stages and wrap the logits in a ``ModelOutput``.

        Args:
            x: Image batch. The 9216-wide classifier input implies 1x28x28
                images (presumably shaped (batch, 1, 28, 28)).

        Returns:
            A ``ModelOutput`` whose ``logits`` attribute holds the 10 class scores.
        """
        features = self.repr_model(x)
        logits = self.classifier(features)
        result = ModelOutput()
        result.logits = logits
        return result


class CIFARConvNet(nn.Module):
    """VGG-style CNN classifier for 3-channel images with 10 classes.

    ``repr_model`` stacks three identical stages of
    ``conv3x3 -> ReLU -> conv3x3 -> ReLU -> maxpool2x2`` (padding 1, so only
    the pools shrink the spatial size) and then flattens. With 32x32 inputs
    this gives 256 * 4 * 4 = 4096 features per sample. ``classifier`` is a
    4096 -> 1024 -> 512 -> 10 MLP producing the logits.

    ``in_features`` records the representation width (4096).
    """

    def __init__(self):
        super().__init__()
        self.name = "conv_net"

        # (input channels, first conv output, second conv output) per stage.
        stages = ((3, 32, 64), (64, 128, 128), (128, 256, 256))
        layers = []
        for c_in, c_mid, c_out in stages:
            # Layer creation order matches the flat original list, keeping
            # parameter init order and state_dict indices unchanged.
            layers += [
                nn.Conv2d(c_in, c_mid, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.Conv2d(c_mid, c_out, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            ]
        layers.append(nn.Flatten())
        self.repr_model = nn.Sequential(*layers)

        feature_dim = 256 * 4 * 4  # last stage channels x 4 x 4 spatial
        self.classifier = nn.Sequential(
            nn.Linear(feature_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )
        self.in_features = feature_dim

    def forward(self, x):
        """Run ``x`` through both stages and wrap the logits in a ``ModelOutput``.

        Args:
            x: Image batch. The 4096-wide classifier input implies 3x32x32
                images (presumably shaped (batch, 3, 32, 32)).

        Returns:
            A ``ModelOutput`` whose ``logits`` attribute holds the 10 class scores.
        """
        features = self.repr_model(x)
        logits = self.classifier(features)
        result = ModelOutput()
        result.logits = logits
        return result


class CIFAR100ConvNet(nn.Module):
    """ResNet9-based classifier with 100 output classes.

    ``repr_model`` is the project's ``ResNet9`` backbone (constructed with the
    argument 3, presumably the number of input channels) followed by dropout
    with p=0.2. ``classifier`` is a single linear layer from ``in_features``
    to 100 logits.
    """

    def __init__(self):
        super().__init__()
        self.name = "conv_net"
        backbone = ResNet9(3)
        self.repr_model = nn.Sequential(backbone, nn.Dropout(0.2))
        # NOTE(review): 1028 is an unusual width (1024?). It must equal the
        # size of ResNet9's output -- confirm against src.model.utils.ResNet9.
        feature_dim = 1028
        self.classifier = nn.Sequential(nn.Linear(feature_dim, 100))
        self.in_features = feature_dim

    def forward(self, x):
        """Encode ``x`` with the backbone and return its 100-way logits.

        Args:
            x: Input batch accepted by ``ResNet9`` (presumably (batch, 3, H, W)
                images -- confirm).

        Returns:
            A ``ModelOutput`` whose ``logits`` attribute holds the class scores.
        """
        features = self.repr_model(x)
        logits = self.classifier(features)
        result = ModelOutput()
        result.logits = logits
        return result


class BertClassifier(nn.Module):
    """Text classifier: a ``BertRepr`` encoder followed by one linear head.

    Args:
        num_labels: Number of output classes.
        model_name: Checkpoint name passed to ``BertRepr``.
        *args: Forwarded unchanged to ``nn.Module.__init__``.
        hidden_size: Width of the representation ``BertRepr`` returns, used as
            the classifier's input size. Defaults to 768 (the BERT-base
            hidden size). Set it when ``model_name`` names a checkpoint with a
            different hidden size (e.g. BERT-large uses 1024). Previously
            this was hard-coded, so such checkpoints could not be used.
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(
        self,
        num_labels,
        model_name="bert-base-uncased",
        *args,
        hidden_size=768,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # NOTE(review): fixed label, not derived from model_name.
        self.name = "bert-base"
        self.repr_model = BertRepr(model_name=model_name)
        self.classifier = nn.Linear(
            in_features=hidden_size, out_features=num_labels, bias=True
        )
        self.in_features = hidden_size

    def forward(self, x):
        """Encode ``x`` with ``BertRepr`` and return the classification logits.

        Args:
            x: Whatever input ``BertRepr`` accepts (presumably tokenized
                text -- confirm against src.model.utils.BertRepr).

        Returns:
            A ``ModelOutput`` whose ``logits`` attribute holds ``num_labels``
            scores per example.
        """
        features = self.repr_model(x)
        logits = self.classifier(features)
        out = ModelOutput()
        out.logits = logits
        return out


class ResNet50Classifier(nn.Module):
    """ImageNet-pretrained ResNet-50 split into a feature extractor and a head.

    ``repr_model`` is every child of torchvision's ``resnet50`` except the
    final ``fc`` layer, followed by ``nn.Flatten()``, so it yields one feature
    vector per image. ``classifier`` maps those features to ``num_labels``
    logits.

    Args:
        num_labels: Number of output classes. When it equals 1000, the
            pretrained ImageNet ``fc`` layer is reused as the classifier.
            Otherwise a freshly initialised ``nn.Linear`` is created.
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, num_labels, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.name = "resnet50"
        weights = ResNet50_Weights.IMAGENET1K_V2
        model = resnet50(weights=weights)
        self.repr_model = nn.Sequential(*list(model.children())[:-1], nn.Flatten())
        # Read the feature width from the backbone's own head rather than
        # duplicating a hard-coded constant.
        self.in_features = model.fc.in_features
        if num_labels == 1000:
            # Reuse the pretrained ImageNet head as-is.
            self.classifier = model.fc
        else:
            # Build a new head only when it is needed, so no throwaway layer
            # is initialised and no global RNG draws are wasted.
            self.classifier = nn.Linear(
                in_features=self.in_features, out_features=num_labels, bias=True
            )

        # Inference transforms that match the pretrained weights. forward()
        # does not apply them; presumably the data pipeline does -- confirm.
        self.preprocess = weights.transforms()

    def forward(self, x):
        """Return classification logits for an image batch.

        Args:
            x: Image batch, presumably already transformed as by
                ``self.preprocess`` -- no preprocessing happens here.

        Returns:
            A ``ModelOutput`` whose ``logits`` attribute holds ``num_labels``
            scores per image.
        """
        features = self.repr_model(x)
        logits = self.classifier(features)
        out = ModelOutput()
        out.logits = logits
        return out


class ResNet101Classifier(nn.Module):
    """ImageNet-pretrained ResNet-101 split into a feature extractor and a head.

    ``repr_model`` is every child of torchvision's ``resnet101`` except the
    final ``fc`` layer, followed by ``nn.Flatten()``, so it yields one feature
    vector per image. ``classifier`` maps those features to ``num_labels``
    logits.

    Args:
        num_labels: Number of output classes. When it equals 1000, the
            pretrained ImageNet ``fc`` layer is reused as the classifier.
            Otherwise a freshly initialised ``nn.Linear`` is created.
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, num_labels, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.name = "resnet101"
        weights = ResNet101_Weights.IMAGENET1K_V2
        model = resnet101(weights=weights)
        self.repr_model = nn.Sequential(*list(model.children())[:-1], nn.Flatten())
        # Read the feature width from the backbone's own head rather than
        # duplicating a hard-coded constant.
        self.in_features = model.fc.in_features
        if num_labels == 1000:
            # Reuse the pretrained ImageNet head as-is.
            self.classifier = model.fc
        else:
            # Build a new head only when it is needed, so no throwaway layer
            # is initialised and no global RNG draws are wasted.
            self.classifier = nn.Linear(
                in_features=self.in_features, out_features=num_labels, bias=True
            )
        # Inference transforms that match the pretrained weights. forward()
        # does not apply them; presumably the data pipeline does -- confirm.
        self.preprocess = weights.transforms()

    def forward(self, x):
        """Return classification logits for an image batch.

        Args:
            x: Image batch, presumably already transformed as by
                ``self.preprocess`` -- no preprocessing happens here.

        Returns:
            A ``ModelOutput`` whose ``logits`` attribute holds ``num_labels``
            scores per image.
        """
        features = self.repr_model(x)
        logits = self.classifier(features)
        out = ModelOutput()
        out.logits = logits
        return out


class SwinClassifier(nn.Module):
    """ImageNet-pretrained Swin-V2-B split into a feature extractor and a head.

    ``repr_model`` holds every child of torchvision's ``swin_v2_b`` except the
    final ``head`` layer. ``classifier`` maps the resulting features to
    ``num_labels`` logits.

    Args:
        num_labels: Number of output classes. When it equals 1000, the
            pretrained ImageNet ``head`` is reused as the classifier.
            Otherwise a freshly initialised ``nn.Linear`` is created.
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, num_labels, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.name = "swin"
        weights = Swin_V2_B_Weights.DEFAULT
        model = swin_v2_b(weights=weights)
        # NOTE(review): slicing children() only reproduces the backbone's
        # forward pass if torchvision registers its post-feature steps
        # (norm / permute / pooling / flatten) as child modules in execution
        # order. Confirm this for the pinned torchvision version.
        self.repr_model = nn.Sequential(*list(model.children())[:-1])
        # Read the feature width from the backbone's own head rather than
        # duplicating a hard-coded constant.
        self.in_features = model.head.in_features
        if num_labels == 1000:
            # Reuse the pretrained ImageNet head as-is.
            self.classifier = model.head
        else:
            # Build a new head only when it is needed, so no throwaway layer
            # is initialised and no global RNG draws are wasted.
            self.classifier = nn.Linear(
                in_features=self.in_features, out_features=num_labels, bias=True
            )

        # Inference transforms that match the pretrained weights. forward()
        # does not apply them; presumably the data pipeline does -- confirm.
        self.preprocess = weights.transforms()

    def forward(self, x):
        """Return classification logits for an image batch.

        Args:
            x: Image batch, presumably already transformed as by
                ``self.preprocess`` -- no preprocessing happens here.

        Returns:
            A ``ModelOutput`` whose ``logits`` attribute holds ``num_labels``
            scores per image.
        """
        features = self.repr_model(x)
        logits = self.classifier(features)
        out = ModelOutput()
        out.logits = logits
        return out
