import os
import sys
sys.path.append('..')

import torch
from torch.utils.data import Dataset
import numpy as np
# import cv2
from PIL import Image
from torchvision import transforms
import torchvision.datasets as datasets

from config import opt

def TinyImageNetDataset(root, download, transform, split='train'):
    """Build an ``ImageFolder`` over ``<root>/tiny-imagenet-200/<split>``.

    Args:
        root: directory that contains the ``tiny-imagenet-200`` folder.
        download: accepted but ignored; nothing is downloaded here.
        transform: transform handed to ``ImageFolder``.
        split: sub-folder of ``tiny-imagenet-200`` to load (default ``'train'``).

    Returns:
        A ``torchvision.datasets.ImageFolder`` for the requested split.
    """
    split_dir = os.path.join(root, 'tiny-imagenet-200', split)
    return datasets.ImageFolder(split_dir, transform=transform)

class TinyImageNet(object):
    """Tiny-ImageNet-200 train/val/test datasets and their DataLoaders.

    Each split is a ``torchvision.datasets.ImageFolder`` rooted at
    ``<self.root>/tiny-imagenet-200/<split>``.

    Args:
        input_size: size passed to ``transforms.Resize`` (all splits) and to
            ``transforms.RandomCrop`` (training split only).
        transform: optional extra transform; when truthy, ``self.dataset`` is
            built as a second view of the training split using it.
        n_classes: unused; kept for interface compatibility.
        partition: unused; kept for interface compatibility.
    """

    def __init__(self, input_size=32, transform=None, n_classes=None, partition=None):
        self.input_size = input_size
        # NOTE(review): plain string concatenation, so opt.data_dir is assumed
        # to end with a path separator — confirm.
        self.root = opt.data_dir+'datasets'

        # NOTE(review): these mean/std values look like the ones commonly used
        # for CIFAR-100 rather than statistics of Tiny-ImageNet — confirm intent.
        normalize = transforms.Normalize(
            (0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)
        )

        train_transform = transforms.Compose([
            transforms.Resize(self.input_size),
            # Augmentation (pad-and-crop + horizontal flip) is train-only.
            transforms.RandomCrop(self.input_size, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

        test_transform = transforms.Compose([
            transforms.Resize(self.input_size),
            transforms.ToTensor(),
            normalize,
        ])

        self.train_dataset = self._image_folder('train', train_transform)
        self.val_dataset = self._image_folder('val', test_transform)
        self.test_dataset = self._image_folder('test', test_transform)
        if transform:
            self.dataset = self._image_folder('train', transform)

    def _image_folder(self, split, transform):
        """Return an ImageFolder over ``<self.root>/tiny-imagenet-200/<split>``."""
        return datasets.ImageFolder(
            os.path.join(self.root, 'tiny-imagenet-200', split), transform
        )

    @staticmethod
    def _eval_loader(dataset):
        """Return a non-shuffled loader (batch 100, last partial batch kept)."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=100,
            num_workers=4,
            drop_last=False,
        )

    def train_dataloader(self, *args, **kwargs):
        """Return a shuffled loader over the training split.

        Uses batch size 128 and drops the last incomplete batch. Positional
        and keyword arguments are accepted and ignored.
        """
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=128,
            shuffle=True,
            num_workers=4,
            drop_last=True,
        )

    def val_dataloader(self, *args, **kwargs):
        """Return a non-shuffled loader over the validation split.

        Positional and keyword arguments are accepted and ignored.
        """
        return self._eval_loader(self.val_dataset)

    def test_dataloader(self, *args, **kwargs):
        """Return a non-shuffled loader over the test split.

        Positional and keyword arguments are accepted and ignored.
        """
        return self._eval_loader(self.test_dataset)

if __name__ == '__main__':
    # Smoke test: run one full epoch of the training loader and report how
    # many batches it produced. Counting explicitly avoids a NameError on an
    # empty loader, which the old `print(i)` after the loop would raise.
    dataset = TinyImageNet(input_size=32)
    dataloader = dataset.train_dataloader()
    n_batches = 0
    for img, targets in dataloader:
        n_batches += 1
    print(n_batches)
