import os
import numpy as np
from .base_dataset import BaseDataset
from os import PathLike
import os.path as osp
from PIL import Image
from .transforms import Compose
def expanduser(path):
    """Expand a leading ``~`` in *path* if it is a filesystem path.

    ``str`` and ``os.PathLike`` values go through ``os.path.expanduser``
    (so a PathLike comes back as ``str``).  Any other value, e.g. ``None``,
    is returned untouched, which lets optional path arguments be passed
    through without a separate ``None`` check.
    """
    is_path = isinstance(path, (str, PathLike))
    if not is_path:
        return path
    return osp.expanduser(path)
    
class TUBerlin(BaseDataset):
    """TU-Berlin sketch classification dataset (2-D PNG renderings only).

    Directory layout read by ``load_annotations``::

        <root>/png/<class_name>/<image>.png

    Every non-hidden sub-directory of ``<root>/png`` is one class; classes
    are labelled ``0..K-1`` in sorted directory-name order.  Within a class
    the images are sorted by filename; the first ``split_num`` form the
    training split and the remaining ones the test split.

    Args:
        root (str | PathLike): Dataset root directory; ``~`` is expanded.
        transforms (callable | None): Applied to each image, which is given
            as an RGB ``PIL.Image``.  If None, ``self.DEFAULT_TRNASFORMS``
            is used instead.
        target_transforms (callable | None): Applied to the integer label.
        split_num (int): Number of images per class put in the train split.
        classes: Forwarded to ``self.get_classes``.  Note that
            ``self.CLASSES`` is afterwards replaced by the class directory
            names discovered in ``load_annotations``.
        ann_file (str | PathLike | None): Stored with ``~`` expanded; not
            read by this class.
        test_mode (bool): If True, load the test split instead of train.
    """

    def __init__(self,
                 root,
                 transforms,
                 target_transforms=None,
                 split_num=60,
                 classes=None,
                 ann_file=None,
                 test_mode=False):
        # super(BaseDataset, self) starts the MRO lookup *after* BaseDataset,
        # so BaseDataset.__init__ itself is NOT executed here.
        # NOTE(review): presumably deliberate (all state is set up below);
        # confirm before switching to a plain super().__init__().
        super(BaseDataset, self).__init__()
        self.root = expanduser(root)
        if transforms is None:
            # NOTE(review): the name is spelled "TRNASFORMS"; it must match
            # the attribute defined elsewhere (presumably on BaseDataset).
            transforms = self.DEFAULT_TRNASFORMS
        self.transforms = transforms
        self.target_transforms = target_transforms
        self.CLASSES = self.get_classes(classes)
        self.ann_file = expanduser(ann_file)
        self.test_mode = test_mode
        # Must run after test_mode is set: it decides which split is loaded.
        self.data_infos = self.load_annotations(split_num)

    def load_annotations(self, split_num=60):
        """Scan ``<root>/png`` and build the sample list for this split.

        Side effect: ``self.CLASSES`` is overwritten with the discovered
        class directory names, in label order.

        Args:
            split_num (int): Images per class assigned to the training
                split; the rest of each class forms the test split.
                ``self.test_mode`` selects which of the two is returned.

        Returns:
            list[dict]: One ``{'img': <file path>, 'gt_label': <int>}``
            entry per sample.
        """
        root = os.path.join(self.root, 'png')
        classes = []
        data_infos = []
        # Sort so label ids and the train/test split are reproducible:
        # os.listdir returns entries in arbitrary, filesystem-dependent order.
        for dir_name in sorted(os.listdir(root)):
            dir_path = os.path.join(root, dir_name)
            # Skip hidden entries (e.g. .DS_Store) and any stray regular
            # files; only sub-directories are classes.
            if dir_name.startswith('.') or not os.path.isdir(dir_path):
                continue
            filenames = sorted(
                name for name in os.listdir(dir_path) if name.endswith('.png'))

            if self.test_mode:
                filenames = filenames[split_num:]
            else:
                filenames = filenames[:split_num]

            label = len(classes)
            data_infos.extend(
                {'img': os.path.join(dir_path, name), 'gt_label': label}
                for name in filenames)
            classes.append(dir_name)
        self.CLASSES = classes

        return data_infos

    def __getitem__(self, idx):
        """Return the ``(img, label)`` pair for sample ``idx``.

        When ``self.transforms`` is set, the image is loaded with PIL and
        converted to a 3-channel RGB ``PIL.Image`` before the transforms are
        applied.  Grayscale, bilevel, palette and alpha images therefore all
        reach the pipeline in the same form.  When ``self.transforms`` is
        None, the image file path is returned in place of the image.
        """
        info = self.data_infos[idx]
        img, label = info['img'], info['gt_label']

        if self.transforms is not None:
            # `with` closes the file handle.  convert('RGB') reads the pixel
            # data and replicates single-channel images into R, G and B.
            with Image.open(img) as pil_img:
                img = pil_img.convert('RGB')
            img = self.transforms(img)
        if self.target_transforms is not None:
            label = self.target_transforms(label)

        return img, label
        





