import json
import os
import random

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data.dataset import Dataset


class CC15M(Dataset):
    """Image/caption dataset driven by a JSON annotation file.

    The annotation file is expected to hold a JSON array of records, each with
    a ``'file_path'`` key (image path) and a ``'text'`` key (caption).

    Args:
        json_path: Path to the JSON annotation file.
        video_folder: Optional root directory joined in front of every
            ``file_path``. When ``None``, ``file_path`` is used as-is.
        resolution: Output size, either an int (square) or an ``(h, w)`` pair.
            Only used when ``enable_bucket`` is False.
        enable_bucket: If True, samples are returned as untransformed numpy
            arrays of the RGB image (presumably so a bucketing sampler/collate
            can handle variable sizes — confirm against the caller). If False,
            samples are resized, center-cropped and normalized tensors.
    """

    # Maximum number of consecutive failed loads in __getitem__ before giving
    # up. Without a bound, a misconfigured dataset (e.g. wrong video_folder)
    # would make __getitem__ loop forever.
    max_load_attempts = 100

    def __init__(
        self,
        json_path,
        video_folder=None,
        resolution=512,
        enable_bucket=False,
    ):
        print(f"loading annotations from {json_path} ...")
        # Binary mode lets json.load auto-detect UTF-8/16/32 (and a UTF-8 BOM)
        # instead of relying on the platform's locale encoding; the `with`
        # closes the handle promptly.
        with open(json_path, 'rb') as f:
            self.dataset = json.load(f)
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

        self.enable_bucket = enable_bucket
        self.video_folder = video_folder

        resolution = tuple(resolution) if not isinstance(resolution, int) else (resolution, resolution)
        self.pixel_transforms = transforms.Compose([
            # NOTE(review): an int size scales the shorter image side to
            # resolution[0] only; for non-square `resolution`, CenterCrop may
            # have to pad if resolution[1] exceeds the resized extent. Confirm
            # this is intended.
            transforms.Resize(resolution[0]),
            transforms.CenterCrop(resolution),
            transforms.ToTensor(),
            # Per-channel (x - 0.5) / 0.5: maps ToTensor's [0, 1] range to [-1, 1].
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def _resolve_path(self, file_path):
        """Return the path to open for an annotation's ``file_path``."""
        if self.video_folder is None:
            return file_path
        return os.path.join(self.video_folder, file_path)

    def get_batch(self, idx):
        """Load record ``idx`` and return ``(rgb_pil_image, caption)``.

        Despite its name this loads a single sample. Any error from indexing,
        missing keys, or opening/decoding the image propagates to the caller.
        """
        record = self.dataset[idx]
        image_path = self._resolve_path(record['file_path'])
        caption = record['text']
        with Image.open(image_path) as image:
            # convert() returns a new, fully loaded image, so it remains valid
            # after the source file is closed by the `with` block.
            pixel_values = image.convert("RGB")
        return pixel_values, caption

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``{'pixel_values': ..., 'text': caption}`` for sample ``idx``.

        If loading fails, the error is printed and a uniformly random index is
        tried instead, so a few broken files do not abort an epoch. After
        ``max_load_attempts`` consecutive failures a ``RuntimeError`` is raised
        (chained to the last load error).
        """
        start_idx = idx
        last_error = None
        for _ in range(self.max_load_attempts):
            try:
                pixel_values, caption = self.get_batch(idx)
                break
            except Exception as e:  # deliberate best-effort: skip bad samples
                print(e)
                last_error = e
                idx = random.randint(0, self.length - 1)
        else:
            raise RuntimeError(
                f"CC15M: {self.max_load_attempts} consecutive sample loads failed "
                f"(starting at index {start_idx})"
            ) from last_error

        if self.enable_bucket:
            # Untransformed array; resizing/normalization is left to the caller.
            pixel_values = np.array(pixel_values)
        else:
            pixel_values = self.pixel_transforms(pixel_values)

        return dict(pixel_values=pixel_values, text=caption)

if __name__ == "__main__":
    # Smoke test: replace the "xxx" placeholder with a real annotation file.
    # CC15M's first parameter is `json_path`; the previous `csv_path=` keyword
    # does not exist and raised TypeError.
    dataset = CC15M(
        json_path="xxx",
        resolution=512,
    )

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0)
    for batch in dataloader:
        print(batch["pixel_values"].shape, len(batch["text"]))