import os
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from PIL import Image


class PascalVoc07Dataset(Dataset):
    """PASCAL VOC 2007 image-classification dataset read from a VOCdevkit tree.

    Two label layouts are produced, selected by ``mode``:

    * ``mode='train'``: one example per image id listed in
      ``ImageSets/Main/{split}.txt``. The label is a float multi-hot tensor of
      length ``len(CLASSES)``, with 1 for every class whose
      ``{cls}_{split}.txt`` entry for that image is ``1``. Every other value,
      including ``0``, is treated as negative.
    * ``mode='test'``: one example per line of every ``{cls}_{split}.txt``
      file, so an image appears once per class. The label is the int
      ``(class_index + 1) * file_label``. For example, ``3`` means positive
      for class index 2, ``-6`` means negative for class index 5, and ``0``
      means the file label was ``0``.

    Args:
        root_dir: VOC2007 root containing ``JPEGImages`` and ``ImageSets/Main``.
        split: one of ``SPLITS``.
        transform: optional callable applied to the opened PIL image. When it
            is ``None``, ``__getitem__`` returns the image *path* instead.
        data_size: maximum amount of data to load; ``-1`` loads everything.
            In train mode it keeps the first ``data_size`` image ids of
            ``{split}.txt``. In test mode it keeps the first ``data_size``
            lines of each class file.
        mode: ``'train'`` or ``'test'``, see above.

    Raises:
        ValueError: if ``split`` or ``mode`` is unsupported, or if a class file
            marks as positive an image id that ``{split}.txt`` does not list.
    """

    CLASSES = [
        'aeroplane',
        'bicycle',
        'bird',
        'boat',
        'bottle',
        'bus',
        'car',
        'cat',
        'chair',
        'cow',
        'diningtable',
        'dog',
        'horse',
        'motorbike',
        'person',
        'pottedplant',
        'sheep',
        'sofa',
        'train',
        'tvmonitor'
    ]
    SPLITS = [
        'train',
        'trainval',
        'val',
        'test'
    ]
    MODES = ['train', 'test']

    def __init__(self, root_dir='../datasets/VOCdevkit/VOC2007',
                 split='train', transform=None, data_size=-1, mode='train'):
        # Raise instead of assert: asserts disappear under `python -O`.
        if split not in self.SPLITS:
            raise ValueError(f'split must be one of {self.SPLITS}, got {split!r}')
        if mode not in self.MODES:
            raise ValueError(f'mode must be one of {self.MODES}, got {mode!r}')

        self.split = split  # train, trainval, val, test
        self.mode = mode

        self.root_dir = root_dir
        self.image_dir = os.path.join(root_dir, 'JPEGImages')
        self.splits_dir = os.path.join(root_dir, 'ImageSets/Main')
        self.transform = transform
        self.id2label = dict(enumerate(self.CLASSES))
        self.label2id = {name: i for i, name in enumerate(self.CLASSES)}
        self.data_size = data_size

        # -1 means "no limit". Any other negative value loads nothing, as the
        # original `i >= data_size` check did. `None` works as a slice bound.
        limit = None if data_size == -1 else max(data_size, 0)

        if mode == 'train':
            self._load_multi_hot(limit)
        else:  # test
            self._load_class_pairs(limit)

        print(f'Loaded VOC {split} {len(self.image_ids)} examples')

    @staticmethod
    def _read_lines(path):
        """Return the stripped, non-blank lines of the text file at ``path``."""
        with open(path, 'rt', encoding='utf-8') as input_file:
            return [line.strip() for line in input_file if line.strip()]

    def _load_multi_hot(self, limit):
        """Fill ``image_ids`` and a multi-hot ``labels`` tensor (train mode)."""
        all_ids = self._read_lines(
            os.path.join(self.splits_dir, f'{self.split}.txt'))
        self.image_ids = all_ids[:limit]
        kept = len(self.image_ids)

        # Map each id to its position once (first occurrence wins, like
        # list.index) instead of a linear search per positive label.
        id_to_idx = {}
        for idx, image_id in enumerate(all_ids):
            id_to_idx.setdefault(image_id, idx)

        self.labels = torch.zeros(kept, len(self.CLASSES))
        for c_i, cls in enumerate(self.CLASSES):
            path = os.path.join(self.splits_dir, f'{cls}_{self.split}.txt')
            # Scan whole class files so kept images get complete labels even
            # if a class file orders ids differently from {split}.txt.
            for line in self._read_lines(path):
                image_id, label = line.split()
                if label != '1':
                    continue
                image_idx = id_to_idx.get(image_id)
                if image_idx is None:
                    raise ValueError(
                        f'{image_id!r} is listed in {path} '
                        f'but not in {self.split}.txt')
                if image_idx < kept:
                    self.labels[image_idx, c_i] = 1

    def _load_class_pairs(self, limit):
        """Fill one (image_id, signed class label) pair per class-file line (test mode)."""
        self.image_ids = []
        self.labels = []
        for c_i, cls in enumerate(self.CLASSES):
            path = os.path.join(self.splits_dir, f'{cls}_{self.split}.txt')
            for line in self._read_lines(path)[:limit]:
                image_id, label = line.split()
                self.image_ids.append(image_id)
                # 3 = correct for class 3, -6 = incorrect for class 6
                # (1-based class numbers, i.e. c_i + 1).
                self.labels.append((c_i + 1) * int(label))

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for example ``idx``.

        ``image`` is ``transform(PIL image)`` when a transform was given,
        otherwise it is the path of the example's JPEG file.
        """
        image_id = self.image_ids[idx]
        label = self.labels[idx]
        image_path = os.path.join(self.image_dir, f'{image_id}.jpg')
        if self.transform:
            image = Image.open(image_path)
            image = self.transform(image)
        else:
            image = image_path
        return image, label
    

if __name__ == '__main__':
    # Smoke test: build the default training split; the constructor prints
    # how many examples it loaded.
    voc_train = PascalVoc07Dataset(split='train', mode='train')
