from transformers import CLIPProcessor
from torchvision.transforms import Compose, Resize, Normalize, ToTensor
from torch.utils.data import DataLoader, Dataset
import cv2  

class TextVideoDataset(Dataset):
    """Dataset yielding (token ids, first video frame, label) triples.

    Each sample combines a tokenized text query, the first frame of the
    video at the matching path (converted from BGR to RGB and passed
    through ``frame_transform``), and the matching label.

    Args:
        textual_queries: Sequence of text queries.
        video_paths: Sequence of video file paths, parallel to
            ``textual_queries``.
        labels: Sequence of labels, parallel to ``textual_queries``.
        clip_processor: Callable invoked as
            ``clip_processor(text=..., return_tensors="pt", ...)`` whose
            result exposes ``input_ids``.
        frame_transform: Callable applied to the RGB frame array.

    Raises:
        ValueError: If the three input sequences differ in length.
    """

    def __init__(self, textual_queries, video_paths, labels, clip_processor, frame_transform):
        # The sequences are indexed in lockstep; a mismatch would otherwise
        # surface later as an IndexError or as silently ignored entries.
        if not (len(textual_queries) == len(video_paths) == len(labels)):
            raise ValueError(
                "textual_queries, video_paths and labels must have the same "
                f"length, got {len(textual_queries)}, {len(video_paths)} "
                f"and {len(labels)}"
            )
        self.textual_queries = textual_queries
        self.video_paths = video_paths
        self.labels = labels
        self.clip_processor = clip_processor
        self.frame_transform = frame_transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.textual_queries)

    def __getitem__(self, idx):
        """Return ``(text_inputs, frame, label)`` for sample ``idx``.

        Raises:
            RuntimeError: If no frame can be read from the video.
        """
        text = self.textual_queries[idx]
        video_path = self.video_paths[idx]
        label = self.labels[idx]

        # Pad/truncate to the tokenizer's maximum length so every sample has
        # the same number of token ids and can be stacked into a batch.
        text_inputs = self.clip_processor(
            text=text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
        ).input_ids.squeeze(0)

        frame = self.frame_transform(self._read_first_frame(video_path))
        return text_inputs, frame, label

    @staticmethod
    def _read_first_frame(video_path):
        """Read the first frame of ``video_path`` and return it in RGB order."""
        cap = cv2.VideoCapture(video_path)
        try:
            ret, frame = cap.read()
        finally:
            # Always free the capture handle, even if reading raises.
            cap.release()
        if not ret:
            # Fail loudly here instead of handing a ``None`` frame to the
            # batching code, where the error would be much harder to trace.
            raise RuntimeError(f"Could not read a frame from video: {video_path!r}")
        # OpenCV decodes frames as BGR; convert to RGB before transforming.
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

# Preprocessing applied to every RGB frame: convert to a tensor, resize to
# a fixed square size, then normalize each channel.
_FRAME_SIZE = (224, 224)
# NOTE(review): these look like the usual ImageNet channel statistics; if the
# frames are fed to CLIP's vision encoder, its own mean/std may be expected —
# confirm against the downstream model.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

frame_transform = Compose(
    [
        ToTensor(),
        Resize(_FRAME_SIZE),
        Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)

# Wire everything together: pretrained processor -> dataset -> shuffled
# loader producing batches of 4 samples.
# NOTE(review): textual_queries, video_paths and labels are not defined in the
# visible source; they must be supplied before this point — confirm.
clip_processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
dataset = TextVideoDataset(
    textual_queries=textual_queries,
    video_paths=video_paths,
    labels=labels,
    clip_processor=clip_processor,
    frame_transform=frame_transform,
)
data_loader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)
