import os
import pickle
import sys
sys.path.append("..")

import numpy as np
from torchvision import transforms
import torch
from torch.utils import data
import torchvision
from PIL import Image

from config import opt


class NoiseDataset(data.Dataset):
    """Dataset of uniform random RGB noise images, every one labelled ``0``.

    The images are generated once and cached as a pickle at
    ``<opt.work_dir>/code/data/checkpoints/noise.pkl``; later instances load
    that file instead of regenerating.

    NOTE: an existing cache file is reused as-is, whatever ``input_size`` and
    ``ds_size`` are passed. Delete the file to regenerate with new values.
    """

    def __init__(self, input_size=32, ds_size=50000, transform=None):
        """Load the cached noise images, or generate and cache them.

        Args:
            input_size: Side length in pixels of each generated square image;
                also the target size of the default transform's ``Resize``.
            ds_size: Number of images to generate when no cache file exists.
            transform: Callable applied to each PIL image in ``__getitem__``.
                When ``None``, a Resize -> ToTensor -> Normalize pipeline is
                used.
        """
        cache_dir = os.path.join(opt.work_dir, 'code/data/checkpoints')
        noise_file_path = os.path.join(cache_dir, 'noise.pkl')

        if os.path.exists(noise_file_path):
            print('Loading saved noise file...')
            # NOTE(security): pickle.load can execute arbitrary code; this is
            # only safe while the cache file is produced by this class itself.
            with open(noise_file_path, 'rb') as fp:
                self.items = pickle.load(fp)
        else:
            self.items = [
                (
                    Image.fromarray(
                        (np.random.rand(input_size, input_size, 3) * 255).astype(np.uint8),
                        mode='RGB',
                    ),
                    0,
                )
                for _ in range(ds_size)
            ]
            # The checkpoints directory may not exist on a fresh checkout;
            # create it so that writing the cache cannot fail on a missing path.
            os.makedirs(cache_dir, exist_ok=True)
            with open(noise_file_path, 'wb') as fp:
                pickle.dump(self.items, fp)

        if transform is not None:
            self.transform = transform
        else:
            self.transform = transforms.Compose([
                transforms.Resize(input_size),
                transforms.ToTensor(),
                transforms.Normalize(
                    (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
                ),
            ])

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for the item at ``index``."""
        image, target = self.items[index][0], self.items[index][1]
        return self.transform(image), target

    def __len__(self):
        """Return the number of stored noise images."""
        return len(self.items)


class Noise(object):
    """Bundle of train/test ``NoiseDataset`` instances and their DataLoaders.

    Attributes:
        n_classes: Number of classes reported for this dataset (default 10).
        train_dataset: ``NoiseDataset`` with random-crop / horizontal-flip
            augmentation.
        test_dataset: ``NoiseDataset`` with resize-only preprocessing.
        dataset: ``NoiseDataset`` using the caller-supplied ``transform``;
            only set when ``transform`` is given.
    """

    def __init__(self, input_size=32, transform=None, n_classes=None):
        """Build the train and test noise datasets.

        Args:
            input_size: Side length in pixels used both for generating the
                noise images and for the crop/resize transforms.
            transform: Optional extra transform; when given, a third dataset
                using it is exposed as ``self.dataset``.
            n_classes: Number of classes to report; defaults to 10 when
                ``None``.
        """
        # Honour the caller's value instead of silently discarding it.
        self.n_classes = 10 if n_classes is None else n_classes

        # Tensor conversion and normalisation stages common to both pipelines.
        to_normalized_tensor = [
            transforms.ToTensor(),
            transforms.Normalize(
                (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
            ),
        ]
        train_transform = transforms.Compose(
            [
                transforms.RandomCrop(input_size, padding=4),
                transforms.RandomHorizontalFlip(),
            ]
            + to_normalized_tensor
        )
        test_transform = transforms.Compose(
            [transforms.Resize(input_size)] + to_normalized_tensor
        )

        # Forward input_size so generated images match the transform sizes
        # (the dataset otherwise always falls back to its own default of 32).
        self.train_dataset = NoiseDataset(
            input_size=input_size,
            transform=train_transform,
        )
        self.test_dataset = NoiseDataset(
            input_size=input_size,
            transform=test_transform,
        )
        if transform is not None:
            self.dataset = NoiseDataset(
                input_size=input_size,
                transform=transform,
            )

    def train_dataloader(self, *args, **kwargs):
        """Return a shuffled training DataLoader (batch 128, 4 workers).

        Incomplete final batches are dropped. Positional and keyword
        arguments are accepted but ignored.
        """
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=128,
            shuffle=True,
            num_workers=4,
            drop_last=True,
        )

    def test_dataloader(self, *args, **kwargs):
        """Return an unshuffled test DataLoader (batch 100, 4 workers).

        The final partial batch is kept. Positional and keyword arguments
        are accepted but ignored.
        """
        return torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=100,
            num_workers=4,
            drop_last=False,
        )