r"""Generates dataset for evaluation"""

import hashlib
import re
from pathlib import Path

import inox
import jax
import jax.numpy as jnp
import numpy as np
import wandb
from datasets import load_dataset, Dataset, Array3D, Features, load_from_disk
from tqdm import trange, tqdm

from utils import *

# When True: read the small '_test' datasets, generate 128 images, use batches of 8.
TEST_MODE = True

# Number of images each generation run writes to disk.
NUM_DATA = 128 if TEST_MODE else 50_000

# Hyper-parameters shared with training; only the sampling / sde / batch_size
# entries are read by this script.
CONFIG = dict(
    # --- data ---
    corruption=50,
    img_shape=(64, 64, 3),
    # --- architecture ---
    hid_channels=(128, 256, 384, 512),
    hid_blocks=(3, 3, 3, 3),
    kernel_size=(3, 3),
    emb_features=256,
    heads={3: 4},
    dropout=0.1,
    # --- sampling ---
    sampler='ddpm',
    heuristic=None,
    sde=dict(a=1e-3, b=1e2),
    discrete=64,
    maxiter=3,
    # --- training ---
    epochs=64,
    batch_size=8 if TEST_MODE else 256,
    scheduler='constant',
    lr_init=1e-4,
    lr_end=1e-6,
    lr_warmup=0.0,
    optimizer='adam',
    weight_decay=None,
    clip=1.0,
    ema_decay=0.999,
)
config = MyDict(CONFIG)


class Counter:
    """Small mutable integer that nested functions can bump without ``nonlocal``."""

    def __init__(self):
        self.cnt = 0  # running total, starts at zero

    def inc(self):
        """Add one to the running total."""
        self.cnt = self.cnt + 1

    def get(self):
        """Return the current running total."""
        current = self.cnt
        return current

def _generate_conditional_and_save(model, dataset, rng, batch_size, sde, save_dir, **kwargs):
    """Sample conditionally on each batch of ``dataset`` and save the results as PNGs.

    Walks ``dataset`` in consecutive full batches of ``batch_size`` rows (a
    trailing partial batch is skipped, matching ``drop_last_batch=True``),
    calls ``sample_conditional`` on each batch's ``'y'`` column and writes
    every returned sample to ``save_dir / f'{index}.png'``. Iteration stops
    as soon as ``NUM_DATA`` images have been written, so the rest of the
    dataset is never read.

    Args:
        model: Model loaded with ``load_module``.
        dataset: Dataset whose ``'y'`` column comes back as numpy data when
            sliced (the caller sets a numpy format).
        rng: ``inox.random.PRNG``; one key is split off per sampled batch.
        batch_size: Number of rows passed to each sampling call.
        sde: Unused here; kept for interface compatibility.
        save_dir: Existing directory the PNGs are written into.
        **kwargs: Ignored.

    Returns:
        ``dataset`` itself, unchanged.
    """
    saved = 0
    num_full_batches = len(dataset) // batch_size

    for b in range(num_full_batches):
        if saved >= NUM_DATA:
            break  # enough images; skip the remaining batches entirely

        batch = dataset[b * batch_size:(b + 1) * batch_size]
        y_cond = np.asarray(batch['y'])

        # DiffEM checkpoints are sampled with `sample_conditional`. For
        # moment-matching checkpoints use instead:
        #   x = sample(model=model, y=y_cond, A=A_cond, key=rng.split(),
        #              shard=True, sampler=config.sampler,
        #              steps=config.discrete, maxiter=config.maxiter)
        # which also needs the 'A' column; NOTE(review): 'A' must then be
        # added to the dataset's format columns, since the caller formats
        # only 'y'.
        x = sample_conditional(
            model,
            y_cond,
            rng.split(),
            shard=True,
            sampler=config.sampler,
            steps=config.discrete,
            maxiter=config.maxiter,
        )

        # Only keep as many samples as are still needed to reach NUM_DATA.
        for each in np.asarray(x)[: NUM_DATA - saved]:
            to_pil(each).save(save_dir / f'{saved}.png')
            saved += 1

    return dataset

