"""
The difference between this and train_conditional.py is
that we diffuse the k-space and also condition on the
k-space.
"""

import inox
import inox.nn as nn
import jax
import numpy as np
import optax
import wandb

from datasets import Array3D, Features, concatenate_datasets, load_from_disk
from dawgz import job, schedule
from functools import partial
from tqdm import trange
from typing import *

import pathlib
import re

# isort: split
from utils2 import *

# set these
# TEST_MODE switches to the '-test' dataset directory, appends '_test' to the
# run name, and shortens training (see CONFIG['epochs'] and the `iterations`
# argument of fit_moments in train()).
TEST_MODE = True
# CORRUPTION is embedded in the dataset/run names and is passed to
# make_mask(r=CORRUPTION) when samples are re-corrupted in train().
CORRUPTION = 6
RUN_NAME = (f'diffem_kspace2_r{CORRUPTION}' + ('_test' if TEST_MODE else ''))

# don't touch these
# Root directory that holds both the dataset and the per-run outputs.
PATH = Path('/data/vision/___/scratch/___ht/fastmri_dir/')
# Must be computed before PATH is rebound below, since it is relative to the root.
DATASET_PATH = PATH / (f'hf/fastmri-kspace-r{CORRUPTION}' + ('-test' if TEST_MODE else ''))
# NOTE: PATH is rebound here; from this line on it is the per-run directory.
PATH = PATH / RUN_NAME
CHECKPOINT_PATH = PATH / 'checkpoints'
# Import-time side effect: the checkpoint directory (and parents) is created.
CHECKPOINT_PATH.mkdir(parents=True, exist_ok=True)

# Hyper-parameters of a run. The dict is given to wandb.init as the run config
# and is unpacked as keyword arguments into make_model_conditional (**CONFIG)
# and Adam (**run.config) inside train(). Keys that train() does not read
# directly are presumably consumed by those helpers — verify against utils2.
CONFIG = {
    # Data
    'duplicate': 2,  # the train split is concatenated with itself this many times
    # Architecture
    'hid_channels': (128, 256, 384, 512),
    'hid_blocks': (3, 3, 3, 3),
    'kernel_size': (3, 3),
    'emb_features': 256,
    'heads': {3: 4},
    'dropout': 0.1,
    # Sampling
    'sampler': 'ddpm',  # forwarded to sample_conditional(sampler=...)
    'heuristic': None,  # not read anywhere in this file
    'sde': {'a': 1e-3, 'b': 1e2},  # keyword arguments of VESDE
    'discrete': 64,  # forwarded to sample_conditional(steps=...)
    'maxiter': 3,  # forwarded to sample_conditional(maxiter=...)
    # Training
    'epochs': 64 if not TEST_MODE else 17,  # epochs per lap (17 in test mode)
    'batch_size': 256,  # used for both sample generation and training batches
    'scheduler': 'constant',
    'lr_init': 1e-4,
    'lr_end': 1e-6,
    'lr_warmup': 0.0,
    'optimizer': 'adam',
    'weight_decay': None,
    'clip': 1.0,
    'ema_decay': 0.999,  # decay passed to EMA(decay=...)
}

def generate_emmmps(model, dataset, rng, batch_size, config, **kwargs):
    """Replace the measurement columns of `dataset` with image-space samples.

    Every batch of measurements ``y`` is passed to `sample_conditional` as the
    conditioning signal and the result is stored in a new column ``x`` declared
    as a float32 ``Array3D`` of shape (320, 320, 1). The ``y`` and ``A`` columns
    are removed from the returned dataset.

    Args:
        model: Denoiser handed to `sample_conditional`.
        dataset: Dataset with ``y`` and ``A`` columns (both are removed; only
            ``y`` is read).
        rng: Key source; ``rng.split()`` is called once per batch.
        batch_size: Number of rows per mapped batch. The last incomplete batch
            is dropped, so the result may have fewer rows than `dataset`.
        config: Object exposing ``sampler``, ``discrete`` and ``maxiter``.
        **kwargs: Accepted so callers can pass a uniform argument set; ignored.

    Returns:
        The mapped dataset, kept in memory, with a single ``x`` column.
    """
    def transform(batch):
        # Only the measurements are needed here; the mask column is dropped by
        # `remove_columns` below without being read.
        x = sample_conditional(
            model=model,
            y_cond=batch['y'],
            key=rng.split(),
            shard=True,
            shape=320 * 320 * 1,  # for em-mmps we are generating in the image space
            sampler=config.sampler,
            steps=config.discrete,
            maxiter=config.maxiter,
        )

        return {'x': np.asarray(x)}

    types = {'x': Array3D(shape=(320, 320, 1), dtype='float32')}

    return dataset.map(
        transform,
        features=Features(types),
        remove_columns=['y', 'A'],
        keep_in_memory=True,
        batched=True,
        batch_size=batch_size,
        drop_last_batch=True,
    )

