from torch import nn
import torch
import math


class DetAdaptor(nn.Module):
    """Deterministic adaptor: an MLP plus a symmetric contrastive loss.

    The module holds a four-layer ReLU MLP (``etta``) mapping 512-dimensional
    inputs to 512-dimensional outputs through 1024-wide hidden layers, and a
    learnable scalar ``temperature`` that scales the cosine-similarity logits
    used by :meth:`loss`.
    """

    def __init__(self):
        super().__init__()
        embed_dim, hidden_dim = 512, 1024
        self.etta = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
        )
        # Learnable multiplier applied directly (not exponentiated) to the
        # cosine similarities in ``matrixwise_ll``.
        self.temperature = nn.Parameter(torch.tensor(1.0))

    def forward(self, input):
        """Apply the MLP to ``input`` (last dimension must be 512).

        The parameter name ``input`` shadows the builtin but is kept so that
        existing keyword callers continue to work.
        """
        return self.etta(input)

    def loss(self, mu_t, mu_i):
        """Return the symmetric cross-entropy loss over the score matrix.

        The matrix from :meth:`matrixwise_ll` is treated as logits with the
        diagonal as the target class, once row-wise and once column-wise
        (via the transpose); the two cross-entropy terms are averaged.
        This requires a square score matrix, i.e. ``mu_t`` and ``mu_i`` must
        contain the same number of rows.
        """
        log_ll = self.matrixwise_ll(mu_t, mu_i)
        labels = torch.arange(log_ll.size(0), device=log_ll.device)
        l1 = nn.functional.cross_entropy(log_ll, labels)
        l2 = nn.functional.cross_entropy(log_ll.t(), labels)
        return (l1 + l2) / 2

    # Compute the pairwise score between every image and every text embedding
    # within the batch. This is the default (cosine-similarity) implementation;
    # subclasses may override it with their own scoring function.
    def matrixwise_ll(self, mu_t, mu_i):
        """Return temperature-scaled cosine similarities of ``mu_i`` vs ``mu_t``.

        Both inputs are L2-normalised along the last dimension. Entry
        ``[a, b]`` of the result is the similarity between row ``a`` of
        ``mu_i`` and row ``b`` of ``mu_t``, multiplied by ``self.temperature``.
        ``mu_t`` must be at most 2-D because it is transposed with ``.t()``.
        """
        # ``normalize`` clamps the norm with a small eps, so an all-zero row
        # yields zeros instead of the NaNs that a plain division would give.
        mu_i = nn.functional.normalize(mu_i, dim=-1)
        mu_t = nn.functional.normalize(mu_t, dim=-1)
        log_ll = (mu_i @ mu_t.t()) * self.temperature
        return log_ll
