#!/usr/bin/env python

from datasets import Array3D, Features, load_dataset, Dataset, load_from_disk
from dawgz import after, job, schedule


from imagecorruptions import corrupt

# isort: split
from utils import *


@job(cpus=4, ram='64GB', time='06:00:00')
def download():
    """Fetch the ``cifar10`` dataset into the ``PATH / 'hf'`` Hugging Face cache."""
    cache = PATH / 'hf'
    load_dataset('cifar10', cache_dir=cache)


@after(download)
@job(cpus=4, ram='64GB', time='06:00:00')
def corrupt_dataset(corruption: int = 75, seed: int = 0):
    """Create a randomly masked, slightly noisy copy of CIFAR-10 and save it to disk.

    For every image ``x`` a boolean mask ``A`` of shape (32, 32, 1) is drawn in
    which each entry is False with probability ``corruption / 100``. The stored
    observation is ``y = A * x + eps`` with ``eps ~ N(0, (1e-3)**2)``. The
    ``img`` and ``label`` columns are dropped, so only ``A`` and ``y`` are saved,
    to ``<root>/hf/cifar-mask-<corruption>``.

    Args:
        corruption: Percentage (0-100) of pixels to mask out.
        seed: Base seed. Each row draws from its own generator seeded with
            ``(seed, split_index, row_index)``. This makes the output
            reproducible and independent of how ``map`` shards rows across
            the ``num_proc`` worker processes.
    """
    root = Path('/data/vision/___/scratch/___ht/cifar_dir')

    def transform(row, idx, split_id):
        # Use a per-row generator instead of NumPy's global RNG. Where map's
        # workers are forked (the Linux default), each one inherits an
        # identical copy of the global RNG state. Every shard would then get
        # the same sequence of masks and noise.
        rng = np.random.default_rng((seed, split_id, idx))
        x = from_pil(row['img'])
        # One mask per pixel, broadcast over the channel axis.
        A = rng.uniform(size=(32, 32, 1)) > corruption / 100
        y = rng.normal(loc=A * x, scale=1e-3)

        return {'A': A, 'y': y}

    types = {
        'A': Array3D(shape=(32, 32, 1), dtype='bool'),
        'y': Array3D(shape=(32, 32, 3), dtype='float32'),
    }

    # NOTE(review): this cache_dir is hard-coded, whereas `download` uses
    # `PATH` from utils. Confirm both point to the same place; otherwise the
    # `@after(download)` dependency does not warm this cache.
    dataset = load_dataset('cifar10', cache_dir=root / 'hf')

    # Map each split separately so the split index can be folded into the
    # per-row seed (row indices restart at 0 in every split).
    for split_id, split in enumerate(list(dataset)):
        dataset[split] = dataset[split].map(
            transform,
            with_indices=True,
            fn_kwargs={'split_id': split_id},
            features=Features(types),
            remove_columns=['img', 'label'],
            keep_in_memory=True,
            num_proc=4,
        )

    dataset.save_to_disk(root / f'hf/cifar-mask-{corruption}')

# Module-level counter kept only for backward compatibility with external
# code that may reference it; nothing in this file reads it.
counter = 0


def save_to_png():
    """Export every CIFAR-10 training image as ``<row index>.png``.

    Images are read from the ``cifar10`` cache under the hard-coded
    ``cifar_dir/hf`` directory. They are written to
    ``cifar_dir/datasets_for_eval/ref``, which is created if missing.
    Re-running the function overwrites the same ``0.png`` ... ``N-1.png`` files
    rather than continuing the numbering.
    """
    path_from = Path('/data/vision/___/scratch/___ht/cifar_dir')
    path_to = Path('/data/vision/___/scratch/___ht/cifar_dir/datasets_for_eval/ref')
    path_to.mkdir(parents=True, exist_ok=True)

    dataset = load_dataset('cifar10', cache_dir=path_from / 'hf')['train']

    # Iterate directly rather than running `Dataset.map` purely for its side
    # effects. The row index gives stable, gap-free file names without
    # mutable global state. It also avoids map's in-memory copy of the whole
    # dataset and any extra "probe" call some `datasets` versions make on the
    # first row.
    for index, row in enumerate(dataset):
        row['img'].save(path_to / f'{index}.png')


