# Generate samples from a trained conditional model so that an unconditional model can later be trained on them.

import zlib

from priors import *
from datasets import Features, Array3D, load_from_disk

from utils import *

# Flat hyper-parameter table for the run, grouped by concern.
CONFIG = dict(
    # --- Data ---
    corruption=75,
    # --- Architecture ---
    hid_channels=(128, 256, 384),
    hid_blocks=(5, 5, 5),
    kernel_size=(3, 3),
    emb_features=256,
    heads={1: 4},
    dropout=0.1,
    # --- Sampling ---
    sampler='ddpm',
    sde=dict(a=1e-3, b=1e2),
    heuristic=None,
    discrete=256,
    maxiter=1,
    # --- Training ---
    epochs=256,
    batch_size=256,
    scheduler='constant',
    lr_init=2e-4,
    lr_end=1e-6,
    lr_warmup=0.0,
    optimizer='adam',
    weight_decay=None,
    clip=1.0,
    ema_decay=0.9999,
)

# Checkpoint directory name; index 1 picks the second of the two known runs.
dir_name = ('checkpoints_1x_conditional', 'checkpoints_large_conditional')[1]
PATH = Path('/data/vision/___/scratch/___ht/') / dir_name

# Number used to build the checkpoint file name (`checkpoint_{lap}.pkl`).
lap = 16

jax.config.update('jax_threefry_partitionable', True)

# Device mesh with a single axis named 'i' spanning every device JAX reports,
# plus two shardings over it: one with an empty partition spec and one
# partitioned along 'i'.
mesh = jax.sharding.Mesh(jax.devices(), 'i')
replicated = jax.sharding.NamedSharding(
    mesh,
    jax.sharding.PartitionSpec(),
)
distributed = jax.sharding.NamedSharding(
    mesh,
    jax.sharding.PartitionSpec('i'),
)

class MyDict:
    """Attribute-style read access to a mapping: ``obj.key`` -> ``mapping['key']``.

    The wrapped mapping stays available as ``obj.dict``. Unknown names raise
    ``AttributeError`` (not ``KeyError``) so that ``hasattr``, three-argument
    ``getattr`` and the ``copy``/``pickle`` protocols behave normally.
    """

    def __init__(self, dict):
        # The parameter name shadows the builtin; it is kept unchanged so that
        # existing keyword callers (`MyDict(dict=...)`) keep working.
        self.dict = dict

    def __getattr__(self, key):
        """Return ``self.dict[key]``; raise ``AttributeError`` if unavailable."""
        # __getattr__ only runs after regular lookup failed. Read the wrapped
        # mapping straight from the instance __dict__: going through
        # `self.dict` would re-enter __getattr__ forever on an instance whose
        # __init__ has not run (e.g. one created by `__new__` during copying).
        try:
            mapping = self.__dict__['dict']
        except KeyError:
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {key!r}'
            ) from None

        try:
            return mapping[key]
        except KeyError:
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {key!r}'
            ) from None

config = MyDict(CONFIG)

# RNG
# The built-in hash() of a str is salted per interpreter process (see
# PYTHONHASHSEED), so it yields a different seed on every run. CRC32 of the
# same tag is stable across runs and machines, which makes the generated data
# reproducible.
seed = zlib.crc32(b'Hosseintabar') % 2**16
rng = inox.random.PRNG(seed)

# SDE
sde = VESDE(**config.sde)


# Data
dataset = load_from_disk(f'/data/vision/___/scratch/___ht/diffusion-priors/experiments/cifar/hf/cifar-mask-{config.corruption}')
dataset.set_format('numpy')

trainset_yA = dataset['train']

# Checkpoint number `lap` from the selected checkpoint directory; it is the
# model handed to `generate_conditional` further down.
previous = load_module(PATH / f'checkpoint_{lap}.pkl')

def generate_conditional(model, dataset, rng, batch_size, **kwargs):
    """Map *dataset* to a new dataset holding samples drawn from *model*.

    Every batch of the ``'y'`` column is passed to ``sample_conditional``
    together with a fresh key from ``rng.split()``; the result is converted to
    a NumPy array and stored in a new ``'x'`` column declared as float32
    ``Array3D`` of shape ``(32, 32, 3)``. The ``'y'`` and ``'A'`` columns are
    removed from the output.

    Args:
        model: Forwarded as the first argument of ``sample_conditional``.
        dataset: Object exposing a ``map`` method with the keyword arguments
            used below; it must contain the columns ``'y'`` and ``'A'``.
        rng: Object whose ``split()`` returns the key for one sampling call.
        batch_size: Number of rows handed to each sampling call. A trailing
            incomplete batch is dropped (``drop_last_batch=True``).
        **kwargs: Forwarded unchanged to ``sample_conditional``.

    Returns:
        Whatever ``dataset.map`` returns, kept in memory.
    """

    def transform(batch):
        # Only the 'y' column is needed for sampling; the 'A' column is not
        # passed to `sample_conditional`, so it is no longer read here.
        # NOTE(review): confirm that `sample_conditional` is really meant to
        # ignore 'A'.
        x = sample_conditional(model, batch['y'], rng.split(), **kwargs)

        return {'x': np.asarray(x)}

    types = {'x': Array3D(shape=(32, 32, 3), dtype='float32')}

    return dataset.map(
        transform,
        features=Features(types),
        remove_columns=['y', 'A'],
        keep_in_memory=True,
        batched=True,
        batch_size=batch_size,
        drop_last_batch=True,
    )

# Run conditional sampling over the training split with the loaded checkpoint.
# The first four arguments are model, dataset, rng and batch size; the
# remaining keywords are forwarded to the sampler by `generate_conditional`.
trainset = generate_conditional(
    previous,
    trainset_yA,
    rng,
    config.batch_size,
    shard=True,
    sampler=config.sampler,
    sde=sde,
    steps=config.discrete,
    maxiter=config.maxiter,
)

# Hand the generated dataset to `dump_module` together with the output path.
dump_module(
    trainset,
    '/data/vision/___/scratch/___ht/generated_data.pkl',
)
