#!/usr/bin/env python

import inox
import inox.nn as nn
import jax
import numpy as np
import optax
import wandb

from datasets import Array3D, Features, concatenate_datasets, load_from_disk
from dawgz import job, schedule
from functools import partial
from tqdm import trange
from typing import *

# isort: split
from utils import *

# Root directory of the experiment: the k-space dataset and the conditional
# checkpoints are loaded from below it, and it is passed to W&B as run `dir`.
# NOTE(review): `Path` is not imported explicitly in this file; presumably it
# is re-exported by `from utils import *` -- confirm.
PATH = Path('/data/vision/___/scratch/___ht/fastmri_dir')
# When True, generate_dataset() keeps only the first 512 training examples.
TEST_MODE = False
# Hyperparameters of the run. The whole dict is forwarded to make_model() as
# keyword arguments and logged as the W&B run config; individual entries are
# also read through the attribute-style `config` wrapper built in the script
# entry point.
CONFIG = {
    # Data
    # Number of times the training split is concatenated with itself before
    # conditional sampling.
    'duplicate': 2,
    # Architecture
    # (not read by name in this file; reaches make_model() through **CONFIG)
    'hid_channels': (128, 256, 384, 512),
    'hid_blocks': (3, 3, 3, 3),
    'kernel_size': (3, 3),
    'emb_features': 256,
    'heads': {3: 4},
    'dropout': 0.1,
    # Sampling
    # 'sampler', 'discrete' and 'maxiter' are passed to the sampling helpers
    # as `sampler`, `steps` and `maxiter`; 'sde' holds the keyword arguments
    # of VESDE. 'heuristic' is not read by name in this file.
    'sampler': 'ddpm',
    'heuristic': None,
    'sde': {'a': 1e-3, 'b': 1e2},
    'discrete': 64,
    'maxiter': 3,
    # Training
    # 'epochs' and 'batch_size' drive the training loop; the scheduler,
    # learning-rate, weight-decay and clip entries are forwarded to the Adam
    # helper; 'ema_decay' is forwarded to EMA. 'optimizer' is not read by
    # name in this file.
    'epochs': 64 * 4,
    'batch_size': 256,
    'scheduler': 'constant',
    'lr_init': 1e-4,
    'lr_end': 1e-6,
    'lr_warmup': 0.0,
    'optimizer': 'adam',
    'weight_decay': None,
    'clip': 1.0,
    'ema_decay': 0.999,
}


def generate_ifft(model, dataset, rng, batch_size, config, **kwargs):
    """Replace every measurement in `dataset` by a conditional sample.

    For each batch, the conditioning input is the magnitude (`np.abs`) of
    `ifft2c(real2complex(y))`; `sample_conditional` is then called with it and
    the result is stored in a new column `'x'`. The `'y'` and `'A'` columns
    are removed from the returned dataset.

    Args:
        model: Conditional model handed to `sample_conditional`.
        dataset: Dataset exposing `.map` with the columns `'y'` and `'A'`.
        rng: Key source; `rng.split()` is called once per batch.
        batch_size: Number of examples per `map` batch. The last incomplete
            batch is dropped.
        config: Object with the attributes `sampler`, `discrete` and `maxiter`.
        **kwargs: Accepted for call-site compatibility and ignored; the
            sampling settings are always taken from `config`.

    Returns:
        The mapped dataset with a single `'x'` column declared as
        `Array3D(shape=(320, 320, 1), dtype='float32')`.
    """

    def transform(batch):
        # Only 'y' is needed: the sampler is conditioned on the image obtained
        # from the measurement alone, so the 'A' column is never read here.
        y = np.abs(ifft2c(real2complex(batch['y'])))
        x = sample_conditional(
            model=model,
            y_cond=y,
            key=rng.split(),
            shard=True,
            sampler=config.sampler,
            steps=config.discrete,
            maxiter=config.maxiter,
        )

        return {'x': np.asarray(x)}

    types = {'x': Array3D(shape=(320, 320, 1), dtype='float32')}

    return dataset.map(
        transform,
        features=Features(types),
        remove_columns=['y', 'A'],
        keep_in_memory=True,
        batched=True,
        batch_size=batch_size,
        drop_last_batch=True,
    )


