import jax
import jax.numpy as jnp
import numpy as npdat
from datasets import load_dataset, Dataset, Array3D, Features, load_from_disk, enable_progress_bars
import inox
from tqdm import trange, tqdm

from utils import *

# Experiment configuration.
# Only 'corruption' and 'img_shape' are read in this file.
# NOTE(review): the remaining keys are presumably consumed by the
# training / sampling code elsewhere in the project — confirm.
CONFIG = {
    # Data
    # 'duplicate': 2,
    'corruption': 75,  # per-pixel masking probability, in percent (see `corrupt`)
    'img_shape': (64, 64, 3),  # shape of the stored 'y' arrays; the first two dims give the mask shape
    # Architecture
    'hid_channels': (128, 256, 384, 512),
    'hid_blocks': (3, 3, 3, 3),
    'kernel_size': (3, 3),
    'emb_features': 256,
    'heads': {3: 4},
    'dropout': 0.1,
    # Sampling
    'sampler': 'ddpm',
    'heuristic': None,
    'sde': {'a': 1e-3, 'b': 1e2},
    'discrete': 64,
    'maxiter': 3,
    # Training
    'epochs': 64,
    'batch_size': 256,
    'scheduler': 'constant',
    'lr_init': 1e-4,
    'lr_end': 1e-6,
    'lr_warmup': 0.0,
    'optimizer': 'adam',
    'weight_decay': None,
    'clip': 1.0,
    'ema_decay': 0.999,
}

# Wrapped in `MyDict` (from `utils`); entries are read as attributes below,
# e.g. `config.corruption` and `config.img_shape`.
config = MyDict(CONFIG)

def corrupt(rng, corruption, dataset: Dataset):
    """
    Randomly masks pixels of every image in `dataset`.

    For each row, a binary mask `A` of shape `config.img_shape[:2] + (1,)`
    is drawn with `rng.bernoulli`, each entry being 1 with probability
    `1 - corruption / 100`. The mask is broadcast over the channel axis:
    pixels where `A` is 1 keep their value, all others are set to 128.

    Arguments:
        rng: Random source exposing `bernoulli(p=..., shape=...)`. The
            callers in this file pass an `inox.random.PRNG`.
        corruption: Per-pixel probability of being masked, in percent.
        dataset: Dataset whose rows hold an image under the key 'x'.
            Assumes each image has shape `config.img_shape` — TODO confirm.

    Returns:
        A new dataset with the original columns removed and two columns
        added: 'y' (the masked image, stored as float32) and 'A' (the mask,
        stored as bool, true where the pixel was kept).
    """
    # `np` is not bound by this module's own imports (numpy is imported as
    # `npdat`), so it would only resolve if `from utils import *` happened to
    # export it. Import it here so the function does not rely on that.
    import numpy as np

    keep_prob = 1 - corruption / 100
    mask_shape = config.img_shape[:2] + (1, )

    def transform(row):
        x = np.asarray(row['x'])
        A = rng.bernoulli(p=keep_prob, shape=mask_shape)
        # Kept pixels pass through; masked pixels are filled with 128.
        y = np.array(A * x + (1 - A) * (128))
        return {
            'y': y, 'A': np.array(A)
        }

    types = {
        'y': Array3D(shape=config.img_shape, dtype='float32'),
        'A': Array3D(shape=mask_shape, dtype='bool')
    }

    return dataset.map(
        transform,
        remove_columns=dataset.column_names,
        features=Features(types),
        keep_in_memory=True,
        num_proc=1
    )

def corrupt_dataset():
    """
    Loads the dataset stored under the hard-coded `celeba_64` path, corrupts
    its 'train' split with `corrupt` at the level `config.corruption`, and
    saves the result to `celeba_64_mask{config.corruption}`.
    """
    # Function-scope import: `zlib` is not imported at the top of this file.
    import zlib

    print('! loading the dataset')
    dataset = load_dataset('/data/vision/___/scratch/___ht/celeba_64/', keep_in_memory=True)
    print('! dataset loaded')

    # `corrupt` reads each image from the 'x' column.
    trainset = dataset['train'].rename_column('image', 'x')

    # The built-in `hash()` of a str is salted per interpreter process
    # (unless PYTHONHASHSEED is fixed), so it yields a different seed — and
    # therefore different masks — on every run. CRC-32 is stable across runs.
    seed = zlib.crc32(b'celeba') % (1 << 16)
    rng = inox.random.PRNG(seed)
    trainset_corrupted = corrupt(rng, config.corruption, trainset)

    trainset_corrupted.save_to_disk(f'/data/vision/___/scratch/___ht/celeba_64_mask{config.corruption}')


