__author__ = "Anon"
__version__ = "0.1"

import os
from torchvision.datasets import VisionDataset
import numpy as np
from PIL import Image
import torch
import pandas as pd


class Pets(VisionDataset):
    """Image-classification dataset backed by an annotation list and a JPEG folder.

    Expected on-disk layout (presumably the Oxford-IIIT Pet release; verify
    before reusing for other data)::

        root/images/<name>.jpg
        root/annotations/<image_set>.txt

    Each annotation line is whitespace-separated. Column 0 is the image name
    without the ``.jpg`` extension. Column ``class_id_column`` holds an integer
    class id that is assumed to be 1-based; it is shifted to 0-based on access.

    Args:
        root: Dataset root directory.
        num_classes: Number of classes. Stored as ``self.num_classes`` and not
            otherwise used by this class.
        image_set: Stem of the annotation file under ``root/annotations``
            (e.g. ``'trainval'``).
        transform: Optional callable applied to the RGB PIL image.
        class_id_column: Index of the annotation column used as the label.
    """

    def __init__(self, root, num_classes=37, image_set='trainval', transform=None, class_id_column=1):
        super().__init__(root, transform=transform, target_transform=None)

        self.num_classes = num_classes
        self.image_dir = os.path.join(root, 'images')
        annotation_file = os.path.join(root, 'annotations', f'{image_set}.txt')
        # One row per sample, with columns 'image' (str) and 'class_id'.
        self.data = pd.read_csv(
            annotation_file,
            # sep=r'\s+' is the documented replacement for the deprecated
            # delim_whitespace=True; parsing is identical.
            sep=r'\s+',
            header=None,
            usecols=[0, class_id_column],
            names=['image', 'class_id'],
            # Keep image names verbatim: never coerce numeric-looking names to int.
            dtype={'image': str},
            # Skip '#'-prefixed header/comment lines that some annotation files carry.
            comment='#',
        )

    def __len__(self):
        """Return the number of annotated samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample at position ``index``.

        ``image`` is the RGB PIL image, passed through ``self.transform`` when
        one is set. ``target`` is a 0-dim ``torch.long`` tensor holding the
        0-based class id.
        """
        row = self.data.iloc[index]
        path = os.path.join(self.image_dir, row['image'] + '.jpg')
        # Open through a context manager so the file handle is released
        # deterministically once the pixels are decoded (same pattern as
        # torchvision's default pil_loader).
        with open(path, 'rb') as f:
            img = Image.open(f).convert('RGB')
        # Annotation class ids are treated as 1-based; convert to 0-based.
        target = int(row['class_id']) - 1
        if self.transform is not None:
            img = self.transform(img)
        return img, torch.tensor(target, dtype=torch.long)