def generate_dataset(lap: int, config, rng, sde):
    """Build the training set by sampling from a conditional checkpoint.

    Loads the `hf/fastmri-kspace-r6` dataset below `PATH` in numpy format,
    takes its `'train'` split (only the first 512 rows when `TEST_MODE` is
    set), repeats it `config.duplicate` times, and maps it through
    `generate_ifft` using the checkpoint `checkpoint_{lap}.pkl`.

    Args:
        lap: Index of the conditional checkpoint to load.
        config: Object with the attributes `duplicate`, `batch_size`,
            `sampler`, `discrete` and `maxiter`.
        rng: Key source forwarded to `generate_ifft`.
        sde: Forwarded to `generate_ifft` as a keyword argument.

    Returns:
        The dataset returned by `generate_ifft`.
    """
    splits = load_from_disk(PATH / 'hf/fastmri-kspace-r6')
    splits.set_format('numpy')

    measurements = splits['train']
    if TEST_MODE:
        measurements = measurements.select(range(512))
    measurements = concatenate_datasets([measurements] * config.duplicate)

    checkpoint = PATH / f'checkpoints_conditional_ifft_itnog/checkpoint_{lap}.pkl'
    model = load_module(checkpoint)

    return generate_ifft(
        model=model,
        dataset=measurements,
        rng=rng,
        config=config,
        batch_size=config.batch_size,
        shard=True,
        sampler=config.sampler,
        sde=sde,
        steps=config.discrete,
        maxiter=config.maxiter,
    )