def generate_dataset_conditional(checkpoint_path: Path, save_dir: Path, corruption: int):
    """Generate ``NUM_DATA`` conditional samples from a checkpoint and save them as PNGs.

    Loads ``celeba_64_mask{corruption}`` (the ``_test`` variant when
    ``TEST_MODE`` is set) and rescales its ``'y'`` column with
    ``y * 4 / 256 - 2``. It then loads the model from ``checkpoint_path`` and
    delegates sampling and saving to ``_generate_conditional_and_save``.

    Args:
        checkpoint_path: Checkpoint passed to ``load_module``; its string form
            also determines the RNG seed.
        save_dir: Output directory, created if missing (``str`` also accepted).
        corruption: Mask percentage used to select the dataset directory.
    """
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # Built-in hash() of a str is salted per process (PYTHONHASHSEED), which
    # made the seed differ between runs; a sha256 digest is stable.
    digest = hashlib.sha256(f'celebA|{checkpoint_path}'.encode()).hexdigest()
    rng = inox.random.PRNG(int(digest, 16) % (1 << 16))

    # SDE
    sde = VESDE(**CONFIG.get('sde'))

    if TEST_MODE:
        dataset = load_from_disk(f'/data/vision/___/scratch/___ht/celeba_64_mask{corruption}_test/', keep_in_memory=True)
    else:
        dataset = load_from_disk(f'/data/vision/___/scratch/___ht/celeba_64_mask{corruption}/', keep_in_memory=True)
    dataset.set_format(type='numpy', columns=['y'])

    def normalize_transform(batch):
        # Maps 0..255 to [-2, 2) (assuming 8-bit pixel values).
        y = np.asarray(batch['y'])
        if np.issubdtype(y.dtype, np.integer):
            # `y * 4` would wrap around if y were uint8; the stored dtype is
            # not visible here, so upcast integers first (exact for wider ints).
            y = y.astype(np.float64)
        return {'y': y * 4 / 256 - 2}

    dataset = dataset.map(
        normalize_transform,
        batched=True,
        batch_size=config.batch_size,
        desc="Normalizing",
        num_proc=4,  # Use multiple processes for faster mapping
        keep_in_memory=True,
    )

    model = load_module(checkpoint_path)
    _generate_conditional_and_save(model, dataset, rng, config.batch_size, sde, save_dir)

def generate_clear_dataset():
    """Write the uncorrupted CelebA images to disk as the evaluation reference.

    Loads ``celeba_64_mask0`` and rescales ``'y'`` with ``y * 4 / 256 - 2``,
    the same transform applied to the sampled datasets. It then saves the
    first 50 000 rows (or fewer if the dataset is smaller) as
    ``datasets_for_eval/ref/{index}.png``.
    """
    num_ref = 50_000  # reference-set size; deliberately independent of NUM_DATA
    chunk = 1_000     # rows read per slice; only affects speed
    save_dir = Path('/data/vision/___/scratch/___ht/celeba_dir/datasets_for_eval/ref')
    # PIL's save() does not create parent directories.
    save_dir.mkdir(parents=True, exist_ok=True)

    dataset = load_from_disk('/data/vision/___/scratch/___ht/celeba_64_mask0/', keep_in_memory=True)
    dataset.set_format(type='numpy', columns=['y'])

    total = min(num_ref, len(dataset))
    for start in range(0, total, chunk):
        ys = np.asarray(dataset[start:min(start + chunk, total)]['y'])
        if np.issubdtype(ys.dtype, np.integer):
            # Avoid uint8 wrap-around in `ys * 4`; exact for wider integer dtypes.
            ys = ys.astype(np.float64)
        ys = ys * 4 / 256 - 2
        for offset, y in enumerate(ys):
            to_pil(y).save(save_dir / f'{start + offset}.png')

def _generate_unconditional_and_save(model, rng, batch_size, sde, save_dir, **kwargs):
    """Draw unconditional samples until ``NUM_DATA`` PNGs have been written.

    Repeatedly calls ``sample`` with ``y=None`` and ``A=None``, splitting a
    fresh key from ``rng`` each time. Every returned image is saved to
    ``save_dir / f'{index}.png'``, and the last batch is truncated so that
    exactly ``NUM_DATA`` files are produced.

    Args:
        model: Model loaded with ``load_module``.
        rng: ``inox.random.PRNG`` used to derive one key per ``sample`` call.
        batch_size: Unused here; ``sample`` is not given a batch size.
        sde: Unused here; kept for interface compatibility.
        save_dir: Existing directory the PNGs are written into.
        **kwargs: Ignored.

    Raises:
        RuntimeError: if ``sample`` returns no images, which would otherwise
            make the loop spin forever.
    """
    saved = 0

    while saved < NUM_DATA:
        x = sample(
            model=model,
            key=rng.split(),
            shard=True,
            sampler=config.sampler,
            steps=config.discrete,
            maxiter=config.maxiter,
            y=None,
            A=None,
        )
        x = np.asarray(x)

        if x.size == 0:
            raise RuntimeError('sample() returned no images; cannot reach NUM_DATA')

        # Only keep as many samples as are still needed to reach NUM_DATA.
        for each in x[: NUM_DATA - saved]:
            to_pil(each).save(save_dir / f'{saved}.png')
            saved += 1

