# Evaluating the Conditional Model
import inox
import inox.nn as nn
import jax
import numpy as np
import optax
import wandb

from datasets import Array3D, Features, concatenate_datasets, load_from_disk
from dawgz import job, schedule
from functools import partial
from tqdm import trange
from typing import *

# isort: split
from utils import *

# Hyper-parameters. In the code visible here only `duplicate`, `sde`, `sampler`,
# `discrete` and `maxiter` are read; the architecture/training entries are
# presumably kept to mirror the training run's configuration — confirm.
CONFIG = dict(
    # Data
    duplicate=2,
    # Architecture
    hid_channels=(128, 256, 384, 512),
    hid_blocks=(3, 3, 3, 3),
    kernel_size=(3, 3),
    emb_features=256,
    heads={3: 4},
    dropout=0.1,
    # Sampling
    sampler='ddpm',
    heuristic=None,
    sde={'a': 1e-3, 'b': 1e2},
    discrete=64,
    maxiter=3,
    # Training
    epochs=64,
    batch_size=256,
    scheduler='constant',
    lr_init=1e-4,
    lr_end=1e-6,
    lr_warmup=0.0,
    optimizer='adam',
    weight_decay=None,
    clip=1.0,
    ema_decay=0.999,
)

# Attribute-access view of CONFIG, used below as config.duplicate,
# config.sampler, config.discrete and config.maxiter.
# NOTE(review): MyDict is not defined in this file; presumably it is provided
# by `from utils import *` — confirm.
config = MyDict(CONFIG)

def save_image(index, acceleration_factor, image_name):
    """Save the inverse-FFT magnitude image of one validation sample as a PNG.

    The sample is drawn from the evaluation subset used throughout this script:
    every other row of the first 1024 rows of the ``val`` split. ``index``
    selects a position in that subset with ordinary Python sequence semantics,
    so negative indices count from the end.

    Args:
        index: Position in the evaluation subset.
        acceleration_factor: Selects the on-disk dataset
            ``hf/fastmri-kspace-r{acceleration_factor}``.
        image_name: Filename prefix. The PNG is written to
            ``<PATH>/art/{image_name}{index}.png``.

    Raises:
        IndexError: If ``index`` is outside the evaluation subset.
    """
    PATH = Path('/data/vision/___/scratch/___ht/fastmri_dir')

    # Process-global JAX flag, kept for parity with the other entry points.
    jax.config.update('jax_threefry_partitionable', True)

    dataset = load_from_disk(PATH / f'hf/fastmri-kspace-r{acceleration_factor}')
    dataset.set_format('numpy')
    testset_yA = dataset['val']

    # Resolve `index` against the strided subset `testset_yA[:1024:2]` without
    # materialising (and transforming) every row of it. Slicing and indexing a
    # range follows the same rules as slicing/indexing an array. The train split
    # and the `A` column were previously loaded but never used.
    row = range(len(testset_yA))[:1024:2][index]
    y = testset_yA[row]['y']

    # Convert the stored measurement with real2complex, then take |ifft2c(.)|
    # (ifft2c is presumably a centered 2-D inverse FFT). Applying it to one
    # sample gives the same result `jax.vmap` produced for that sample.
    image = jnp.abs(ifft2c(real2complex(y)))

    out_dir = PATH / 'art'
    out_dir.mkdir(parents=True, exist_ok=True)
    # NOTE(review): the magnitude is negated before conversion, as before,
    # whereas save_reconstructed_image does not negate — confirm which
    # intensity convention to_pil expects.
    to_pil(-image).save(out_dir / f'{image_name}{index}.png')