import datetime

def corrupt_dataset_test():
    """
    Small-scale variant of `corrupt_dataset`.

    Loads the same dataset (with 64 loader processes), keeps only the first
    1024 rows of the 'train' split, corrupts them with `corrupt` at the level
    `config.corruption`, and saves the result to
    `celeba_64_mask{config.corruption}_test`. Timestamps are printed before
    and after loading.
    """
    # Function-scope import: `zlib` is not imported at the top of this file.
    import zlib

    print(f'! loading the dataset {datetime.datetime.now()}')
    dataset = load_dataset('/data/vision/___/scratch/___ht/celeba_64/', keep_in_memory=True, num_proc = 64)
    print(f'! dataset loaded {datetime.datetime.now()}')

    # Keep a 1024-row subset; `corrupt` reads each image from the 'x' column.
    trainset = dataset['train'].select(range(1 << 10))
    trainset = trainset.rename_column('image', 'x')

    # The built-in `hash()` of a str is salted per interpreter process
    # (unless PYTHONHASHSEED is fixed), so it yields a different seed — and
    # therefore different masks — on every run. CRC-32 is stable across runs.
    seed = zlib.crc32(b'celeba') % (1 << 16)
    rng = inox.random.PRNG(seed)
    trainset_corrupted = corrupt(rng, config.corruption, trainset)

    trainset_corrupted.save_to_disk(f'/data/vision/___/scratch/___ht/celeba_64_mask{config.corruption}_test')

def normalize_dataset(corruption=50):
    """
    Rescales a corrupted dataset and sets the dead pixels to 0 (not -2).

    Every 'y' value is mapped through `y * 4 / 256 - 2`, which sends values
    in [0, 256) to [-2, 2), and is then multiplied by the mask 'A' so that
    masked-out pixels become exactly 0. The result is written to
    `celeba_64_mask{corruption}_normalized`.

    Arguments:
        corruption: Corruption level used in the dataset directory names.
            Defaults to 50, the level that was previously hard-coded here;
            it is independent of `config.corruption`.
    """
    # Function-scope import: `Path` is not imported explicitly at the top of
    # this file (it would otherwise have to come from `from utils import *`).
    from pathlib import Path

    dataset = load_from_disk(f'/data/vision/___/scratch/___ht/celeba_64_mask{corruption}/', keep_in_memory=True)
    dataset.set_format('numpy')

    # Read the mask column once and reuse it for both outputs.
    A_new = dataset['A']
    y_new = A_new * (dataset['y'] * 4 / 256 - 2)

    def generator():
        # Yield one row at a time, with a progress bar over the rows.
        for A, y in tqdm(zip(A_new, y_new), total=len(y_new)):
            yield {'A': A, 'y': y}

    ds = Dataset.from_generator(generator)

    save_path = Path(f'/data/vision/___/scratch/___ht/celeba_64_mask{corruption}_normalized/')
    save_path.mkdir(exist_ok=True, parents=True)
    ds.save_to_disk(save_path)
if __name__ == '__main__':
    # Script entry point. Only the 1024-row test corruption run is active;
    # the full-dataset steps are left commented out.

    # corrupt_dataset()
    # normalize_dataset()

    corrupt_dataset_test()

    # load_and_test()

    # corrupted_dataset = load_from_disk('/data/vision/___/scratch/___ht/celeba_64_mask50_test/')

    # print(f'! loading the dataset {datetime.datetime.now()}')
    # dataset = load_dataset('/data/vision/___/scratch/___ht/celeba_64/', keep_in_memory=True, num_proc = 64)
    # dataset = load_from_disk('/data/vision/___/scratch/___ht/celeba_64_test/')
    # print(f'! dataset loaded {datetime.datetime.now()}')

    # seed = hash('celeba') % (1<<16)
    # rng = inox.random.PRNG(seed)
    # corrupted_dataset = corrupt(rng, config.corruption, dataset)

    # creates samples to visualize
    # first_y = np.array(corrupted_dataset['y'][0], dtype = np.uint8)
    # first_A = 255 * np.array(corrupted_dataset['A'][0], dtype = np.uint8)

    # Image.fromarray(first_y).save('/data/vision/___/scratch/___ht/celeba_64_test/first_y.png')
    # Image.fromarray(first_A).save('/data/vision/___/scratch/___ht/celeba_64_test/first_A.png')

    # save the corrupted dataset
    # corrupted_dataset.save_to_disk('/data/vision/___/scratch/___ht/celeba_64_mask50_test/')

