
import torch.nn as nn
from typing import List, Tuple, Union
from transformers import CLIPVisionModel

class CLIPWithProject(nn.Module):
    """CLIP vision encoder followed by one or more linear projection heads.

    Images are encoded with a Hugging Face ``CLIPVisionModel``. The last
    entry of the encoder's ``hidden_states`` is then projected by one of
    several bias-free linear heads, selected per call.

    Args:
        version: Model name or path passed to
            ``CLIPVisionModel.from_pretrained``.
        latent_dims: Output width of each projection head; one head is
            created per entry.
        frozen: If True, disable gradients for every CLIP parameter so only
            the projection heads are trainable.
    """

    def __init__(self, version, latent_dims, frozen=False):
        super().__init__()
        self.clip_model = CLIPVisionModel.from_pretrained(version)
        # Take the encoder width from the loaded config instead of assuming
        # 768 (ViT-B). Other checkpoints (e.g. ViT-L) have a wider hidden
        # size, and a hard-coded 768 would break forward().
        hidden_size = self.clip_model.config.hidden_size
        self.project_layer = nn.ModuleList(
            [nn.Linear(hidden_size, ld, bias=False) for ld in latent_dims]
        )
        # forward() reads raw encoder hidden states and never the pooled
        # output, so post_layernorm presumably receives no gradient. Freezing
        # it keeps it out of the trainable set (e.g. avoids DDP
        # unused-parameter errors).
        # NOTE(review): relies on HF applying post_layernorm only to
        # pooler_output — verify against the installed transformers version.
        for n, p in self.named_parameters():
            if n.startswith('clip_model.vision_model.post_layernorm'):
                p.requires_grad = False
        if frozen:
            self.clip_model.requires_grad_(False)

    def forward(self, imgs, i=0):
        """Encode ``imgs`` and project the final hidden states with head ``i``.

        Args:
            imgs: Pixel values accepted by the CLIP vision model (presumably
                (batch, 3, H, W); TODO confirm with callers).
            i: Index into ``self.project_layer`` selecting the head to use.

        Returns:
            The projection of ``hidden_states[-1]``. Leading dimensions follow
            the hidden-state tensor; the last dimension is ``latent_dims[i]``.
        """
        clip_output = self.clip_model(imgs, output_hidden_states=True)
        return self.project_layer[i](clip_output["hidden_states"][-1])
    

class CLIPWithProject2(nn.Module):
    """CLIP vision encoder with a single projection of its penultimate layer.

    Images are encoded with a Hugging Face ``CLIPVisionModel``. The
    second-to-last entry of ``hidden_states`` is then projected by one
    bias-free linear layer.

    Args:
        latent_dim: Output width of the projection head.
        frozen: If True, disable gradients for every CLIP parameter so only
            the projection head is trainable.
        version: Model name or path passed to
            ``CLIPVisionModel.from_pretrained``. Defaults to the ViT-B/32
            checkpoint this class previously hard-coded.
    """

    def __init__(self, latent_dim=1536, frozen=False,
                 version="openai/clip-vit-base-patch32"):
        super().__init__()
        self.clip_model = CLIPVisionModel.from_pretrained(version)
        # Take the encoder width from the loaded config instead of assuming
        # 768, so non-base checkpoints also build a matching projection.
        hidden_size = self.clip_model.config.hidden_size
        self.project_layer = nn.Linear(hidden_size, latent_dim, bias=False)
        # forward() never uses the pooled output, so post_layernorm
        # presumably receives no gradient; freeze it explicitly.
        # NOTE(review): relies on HF applying post_layernorm only to
        # pooler_output — verify against the installed transformers version.
        for n, p in self.named_parameters():
            if n.startswith('clip_model.vision_model.post_layernorm'):
                p.requires_grad = False
        if frozen:
            self.clip_model.requires_grad_(False)

    def forward(self, imgs):
        """Encode ``imgs`` and project the penultimate hidden states.

        Args:
            imgs: Pixel values accepted by the CLIP vision model (presumably
                (batch, 3, H, W); TODO confirm with callers).

        Returns:
            The projection of ``hidden_states[-2]``. Leading dimensions follow
            the hidden-state tensor; the last dimension is ``latent_dim``.
        """
        clip_output = self.clip_model(imgs, output_hidden_states=True)
        return self.project_layer(clip_output["hidden_states"][-2])
    

def get_preprocess_module(name, clip_version="openai/clip-vit-base-patch32", dim: Union[int, List] = 1536):
    """Build the image preprocessing module selected by ``name``.

    Args:
        name: ``"CLIP"`` builds a trainable ``CLIPWithProject`` from
            ``clip_version`` with one projection head per entry of ``dim``.
            ``"CLIPfrozen2"`` builds a frozen ``CLIPWithProject2`` with a
            1536-wide head.
        clip_version: Checkpoint used by the ``"CLIP"`` option only.
        dim: Projection width (int) or list of widths for the ``"CLIP"``
            option only.

    Returns:
        The constructed ``nn.Module``.

    Raises:
        NotImplementedError: If ``name`` is not a recognised option.
    """
    if name == "CLIP":
        print("using clip as preprocess model")
        # A single int means one head; a list means one head per width.
        return CLIPWithProject(clip_version, [dim] if isinstance(dim, int) else dim, frozen=False)
    elif name == "CLIPfrozen2":
        print("using frozen clip 2 as preprocess model")
        # NOTE(review): clip_version and dim are ignored here; the width is
        # fixed at 1536. Confirm this is intentional.
        return CLIPWithProject2(1536, frozen=True)
    else:
        # Name the bad value so misconfiguration is easy to diagnose.
        raise NotImplementedError(f"unknown preprocess module: {name!r}")
