import torch.nn as nn
from src.models.model_components import build_encoder, build_projector, build_predictor


# Adapted from https://github.com/facebookresearch/simsiam
class SimSiam(nn.Module):
    """SimSiam siamese network: shared encoder + projector, then a predictor head.

    Both views go through the *same* ``online_encoder`` and ``projector``
    modules (shared weights); ``predictor`` is then applied to each projection.

    Args:
        args: configuration object forwarded unchanged to ``build_encoder``,
            ``build_projector`` and ``build_predictor``; the fields it must
            carry are defined by those factories, not here.
    """

    def __init__(self, args):
        super().__init__()

        # Sub-networks, each produced by its project-level factory.
        self.online_encoder = build_encoder(args)
        self.projector = build_projector(args)
        self.predictor = build_predictor(args)

        # Bookkeeping containers. Nothing in this class reads or updates them
        # after construction; presumably the training / evaluation code fills
        # them in (confirm against callers).
        self.loss_history = []
        self.linear_probing_history = {}

    def forward(self, view_1, view_2):
        """Embed two views; return predictions and stop-gradient projections.

        Args:
            view_1: first view of the batch (input format is whatever
                ``online_encoder`` accepts — TODO confirm expected shape).
            view_2: second view, processed by the same modules.

        Returns:
            Tuple ``(p1, p2, z1, z2)`` with ``z_i = projector(online_encoder(view_i))``
            and ``p_i = predictor(z_i)``. ``z1`` and ``z2`` are returned
            detached from the autograd graph (the stop-gradient side); the loss
            that consumes them is computed outside this class.
        """
        # Each view gets its own forward pass (the views are not concatenated),
        # view_1 first and view_2 second.
        z1, z2 = (
            self.projector(self.online_encoder(view)) for view in (view_1, view_2)
        )
        p1, p2 = self.predictor(z1), self.predictor(z2)
        return p1, p2, z1.detach(), z2.detach()
    