import os
import random
import clip
import scipy
import torch
import torch.nn.functional as F
import tqdm
from PIL import Image
from torchvision.datasets import CIFAR100
from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize,
                                    ToTensor)


class CLIPImageDataset(torch.utils.data.Dataset):
    """Dataset that opens images with PIL and passes them through ``preprocess``.

    Each item is a dict with a single ``'image'`` key holding whatever
    ``preprocess`` returned for the opened image.
    """

    def __init__(self, data, preprocess):
        """Store the image sources and the transform applied to each of them.

        Args:
            data: Indexable, sized collection whose elements are handed to
                ``PIL.Image.open`` (presumably file paths -- confirm against
                the caller).
            preprocess: Callable taking the opened PIL image and returning
                the value stored under ``'image'``.
        """
        self.data = data
        self.preprocess = preprocess

    def _transform_test(self, n_px):
        """Build a resize / center-crop / RGB / tensor / normalize pipeline.

        Not called by any other method of this class; ``__getitem__`` uses
        the ``preprocess`` callable supplied to the constructor instead.

        Args:
            n_px: Target size passed to both ``Resize`` and ``CenterCrop``.

        Returns:
            A torchvision ``Compose`` of the transforms listed above.
        """
        return Compose([
            Resize(n_px, interpolation=Image.BICUBIC),
            CenterCrop(n_px),
            lambda image: image.convert("RGB"),
            ToTensor(),
            # NOTE(review): these look like CLIP's per-channel mean / std --
            # confirm against the model's own preprocessing.
            Normalize((0.48145466, 0.4578275, 0.40821073),
                      (0.26862954, 0.26130258, 0.27577711)),
        ])

    def __getitem__(self, idx):
        """Open the ``idx``-th image, preprocess it and wrap it in a dict.

        ``preprocess`` must finish reading the pixel data before it returns
        (as a tensor-producing transform does), because the underlying file
        is released as soon as it completes.

        Returns:
            ``{'image': preprocess(opened_image)}``.
        """
        c_data = self.data[idx]
        # Image.open is lazy and keeps the file open; the context manager
        # releases the handle deterministically, even if preprocess raises.
        with Image.open(c_data) as raw_image:
            image = self.preprocess(raw_image)
        return {'image': image}

    def __len__(self):
        """Return the number of image sources in ``data``."""
        return len(self.data)



class CLIPCapDataset(torch.utils.data.Dataset):
    """Dataset that tokenizes prefixed captions with ``clip.tokenize``.

    Each item is a dict with a ``'caption'`` entry and, when human scores
    were supplied, a ``'human_score'`` entry taken from the same index.
    """

    def __init__(self, data, human_score=None, prefix='A photo depicts'):
        """Store the captions, optional scores and the normalized prefix.

        Args:
            data: Indexable, sized collection of caption strings.
            human_score: Optional collection indexed in parallel with
                ``data``; ``None`` disables the ``'human_score'`` entry.
            prefix: Text prepended to every caption. A single trailing space
                is appended when the prefix is non-empty and does not
                already end with one. An empty string or ``None`` means
                "no prefix" and leaves the captions untouched.
        """
        self.data = data
        # Indexing prefix[-1] would raise on '' (IndexError) or None
        # (TypeError); treat both as "no prefix" instead of crashing.
        prefix = '' if prefix is None else prefix
        if prefix and not prefix.endswith(' '):
            prefix += ' '
        self.prefix = prefix
        self.human_score = human_score

    def __getitem__(self, idx):
        """Tokenize the ``idx``-th caption with the prefix prepended.

        Returns:
            ``{'caption': tokens}``, plus ``'human_score'`` when scores were
            given. ``tokens`` is the ``clip.tokenize(..., truncate=True)``
            result with size-1 dimensions squeezed out.
        """
        c_data = self.data[idx]
        c_data = clip.tokenize(self.prefix + c_data, truncate=True).squeeze()
        if self.human_score is not None:
            return {'caption': c_data, 'human_score': self.human_score[idx]}
        return {'caption': c_data}

    def __len__(self):
        """Return the number of captions in ``data``."""
        return len(self.data)
