import torch.nn as nn
from torch.nn.functional import normalize
from torchvision.models import resnet18


class ProjectionNet(nn.Module):
    """ResNet18 backbone followed by an MLP head and a linear classifier.

    The backbone's own ``fc`` layer is replaced with ``nn.Identity`` so the
    backbone emits its pooled features directly. Those features go through
    ``self.head`` (Linear -> BatchNorm1d -> ReLU per entry of ``head_layers``)
    and then through ``self.out`` to produce class logits.
    """

    def __init__(self, head_layers=None, num_classes=2, pretrained=False):
        """Build the backbone, the MLP head and the output layer.

        Args:
            head_layers: Output widths of the successive Linear layers in the
                MLP head. ``None`` selects eight 512-wide layers followed by
                one 128-wide layer. An empty list yields an empty head.
            num_classes: Number of outputs of the final linear layer.
            pretrained: Forwarded to ``torchvision.models.resnet18``.
        """
        super().__init__()

        if head_layers is None:
            head_layers = [512, 512, 512, 512, 512, 512, 512, 512, 128]

        # Create an MLP head as in the following:
        # - https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py
        # - Not sure if this is the right architecture used in CutPaste.
        # 512 is the input width of the first head layer, i.e. the feature
        # size the head expects from the backbone once its fc is removed.
        last_layer = 512
        sequential_layers = []
        for num_neurons in head_layers:
            sequential_layers.append(nn.Linear(last_layer, num_neurons))
            sequential_layers.append(nn.BatchNorm1d(num_neurons))
            sequential_layers.append(nn.ReLU(inplace=True))
            last_layer = num_neurons
        self.head = nn.Sequential(*sequential_layers)

        # NOTE(review): newer torchvision releases deprecate `pretrained=` in
        # favour of `weights=`; kept as-is for compatibility with the
        # torchvision version this file was written against.
        self.resnet18 = resnet18(pretrained=pretrained)
        # Drop the backbone's classifier so it returns features; the MLP head
        # lives in `self.head`, not in `self.resnet18.fc`.
        self.resnet18.fc = nn.Identity()
        self.out = nn.Linear(last_layer, num_classes)

    def forward(self, x, emb_type='first', emb_norm=True):
        """Run the network and return an embedding together with the logits.

        Args:
            x: Input batch passed straight to the ResNet18 backbone
                (presumably images shaped (N, 3, H, W) -- confirm with caller).
            emb_type: ``'first'`` returns the backbone features as the
                embedding, ``'second'`` returns the MLP head output.
            emb_norm: If true, L2-normalize the embedding along dim 1.

        Returns:
            Tuple ``(emb, logits)``.

        Raises:
            ValueError: If ``emb_type`` is neither ``'first'`` nor ``'second'``.
        """
        # Validate before running the backbone so a bad argument fails fast.
        if emb_type not in ('first', 'second'):
            raise ValueError(emb_type)

        emb1 = self.resnet18(x)
        emb2 = self.head(emb1)
        logits = self.out(emb2)

        emb = emb1 if emb_type == 'first' else emb2
        if emb_norm:
            emb = normalize(emb, p=2, dim=1)
        return emb, logits

    def freeze_resnet(self):
        """Freeze the ResNet18 backbone and make sure the MLP head is trainable.

        Only ``requires_grad`` flags are changed; ``self.out`` is left as it is.
        """
        # Freeze the ResNet18 network.
        self.resnet18.requires_grad_(False)

        # Unfreeze the MLP head. `self.resnet18.fc` is an nn.Identity with no
        # parameters, so the head has to be addressed through `self.head`.
        self.head.requires_grad_(True)

    def unfreeze(self):
        """Enable gradients for every parameter of the model."""
        # Unfreeze all.
        self.requires_grad_(True)
