from __future__ import print_function
import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# from .image_transforms import Low_pass_filter
from PIL import Image
import pdb


class CIFAR10_twopass(datasets.CIFAR10):
    """CIFAR-10 dataset that yields two separately transformed views per sample.

    Each item is ``(fine_img, coarse_img, target)``. Both views come from the
    same underlying image, passed independently through ``fine_transform``
    and ``coarse_transform``. When a transform is ``None``, that view is the
    untransformed ``PIL.Image``.

    Args:
        root: Dataset root directory, forwarded to ``datasets.CIFAR10``.
        train: Split selector, forwarded to ``datasets.CIFAR10``.
        fine_transform: Optional callable applied to the PIL image to build
            the first returned view.
        target_transform: Optional callable applied to the label.
        coarse_transform: Optional callable applied to the PIL image to build
            the second returned view.
        download: Forwarded to ``datasets.CIFAR10``. Defaults to ``False``,
            which is what the parent previously always received.
    """

    def __init__(self, root, train, fine_transform=None, target_transform=None,
                 coarse_transform=None, download=False):
        # No ``transform`` is handed to the parent: the two image views are
        # produced in ``__getitem__`` below. ``target_transform`` is forwarded
        # so the parent's own transform bookkeeping matches this instance.
        super(CIFAR10_twopass, self).__init__(root=root,
                                              train=train,
                                              target_transform=target_transform,
                                              download=download)
        self.fine_transform = fine_transform
        self.coarse_transform = coarse_transform
        self.target_transform = target_transform

    @staticmethod
    def _maybe_apply(transform, value):
        """Return ``transform(value)``, or ``value`` unchanged if ``transform`` is None."""
        return value if transform is None else transform(value)

    def __getitem__(self, index):
        """Return ``(fine_img, coarse_img, target)`` for sample ``index``."""
        img, target = self.data[index], self.targets[index]
        # Convert the stored array to a PIL image so both pipelines get the
        # same input type.
        img = Image.fromarray(img)

        # Order (fine, coarse, target) matches the original. This matters if
        # the transforms draw from a shared random state.
        fine_img = self._maybe_apply(self.fine_transform, img)
        coarse_img = self._maybe_apply(self.coarse_transform, img)
        target = self._maybe_apply(self.target_transform, target)

        return fine_img, coarse_img, target

    def __len__(self):
        """Return the number of samples, counted as the number of labels."""
        return len(self.targets)
