import torch
from torchvision.datasets import VOCSegmentation
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageNet

class CustomVOCDataset(torch.utils.data.Dataset):
    """Pascal VOC segmentation dataset yielding ``(image, mask)`` pairs.

    Wraps ``VOCSegmentation`` and applies ``transform`` itself (rather than
    handing it to ``VOCSegmentation``) so the same callable can be applied to
    both the image and the mask. When a transform is set, the transformed mask
    is binarised into a float tensor: 1.0 wherever the value is > 0, else 0.0.
    Without a transform, the underlying pair is returned unchanged.
    """

    def __init__(self, root, year="2012", image_set="train", transform=None, download=True):
        """Create the dataset.

        Args:
            root: Root directory forwarded to ``VOCSegmentation``.
            year: VOC release year forwarded to ``VOCSegmentation``.
            image_set: Split name forwarded to ``VOCSegmentation``
                (e.g. ``"train"`` or ``"val"``).
            transform: Optional callable applied to BOTH the image and the mask.
            download: Forwarded to ``VOCSegmentation``. Defaults to ``True``,
                matching the previously hard-coded behaviour.
        """
        # transform=None here on purpose: the transform is applied in
        # __getitem__ so it can be run on the mask as well.
        self.voc_dataset = VOCSegmentation(
            root=root,
            year=year,
            image_set=image_set,
            download=download,
            transform=None,
        )
        self.transform = transform

    def __len__(self):
        """Return the number of samples in the wrapped VOC split."""
        return len(self.voc_dataset)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair at ``idx``.

        If a transform is set, both elements are transformed and the mask is
        turned into a float tensor of 0.0/1.0 values.
        """
        img, mask = self.voc_dataset[idx]
        if self.transform is not None:
            img = self.transform(img)
            # NOTE(review): the image transform is reused for the mask, which
            # assumes it contains no normalisation or random augmentation.
            mask = self.transform(mask)
            # Any non-zero value becomes foreground.
            # NOTE(review): VOC uses label 255 for void/boundary pixels, so
            # those also count as foreground here. Confirm that is intended.
            mask = (mask > 0).float()
        return img, mask

def get_pascal(dataset_path="../data/pascal", batch_size=20, image_size=224):
    """Build a non-shuffling DataLoader over the Pascal VOC 2012 ``val`` split.

    Images and masks are resized to ``image_size`` x ``image_size`` and
    converted to tensors. Masks are then binarised by ``CustomVOCDataset``.

    Args:
        dataset_path: Root directory for the VOC data.
        batch_size: Samples per batch (default 20, the previous hard-coded value).
        image_size: Side length of the square resize (default 224, the previous
            hard-coded value).

    Returns:
        A ``DataLoader`` over ``CustomVOCDataset`` with ``shuffle=False``.
    """
    # NOTE(review): this transform is also applied to the masks. That is only
    # safe while it stays limited to geometric resizing plus ToTensor.
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ])
    custom_voc_dataset = CustomVOCDataset(
        root=dataset_path, year="2012", image_set="val", transform=transform
    )
    return DataLoader(custom_voc_dataset, batch_size=batch_size, shuffle=False)

def get_imagenet(args, dataset_path='../data/ImageNet'):
    """Build a non-shuffling DataLoader over half of the ImageNet validation split.

    Args:
        args: Configuration object. ``args.image_size`` (side of the square
            resize) and ``args.batch_size`` are read.
        dataset_path: Root directory passed to ``ImageNet``. Defaults to the
            previously hard-coded path.

    Returns:
        A ``DataLoader`` (``shuffle=False``) over the first partition of a
        seeded ``random_split``. That partition holds
        ``int(len(val) * 0.5)`` samples.
    """
    # Named ``transform`` so it does not shadow the torchvision ``transforms``
    # module. Assigning to ``transforms`` inside this function would make it a
    # local name, and the ``transforms.Compose`` lookup would raise
    # UnboundLocalError.
    transform = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
    ])

    validation_dataset = ImageNet(root=dataset_path, split='val', transform=transform)
    partition_to_take = int(len(validation_dataset) * .5)
    # Fixed seed so the same half of the validation set is selected every run.
    # random_split returns a list of Subsets; keep only the first one.
    subset, _ = torch.utils.data.random_split(
        validation_dataset,
        [partition_to_take, len(validation_dataset) - partition_to_take],
        generator=torch.Generator().manual_seed(1),
    )
    validation_loader = DataLoader(subset, batch_size=args.batch_size, shuffle=False)

    return validation_loader