# Module metadata: author credit and version string for this file.
__author__ = "Anon"
__version__ = "0.1"

import os
from torchvision.datasets import VisionDataset
import numpy as np
from PIL import Image
import torch


class BMW10(VisionDataset):
    """Image classification dataset read from a MATLAB annotation file.

    The following layout is expected under ``root``:

    * ``bmw10_ims/`` -- directory that the annotated image paths are joined to.
    * ``bmw10_annos.mat`` -- annotation file providing the ``annos``,
      ``train_indices`` and ``test_indices`` variables.

    Args:
        root: Directory containing the files listed above.
        image_set: ``'train'`` selects ``train_indices``; any other value
            selects ``test_indices``.
        transform: Optional callable applied to each loaded RGB PIL image.

    Attributes:
        num_classes: Number of classes (fixed at 10).
        ids: 1-D array with the annotation indices of the selected split,
            exactly as stored in the annotation file (treated as 1-based).
        images: Image file paths, one per entry of ``ids``.
        labels: Integer labels, one per entry of ``ids``, shifted down by
            one relative to the value stored in the annotation file.

    Raises:
        RuntimeError: If scipy cannot be imported.
    """

    def __init__(self, root, image_set='train', transform=None):
        super(BMW10, self).__init__(root, transform=transform, target_transform=None)
        # Keep the try body minimal: only the import itself can raise ImportError.
        try:
            from scipy.io import loadmat
        except ImportError:
            raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: "
                               "pip install scipy")
        self._loadmat = loadmat

        self.num_classes = 10

        image_dir = os.path.join(root, 'bmw10_ims')
        data = self._loadmat(os.path.join(root, 'bmw10_annos.mat'))

        # Anything other than 'train' falls back to the test split.
        split_key = 'train_indices' if image_set == 'train' else 'test_indices'
        # ravel() (unlike squeeze()) always yields a 1-D array, so a split or
        # annotation table holding a single entry is still iterable/indexable.
        self.ids = data[split_key].ravel()
        annos = data['annos'].ravel()  # flattened once instead of per sample

        self.images = []
        self.labels = []
        for raw_index in self.ids:
            # Indices in the file are 1-based; convert to a Python int first
            # so the subtraction cannot wrap around on an unsigned dtype.
            record = annos[int(raw_index) - 1]
            # NOTE(review): assumes field 0 holds the relative file name and
            # field 1 the class label -- confirm against the .mat schema.
            self.images.append(os.path.join(image_dir, record[0][0]))
            # .item() extracts the single stored value; calling int() on an
            # array with ndim > 0 is deprecated in recent NumPy releases.
            self.labels.append(int(np.asarray(record[1]).item()) - 1)

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.images)

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index: Position of the sample within the selected split.

        Returns:
            A ``(img, target)`` tuple: the image converted to RGB (after
            ``transform`` if one was given) and the label as a 0-d
            ``torch.long`` tensor.
        """
        # The context manager closes the underlying file deterministically;
        # convert() returns an independent, fully loaded copy.
        with Image.open(self.images[index]) as handle:
            img = handle.convert('RGB')
        target = self.labels[index]
        if self.transform is not None:
            img = self.transform(img)
        return img, torch.tensor(target).long()

