import os
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as T


class CustomDataset(Dataset):
    """Image-classification dataset backed by a whitespace-separated list file.

    Each non-blank line of ``txt_file`` holds an image path (relative to
    ``root_dir``) followed by an integer class label, e.g. ``cat/001.jpg 3``.
    Only the first two whitespace-separated fields of a line are used.

    Args:
        txt_file: Path to the list file.
        root_dir: Directory that the image paths in ``txt_file`` are joined onto.
        transform: Optional callable applied to each loaded RGB PIL image.

    Raises:
        ValueError: If a non-blank line has fewer than two fields or its
            label is not an integer.
    """

    def __init__(self, txt_file, root_dir, transform=None):
        self.image_list = []  # image paths relative to root_dir
        self.id_list = []     # integer labels, index-aligned with image_list
        self.root_dir = root_dir
        self.transform = transform
        with open(txt_file, 'r') as f:
            for line_no, line in enumerate(f, start=1):
                fields = line.split()
                # Skip blank lines (e.g. a trailing newline at EOF) instead of
                # crashing with an IndexError on fields[0].
                if not fields:
                    continue
                if len(fields) < 2:
                    raise ValueError(
                        f"{txt_file}:{line_no}: expected '<image_path> <label>', "
                        f"got {line.rstrip()!r}"
                    )
                label = int(fields[1])
                self.image_list.append(fields[0])
                self.id_list.append(label)

    def __len__(self):
        """Return the number of samples listed in the list file."""
        return len(self.id_list)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is read from ``os.path.join(root_dir, path)`` and converted
        to RGB. It is then passed through ``transform`` when one was given.
        """
        img_path = os.path.join(self.root_dir, self.image_list[idx])
        # The context manager releases the file handle deterministically.
        # convert() returns a new, fully loaded image that stays valid after close.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)
        return image, self.id_list[idx]


def DESINet(target_size=224, root_dir="./DESINet/", train_txt="./DESINet_train_100.txt"):
    """Build the DESINet training dataset with its training-time augmentations.

    The pipeline applies, in order:
    - a random resized crop to ``target_size`` (crop area 20%-100%, bicubic);
    - a horizontal flip with probability 0.5;
    - conversion to a tensor;
    - per-channel normalization.

    Args:
        target_size: Output side length, in pixels, of the random resized crop.
        root_dir: Directory that the image paths listed in ``train_txt`` are
            relative to. Defaults to the previously hard-coded location.
        train_txt: List file of ``<image_path> <label>`` lines. Defaults to
            the previously hard-coded file.

    Returns:
        A ``CustomDataset`` yielding ``(normalized_tensor, label)`` pairs.
    """
    train_transforms = T.Compose([
        # BICUBIC is the named equivalent of the legacy integer code 3.
        T.RandomResizedCrop(
            target_size,
            scale=(0.2, 1.0),
            interpolation=T.InterpolationMode.BICUBIC,
        ),
        T.RandomHorizontalFlip(p=0.5),
        T.ToTensor(),
        # Per-channel (R, G, B) statistics specific to this dataset. They are
        # presumably computed on the DESINet images; confirm before reusing them.
        T.Normalize(mean=[0.1448, 0.1373, 0.1340], std=[0.1058, 0.0987, 0.0920]),
    ])

    return CustomDataset(txt_file=train_txt, root_dir=root_dir, transform=train_transforms)
