from glob import glob 
from . import constants as cs
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from os.path import join as osj
from PIL import Image

# Texture category names, in alphabetical order. A name's position in this
# list is the integer class label the DTD dataset class assigns to it, so the
# order must not change. Entries are grouped by initial letter for readability.
DTD_CLASS_NAMES = [
    'banded', 'blotchy', 'braided', 'bubbly', 'bumpy',
    'chequered', 'cobwebbed', 'cracked', 'crosshatched', 'crystalline',
    'dotted',
    'fibrous', 'flecked', 'freckled', 'frilly',
    'gauzy', 'grid', 'grooved',
    'honeycombed',
    'interlaced',
    'knitted',
    'lacelike', 'lined',
    'marbled', 'matted', 'meshed',
    'paisley', 'perforated', 'pitted', 'pleated', 'polka-dotted', 'porous', 'potholed',
    'scaly', 'smeared', 'spiralled', 'sprinkled', 'stained', 'stratified', 'striped', 'studded', 'swirly',
    'veined',
    'waffled', 'woven', 'wrinkled',
    'zigzagged',
]

class DTD(Dataset):
    """Map-style dataset over one split of the DTD texture images.

    (DTD is presumably the Describable Textures Dataset -- confirm.)

    Expected layout under ``root``:

    * ``labels/train{split}.txt``, ``labels/val{split}.txt`` and
      ``labels/test{split}.txt`` -- one relative image path per line, whose
      first ``/``-separated component is the class name
      (e.g. ``banded/some_image.jpg``).
    * ``images/`` -- the image files those relative paths point to.

    Attributes:
        ims: Relative image paths of the selected split (whitespace-stripped).
        full_ims: The same paths joined onto ``root/images``.
        c_to_t: Mapping from class name to integer label (its index in
            ``DTD_CLASS_NAMES``).
        classes: ``range`` over all integer labels.
        labels: Integer label of every entry in ``ims``, in the same order.
        transform: Callable applied to each opened image.
    """

    def __init__(self, root, split="1", train=False,  transform=lambda x: x):
        """Index the image list of one split; no image is opened here.

        Args:
            root: Dataset root directory containing ``labels/`` and ``images/``.
            split: Split identifier inserted into the label file names.
            train: If true, use the train and val lists concatenated (in that
                order); otherwise use the test list only.
            transform: Callable applied to the opened image in
                ``__getitem__``. Defaults to the identity.

        Raises:
            OSError: If a required label file cannot be opened.
            KeyError: If an entry's leading path component is not one of
                ``DTD_CLASS_NAMES``.
        """
        super().__init__()
        if train:
            # The training set is the union of the official train and val lists.
            list_files = (f"labels/train{split}.txt", f"labels/val{split}.txt")
        else:
            list_files = (f"labels/test{split}.txt",)

        self.ims = []
        for rel_file in list_files:
            self.ims.extend(self._read_entries(osj(root, rel_file)))

        self.full_ims = [osj(root, "images", rel_path) for rel_path in self.ims]

        self.c_to_t = {name.strip(): idx for idx, name in enumerate(DTD_CLASS_NAMES)}
        # Derived from the class list instead of a hard-coded count.
        self.classes = range(len(DTD_CLASS_NAMES))

        self.transform = transform
        # The class name is the directory component preceding the first "/".
        self.labels = [self.c_to_t[rel_path.split("/")[0]] for rel_path in self.ims]

    @staticmethod
    def _read_entries(path):
        """Return the non-blank, whitespace-stripped lines of the file at ``path``."""
        # The context manager closes the handle deterministically; blank lines
        # (e.g. a trailing empty line) are dropped so they cannot fail the
        # class-name lookup.
        with open(path) as handle:
            return [line.strip() for line in handle if line.strip()]

    def __getitem__(self, index):
        """Return ``(transform(image), label)`` for the sample at ``index``."""
        # Paths were already stripped of surrounding whitespace at load time.
        im = Image.open(self.full_ims[index])
        im = self.transform(im)
        return im, self.labels[index]

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.ims)

