import torch.nn as nn
from src.models.model_components import build_encoder, build_projector


class SDMI(nn.Module):
    """Two parallel encoder/projector branches applied to a pair of views.

    The "E" and "M" branches are both built from the same ``args`` via
    ``build_encoder`` / ``build_projector``, but each call produces its own
    object, so the branches are separate instances. What the E/M prefixes
    stand for is not established in this file. NOTE(review): confirm against
    the training code.

    The history containers below are only initialised here. Nothing in this
    class writes to them; presumably the training/evaluation loop does
    (verify).
    """

    def __init__(self, args):
        super().__init__()

        # Keep this exact order (E-encoder, E-projector, M-encoder,
        # M-projector). If the builders draw random initial weights, it
        # fixes the RNG consumption order. Assuming they return nn.Module
        # instances, it also fixes registration order as seen by
        # parameters() / state_dict().
        self.E_encoder = build_encoder(args)
        self.E_projector = build_projector(args)
        self.M_encoder = build_encoder(args)
        self.M_projector = build_projector(args)

        # Per-branch bookkeeping, empty at construction time.
        self.E_loss_history = []
        self.M_loss_history = []
        self.E_linear_probing_history = {}
        self.M_linear_probing_history = {}

    def forward(self, view_1, view_2):
        """Run both views through each branch.

        Returns:
            A 4-tuple ``(E(view_1), E(view_2), M(view_1), M(view_2))``, where
            ``X(v)`` denotes ``X_projector(X_encoder(v))``.
        """

        def through(encoder, projector):
            # view_1 is always processed before view_2.
            return projector(encoder(view_1)), projector(encoder(view_2))

        # Branch E is evaluated before branch M.
        e_first, e_second = through(self.E_encoder, self.E_projector)
        m_first, m_second = through(self.M_encoder, self.M_projector)
        return e_first, e_second, m_first, m_second
    