from torch import nn
import torch
import clip


class Guider(nn.Module):
    """Base interface for models that embed images and text.

    Subclasses must override :meth:`encode_image` and :meth:`encode_text`.
    The base implementations raise ``NotImplementedError`` so that a missing
    override fails loudly instead of silently returning ``None``.

    Args:
        device: Device identifier stored on the instance as ``self.device``.
    """

    def __init__(self, device):
        super().__init__()
        self.device = device

    def encode_image(self, i):
        """Return features for image input ``i``; must be overridden."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement encode_image()"
        )

    def encode_text(self, t):
        """Return features for text input ``t``; must be overridden."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement encode_text()"
        )


class CLIPGuilder(Guider):
    """Guider backed by a CLIP model loaded through ``clip.load``.

    Both encoders run under ``torch.no_grad()`` and return float32 features,
    so image and text features can be compared directly regardless of the
    precision the model weights are held in.

    Args:
        name: CLIP model name passed to ``clip.load``.
        device: Device the model is loaded on; inputs are moved to it before
            encoding.
    """

    def __init__(self, name='ViT-B/32', device='cuda'):
        super().__init__(device)
        # clip.load returns the model together with its image preprocessing
        # transform; both are kept so callers can prepare raw images.
        self.model, self.preprocess = clip.load(name, device=device)

    def encode_image(self, i):
        """Encode an image tensor batch into float32 CLIP image features.

        Args:
            i: Image tensor accepted by the CLIP image encoder (presumably an
                already-preprocessed batch -- confirm against callers).

        Returns:
            The image features cast to float32, with no gradient tracking.
        """
        with torch.no_grad():
            # Move the input to the model's device so CPU tensors also work.
            image_features = self.model.encode_image(i.to(self.device))
        return image_features.float()

    def encode_text(self, t):
        """Tokenize and encode text into float32 CLIP text features.

        Args:
            t: A string or list of strings passed to ``clip.tokenize``.

        Returns:
            The text features cast to float32, with no gradient tracking.
        """
        # clip.tokenize builds its tensor on the CPU; it must be moved to the
        # model's device or encode_text fails with a device mismatch.
        tokens = clip.tokenize(t).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(tokens)
        # Match encode_image: always hand back float32 features.
        return text_features.float()


if __name__ == '__main__':
    # Smoke test: encode a random image batch and a few prompts, printing
    # the resulting feature shapes.
    # Fall back to the CPU when CUDA is unavailable so the demo still runs.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    clip_guilder = CLIPGuilder(device=device)

    # Create the dummy batch directly on the model's device.
    image = torch.randn(2, 3, 224, 224, device=device)
    text = ["a diagram", "a dog", "a cat", "noise"]

    # Run each encoder twice to check that repeated calls behave the same.
    for _ in range(2):
        i = clip_guilder.encode_image(image)
        print(i.shape)
        t = clip_guilder.encode_text(text)
        print(t.shape)
