import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel


class ModelOutput:
    """Minimal holder for a model's ``loss`` and ``logits``.

    Both fields start out as class-level ``None`` defaults; assigning to them
    on an instance creates instance attributes and leaves the class defaults
    untouched.
    """

    # Both default to None until a caller assigns them.
    loss = logits = None


# https://www.kaggle.com/code/yiweiwangau/cifar-100-resnet-pytorch-75-17-accuracy
class ResNet9(nn.Module):
    """ResNet9-style convolutional feature extractor (no classifier head).

    The network stacks conv/BN/ReLU blocks with three residual stages and
    returns the flattened output of a final 2x2 max-pool. There are five
    2x downsampling steps in total (in conv2, conv3, conv4, conv5 and the
    final pool), so a 32x32 input yields a ``(N, 1028)`` feature tensor.

    Args:
        in_channels: Number of channels in the input images.
    """

    def __init__(self, in_channels):
        super().__init__()
        block = self.__conv_block

        # Submodules are created in a fixed order. That order determines the
        # parameter-initialisation RNG sequence and the state_dict key order.
        self.conv1 = block(in_channels, 64)
        self.conv2 = block(64, 128, pool=True)
        self.res1 = nn.Sequential(block(128, 128), block(128, 128))

        self.conv3 = block(128, 256, pool=True)
        self.conv4 = block(256, 512, pool=True)
        self.res2 = nn.Sequential(block(512, 512), block(512, 512))

        # NOTE(review): 1028 breaks the power-of-two channel progression and
        # may be a typo for 1024. Changing it would alter the output width and
        # checkpoint shapes, so confirm with downstream users first.
        self.conv5 = block(512, 1028, pool=True)
        self.res3 = nn.Sequential(block(1028, 1028), block(1028, 1028))

        self.pool = nn.MaxPool2d(2)
        self.flatten = nn.Flatten()

    def forward(self, xb):
        """Return flattened features for the image batch ``xb``."""
        features = self.conv2(self.conv1(xb))
        # Residual stage: the block's first op is a Conv2d, so its in-place
        # ReLU never touches the skip tensor ``features``.
        features = self.res1(features) + features
        features = self.conv4(self.conv3(features))
        features = self.res2(features) + features
        features = self.conv5(features)
        features = self.res3(features) + features
        return self.flatten(self.pool(features))

    def __conv_block(self, in_channels, out_channels, pool=False):
        """Build Conv3x3(pad=1) -> BatchNorm -> ReLU, optionally + MaxPool2d(2)."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        norm = nn.BatchNorm2d(out_channels)
        act = nn.ReLU(inplace=True)
        downsample = (nn.MaxPool2d(2),) if pool else ()
        return nn.Sequential(conv, norm, act, *downsample)


class BertRepr(nn.Module):
    """Pretrained BERT encoder that returns a dropout-regularised pooled vector.

    Args:
        model_name: Checkpoint name or path passed to
            ``BertModel.from_pretrained``.
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, model_name="bert-base-uncased", *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.bert = BertModel.from_pretrained(model_name)

        # Prefer the checkpoint's classifier-specific dropout rate. Fall back
        # to the generic hidden-layer dropout when that is unset (None).
        config = self.bert.config
        rate = config.classifier_dropout
        if rate is None:
            rate = config.hidden_dropout_prob
        self.dropout = nn.Dropout(rate)

    def forward(self, x):
        """Encode ``x`` and return the dropped-out second model output.

        ``x`` is unpacked as keyword arguments into the BERT model; presumably
        it is a tokenizer output dict (input_ids, attention_mask, ...) — confirm
        against callers.
        """
        outputs = self.bert(**x)
        # Index 1 of BertModel's output is the pooled ([CLS]) representation.
        pooled = outputs[1]
        return self.dropout(pooled)
