import torch
import torch.nn as nn
from models import objectives

class Finetune_cont(nn.Module):
    """Fine-tuning wrapper that runs a backbone and applies ``objectives.compute_vacon``.

    Args:
        model: Backbone module. It is called with the ``inputs`` mapping and must
            expose ``contrast_loss_weight`` and ``contrast_tau`` attributes, which
            are forwarded to ``compute_vacon``.
        device: Stored as ``self.device``. It is not used inside this class.
        **kwargs: Accepted and ignored. Presumably this lets a shared
            constructor call pass extra config; confirm against the caller.
    """

    def __init__(self, model: nn.Module, device, **kwargs) -> None:
        super().__init__()
        self.backbone = model
        self.device = device

        # Both flags are hard-wired to False and never read in this class.
        # NOTE(review): presumably a training loop inspects them to decide whether
        # a penalty term / custom optimizer step is needed. Confirm.
        self._req_penalty, self._req_opt = False, False

    def forward(self, inputs):
        """Run the backbone, adding the contrastive objective unless in retrieval mode.

        Args:
            inputs: Mapping handed to the backbone. If it holds a truthy
                ``'retrieval'`` entry, the raw backbone output is returned.

        Returns:
            The backbone output for retrieval. Otherwise, the result of
            ``objectives.compute_vacon`` on the backbone output.

        Note:
            Outside retrieval mode this sets ``'modality_token'``,
            ``'masked_audio'`` and ``'masked_visual'`` to ``True`` directly on
            ``inputs``, so the caller's mapping is modified in place.
        """
        # Retrieval task: skip the objective and hand back the backbone output.
        retrieval_requested = 'retrieval' in inputs and inputs['retrieval']
        if retrieval_requested:
            return self.backbone(inputs)

        # NOTE(review): these flags are consumed by the backbone. Their names suggest
        # modality tokens plus masked audio/visual inputs; confirm against the backbone.
        for flag_key in ('modality_token', 'masked_audio', 'masked_visual'):
            inputs[flag_key] = True

        backbone_out = self.backbone(inputs)
        return objectives.compute_vacon(
            backbone_out,
            loss_weight=self.backbone.contrast_loss_weight,
            tau=self.backbone.contrast_tau,
        )
