import os
import torch
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset


class OxfordPetDataset(Dataset):
    """Map-style dataset over an Oxford-IIIT Pet style directory layout.

    Expected layout under ``root``::

        root/images/<image_name>.jpg
        root/annotations/trainval.txt   # read when split == 'train'
        root/annotations/test.txt       # read for any other split value
        root/annotations/name.txt       # one class name per line

    Each annotation line is whitespace-separated: the first field is the
    image name (without the ``.jpg`` extension) and the second is an integer
    class id. The label is that id minus one (presumably the file ids are
    1-based -- confirm against the annotation files).

    Args:
        root: Dataset root directory.
        split: ``'train'`` selects ``trainval.txt``; any other value selects
            ``test.txt``.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, root, split='train', transform=None):
        self.transform = transform
        self.root = root
        self.image_dir = os.path.join(root, 'images')
        self.split = split
        # Stripped, non-empty annotation lines for the selected split.
        self.data = self.load_annotations()
        self.classes = self.get_class_names(os.path.join(root, 'annotations', 'name.txt'))
        # Second field of each annotation line, shifted down by one.
        self.targets = [int(line.split()[1]) - 1 for line in self.data]

    def load_annotations(self):
        """Return the stripped annotation lines for ``self.split``.

        Blank lines and ``#`` comment lines are skipped so a stray empty line
        does not break label parsing in ``__init__``.
        """
        filename = 'trainval.txt' if self.split == 'train' else 'test.txt'
        annotation_file = os.path.join(self.root, 'annotations', filename)
        with open(annotation_file, 'r', encoding='utf-8') as f:
            stripped = (line.strip() for line in f)
            return [line for line in stripped if line and not line.startswith('#')]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``.

        The image is loaded from ``image_dir`` and converted to RGB, then passed
        through ``transform`` when one was provided.
        """
        image_name = self.data[index].split()[0]
        path = os.path.join(self.image_dir, image_name + '.jpg')
        # The context manager closes the file handle Image.open keeps;
        # convert() returns a new image that no longer depends on it.
        with Image.open(path) as img:
            image = img.convert('RGB')
        label = self.targets[index]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

    def get_class_names(self, name_dir):
        """Read class names, one per line, from the file at ``name_dir``.

        Blank lines are skipped so they do not become empty class names.
        """
        with open(name_dir, 'r', encoding='utf-8') as file:
            return [line.strip() for line in file if line.strip()]
