import torch
import torch.nn as nn
from torch.nn import functional as F


class ImageAdapter(nn.Module):
    """Residual bottleneck MLP applied to image features.

    Computes ``x + fc2(relu(fc1(x)))``. ``fc1`` projects ``input_dim`` down to
    ``input_dim // reduction`` and ``fc2`` projects back up. Neither layer
    has a bias.

    Args:
        input_dim: Size of the last dimension of the incoming features.
        reduction: Bottleneck reduction factor. Defaults to 4, the value that
            was previously hard-coded.

    Raises:
        ValueError: If ``reduction`` is not positive.
    """

    def __init__(self, input_dim, reduction=4):
        super().__init__()
        if reduction <= 0:
            raise ValueError(f"reduction must be positive, got {reduction}")

        # Integer division gives the same width as the old int(input_dim / 4)
        # for non-negative ints, without going through a float.
        output_dim = int(input_dim // reduction)

        self.fc1 = nn.Linear(input_dim, output_dim, bias=False)
        self.fc2 = nn.Linear(output_dim, input_dim, bias=False)
        self.relu = nn.ReLU()

    def forward(self, image_features):
        """Return ``image_features`` plus the bottleneck MLP's output.

        Because of the residual connection, the output has the same shape as
        the input. The input's last dimension must equal ``input_dim``.
        """
        out = self.fc2(self.relu(self.fc1(image_features)))
        # An out-of-place add keeps the residual explicit and leaves the
        # layer output untouched.
        return out + image_features


class CLIPAdapter(nn.Module):
    """Image-side adapter that produces scaled cosine-similarity logits.

    Passes image features through an ``ImageAdapter`` and L2-normalises both
    the adapted image features and the text features along the last
    dimension. It then returns ``logit_scale`` times their dot products.

    Args:
        clip_model: Object exposing ``visual.output_dim``. Only that value is
            read, and it sizes the image adapter.
        logit_scale: Multiplier applied to the cosine similarities. Defaults
            to 100.0, the value that was previously hard-coded.
    """

    def __init__(self, clip_model, logit_scale=100.0):
        super().__init__()
        self.image_adapter = ImageAdapter(clip_model.visual.output_dim)
        self.logit_scale = logit_scale

    def forward(self, image_features, text_features):
        """Return scaled cosine similarities between image and text features.

        Args:
            image_features: Input to the image adapter. Presumably shaped
                (batch, 1, dim), so that squeezing dim 1 yields
                (batch, num_texts). NOTE(review): confirm against callers.
            text_features: Tensor with at least 3 dims, because
                ``transpose(1, 2)`` is applied. Presumably shaped
                (batch, num_texts, dim). TODO confirm.

        Returns:
            ``logit_scale * img_n @ text_n^T``, with dim 1 squeezed when it
            has size 1.
        """
        adapted = self.image_adapter(image_features)
        # F.normalize clamps the norm at a small eps. An all-zero vector
        # therefore gives zeros instead of NaN, which x / x.norm() would give.
        adapted = F.normalize(adapted, dim=-1)

        # Normalise in float32 (as before), but after the cast, so the norm
        # is not computed in a lower-precision input dtype. Then match the
        # adapter's output dtype so the matmul never mixes dtypes; this is
        # float32 unless the module itself was cast.
        text = F.normalize(text_features.to(torch.float32), dim=-1)
        text = text.to(adapted.dtype)

        logits = self.logit_scale * adapted @ torch.transpose(text, 1, 2)
        return torch.squeeze(logits, 1)