# Generate PNG image datasets used for FID evaluation


import inox
import inox.nn as nn
import jax
import numpy as np
import optax
import wandb

from datasets import Array3D, Features, concatenate_datasets, load_from_disk
from dawgz import job, schedule
from functools import partial
from tqdm import trange
from typing import *

from torch import Tensor
from torch.utils import data

# isort: split
from utils import *

# Hyper-parameter set for this script. Only a few keys are read by the code
# in this file ('duplicate', 'sampler', 'sde', 'discrete', 'maxiter',
# 'batch_size'); the remaining keys are not referenced here.
CONFIG = {
    # Data
    'duplicate': 2,  # number of times the training split is concatenated with itself
    # Architecture
    'hid_channels': (128, 256, 384, 512),
    'hid_blocks': (3, 3, 3, 3),
    'kernel_size': (3, 3),
    'emb_features': 256,
    'heads': {3: 4},
    'dropout': 0.1,
    # Sampling
    'sampler': 'ddpm',  # forwarded as `sampler=` to sample_conditional
    'heuristic': None,
    'sde': {'a': 1e-3, 'b': 1e2},  # keyword arguments for VESDE in the __main__ block
    'discrete': 64,  # forwarded as `steps=` to sample_conditional
    'maxiter': 3,  # forwarded as `maxiter=` to sample_conditional
    # Training
    'epochs': 64 * 4,
    'batch_size': 256,  # batch size used by generate_dataset when sampling
    'scheduler': 'constant',
    'lr_init': 1e-4,
    'lr_end': 1e-6,
    'lr_warmup': 0.0,
    'optimizer': 'adam',
    'weight_decay': None,
    'clip': 1.0,
    'ema_decay': 0.999,
}

# When True, generate_dataset and generate_reference_dataset only use the
# first 512 rows of the training split.
TEST_MODE = False
# Root directory holding the `hf` datasets, the checkpoints and the exported
# `datasets_for_eval` folders.
PATH = Path('/data/vision/___/scratch/___ht/fastmri_dir')

def generate_ifft(model, dataset, rng, batch_size, config, **kwargs):
    """Sample images from the conditional model, batch by batch.

    For every batch, the measurements `y` are turned into the conditioning
    input `np.abs(ifft2c(real2complex(y)))`, which is then passed to
    `sample_conditional`.

    Arguments:
        model: The model handed to `sample_conditional`.
        dataset: A dataset with columns 'y' and 'A'. Both columns are removed
            from the result.
        rng: An object whose `split()` method is called once per batch to
            obtain the sampling key.
        batch_size: The number of rows processed per call.
        config: An object providing the attributes `sampler`, `discrete`
            (number of steps) and `maxiter`.
        kwargs: Accepted so that existing call sites keep working; ignored.

    Returns:
        The mapped dataset with a single column 'x' declared as float32
        `Array3D` of shape (320, 320, 1). The last incomplete batch is dropped.
    """

    def transform(batch):
        # Only 'y' is needed here; 'A' is dropped through `remove_columns`
        # below, so it is not read from the batch.
        y_cond = np.abs(ifft2c(real2complex(batch['y'])))

        x = sample_conditional(
            model=model,
            y_cond=y_cond,
            key=rng.split(),
            shard=True,
            sampler=config.sampler,
            steps=config.discrete,
            maxiter=config.maxiter,
        )

        return {'x': np.asarray(x)}

    types = {'x': Array3D(shape=(320, 320, 1), dtype='float32')}

    return dataset.map(
        transform,
        features=Features(types),
        remove_columns=['y', 'A'],
        keep_in_memory=True,
        batched=True,
        batch_size=batch_size,
        drop_last_batch=True,
    )