def gaussian_kernel(kernel_size: int, sigma: float):
    """Return a normalised square 2-D Gaussian kernel.

    The kernel has shape ``(kernel_size, kernel_size)``, is centred on the
    middle of the grid with unit spacing between samples, and its entries sum
    to 1.

    NOTE: ``sigma == 0`` produces NaNs (0/0 at the centre sample).
    """
    half_width = (kernel_size - 1) / 2.
    coords = np.linspace(-half_width, half_width, kernel_size)
    grid_x, grid_y = np.meshgrid(coords, coords)
    squared_radius = grid_x ** 2 + grid_y ** 2
    weights = np.exp(-0.5 * squared_radius / sigma ** 2)
    return weights / weights.sum()


def corrupt_blur(severity: float, kernel_size: int):
    """Save a Gaussian-blurred copy of CIFAR-10 (observations only) to disk.

    Each image is blurred channel by channel with a ``kernel_size`` x
    ``kernel_size`` Gaussian kernel whose standard deviation is ``severity``.
    The convolution uses zero padding and keeps the 32x32 size. Only the
    blurred ``y`` column is kept; it is written to
    ``<root>/hf/cifar-mask-gaussian-blur-s<severity>-k<kernel_size>``.

    Args:
        severity: Standard deviation of the Gaussian kernel (``sigma``).
        kernel_size: Side length of the square kernel.
    """
    root = Path('/data/vision/___/scratch/___ht/cifar_dir')
    blur = gaussian_kernel(kernel_size, severity)

    def blur_plane(plane):
        # Zero-padded 2-D convolution whose output matches the input size.
        return jax.scipy.signal.convolve2d(plane, blur, mode='same', boundary='fill', fillvalue=0)

    def transform(row):
        # NOTE(review): the uint8 round-trip below looks like a leftover from
        # an imagecorruptions-based version (`corrupt(x, 'gaussian_blur', ...)`
        # expects uint8). It truncates values, and if `from_pil` can return
        # exactly 2 that maps to 256, which wraps around in uint8. Confirm.
        pixels = from_pil(row['img'])
        quantized = ((pixels + 2) * (256 / 4)).astype('uint8')

        blurred = jnp.stack([blur_plane(quantized[..., c]) for c in range(3)], axis=-1)

        # Undo the affine rescaling applied before quantisation.
        rescaled = blurred.astype('float32') * 4 / 256 - 2
        return {'y': np.array(rescaled)}

    features = Features({
        'y': Array3D(shape=(32, 32, 3), dtype='float32'),
    })

    dataset = load_dataset('cifar10', cache_dir=root / 'hf').map(
        transform,
        features=features,
        remove_columns=['img', 'label'],
        keep_in_memory=True,
        num_proc=1,
    )

    dataset.save_to_disk(root / f'hf/cifar-mask-gaussian-blur-s{severity}-k{kernel_size}')

class Counter:
    """Minimal mutable integer counter that starts at zero."""

    def __init__(self):
        self.cnt = 0  # current count

    def inc(self):
        """Increase the count by one."""
        self.cnt = self.cnt + 1

    def get(self):
        """Return the current count."""
        return self.cnt

def save_corrupted_to_png():
    """Export the observations ``y`` of the ``cifar-mask-75`` train split as PNGs.

    Reads the dataset saved at ``cifar_dir/hf/cifar-mask-75`` (the path
    ``corrupt_dataset(75)`` writes to). Each ``y`` is written as
    ``<row index>.png`` into ``cifar_dir/datasets_for_eval/masked75dataset``,
    which is created if missing. The test split is not exported.
    """
    dataset = load_from_disk('/data/vision/___/scratch/___ht/cifar_dir/hf/cifar-mask-75')
    # NumPy format so that `row['y']` arrives as an array; presumably that is
    # what `to_pil` expects -- confirm against utils.
    dataset.set_format('numpy')

    out_dir = Path('/data/vision/___/scratch/___ht/cifar_dir/datasets_for_eval/masked75dataset')
    # PIL's `save` does not create missing parent directories.
    out_dir.mkdir(parents=True, exist_ok=True)

    # Plain iteration instead of a side-effecting `map`: the row index gives
    # the file name directly, with no shared counter object.
    for index, row in enumerate(dataset['train']):
        to_pil(row['y']).save(out_dir / f'{index}.png')

if __name__ == '__main__':
    # Default entry point: build the 90%-masked CIFAR-10 variant.
    # Other one-off exports, enabled by hand when needed:
    #   save_to_png()
    #   corrupt_blur(severity=2, kernel_size=5)
    #   save_corrupted_to_png()
    corrupt_dataset(90)
