import torch.nn as nn
from src.models.model_components import build_encoder, build_projector


class JMI(nn.Module):
    """Two-view model that embeds both inputs with one shared network stack.

    The same ``online_encoder`` followed by the same ``projector`` is applied
    to each view, so both views share all weights.

    Args:
        args: configuration object handed unchanged to ``build_encoder`` and
            ``build_projector``; its expected fields are defined by those
            builders (not visible here).
    """

    def __init__(self, args):
        super().__init__()

        # Sub-modules are registered encoder-first, then projector, which keeps
        # the parameter / state_dict ordering stable.
        self.online_encoder = build_encoder(args)
        self.projector = build_projector(args)

        # Bookkeeping containers. Nothing in this class writes to them;
        # presumably a training / evaluation loop fills them in.
        self.loss_history = []
        self.linear_probing_history = {}

    def _embed(self, view):
        """Encode a single view, then project the encoding."""
        encoded = self.online_encoder(view)
        return self.projector(encoded)

    def forward(self, view_1, view_2):
        """Embed both views with the shared encoder/projector.

        Args:
            view_1: first input view.
            view_2: second input view.

        Returns:
            A 2-tuple ``(representation_1, representation_2)``, the projected
            embeddings of ``view_1`` and ``view_2`` respectively.
        """
        # map() processes view_1 before view_2, preserving the original call order.
        return tuple(map(self._embed, (view_1, view_2)))