def generate_dataset(lap: int, config, rng, sde):
    """Sample a dataset with the conditional checkpoint `lap` and export it as PNGs.

    Arguments:
        lap: Index of the checkpoint to load
            (`checkpoints_conditional_ifft_r6_itnog/checkpoint_{lap}.pkl`).
            It also names the output directory.
        config: An object providing `duplicate`, `batch_size` and the sampling
            attributes read by `generate_ifft`.
        rng: The random generator forwarded to `generate_ifft`.
        sde: Kept for interface compatibility. `generate_ifft` ignores it, so
            it is not forwarded.

    Returns:
        The generated dataset (column 'x'). Its images are also written to
        `datasets_for_eval/con_itnog_checkpoint_{lap}/{i}.png`, numbered from 1.
    """
    dataset = load_from_disk(PATH / 'hf/fastmri-kspace-r6')
    dataset.set_format('numpy')

    trainset_yA = dataset['train'].select(range(512)) if TEST_MODE else dataset['train']
    trainset_yA = concatenate_datasets([trainset_yA] * config.duplicate)

    model = load_module(PATH / f'checkpoints_conditional_ifft_r6_itnog/checkpoint_{lap}.pkl')

    # The sampler settings are read from `config` inside generate_ifft, hence
    # they are not repeated as keyword arguments here.
    trainset = generate_ifft(
        model=model,
        dataset=trainset_yA,
        rng=rng,
        config=config,
        batch_size=config.batch_size,
    )

    # The directory follows `lap` instead of a hard-coded checkpoint number, so
    # that samples of different checkpoints never end up in the same folder.
    save_dir = PATH / 'datasets_for_eval' / f'con_itnog_checkpoint_{lap}'
    save_dir.mkdir(parents=True, exist_ok=True)

    for i, x in enumerate(trainset['x'], start=1):
        to_pil(x).save(save_dir / f'{i}.png')

    return trainset

def generate_reference_dataset(duplicate: Optional[int] = None):
    """Export the clean images of the `hf/fastmri` training split as PNGs.

    Arguments:
        duplicate: The number of times the training split is repeated. If
            `None`, the value is read from the module-level `config` object,
            which only exists when this file is run as a script.

    The images are written to `datasets_for_eval/ref/{i}.png`, numbered from 1
    in dataset order.
    """
    if duplicate is None:
        duplicate = config.duplicate

    dataset = load_from_disk(PATH / 'hf/fastmri')
    dataset.set_format('numpy')

    trainset = dataset['train'].select(range(512)) if TEST_MODE else dataset['train']
    trainset = concatenate_datasets([trainset] * duplicate)

    save_dir = PATH / 'datasets_for_eval' / 'ref'
    save_dir.mkdir(parents=True, exist_ok=True)

    def transform(row, idx):
        # The file name is derived from the row index rather than from a call
        # counter, so it does not depend on how often `map` calls this function.
        to_pil(row['x']).save(save_dir / f'{idx + 1}.png')

    # `transform` returns nothing: the map is only used for its side effect.
    trainset.map(transform, with_indices=True)

def generate(model, dataset, rng, batch_size, **kwargs):
    """Replace the columns 'y' and 'A' of `dataset` with samples 'x'.

    Each batch is passed to `sample` together with `model`, a key obtained
    from `rng.split()` and the extra keyword arguments. The result is stored
    as a float32 `Array3D` column of shape (320, 320, 1). The last incomplete
    batch is dropped and the mapped dataset is kept in memory.
    """

    def sample_batch(batch):
        samples = sample(model, batch['y'], batch['A'], rng.split(), **kwargs)
        return {'x': np.asarray(samples)}

    features = Features({'x': Array3D(shape=(320, 320, 1), dtype='float32')})

    return dataset.map(
        sample_batch,
        batched=True,
        batch_size=batch_size,
        drop_last_batch=True,
        remove_columns=['y', 'A'],
        features=features,
        keep_in_memory=True,
    )

def generate_dataset_emmmps():
    """
    generates png dataset, using MMPS conditional method

    The checkpoint `checkpoints_r6/checkpoint_15.pkl` is used to sample from
    the training split of `hf/fastmri-kspace-r6` (repeated twice), and the
    samples are saved as `datasets_for_eval/em-mmps/{i}.png`, numbered from 1.

    This function reads the module-level names `seed` and `sde`, which are
    only defined when this file is run as a script.
    """
    # Sharding
    jax.config.update('jax_threefry_partitionable', True)

    rng = inox.random.PRNG(seed)
    previous = load_module("/data/vision/___/scratch/___ht/fastmri_dir/checkpoints_r6/checkpoint_15.pkl")
    batch_size = 64
    save_path = Path("/data/vision/___/scratch/___ht/fastmri_dir/datasets_for_eval/em-mmps")
    save_path.mkdir(parents=True, exist_ok=True)
    dataset = load_from_disk("/data/vision/___/scratch/___ht/fastmri_dir/hf/fastmri-kspace-r6")
    dataset.set_format('numpy')

    trainset_yA = concatenate_datasets([dataset['train']] * 2)

    # The row count is taken from the dataset itself; there is no need to load
    # the whole 'y' column to measure it.
    print("Length", len(trainset_yA))

    samples = generate(
        model=previous,
        dataset=trainset_yA,
        rng=rng,
        batch_size=batch_size,
        shard=True,
        sampler='ddpm',
        sde=sde,
        steps=64,
        maxiter=3,
    )

    # Errors are left to propagate. The former `except e as Exception` clause
    # could only raise a NameError on top of the actual failure.
    for idx, image_arr in enumerate(samples['x'], start=1):
        to_pil(image_arr).save(save_path / f"{idx}.png")


