import PIL
import numpy as np
import torchvision
import torch
import random
from PIL import Image


class STL10_UNLABELED(torchvision.datasets.STL10):
    """STL10 "unlabeled" split yielding two independently transformed views.

    Each item is ``([view1, view2], -1.0, index)``: the same image is turned
    into two PIL images and ``transform`` is applied to each one separately.
    The target is the constant ``-1.0`` because this split carries no labels.

    Args:
        transform: Callable applied to each PIL view, or ``None`` to return
            the untransformed PIL images.
        root: Dataset directory, forwarded to torchvision's ``STL10``.
        download: Forwarded to torchvision's ``STL10``. Defaults to ``True``,
            which was the previously hard-coded behavior.
    """

    def __init__(self, transform, root='./Data', *, download=True):
        # Let the base class store ``transform`` so ``self.transform`` and
        # ``self.transforms`` stay consistent (instead of patching it after).
        super().__init__(
            root=root,
            split="unlabeled",
            transform=transform,
            download=download,
        )

    def __getitem__(self, index: int):
        """Return ``([view1, view2], -1.0, index)`` for the image at ``index``."""
        # Move the first axis to the end: torchvision stores STL10 images
        # channel-first, while PIL's fromarray expects (H, W, C).
        img = np.transpose(self.data[index], (1, 2, 0))
        # Two separate PIL images, so a transform that mutates its input
        # cannot leak changes from one view into the other.
        view1 = Image.fromarray(img)
        view2 = Image.fromarray(img)
        if self.transform is not None:
            view1 = self.transform(view1)
            view2 = self.transform(view2)
        return [view1, view2], -1.0, index


class STL10_Test(torchvision.datasets.STL10):
    """STL10 "test" split yielding two independently transformed views.

    Each item is ``([view1, view2], label, index)``: the same image is turned
    into two PIL images and ``transform`` is applied to each one separately.
    ``label`` is the dataset label converted to a Python ``int``.

    Args:
        transform: Callable applied to each PIL view, or ``None`` to return
            the untransformed PIL images.
        root: Dataset directory, forwarded to torchvision's ``STL10``.
        download: Forwarded to torchvision's ``STL10``. Defaults to ``True``,
            which was the previously hard-coded behavior.
    """

    def __init__(self, transform, root='./Data', *, download=True):
        # Let the base class store ``transform`` so ``self.transform`` and
        # ``self.transforms`` stay consistent (instead of patching it after).
        super().__init__(
            root=root,
            split="test",
            transform=transform,
            download=download,
        )

    def __getitem__(self, index: int):
        """Return ``([view1, view2], label, index)`` for the image at ``index``."""
        label = int(self.labels[index])
        # Move the first axis to the end: torchvision stores STL10 images
        # channel-first, while PIL's fromarray expects (H, W, C).
        img = np.transpose(self.data[index], (1, 2, 0))
        # Two separate PIL images, so a transform that mutates its input
        # cannot leak changes from one view into the other.
        view1 = Image.fromarray(img)
        view2 = Image.fromarray(img)
        if self.transform is not None:
            view1 = self.transform(view1)
            view2 = self.transform(view2)
        return [view1, view2], label, index
