# Modified based on https://github.com/facebookresearch/simsiam

import copy

import torch
import torch.nn as nn


class SimCLR(nn.Module):
    """
    Build a SimCLR model.

    The backbone produced by ``base_encoder`` must expose two attributes:
    ``fc`` (a linear head with a ``weight``) and ``avgpool``. The ``fc`` head
    is replaced by a 3-layer MLP projector, and a forward hook on ``avgpool``
    records the pooled features so they can be returned in eval mode.
    """
    def __init__(self, base_encoder, dim=2048):
        """
        base_encoder: callable building the backbone; called as
            ``base_encoder(num_classes=dim)``
        dim: feature dimension (default: 2048)
        """
        super().__init__()

        # create the encoder
        # num_classes is the output fc dimension
        self.encoder = base_encoder(num_classes=dim)

        # build a 3-layer projector; the encoder's own fc is reused as the
        # last linear layer, so it stays at index 6 of the Sequential
        output_fc = self.encoder.fc
        prev_dim = output_fc.weight.shape[1]
        self.encoder.fc = nn.Sequential(
            nn.Linear(prev_dim, prev_dim, bias=False),
            nn.BatchNorm1d(prev_dim),
            nn.ReLU(inplace=True),  # first layer
            nn.Linear(prev_dim, prev_dim, bias=False),
            nn.BatchNorm1d(prev_dim),
            nn.ReLU(inplace=True),  # second layer
            output_fc,
            nn.BatchNorm1d(dim, affine=False),  # output layer
        )
        # hack: not use bias as it is followed by BN
        # (guarded so a head built without a bias does not crash here)
        if output_fc.bias is not None:
            output_fc.bias.requires_grad = False

        # output embedding of avgpool, refreshed on every encoder forward
        self.embedding = None
        self.encoder.avgpool.register_forward_hook(self._get_avg_output())

    def _get_avg_output(self):
        """Return a forward hook that stores the detached module output."""
        def hook(model, input, output):
            self.embedding = output.detach()
        return hook

    def forward(self, x1, x2=None):
        """
        Input:
            x1: first views of images
            x2: second views of images (required in training mode,
                ignored in eval mode)
        Output:
            training mode: z1, z2 (projector outputs for x1 and x2)
            eval mode: detached avgpool output for x1 with singleton
                dimensions removed, except that the leading (batch)
                dimension is kept even when it has size 1
        Raises:
            ValueError: if called in training mode without x2
        """
        if self.training:
            if x2 is None:
                raise ValueError("x2 is required when the model is in training mode")
            z1 = self.encoder(x1)  # NxC
            z2 = self.encoder(x2)  # NxC
            return z1, z2

        # The encoder output itself is discarded; the avgpool hook has
        # stored the pooled features in self.embedding during this call.
        self.encoder(x1)
        pooled = self.embedding
        features = pooled.squeeze()
        # A bare squeeze() also drops a batch dimension of size 1, turning
        # (1, C, 1, 1) into (C,); put that axis back so callers always get
        # a leading batch dimension.
        if pooled.dim() > 0 and pooled.shape[0] == 1:
            features = features.unsqueeze(0)
        return features
