from ..BaseAdaptor import Adaptor
from .modules import DetAdaptor
import torch


class DetAsymVLM(Adaptor):
    """Deterministic, asymmetric adaptor: only the text branch is adapted.

    Image embeddings are passed through unchanged (``image_adaptor`` is
    ``None``), while text embeddings are mapped by a ``DetAdaptor``.
    The inference helpers return ``(mean, uncertainty)`` pairs; since this
    model is deterministic the "uncertainty" is a constant vector of ones.
    """

    def __init__(self, distribution=None):
        """Build the adaptor.

        Args:
            distribution: Accepted but not used by this class.
                NOTE(review): presumably kept for signature parity with
                other adaptors -- confirm against callers.
        """
        super().__init__()
        # No image-side module: image embeddings are returned as-is.
        self.image_adaptor = None
        self.text_adaptor = DetAdaptor()

    def forward(self, z_i, z_t):
        """Adapt the text embedding and return it with the raw image one.

        Args:
            z_i: Image embeddings; returned untouched.
            z_t: Text embeddings; passed through ``self.text_adaptor``.

        Returns:
            Tuple ``(z_i, z_t_prime)``.
        """
        z_t_prime = self.text_adaptor(z_t)
        return z_i, z_t_prime

    def loss(self, z_i_prime, z_t_prime):
        """Delegate to the text adaptor's loss.

        Note the argument order is swapped on purpose of the underlying
        call: the text adaptor receives ``(z_t_prime, z_i_prime)``.
        """
        return self.text_adaptor.loss(z_t_prime, z_i_prime)

    # methods for inference: output is the adapted representation
    # in the form of means and uncertainties
    def adapt_text(self, z_t):
        """Return the L2-normalized adapted text embedding and unit scores.

        Args:
            z_t: Text embeddings; ``shape[0]`` is used as the number of
                samples (assumes a leading batch dimension -- TODO confirm).

        Returns:
            Tuple ``(mean, uncertainty)`` where ``mean`` has unit L2 norm
            along the last dimension and ``uncertainty`` is a vector of ones
            of length ``z_t.shape[0]`` on the same device.
        """
        z_t = self.text_adaptor(z_t)
        # normalize the text representation; F.normalize clamps the norm
        # from below (eps) so an all-zero row yields zeros instead of NaN.
        z_t = torch.nn.functional.normalize(z_t, dim=-1)
        return z_t, torch.ones(z_t.shape[0], device=z_t.device)

    def adapt_image(self, z_i):
        """Return the image embedding unchanged together with unit scores.

        Args:
            z_i: Image embeddings; ``shape[0]`` is used as the number of
                samples (assumes a leading batch dimension -- TODO confirm).

        Returns:
            Tuple ``(z_i, uncertainty)``; ``z_i`` is NOT normalized here and
            ``uncertainty`` is a vector of ones of length ``z_i.shape[0]``.
        """
        return z_i, torch.ones(z_i.shape[0], device=z_i.device)

    def log_likelihood(self, z_i, z_t):
        """Score every image against every text by cosine similarity.

        Despite the name, the value returned is the matrix of dot products
        between L2-normalized image and adapted-text embeddings.

        Args:
            z_i: Image embeddings.
            z_t: Text embeddings (adapted internally via ``adapt_text``).
                ``Tensor.t()`` is applied to the adapted text, so it must be
                at most 2-D.

        Returns:
            ``mu_i @ mu_t.T`` computed on unit-norm vectors.
        """
        mu_i, _ = self.adapt_image(z_i)
        mu_t, _ = self.adapt_text(z_t)
        # normalize the vectors (eps-guarded, so zero rows do not poison the
        # whole similarity matrix with NaN). ``mu_t`` is already unit-norm
        # from ``adapt_text``; re-normalizing is kept so this method stays
        # correct if ``adapt_text`` is overridden.
        mu_i = torch.nn.functional.normalize(mu_i, dim=-1)
        mu_t = torch.nn.functional.normalize(mu_t, dim=-1)
        return mu_i @ (mu_t.t())