def generate_dataset_unconditional(checkpoint_path: str|Path, save_dir: str|Path, corruption: int):
    """Generate ``NUM_DATA`` unconditional samples from a checkpoint and save them as PNGs.

    Args:
        checkpoint_path: Checkpoint passed to ``load_module``; its string form
            also determines the RNG seed.
        save_dir: Output directory, created if missing.
        corruption: Unused here; kept so the signature matches
            ``generate_dataset_conditional``.
    """
    save_dir = Path(save_dir)
    # The PNGs are saved straight into save_dir, so make sure it exists.
    save_dir.mkdir(parents=True, exist_ok=True)

    # Built-in hash() of a str is salted per process (PYTHONHASHSEED), which
    # made the seed differ between runs; a sha256 digest is stable.
    digest = hashlib.sha256(f'celebA|{checkpoint_path}'.encode()).hexdigest()
    rng = inox.random.PRNG(int(digest, 16) % (1 << 16))

    # SDE
    sde = VESDE(**CONFIG.get('sde'))

    _generate_unconditional_and_save(
        model=load_module(checkpoint_path),
        rng=rng,
        batch_size=config.batch_size,
        sde=sde,
        save_dir=save_dir,
    )

if __name__ == "__main__":
    # Flip to True to (re)write the clean reference images before sampling.
    make_reference = False
    if make_reference:
        generate_clear_dataset()

    celeba_dir = Path('/data/vision/___/scratch/___ht/celeba_dir')
    eval_dir = celeba_dir / 'datasets_for_eval'

    # Known runs, keyed by name: (checkpoint_path, save_dir, corruption).
    conditional_runs = {
        # DiffEM checkpoints
        'mask50_19': (
            celeba_dir / 'checkpoints_mask50' / 'checkpoint_19.pkl',
            eval_dir / 'conditional' / 'mask50' / 'mask50_19',
            50,
        ),
        'mask75_1_test': (
            celeba_dir / 'checkpoints_mask75' / 'checkpoint_1.pkl',
            eval_dir / 'conditional' / 'mask75' / 'mask75_1_test',
            75,
        ),
        'mask75_16_test': (
            celeba_dir / 'checkpoints_mask75' / 'checkpoint_16.pkl',
            eval_dir / 'conditional' / 'mask75' / 'mask75_16_test',
            75,
        ),
        # Moment-matching (EM-MMPS) checkpoints. Conditional generation
        # samples with `sample_conditional` unless the moment-matching
        # `sample(...)` call is swapped in.
        'em_mmps_mask50_checkpoint_8': (
            celeba_dir / 'em-mmps_mask50' / 'checkpoints' / 'checkpoint_8.pkl',
            eval_dir / 'conditional' / 'em_mmps_mask50_checkpoint_8',
            50,
        ),
        'em_mmps_mask75_checkpoint_8': (
            celeba_dir / 'em-mmps_mask75' / 'checkpoints' / 'checkpoint_8.pkl',
            eval_dir / 'conditional' / 'em_mmps_mask75_checkpoint_8',
            75,
        ),
    }
    unconditional_runs = {
        # Moment-matching (EM-MMPS) checkpoints
        'em_mmps_mask50_checkpoint_8': (
            celeba_dir / 'em-mmps_mask50' / 'checkpoints' / 'checkpoint_8.pkl',
            eval_dir / 'unconditional' / 'em_mmps_mask50_checkpoint_8',
            50,
        ),
        'em_mmps_mask75_checkpoint_8': (
            celeba_dir / 'em-mmps_mask75' / 'checkpoints' / 'checkpoint_8.pkl',
            eval_dir / 'unconditional' / 'em_mmps_mask75_checkpoint_8',
            75,
        ),
        # DiffEM checkpoints
        'mask50_19': (
            celeba_dir / 'unconditional-mask50' / 'checkpoints' / 'checkpoint_512_debugged_19.pkl',
            eval_dir / 'unconditional' / 'mask50' / 'mask50_19',
            50,
        ),
        'mask75_23': (
            celeba_dir / 'unconditional-mask75' / 'checkpoints' / 'checkpoint_512_mask75_checkpoint_23_unconditional.pkl',
            eval_dir / 'unconditional' / 'mask75' / 'mask75_23',
            75,
        ),
    }

    conditional = True
    # Names of the runs to execute, in order, from the table chosen by `conditional`.
    selected = ['mask75_1_test'] if conditional else []

    runs = conditional_runs if conditional else unconditional_runs
    generate = generate_dataset_conditional if conditional else generate_dataset_unconditional
    for name in selected:
        checkpoint_path, save_dir, corruption = runs[name]
        generate(checkpoint_path=checkpoint_path, save_dir=save_dir, corruption=corruption)