def generate_ifft(model, dataset, rng, batch_size, config, **kwargs):
    """Replace the measurement columns of `dataset` with k-space samples.

    For every batch the mask ``A`` is broadcast to (batch, 320, 320, 1), cast to
    the dtype of ``y`` and appended to ``y`` along the last axis; the result is
    the conditioning signal given to `sample_conditional`. The samples are
    stored in a new column ``x`` declared as a float32 ``Array3D`` of shape
    (320, 320, 2), and the ``y`` and ``A`` columns are removed.

    Args:
        model: Denoiser handed to `sample_conditional`.
        dataset: Dataset with ``y`` and ``A`` columns.
        rng: Key source; ``rng.split()`` is called once per batch.
        batch_size: Number of rows per mapped batch. The last incomplete batch
            is dropped, so the result may have fewer rows than `dataset`.
        config: Object exposing ``sampler``, ``discrete`` and ``maxiter``.
        **kwargs: Accepted for call-site uniformity; ignored.

    Returns:
        The mapped dataset, kept in memory, with a single ``x`` column.
    """
    def sample_batch(batch):
        measurements = batch['y']
        mask = batch['A']

        # Expand the mask to one value per pixel so that it can be stacked
        # onto the measurements as an extra channel.
        mask = np.broadcast_to(mask, (measurements.shape[0], 320, 320, 1))
        mask = mask.astype(measurements.dtype)
        condition = np.concatenate([measurements, mask], axis=-1)

        samples = sample_conditional(
            model=model,
            y_cond=condition,
            key=rng.split(),
            shard=True,
            sampler=config.sampler,
            steps=config.discrete,
            maxiter=config.maxiter,
        )

        return {'x': np.asarray(samples)}

    features = Features({'x': Array3D(shape=(320, 320, 2), dtype='float32')})

    return dataset.map(
        sample_batch,
        batched=True,
        batch_size=batch_size,
        drop_last_batch=True,
        remove_columns=['y', 'A'],
        features=features,
        keep_in_memory=True,
    )


