import torch

from experiments.models.stochastic_model import StochasticModel
from experiments.utils import print_module

from torchvision.models import resnet
from torch import nn

from XXX.uib.utils.safe_module import SafeModule


class DeterministicCifar10Resnet(SafeModule):
    """Deterministic ResNet whose ImageNet stem is replaced for small (CIFAR-sized) inputs.

    The backbone's first convolution is swapped for a 3x3, stride-1, padding-1 conv and its
    max-pool is replaced by an identity, so the input is not spatially downsampled by the stem.

    Args:
        resnet_factory: Callable that builds the backbone, invoked as
            ``resnet_factory(pretrained=..., num_classes=...)``. Defaults to
            ``torchvision.models.resnet.resnet18``. The returned module must expose
            ``conv1`` and ``maxpool`` attributes.
        pretrained: Forwarded to ``resnet_factory``. The stem conv is always replaced by a
            freshly initialised layer below, so any pretrained ``conv1`` weights are discarded.
            NOTE(review): torchvision's pretrained weights come with a 1000-class head; combining
            ``pretrained=True`` with ``capacity != 1000`` may be rejected by torchvision —
            confirm for the installed version.
        capacity: Number of output classes, passed as ``num_classes`` to the factory.
    """

    def __init__(self, resnet_factory=resnet.resnet18, *, pretrained=False, capacity=10):
        super().__init__()

        self.resnet = resnet_factory(pretrained=pretrained, num_classes=capacity)
        # Keep the replacement conv as wide as the factory's original stem, so the layer that
        # consumes its output (``bn1`` in torchvision's ResNet) still gets the channel count it
        # was built for. For torchvision's ResNets this is 64, the previously hard-coded value.
        stem_channels = self.resnet.conv1.out_channels
        self.resnet.conv1 = torch.nn.Conv2d(3, stem_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.resnet.maxpool = torch.nn.Identity()

    def safe_forward(self, x):
        """Run ``x`` through the backbone and return its raw outputs (logits).

        No softmax / log-softmax is applied here; callers receive unnormalised scores.
        """
        return self.resnet(x)


class StochasticCifar10Resnet(StochasticModel):
    """Dropout-enabled ResNet whose ImageNet stem is replaced for small (CIFAR-sized) inputs.

    The backbone's first convolution is swapped for a 3x3, stride-1, padding-1 conv and its
    max-pool is replaced by an identity, so the input is not spatially downsampled by the stem.

    Args:
        resnet_factory: Callable that builds the backbone, invoked as
            ``resnet_factory(pretrained=..., num_classes=..., dropout_rate=..., fc_dropout_rate=...)``.
            The returned module must expose ``conv1`` and ``maxpool`` attributes.
        num_samples: Forwarded to ``StochasticModel.__init__``.
            NOTE(review): presumably the number of stochastic forward passes drawn per input —
            confirm against ``StochasticModel``.
        dropout_rate: Forwarded to ``resnet_factory``. The meaning is defined by that factory.
        fc_dropout_rate: Forwarded to ``resnet_factory``. The meaning is defined by that factory.
        pretrained: Forwarded to ``resnet_factory``. The stem conv is always replaced by a
            freshly initialised layer below, so any pretrained ``conv1`` weights are discarded.
        capacity: Number of output classes, passed as ``num_classes`` to the factory.
    """

    def __init__(self, *, resnet_factory, num_samples, dropout_rate, fc_dropout_rate, pretrained=False, capacity=10):
        super().__init__(num_samples)

        self.resnet = resnet_factory(
            pretrained=pretrained, num_classes=capacity, dropout_rate=dropout_rate, fc_dropout_rate=fc_dropout_rate
        )
        # Keep the replacement conv as wide as the factory's original stem, so the layer that
        # consumes its output still gets the channel count it was built for. This matters because
        # the factory here is user-supplied. For torchvision-style ResNets the width is 64, the
        # previously hard-coded value.
        stem_channels = self.resnet.conv1.out_channels
        self.resnet.conv1 = torch.nn.Conv2d(3, stem_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.resnet.maxpool = torch.nn.Identity()

    def stochastic_forward_impl(self, x):
        """Run one forward pass of ``x`` through the backbone and return its raw outputs (logits).

        No softmax / log-softmax is applied here.
        NOTE(review): presumably invoked by ``StochasticModel`` once per stochastic sample — confirm.
        """
        return self.resnet(x)


if __name__ == "__main__":
    # Smoke check: build the model with its default arguments and print its module structure.
    default_model = DeterministicCifar10Resnet()
    print_module(default_model)
