import PIL.Image
import torch, torchvision
import os
import random
import PIL
import pandas as pd
import scipy.io
import numpy as np



class ClassifyDataSet(torch.utils.data.Dataset):
    """Paired-image classification dataset backed by a CSV index file.

    Each CSV row describes one sample through three columns:

    * ``y``     -- path to the input image,
    * ``x``     -- path to the paired image returned as the third element,
    * ``label`` -- class label, converted with ``int()``.

    Both images are opened with PIL, converted to RGB, passed through
    ``transform_img`` and cast with ``.float()``.
    """

    def __init__(self, transform_img, data_csv):
        """Load the CSV index.

        Args:
            transform_img: Callable applied to each RGB ``PIL.Image``. Its
                result must provide a ``.float()`` method (e.g. a
                ``torch.Tensor`` produced by a torchvision transform).
            data_csv: Anything accepted by ``pandas.read_csv`` (path or
                file-like object). Must contain ``x``, ``y`` and ``label``
                columns.

        Raises:
            KeyError: if one of the required columns is missing.
        """
        super().__init__()
        self.transform_img = transform_img
        self.data_csv = data_csv
        self.data = pd.read_csv(self.data_csv)
        self.data_image = self.data['y']      # input image paths
        self.data_label = self.data['label']  # class labels
        self.true = self.data['x']            # paired image paths

    def __len__(self):
        """Return the number of samples (CSV rows)."""
        return len(self.data)

    def _load_image(self, path):
        """Open the image at ``path`` as RGB and return the transformed float result."""
        # The context manager releases the file handle deterministically rather
        # than relying on Pillow's load-time heuristics / garbage collection,
        # which matters with many DataLoader workers or multi-frame formats.
        # convert() returns an independent, fully loaded image, so it stays
        # valid after the source file is closed.
        with PIL.Image.open(path) as img:
            rgb = img.convert('RGB')
        return self.transform_img(rgb).float()

    def __getitem__(self, idx):
        """Return ``(y_tensor, label, x_tensor)`` for sample ``idx``.

        ``label`` is a 0-d ``torch.long`` tensor. ``y_tensor`` and
        ``x_tensor`` are the ``.float()`` results of ``transform_img``.
        """
        # Positional (.iloc) access keeps indexing correct even if the frame's
        # index is not a plain 0..n-1 RangeIndex. int() also accepts numpy
        # integers and one-element tensors handed in by custom samplers.
        idx = int(idx)
        y_tensor = self._load_image(self.data_image.iloc[idx])
        label = torch.tensor(int(self.data_label.iloc[idx]), dtype=torch.long)
        x_tensor = self._load_image(self.true.iloc[idx])
        return y_tensor, label, x_tensor

