import torch
import torch.nn as nn

from promptrl.clipcap.model import MLP

class ClipCapPrompt(nn.Module):
    """Project a feature vector into a sequence of prompt embeddings.

    Adapted from ClipCap for this codebase. Each input vector whose last
    dimension is ``prefix_size`` is mapped to ``prefix_length`` embeddings of
    size ``gpt_embedding_size``, returned together with an all-ones attention
    mask covering those prompt positions.

    Args:
        prefix_length: Number of prompt embeddings produced per input vector.
        prefix_size: Size of the last dimension of the input features.
        gpt_embedding_size: Size of each produced embedding. Defaults to 768,
            the value this class previously hard-coded.
    """

    def __init__(self, prefix_length, prefix_size=512, gpt_embedding_size=768):
        super().__init__()
        self.prefix_length = prefix_length
        self.prefix_size = prefix_size
        self.gpt_embedding_size = gpt_embedding_size
        out_size = gpt_embedding_size * prefix_length
        if prefix_length > 10:
            # The MLP's hidden layer (out_size // 2) grows with prefix_length;
            # for long prefixes use a single linear layer to save memory.
            self.clip_project = nn.Linear(prefix_size, out_size)
        else:
            self.clip_project = MLP((prefix_size, out_size // 2, out_size))

    def forward(self, prefix):
        """Project ``prefix`` into prompt embeddings plus an attention mask.

        Args:
            prefix: Tensor whose last dimension is ``prefix_size``. Presumably
                ``(batch, prefix_size)`` — TODO confirm against callers.

        Returns:
            dict with
              - ``'input_embeds'``: tensor of shape
                ``(-1, prefix_length, gpt_embedding_size)``.
              - ``'attention_mask'``: long tensor of ones with shape
                ``(-1, prefix_length)`` on the same device, one entry per
                prompt position.
        """
        prefix_projections = self.clip_project(prefix).view(
            -1, self.prefix_length, self.gpt_embedding_size
        )
        # Size the mask from the projections (batch, prefix_length), not from
        # `prefix`, whose dim 1 is the feature size rather than the prompt length.
        prompt_mask = torch.ones(
            prefix_projections.shape[:2],
            dtype=torch.long,
            device=prefix_projections.device,
        )
        # NOTE(review): Hugging Face models name this argument `inputs_embeds`;
        # the key is kept as-is because existing callers read 'input_embeds'.
        return {
            'input_embeds': prefix_projections,
            'attention_mask': prompt_mask,
        }
