import os 
import json

from torch.utils.data import Dataset
from utils import read_video

class VidHalDataset(Dataset):
    """Map-style dataset pairing VidHal annotation entries with their videos.

    Each item bundles the ``captions`` and ``aspect`` fields of one annotation
    entry with the corresponding ``<video_root>/<video>.mp4`` file. The file is
    decoded through ``read_video`` and, optionally, passed through
    ``vis_processor``.

    Args:
        data_path: Path to a JSON file whose top-level value is indexed by
            integer position (i.e. a JSON array). Each entry must provide the
            keys ``"video"``, ``"captions"`` and ``"aspect"``.
        video_root: Directory containing the ``<video>.mp4`` files.
        vis_processor: Optional callable applied to the decoded video. It is
            skipped when ``None`` or when no video was loaded.
        num_frames: Number of frames requested from ``read_video``.
        load_video: When ``False``, no decoding happens and the item's
            ``"video"`` is ``None``. This is useful for text-only passes.
        sample: Sampling method name forwarded unchanged to ``read_video``.
        frame_indices_path: Optional JSON file whose value is indexed by video
            name. The selected value is forwarded to ``read_video`` as
            ``frame_indices``. Presumably a JSON object mapping name -> index
            list; confirm against the ``read_video`` contract.
    """

    def __init__(self, data_path, video_root, vis_processor, num_frames, load_video=True, sample="middle", frame_indices_path=None) -> None:
        super().__init__()

        self.examples = self._load_json(data_path)

        self.video_root = video_root
        self.num_frames = num_frames
        self.vis_processor = vis_processor
        self.load_video = load_video
        self.sample = sample  # Sampling method forwarded to read_video

        # Optional per-video frame indices. A falsy path (None or "") means
        # read_video chooses frames itself.
        self.frame_indices = self._load_json(frame_indices_path) if frame_indices_path else None

    @staticmethod
    def _load_json(path):
        """Load a JSON file, decoding it explicitly as UTF-8."""
        # JSON is UTF-8 by spec. Relying on the locale default breaks
        # non-ASCII captions on platforms whose locale is not UTF-8.
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        """Return the sample at ``index``.

        Returns:
            A dict with keys ``"video"`` (the first element returned by
            ``read_video``, possibly transformed by ``vis_processor``, or
            ``None`` when ``load_video`` is ``False``), ``"video_id"``,
            ``"video_path"``, ``"captions"`` and ``"aspect"``.

        Raises:
            KeyError: if the annotation entry lacks a required key, or if a
                frame-indices mapping was loaded and has no entry for this
                video (the latter assumes the indices file is a JSON object).
        """
        example = self.examples[index]
        video_name = example["video"]
        captions = example["captions"]
        aspect = example["aspect"]
        video_path = os.path.join(self.video_root, f"{video_name}.mp4")

        video = None
        if self.load_video:
            frame_indices = (
                self.frame_indices[video_name] if self.frame_indices is not None else None
            )
            video, _, _ = read_video(
                video_path=video_path,
                num_frames=self.num_frames,
                sample=self.sample,
                frame_indices=frame_indices,
            )

        if video is not None and self.vis_processor is not None:
            video = self.vis_processor(video)

        return {
            "video": video, "video_id": video_name, "video_path": video_path,
            "captions": captions, "aspect": aspect,
        }
