import torch
import torch.nn as nn
import open_clip

class CLIPFeatureExtractor(nn.Module):
    """
    A frozen wrapper around a pre-trained open_clip CLIP model that extracts
    feature vectors from images and text.

    All CLIP parameters have ``requires_grad`` set to False, and the wrapped
    model is kept in evaluation mode even when ``train()`` is called on this
    module (or on a parent module that contains it).
    """

    def __init__(self, clip_model_name="ViT-B-32", pretrained="openai"):
        """
        Load the CLIP model and its matching tokenizer.

        Args:
            clip_model_name (str): The name of the CLIP model to use (default is "ViT-B-32").
            pretrained (str): The pre-trained weights to load (default is "openai").
        """
        super().__init__()
        # The preprocessing transforms returned alongside the model are discarded;
        # presumably callers pass already-preprocessed image tensors — confirm.
        self.model, _, _ = open_clip.create_model_and_transforms(clip_model_name, pretrained=pretrained)
        self.model.requires_grad_(False)  # Freeze every CLIP parameter.
        self.model.eval()
        # Kept for backward compatibility. Note that this registers the same visual
        # tower a second time, so state_dict() holds both "model.visual.*" and
        # "visual_encoder.*" keys for the same weights.
        self.visual_encoder = self.model.visual
        self.text_tokenizer = open_clip.get_tokenizer(clip_model_name)

    def train(self, mode=True):
        """
        Set this module's training flag while keeping the frozen CLIP model in eval mode.

        nn.Module.train() recurses into every child module, so without this override a
        parent's .train() call would switch CLIP into training mode (e.g. activating any
        dropout layers) even though its weights are frozen. eval() routes through here too.

        Args:
            mode (bool): Training-mode flag for this module (default True).

        Returns:
            CLIPFeatureExtractor: self, matching nn.Module.train().
        """
        super().train(mode)
        self.model.eval()
        return self

    def _model_device(self):
        """Return the device holding the CLIP model's parameters."""
        return next(self.model.parameters()).device

    def encode_image(self, images, normalize=False):
        """
        Encode a batch of images into feature vectors using the CLIP image encoder.

        Args:
            images (Tensor): A batch of images (Tensor format) to encode; it is moved to
                the model's device before encoding.
            normalize (bool): If True, ask the model to L2-normalize each feature vector
                (default False, which matches the previous behaviour).

        Returns:
            Tensor: Encoded feature vectors for the images, computed without gradients.
        """
        with torch.no_grad():
            # Use model.encode_image (as encode_text uses model.encode_text) instead of
            # calling the raw visual tower, so model-specific post-processing and the
            # normalize option are honoured.
            images = images.to(self._model_device())
            return self.model.encode_image(images, normalize=normalize)

    def encode_text(self, texts, normalize=False):
        """
        Encode a batch of texts into feature vectors using the CLIP text encoder.

        Args:
            texts (list of str): A list of text strings to encode.
            normalize (bool): If True, ask the model to L2-normalize each feature vector
                (default False, which matches the previous behaviour).

        Returns:
            Tensor: Encoded feature vectors for the texts, computed without gradients.
        """
        with torch.no_grad():
            # Tokens are built on the CPU by the tokenizer; move them next to the model.
            tokens = self.text_tokenizer(texts).to(self._model_device())
            return self.model.encode_text(tokens, normalize=normalize)
