import torch.nn as nn
from src.models.model_components import build_encoder, build_projector


class BarlowTwins(nn.Module):
    """Two-view model: encoder -> projector -> non-affine batch norm.

    Both input views are pushed through the same encoder, the same
    projector, and the same ``BatchNorm1d`` layer; the two normalized
    outputs are returned as a pair. No loss is computed here.

    Attributes:
        online_encoder: Module returned by ``build_encoder(args)``.
        projector: Module returned by ``build_projector(args)``.
        bn: ``nn.BatchNorm1d`` over ``args.feature_dim`` features with
            ``affine=False`` (no learnable scale/shift).
        loss_history: List, created empty; never written by this class
            (presumably filled by training code — verify against callers).
        linear_probing_history: Dict, created empty; never written by
            this class (presumably filled by evaluation code — verify
            against callers).
    """

    def __init__(self, args):
        """Build the submodules from ``args``.

        Args:
            args: Configuration object forwarded unchanged to
                ``build_encoder`` and ``build_projector``. Must expose
                ``feature_dim``, which sizes the batch-norm layer.
                NOTE(review): assumes the projector emits
                ``args.feature_dim`` features per sample — TODO confirm.
        """
        super().__init__()

        # Keep this registration order (encoder, projector, bn): it fixes
        # the ordering of parameters() and state_dict() entries.
        self.online_encoder = build_encoder(args)
        self.projector = build_projector(args)
        self.bn = nn.BatchNorm1d(args.feature_dim, affine=False)

        # Bookkeeping containers; this class only initializes them.
        self.loss_history = []
        self.linear_probing_history = {}

    def _embed(self, view):
        """Run one view through encoder, projector and batch norm."""
        features = self.online_encoder(view)
        projected = self.projector(features)
        return self.bn(projected)

    def forward(self, view_1, view_2):
        """Embed both views with the shared encoder/projector/batch norm.

        ``view_1`` is processed before ``view_2``; since ``self.bn`` is
        shared, each call applies batch norm twice, in that order.

        Args:
            view_1: First input batch, passed to ``online_encoder``.
            view_2: Second input batch, passed to ``online_encoder``.

        Returns:
            Tuple ``(representation_1, representation_2)`` of the
            batch-normalized projector outputs for the two views.
        """
        representation_1 = self._embed(view_1)
        representation_2 = self._embed(view_2)

        return representation_1, representation_2