def train(runid: str, lap: int):
    """Run one lap: sample data with the previous model, then train and checkpoint.

    A lap performs the following steps:

    1. Initialise (or resume) the W&B run identified by `runid`.
    2. Load the measurement dataset from ``DATASET_PATH``.
    3. Obtain the previous model: the checkpoint of lap ``lap - 1`` or, for the
       first lap, a ``ConditionalGaussianDenoiser`` built from the moments
       returned by ``fit_moments``.
    4. Generate samples ``x`` for the train and validation splits with the
       previous model (``generate_ifft`` for ``lap > 0``, otherwise
       ``generate_emmmps`` followed by an FFT into k-space).
    5. Re-corrupt the samples with a fresh mask and Gaussian noise to obtain
       ``(y, A)`` pairs.
    6. Train on ``x`` conditioned on ``y`` concatenated with ``A``, logging
       train/validation losses every epoch and samples every 16 epochs.
    7. Dump the EMA-weighted model to ``checkpoint_{lap}.pkl``.

    Args:
        runid: W&B run id shared by all laps.
        lap: Index of the lap to run; ``lap > 0`` requires the checkpoint of
            the previous lap to exist in ``CHECKPOINT_PATH``.

    Raises:
        ValueError: If ``wandb.init`` returns ``None`` or if the configured
            number of epochs is not an ``int``.
    """
    run = wandb.init(
        project='priors-conditional-fastmri-kspace',
        id=str(runid),
        resume='allow',
        dir=PATH,
        config=CONFIG,
        name=f'kspace_r{CORRUPTION}_lap_[{lap}, )' + ('_test' if TEST_MODE else '')
    )
    # NOTE(review): the run is never finished in this function, while the
    # __main__ loop calls train() repeatedly in the same process — confirm that
    # wandb.init behaves as intended on the second and later laps.

    if run is None:
        raise ValueError('WandB did not initialize properly.')

    # Batched versions of the per-sample helpers.
    ifft2cv = jax.vmap(ifft2c)
    real2complexv = jax.vmap(real2complex)

    gpu_count = len(jax.devices())

    runpath = PATH / f'runs/{run.name}_{run.id}'
    runpath.mkdir(parents=True, exist_ok=True)

    config = run.config

    # Sharding
    jax.config.update('jax_threefry_partitionable', True)

    mesh = jax.sharding.Mesh(jax.devices(), 'i')
    replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
    distributed = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('i'))

    # RNG
    # NOTE(review): hash() of string-based objects is salted per process unless
    # PYTHONHASHSEED is set, so this seed is presumably not reproducible across
    # processes — confirm whether that matters.
    seed = hash((runpath, lap)) % 2**16
    rng = inox.random.PRNG(seed)

    # SDE
    sde = VESDE(**CONFIG.get('sde'))

    # Data
    dataset = load_from_disk(DATASET_PATH)
    dataset.set_format('numpy')

    trainset_yA = dataset['train']
    trainset_yA = concatenate_datasets([trainset_yA] * config.duplicate) # TODO: Why are we doing this?
    testset_yA = dataset['val']

    # Fixed evaluation measurements: a strided subset of the first 1024
    # validation rows. The stride depends on the device count; the rows are
    # loaded once and both columns are taken from the same slice.
    eval_stride = 128 if gpu_count == 8 else 256
    eval_rows = testset_yA[:1024:eval_stride]
    y_eval, A_eval = eval_rows['y'], eval_rows['A']

    y_eval, A_eval = jax.device_put((y_eval, A_eval), distributed)
    A_eval = jnp.broadcast_to(A_eval, (y_eval.shape[0], 320, 320, 1)).astype(y_eval.dtype)
    # NOTE(review): here the mask is appended on the channel axis without
    # flattening, whereas the training loop concatenates flatten(y) and
    # flatten(A) on axis 1 — confirm that sample_conditional / the model treat
    # both layouts consistently.
    yA_eval = jnp.concatenate([y_eval, A_eval], axis = -1)

    # Previous
    if lap > 0:
        previous = load_module(CHECKPOINT_PATH / f'checkpoint_{lap - 1}.pkl')
    else:
        # First lap: no checkpoint yet, so fit moments on a strided subset of
        # the training measurements and wrap them in a Gaussian denoiser.
        fit_rows = trainset_yA[:16384:4]
        y_fit, A_fit = fit_rows['y'], fit_rows['A']
        y_fit, A_fit = jax.device_put((y_fit, A_fit), distributed)

        mu_x, cov_x = fit_moments(
            features=320 * 320 * 1,
            rank=64,
            shard=True,
            A=inox.Partial(measure, A_fit, shard=True),
            y=flatten(y_fit),
            cov_y=1e-2**2,
            sampler='ddim',
            sde=sde,
            steps=256,
            maxiter=5,
            key=rng.split(),
            iterations=1 if TEST_MODE else 16, # default
        )

        del fit_rows, y_fit, A_fit

        previous = ConditionalGaussianDenoiser(mu_x, cov_x) # no different than gaussian denoiser

    ## Generate latents
    ## Then corrupt the generated latents

    def corruption_transform(row):
        """
        Takes a k-space sample and applies a freshly drawn corruption mask
        plus Gaussian noise (standard deviation 1e-2) to it.
        Returns the mask under 'A' and the corrupted k-space under 'y'.
        """
        # jax.config.update('jax_platform_name', 'cpu')

        x = row['x']

        A = make_mask(r=CORRUPTION)
        A = np.array(A)
        x = np.random.normal(loc=A * x, scale=1e-2)

        return {'A': A, 'y': x}

    types = {
        'A': Array3D(shape=(1, 320, 1), dtype='bool'),
        'y': Array3D(shape=(320, 320, 2), dtype='float32'),
    }

    # Replicate the previous model's arrays on every device before sampling.
    static, arrays = previous.partition()
    arrays = jax.device_put(arrays, replicated)
    previous = static(arrays)

    # Later laps sample directly in k-space; the first lap samples in image
    # space (EM-MMPS) and is converted to k-space afterwards. The validation
    # split is generated first, as before, so the RNG is consumed in the same
    # order.
    generate = generate_ifft if lap > 0 else generate_emmmps
    generate_kwargs = dict(
        model=previous,
        rng=rng,
        config=config,
        batch_size=config.batch_size,
        shard=True,
        sampler=config.sampler,
        sde=sde,
        steps=config.discrete,
        maxiter=config.maxiter,
    )

    testset = generate(dataset=testset_yA, **generate_kwargs)
    trainset = generate(dataset=trainset_yA, **generate_kwargs)

    if lap <= 0:
        # switching the samples from the image space to the k-space

        def fft_transform(row):
            x = row['x'] # image (320, 320, 1)
            x = np.array(complex2real(fft2c(x)))
            return {'x_fft': x}

        types_fft_transform = {
            'x_fft': Array3D(shape=(320, 320, 2), dtype='float32'),
        }

        testset = testset.map(
            fft_transform,
            features=Features(types_fft_transform),
            remove_columns=['x'],
            num_proc=1,
        )
        trainset = trainset.map(
            fft_transform,
            features=Features(types_fft_transform),
            remove_columns=['x'],
            num_proc=1,
        )

        # change the name of 'x_fft' to 'x' in both trainset and testset
        testset = testset.rename_column('x_fft', 'x')
        trainset = trainset.rename_column('x_fft', 'x')

    # Log the first 16 samples of the trainset. Rows are sliced before the
    # column is selected so that only 16 rows are read, not the whole column.
    samples = ifft2cv(real2complexv(trainset[:16]['x'])).real
    run.log({
        'trainset_samples': wandb.Image(to_pil(samples.reshape(4, 4, 320, 320, 1), zoom = 4))
        })

    testset_corrupted_yA = testset.map(
        corruption_transform,
        features=Features(types),
        remove_columns=['x'],
        num_proc=1,
    )

    trainset_corrupted_yA = trainset.map(
        corruption_transform,
        features=Features(types),
        remove_columns=['x'],
        num_proc=1,
    )

    # Model
    if lap > 0:
        model = previous
    else:
        model = make_model_conditional(key=rng.split(), **CONFIG)

    model.train(True)

    static, params, others = model.partition(nn.Parameter)

    # Objective
    objective = ConditionalDenoiserLoss(sde=sde)

    # Optimizer
    # The number of optimisation steps is derived from the generated training
    # set. `dataset` holds the splits ('train' / 'val'), so its length is the
    # number of splits rather than the number of training rows.
    steps = config.epochs * len(trainset) // config.batch_size
    optimizer = Adam(steps=steps, **config)
    opt_state = optimizer.init(params)

    # EMA
    ema = EMA(decay=config.ema_decay)
    avrg = params

    # Training
    avrg, params, others, opt_state = jax.device_put((avrg, params, others, opt_state), replicated)

    @jax.jit
    @jax.vmap
    def augment(x, key):
        # Per-sample augmentation: one sub-key for the flip, one for the shake.
        keys = jax.random.split(key, 2)

        x = random_flip(x, keys[0], axis=-2)
        x = random_shake(x, keys[1], delta=4)

        return x

    @jax.jit
    def ell(params, others, x, y_cond, key):
        # Loss for one batch: noise z has the shape of x and one time t per
        # sample is drawn from Beta(3, 3).
        keys = jax.random.split(key, 3)

        z = jax.random.normal(keys[0], shape=x.shape)
        t = jax.random.beta(keys[1], a=3, b=3, shape=x.shape[:1])

        return objective(static(params, others), x, z, t, y_cond, key=keys[2])

    @jax.jit
    def sgd_step(avrg, params, others, opt_state, x, y_cond, key):
        # One optimiser update followed by an EMA update of the parameters.
        loss, grads = jax.value_and_grad(ell)(params, others, x, y_cond, key)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        avrg = ema(avrg, params)

        return loss, avrg, params, opt_state

    if not isinstance(config.epochs, int):
        raise ValueError('Number of epochs must be an integer.')

    for epoch in (bar := trange(config.epochs, ncols=88, desc="Training Progress")):
        # Both loaders are shuffled with the same seed so that samples and
        # their corrupted counterparts stay aligned when zipped.
        loader = trainset.shuffle(seed=seed + lap * config.epochs + epoch).iter(
            batch_size=config.batch_size, drop_last_batch=True
        )

        loader_yA = trainset_corrupted_yA.shuffle(seed=seed + lap * config.epochs + epoch).iter(
            batch_size=config.batch_size, drop_last_batch=True
        )

        losses = []

        for batch_x, batch_y in zip(prefetch(loader), prefetch(loader_yA)):
            x = batch_x['x']
            y = batch_y['y']
            A = batch_y['A']
            # The same per-sample keys are used for A, x and y, presumably so
            # that all three receive the same flip/shake.
            aug_key = rng.split(len(x))

            A = jax.device_put(A, distributed)
            A = jnp.broadcast_to(A, (y.shape[0], 320, 320, 1)).astype(y.dtype)
            A = augment(A, aug_key)
            A = flatten(A)

            x = jax.device_put(x, distributed)
            x = augment(x, aug_key)
            x = flatten(x)

            y = jax.device_put(y, distributed)
            y = augment(y, aug_key)
            y = flatten(y)
            y = jnp.concatenate([y, A], axis = 1)

            loss, avrg, params, opt_state = sgd_step(avrg, params, others, opt_state, x, y, key=rng.split())
            losses.append(loss)

        loss_train = np.stack(losses).mean()

        ## Validation
        # Evaluated with the EMA parameters and without augmentation.
        loader = testset.iter(batch_size=config.batch_size, drop_last_batch=True)
        loader_yA = testset_corrupted_yA.iter(batch_size=config.batch_size, drop_last_batch=True)
        losses = []

        for batch_x, batch_y in zip(prefetch(loader), prefetch(loader_yA)):
            x = batch_x['x']
            x = jax.device_put(x, distributed)
            x = flatten(x)

            y_cond = batch_y['y']
            y_cond = jax.device_put(y_cond, distributed)
            y_cond = flatten(y_cond)

            A = batch_y['A']
            A = jax.device_put(A, distributed)
            A = jnp.broadcast_to(A, (y_cond.shape[0], 320, 320, 1)).astype(y_cond.dtype)

            y_cond = jnp.concatenate([y_cond, flatten(A)], axis = 1)

            loss = ell(avrg, others, x, y_cond, key=rng.split())
            losses.append(loss)

        loss_val = np.stack(losses).mean()

        bar.set_postfix(loss=loss_train, loss_val=loss_val)

        ## Eval
        if (epoch + 1) % 16 == 0:
            model = static(avrg, others)
            model.train(False)

            x = sample_conditional(
                model=model,
                y_cond=yA_eval,
                key=rng.split(),
                shard=True,
                sampler=config.sampler,
                steps=config.discrete,
                maxiter=config.maxiter,
            )

            # Convert the k-space samples to images for logging.
            x = x.reshape(-1, 320, 320, 2)
            x = ifft2cv(real2complexv(x)).real
            x = x.reshape(-1, 320, 320, 1)

            # Images of the conditioning measurements come first, followed by
            # the generated samples.
            x = jnp.concatenate([ifft2cv(real2complexv(y_eval)).reshape(-1, 320, 320, 1).real, x], axis=0)

            run.log({
                'loss': loss_train,
                'loss_val': loss_val,
                'samples': wandb.Image(to_pil(x, zoom = 4 if x.shape[0] == 16 else 2)),
            })
        else:
            run.log({
                'loss': loss_train,
                'loss_val': loss_val,
            })

    ## Checkpoint
    # The checkpoint stores the EMA parameters, in evaluation mode.
    model = static(avrg, others)
    model.train(False)

    dump_module(model, CHECKPOINT_PATH / f'checkpoint_{lap}.pkl')


def find_start_lap(checkpoint_dir):
    """Return the index of the first lap that has no checkpoint yet.

    Scans `checkpoint_dir` (non-recursively) for regular files named exactly
    ``checkpoint_<digits>.pkl`` and returns the largest index found plus one.

    Args:
        checkpoint_dir: Directory holding the checkpoints (path-like or str).

    Returns:
        ``0`` when no checkpoint file exists, otherwise ``max index + 1``.
    """
    # The dot is escaped so that only a literal '.pkl' suffix is accepted.
    checkpoint_name = re.compile(r'checkpoint_(\d+)\.pkl')

    last_lap = -1
    for child in pathlib.Path(checkpoint_dir).iterdir():
        if not child.is_file():
            continue
        match = checkpoint_name.fullmatch(child.name)
        if match is None:
            continue
        last_lap = max(last_lap, int(match.group(1)))

    return last_lap + 1


if __name__ == '__main__':
    runid = wandb.util.generate_id()
    print('ITNOG')

    # Automatically resume after the newest checkpoint found on disk.
    start_lap = find_start_lap(CHECKPOINT_PATH)

    for lap in range(start_lap, 128):
        print(f'Lap {lap} started...')
        train(runid = runid, lap = lap)
    