def save_kspace(index, acceleration_factor, image_name):
    """Save a visualisation of one validation sample as a PNG.

    NOTE(review): despite its name, this writes the inverse-FFT magnitude image,
    not k-space. Its body was a verbatim copy of ``save_image``, including the
    same output path ``art/{image_name}{index}.png``, so it now simply
    delegates to ``save_image``; the output is unchanged.
    TODO: implement an actual k-space visualisation (e.g. a log-magnitude of
    ``real2complex(y)``) and write it to a distinct filename.

    Args:
        index: Position in the evaluation subset (see ``save_image``).
        acceleration_factor: Selects the ``fastmri-kspace-r{...}`` dataset.
        image_name: Filename prefix for the written PNG.

    Raises:
        IndexError: If ``index`` is outside the evaluation subset.
    """
    save_image(index, acceleration_factor, image_name)


def save_reconstructed_image(
    index,
    acceleration_factor,
    image_name,
    checkpoint='/data/vision/___/scratch/___ht/fastmri_dir/checkpoints_conditional_ifft_r6_itnog/checkpoint_46.pkl',
    num_samples=4,
    model=None,
):
    """Reconstruct one validation sample with the conditional model and save it.

    The conditioning observation is the inverse-FFT magnitude of the selected
    validation measurement, the same image ``save_image`` writes. It is
    repeated ``num_samples`` times and passed to ``sample_conditional``.
    Only the first returned sample is saved.

    Args:
        index: Position in the evaluation subset (every other row of the first
            1024 rows of the ``val`` split); Python sequence semantics apply.
        acceleration_factor: Selects the ``fastmri-kspace-r{...}`` dataset.
        image_name: Filename prefix. The PNG is written to
            ``<PATH>/art/{image_name}{index}.png``.
        checkpoint: Model checkpoint loaded when ``model`` is None. The default
            path suggests an r6 model; pass a matching checkpoint when using a
            different ``acceleration_factor``.
        num_samples: Batch size given to ``sample_conditional``.
            NOTE(review): presumably must divide evenly across devices because
            ``shard=True`` — confirm.
        model: Optional preloaded model, so loops can avoid reloading the
            checkpoint on every call.

    Raises:
        IndexError: If ``index`` is outside the evaluation subset.
        ValueError: If ``num_samples`` is smaller than 1.
    """
    import zlib

    if num_samples < 1:
        raise ValueError(f'num_samples must be >= 1, got {num_samples}')

    PATH = Path('/data/vision/___/scratch/___ht/fastmri_dir')

    # Process-global JAX flag; presumably needed for the sharded sampling below.
    jax.config.update('jax_threefry_partitionable', True)

    # str hashes are salted per interpreter (PYTHONHASHSEED), so the previous
    # `hash('art') % 2**16` produced a different seed on every run. CRC32 is
    # deterministic, making reconstructions reproducible.
    seed = zlib.crc32(b'art') % 2**16
    rng = inox.random.PRNG(seed)

    dataset = load_from_disk(PATH / f'hf/fastmri-kspace-r{acceleration_factor}')
    dataset.set_format('numpy')
    testset_yA = dataset['val']

    if model is None:
        model = load_module(checkpoint)

    # Same row as `testset_yA[:1024:2][index]`, without loading the whole
    # subset. The train split, SDE and `A` column were never used here.
    row = range(len(testset_yA))[:1024:2][index]
    y = jnp.abs(ifft2c(real2complex(testset_yA[row]['y'])))

    # Condition every sample in the batch on the same observation:
    # shape (num_samples, *y.shape).
    y_cond = np.repeat(np.asarray(y)[None], num_samples, axis=0)

    x = sample_conditional(
        model=model,
        y_cond=y_cond,
        key=rng.split(),
        shard=True,
        sampler=config.sampler,
        steps=config.discrete,
        maxiter=config.maxiter,
    )[0]

    out_dir = PATH / 'art'
    out_dir.mkdir(parents=True, exist_ok=True)
    to_pil(x).save(out_dir / f'{image_name}{index}.png')



if __name__ == "__main__":
    # Reconstruct the first 50 images of the evaluation subset with
    # acceleration_factor 6. The matching inputs can be saved alongside with
    # save_image(idx, 6, '') (corrupted) and save_image(idx, 1, '') (ground truth).
    for idx in range(50):
        save_reconstructed_image(idx, 6, '')
