import logging

import torch.nn as nn
from torchvision.models import resnet18


class Resnet18CelebA(nn.Module):
    """Randomly initialised torchvision ResNet-18 applied to flattened CelebA inputs.

    Inputs arrive as (batch, seq, channels). They are reshaped into a
    (batch, 3, 178, 218) tensor and fed to a ResNet-18 whose final ``fc``
    layer is replaced so it emits ``d_output`` features.
    """

    def __init__(
            self,
            d_output,
            **kwargs,
    ):
        """Build the backbone.

        Args:
            d_output: Output size of the replacement final linear layer.
            **kwargs: Extra options. Only ``l_output`` is read: when present
                and greater than 1 it replaces ``d_output`` as the output
                size. All other keys are ignored.
        """
        super().__init__()
        # An l_output greater than 1 takes precedence over d_output.
        if kwargs.get('l_output', 0) > 1:
            d_output = kwargs['l_output']

        self.resnet = resnet18(pretrained=False)
        # 512 is the feature width entering ResNet-18's classification head.
        self.resnet.fc = nn.Linear(512, d_output)

    def forward(self, x, *args, **kwargs):
        """Reshape a flattened batch into images and run the backbone.

        Args:
            x: Tensor of shape (batch, seq, channels); seq * channels must
                equal 3 * 178 * 218.
            *args, **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            The backbone output, shape (batch, d_output).
        """
        # BSC -> BCS
        x = x.transpose(1, 2)
        # BCS -> BCHW. reshape (unlike view) also works when the transposed
        # tensor's strides do not permit a pure view.
        # NOTE(review): confirm that (178, 218) matches the order in which the
        # dataset flattened each image; a mismatch re-chunks rows silently.
        x = x.reshape(x.shape[0], 3, 178, 218)
        # Call the module rather than .forward() so registered hooks run.
        return self.resnet(x)

class Resnet18Pathfinder(nn.Module):
    """Randomly initialised torchvision ResNet-18 applied to flattened square images.

    Inputs arrive as (batch, seq, d_input). They are reshaped into
    (batch, d_input, resolution, resolution) images. Single-channel input is
    tiled to three channels before the ResNet-18, whose final ``fc`` layer is
    replaced so it emits ``d_output`` features.
    """

    def __init__(
            self,
            d_input,
            d_output,
            resolution=128,
            **kwargs,
    ):
        """Build the backbone.

        Args:
            d_input: Number of input channels; forward supports 1 or 3.
            d_output: Output size of the replacement final linear layer.
            resolution: Side length of the square image the sequence is
                reshaped into.
            **kwargs: Extra options. Only ``l_output`` is read: when present
                and greater than 1 it replaces ``d_output`` as the output
                size. All kwargs are logged at debug level.
        """
        super().__init__()
        # Debug-level logging instead of an unconditional print on every build.
        logging.getLogger(__name__).debug("ResNet kwargs %s", kwargs)
        # An l_output greater than 1 takes precedence over d_output.
        if kwargs.get('l_output', 0) > 1:
            d_output = kwargs['l_output']

        self.d_input = d_input
        self.resolution = resolution
        self.resnet = resnet18(pretrained=False)
        # 512 is the feature width entering ResNet-18's classification head.
        self.resnet.fc = nn.Linear(512, d_output)

    def forward(self, x, *args, **kwargs):
        """Reshape a flattened batch into square images and run the backbone.

        Args:
            x: Tensor of shape (batch, seq, d_input) with
                seq == resolution ** 2.
            *args, **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            The backbone output, shape (batch, d_output).

        Raises:
            NotImplementedError: If ``d_input`` is neither 1 nor 3.
        """
        # Fail fast, with a clear message, before doing any reshaping.
        if self.d_input not in (1, 3):
            raise NotImplementedError(
                f"Resnet18Pathfinder supports d_input of 1 or 3, got {self.d_input}"
            )
        # BSC -> BCS
        x = x.transpose(1, 2)
        # BCS -> BCHW. reshape (unlike view) also works when strides forbid a view.
        x = x.reshape(x.shape[0], self.d_input, self.resolution, self.resolution)
        if self.d_input == 1:
            # The backbone's stem expects three channels; tile the single one.
            x = x.repeat(1, 3, 1, 1)
        # Call the module rather than .forward() so registered hooks run.
        return self.resnet(x)
