import os
import pickle
import sys
sys.path.append("..")

import numpy as np
from torchvision import transforms
import torch
from torch.utils import data
import torchvision
from PIL import Image
import scipy.io as sio

from config import opt


class SVHN(object):
    """SVHN data module built on ``torchvision.datasets.SVHN``.

    Exposes a preprocessed train split (``train_dataset``), a preprocessed
    test split (``test_dataset``) and, when a custom ``transform`` is given,
    an additional train split using that transform (``dataset``). Files live
    under ``<opt.data_dir>/datasets/svhn`` and are downloaded if missing.

    Args:
        input_size: Passed to ``transforms.Resize`` for the train and test
            splits.
        transform: Optional transform for an extra train split stored in
            ``self.dataset``. When omitted, ``self.dataset`` is ``None``.
        n_classes: Ignored; ``self.n_classes`` is always 10. Kept for
            interface compatibility.
        partition: Ignored; kept for interface compatibility.
    """

    # Per-channel normalization statistics (presumably approximate SVHN RGB
    # mean/std — TODO confirm if the dataset preprocessing changes).
    _MEAN = (0.44, 0.44, 0.47)
    _STD = (0.20, 0.20, 0.20)

    def __init__(self, input_size = 32, transform=None, n_classes=None, partition=None):
        self.n_classes = 10
        # Train and test share the same deterministic preprocessing; random
        # crop / horizontal-flip augmentation is intentionally not applied.
        train_transform = self._base_transform(input_size)
        test_transform = self._base_transform(input_size)
        self.train_dataset = self._load_split('train', train_transform)
        self.test_dataset = self._load_split('test', test_transform)
        # Always define the attribute so callers see None rather than an
        # AttributeError when no custom transform was requested.
        self.dataset = None
        if transform is not None:
            self.dataset = self._load_split('train', transform)

    @staticmethod
    def _base_transform(input_size):
        """Build the resize -> to-tensor -> normalize pipeline."""
        return transforms.Compose([
            transforms.Resize(input_size),
            transforms.ToTensor(),
            transforms.Normalize(SVHN._MEAN, SVHN._STD),
        ])

    @staticmethod
    def _load_split(split, transform):
        """Load one SVHN split (downloading if needed) with ``transform`` applied."""
        return torchvision.datasets.SVHN(
            root=os.path.join(opt.data_dir, 'datasets', 'svhn'),
            split=split,
            download=True,
            transform=transform,
        )

    def train_dataloader(self, *args, batch_size=128, num_workers=4, **kwargs):
        """Return a shuffled DataLoader over ``self.train_dataset``.

        Extra positional and keyword arguments are accepted and ignored.

        Args:
            batch_size: Samples per batch (default 128).
            num_workers: DataLoader worker processes (default 4).

        Returns:
            torch.utils.data.DataLoader
        """
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            # Drop the trailing partial batch so every training batch is full.
            drop_last=True,
        )

    def test_dataloader(self, *args, batch_size=100, num_workers=4, **kwargs):
        """Return an unshuffled DataLoader over ``self.test_dataset``.

        Extra positional and keyword arguments are accepted and ignored.

        Args:
            batch_size: Samples per batch (default 100).
            num_workers: DataLoader worker processes (default 4).

        Returns:
            torch.utils.data.DataLoader
        """
        return torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            # Keep the trailing partial batch so every test sample is evaluated.
            drop_last=False,
        )