def train(trainset, config, sde, rng, distributed, replicated, save_path, seed=None):
    """Train a denoiser on `trainset` and save its EMA weights.

    A W&B run is opened in the project `priors-fastmri-kspace-unconditional`,
    a model is built with `make_model(**CONFIG)`, and it is optimised with the
    `DenoiserLoss(sde=sde)` objective for `config.epochs` epochs. Every 16th
    epoch, four samples are drawn from the EMA model and logged as an image.
    When training ends, the EMA model (in eval mode) is written to `save_path`
    with `dump_module`.

    Args:
        trainset: Dataset with an `'x'` column, supporting `len`, `.shuffle`
            and `.iter`.
        config: Object with the attributes `epochs`, `batch_size`,
            `scheduler`, `lr_init`, `lr_end`, `lr_warmup`, `weight_decay`,
            `clip`, `ema_decay`, `sampler`, `discrete` and `maxiter`.
        sde: Noise process handed to `DenoiserLoss`.
        rng: Key source; `rng.split()` is consumed throughout training.
        distributed: Sharding applied to each data batch via `jax.device_put`.
        replicated: Sharding applied to parameters and optimizer state.
        save_path: Destination passed to `dump_module`.
        seed: Base seed for the per-epoch dataset shuffle. When omitted, one
            is drawn from `rng`, so the function does not depend on a
            module-level `seed` variable being defined.
    """
    if seed is None:
        # Previously a global `seed` was read here, which only exists when the
        # file is executed as a script; importing `train` raised NameError.
        seed = int(jax.random.randint(rng.split(), (), 0, 2**16))

    runid = wandb.util.generate_id()

    run = wandb.init(
        project='priors-fastmri-kspace-unconditional',
        id=runid,
        resume='allow',
        dir=PATH,
        config=CONFIG,
    )

    # NOTE(review): the architecture comes from the module-level CONFIG while
    # the training settings come from `config`; the two presumably hold the
    # same values -- confirm before passing a modified `config`.
    model = make_model(key=rng.split(), **CONFIG)
    model.train(True)

    static, params, others = model.partition(nn.Parameter)

    # Objective
    objective = DenoiserLoss(sde=sde)

    # Optimizer
    steps = config.epochs * len(trainset) // config.batch_size
    optimizer = Adam(
        steps=steps,
        scheduler=config.scheduler,
        lr_init=config.lr_init,
        lr_end=config.lr_end,
        lr_warmup=config.lr_warmup,
        weight_decay=config.weight_decay,
        clip=config.clip,
    )
    opt_state = optimizer.init(params)

    # EMA: the running average starts from the initial parameters.
    ema = EMA(decay=config.ema_decay)
    avrg = params

    # Training
    avrg, params, others, opt_state = jax.device_put(
        (avrg, params, others, opt_state), replicated
    )

    @jax.jit
    @jax.vmap
    def augment(x, key):
        # Per-example augmentation: random flip along axis -2, then a random
        # shake with delta=4, each with its own key.
        keys = jax.random.split(key, 2)

        x = random_flip(x, keys[0], axis=-2)
        x = random_shake(x, keys[1], delta=4)

        return x

    @jax.jit
    def ell(params, others, x, key):
        keys = jax.random.split(key, 3)

        # Gaussian noise of the same shape as x, one Beta(3, 3) time per example.
        z = jax.random.normal(keys[0], shape=x.shape)
        t = jax.random.beta(keys[1], a=3, b=3, shape=x.shape[:1])

        return objective(static(params, others), x, z, t, key=keys[2])

    @jax.jit
    def sgd_step(avrg, params, others, opt_state, x, key):
        loss, grads = jax.value_and_grad(ell)(params, others, x, key)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        avrg = ema(avrg, params)

        return loss, avrg, params, opt_state

    for epoch in (bar := trange(config.epochs, ncols=88)):
        # A distinct, deterministic shuffle seed for every epoch.
        loader = trainset.shuffle(seed=seed * config.epochs + epoch).iter(
            batch_size=config.batch_size, drop_last_batch=True
        )

        losses = []

        for batch in prefetch(loader):
            x = batch['x']
            x = jax.device_put(x, distributed)
            x = augment(x, rng.split(len(x)))
            x = flatten(x)

            loss, avrg, params, opt_state = sgd_step(
                avrg, params, others, opt_state, x, key=rng.split()
            )
            losses.append(loss)

        loss_train = np.stack(losses).mean()

        bar.set_postfix(loss=loss_train)

        logs = {'loss': loss_train}

        ## Eval
        if (epoch + 1) % 16 == 0:
            ema_model = static(avrg, others)
            ema_model.train(False)

            samples = sample(
                model=ema_model,
                key=rng.split(),
                y=None,
                A=None,
                cnt=4,
                shard=True,
                sampler=config.sampler,
                steps=config.discrete,
                maxiter=config.maxiter,
            )
            # Arrange the 4 samples as a 2 x 2 grid of 320 x 320 x 1 images.
            samples = samples.reshape(2, 2, 320, 320, 1)

            logs['samples'] = wandb.Image(to_pil(samples))

        run.log(logs)

    # Always export the final EMA weights in eval mode. Saving whatever `model`
    # was bound at the last evaluation would write a stale snapshot whenever
    # `config.epochs` is not a multiple of 16, and the untrained initial model
    # when it is below 16.
    model = static(avrg, others)
    model.train(False)
    dump_module(model, save_path)

if __name__ == "__main__":
    import zlib

    config = MyDict(CONFIG)

    # Index of the conditional checkpoint used to generate the training data;
    # it also names the saved unconditional model.
    lap = 35

    # Sharding: one mesh axis 'i' over all devices. `replicated` keeps a full
    # copy on every device, `distributed` splits the leading axis across them.
    jax.config.update('jax_threefry_partitionable', True)

    mesh = jax.sharding.Mesh(jax.devices(), 'i')
    replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
    distributed = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('i'))

    # RNG
    # The built-in hash() of a str is salted per process (PYTHONHASHSEED), so
    # it yields a different seed on every run. CRC32 is stable across runs.
    # `seed` stays a module-level name because other code in this file may
    # look it up as a global.
    seed = zlib.crc32(b'eval_train_unconditional') % 2**16
    rng = inox.random.PRNG(seed)

    # SDE
    sde = VESDE(**CONFIG.get('sde'))

    # Generate data using the conditional model
    print('! Generating Dataset')
    trainset = generate_dataset(
        lap=lap,
        sde=sde,
        rng=rng,
        config=config,
    )
    print('! Starting Training')
    train(
        trainset=trainset,
        sde=sde,
        config=config,
        rng=rng,
        distributed=distributed,
        replicated=replicated,
        # Same location as before, now derived from PATH and `lap`.
        save_path=str(PATH / 'unconditional_model_evaluation' / f'unconditional_{lap}.pkl'),
    )