def generate_corrupted_png_dataset():
    """
    generates corrupted png dataset using the files in `hf` directory

    Every 'y' entry of the training split of `hf/fastmri-kspace-r6` (repeated
    twice) goes through `ifft2c(real2complex(...))` and is saved as
    `datasets_for_eval/corrupted-r6/{i}.png`, numbered from 1.
    """
    # Imported here because the module only imports `trange` from tqdm.
    from tqdm import tqdm

    path = Path("/data/vision/___/scratch/___ht/fastmri_dir/hf/fastmri-kspace-r6")
    save_path = Path("/data/vision/___/scratch/___ht/fastmri_dir/datasets_for_eval/corrupted-r6")
    save_path.mkdir(parents=True, exist_ok=True)
    dataset = load_from_disk(path)
    dataset.set_format('numpy')

    trainset_yA = concatenate_datasets([dataset['train']] * 2)
    print('Starting the process...')

    # tqdm wraps the column (which has a length) instead of the enumerate
    # iterator, so that the progress bar can display a total.
    for idx, kspace in enumerate(tqdm(trainset_yA['y']), start=1):
        # NOTE(review): generate_ifft applies np.abs to this same expression
        # before using it. Confirm whether to_pil expects the magnitude here.
        image_arr = ifft2c(real2complex(kspace))
        to_pil(image_arr).save(save_path / f"{idx}.png")

def generate_tv_reconstruction_dataset():
    """
    generates tv reconstruction dataset using the files in `hf` directory

    NOTE: the reconstruction itself is not implemented yet. For every 'y'
    entry of the training split of `hf/fastmri-kspace-r6` (repeated twice), a
    placeholder array `jax.numpy.zeros(1)` is passed to `to_pil` and saved as
    `datasets_for_eval/tv/{i}.png`, numbered from 1.
    """
    # Imported here because the module only imports `trange` from tqdm.
    from tqdm import tqdm

    path = Path("/data/vision/___/scratch/___ht/fastmri_dir/hf/fastmri-kspace-r6")
    save_path = Path("/data/vision/___/scratch/___ht/fastmri_dir/datasets_for_eval/tv")
    save_path.mkdir(parents=True, exist_ok=True)
    dataset = load_from_disk(path)
    dataset.set_format('numpy')

    trainset_yA = concatenate_datasets([dataset['train']] * 2)
    print('Starting the process...')

    # tqdm wraps the column (which has a length) instead of the enumerate
    # iterator, so that the progress bar can display a total.
    for idx, kspace in enumerate(tqdm(trainset_yA['y']), start=1):
        # Input of the future reconstruction; unused for now.
        image_arr_corrupted = ifft2c(real2complex(kspace))
        # TODO: replace this placeholder with the TV reconstruction of
        # `image_arr_corrupted`.
        image_reconstructed = jax.numpy.zeros(1)
        to_pil(image_reconstructed).save(save_path / f"{idx}.png")

if __name__ == "__main__":
    import zlib

    config = MyDict(CONFIG)

    # Sharding
    jax.config.update('jax_threefry_partitionable', True)

    mesh = jax.sharding.Mesh(jax.devices(), 'i')
    replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
    distributed = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('i'))

    # RNG
    # The built-in `hash` of a string is salted per interpreter process (unless
    # PYTHONHASHSEED is fixed), which made the seed differ between runs. CRC32
    # gives the same 16-bit seed every time.
    seed = zlib.crc32(b'eval_train_unconditional') % 2**16
    rng = inox.random.PRNG(seed)

    # SDE
    sde = VESDE(**CONFIG['sde'])

    # Uncomment the generator(s) to run.

    # Generate data using the conditional model
    # print('! Generating Dataset')
    # trainset = generate_dataset(
    #             lap=35,
    #             sde=sde,
    #             rng=rng,
    #             config=config
    #             )

    # generate_reference_dataset()
    # generate_corrupted_png_dataset()
    # generate_dataset_emmmps()
