import numpy as np
import os.path as osp

from dassl.utils import listdir_nohidden

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class STL10(DatasetBase):
    """STL-10 dataset.

    Description:
    - 10 classes: airplane, bird, car, cat, deer, dog, horse,
    monkey, ship, truck.
    - Images are 96x96 pixels, color.
    - 500 training images per class, 800 test images per class.
    - 100,000 unlabeled images for unsupervised learning.

    Expected layout under ``cfg.DATASET.ROOT``::

        stl10/
            train/
            test/
            unlabeled/
            stl10_binary/fold_indices.txt

    Image filenames are parsed as ``<prefix>_<label>.<ext>``: the second
    underscore-separated field of the stem is the integer class label, or
    ``"none"`` for images without a label (mapped to -1).

    Reference:
        - Coates et al. An Analysis of Single Layer Networks in
        Unsupervised Feature Learning. AISTATS 2011.
    """

    dataset_dir = "stl10"

    def __init__(self, cfg):
        """Build the labeled, unlabeled and test splits.

        Args:
            cfg: config node. Reads ``DATASET.ROOT``,
                ``DATASET.STL10_FOLD`` (must be in [0, 4]) and
                ``DATASET.ALL_AS_UNLABELED``.
        """
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        train_dir = osp.join(self.dataset_dir, "train")
        test_dir = osp.join(self.dataset_dir, "test")
        unlabeled_dir = osp.join(self.dataset_dir, "unlabeled")
        fold_file = osp.join(
            self.dataset_dir, "stl10_binary", "fold_indices.txt"
        )

        # Only use the first five splits. (_read_data_train would accept a
        # negative fold meaning "all training images", but it is not
        # allowed here.)
        fold = cfg.DATASET.STL10_FOLD
        assert 0 <= fold <= 4, (
            "DATASET.STL10_FOLD must be in [0, 4], got {}".format(fold)
        )

        train_x = self._read_data_train(train_dir, fold, fold_file)
        train_u = self._read_data_all(unlabeled_dir)
        test = self._read_data_all(test_dir)

        if cfg.DATASET.ALL_AS_UNLABELED:
            # Reuse the labeled training images as extra unlabeled data.
            train_u = train_u + train_x

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    @staticmethod
    def _parse_label(imname, allow_unlabeled=False):
        """Return the label encoded in ``imname`` (``<prefix>_<label>.<ext>``).

        Args:
            imname (str): image filename.
            allow_unlabeled (bool): if True, a ``"none"`` label field maps
                to -1. Otherwise the field must be an integer.

        Raises:
            ValueError: if the label field is not an integer (and is not an
                allowed ``"none"``).
            IndexError: if the stem contains no ``"_"``.
        """
        field = osp.splitext(imname)[0].split("_")[1]
        if allow_unlabeled and field == "none":
            return -1
        return int(field)

    def _read_data_train(self, data_dir, fold, fold_file):
        """Read labeled training images, optionally restricted to one fold.

        Args:
            data_dir (str): directory of training images.
            fold (int): zero-based line of ``fold_file`` to use. If
                negative, every image in ``data_dir`` is used.
            fold_file (str): text file with one line per fold, each holding
                whitespace-separated indices into the sorted filename list
                of ``data_dir``.

        Returns:
            list[Datum]: one item per selected image.
        """
        # Sort so that fold indices refer to a stable filename order.
        imnames = listdir_nohidden(data_dir)
        imnames.sort()

        list_idx = range(len(imnames))
        if fold >= 0:
            with open(fold_file, "r") as f:
                str_idx = f.read().splitlines()[fold]
            # Parse as Python ints. np.uint8 cannot represent indices above
            # 255, and the training set holds 5,000 images.
            list_idx = [int(tok) for tok in str_idx.split()]

        return [
            Datum(
                impath=osp.join(data_dir, imnames[i]),
                label=self._parse_label(imnames[i]),
            )
            for i in list_idx
        ]

    def _read_data_all(self, data_dir):
        """Read every (non-hidden) image in ``data_dir``.

        Images whose label field is ``"none"`` get label -1.

        Args:
            data_dir (str): directory of images.

        Returns:
            list[Datum]: items in sorted filename order.
        """
        # Sort for a reproducible item order; the raw directory listing
        # order is not guaranteed.
        imnames = sorted(listdir_nohidden(data_dir))
        return [
            Datum(
                impath=osp.join(data_dir, imname),
                label=self._parse_label(imname, allow_unlabeled=True),
            )
            for imname in imnames
        ]
