#!/usr/bin/env python

import inox
import inox.nn as nn
import jax
import numpy as np
import optax
import wandb
import zlib

from datasets import Array3D, Features, concatenate_datasets, load_from_disk
from dawgz import job, schedule
from functools import partial
from tqdm import trange
from typing import *

# isort: split
from utils import *

# Root directory: holds the HF dataset, wandb files, per-run outputs and
# conditional checkpoints.
PATH = Path('/data/vision/___/scratch/___ht/fastmri_dir')

# Full-size hyperparameters.
CONFIG = {
    # Data
    'duplicate': 2,
    # Architecture
    'hid_channels': (128, 256, 384, 512),
    'hid_blocks': (3, 3, 3, 3),
    'kernel_size': (3, 3),
    'emb_features': 256,
    'heads': {3: 4},
    'dropout': 0.1,
    # Sampling
    'sampler': 'ddpm',
    'heuristic': None,
    'sde': {'a': 1e-3, 'b': 1e2},
    'discrete': 64,
    'maxiter': 3,
    # Training
    'epochs': 64,
    'batch_size': 256,
    'scheduler': 'constant',
    'lr_init': 1e-4,
    'lr_end': 1e-6,
    'lr_warmup': 0.0,
    'optimizer': 'adam',
    'weight_decay': None,
    'clip': 1.0,
    'ema_decay': 0.999,
}

# Lightweight variant for quick test runs: everything is inherited from
# CONFIG (same keys, same order) except a narrower network and smaller batch.
CONFIG_FOR_TEST = {
    **CONFIG,
    'hid_channels': (64, 128, 128, 256),
    'emb_features': 64,
    'batch_size': 64,
}

# NOTE: the test configuration is what actually runs; remove this line to
# use the full-size CONFIG above.
CONFIG = CONFIG_FOR_TEST

def generate(model, dataset, rng, batch_size, config, **kwargs):
    """Draw one conditional sample ``x`` per row of ``dataset``.

    Rows are processed in batches of ``batch_size``. Each batch is passed to
    ``sample_conditional``, conditioned on the batch's ``'y'`` column. The
    ``'y'`` and ``'A'`` columns are removed and replaced by a single ``'x'``
    column, declared as a float32 ``(320, 320, 1)`` array. The trailing
    partial batch is dropped (``drop_last_batch=True``), so the result can
    have fewer rows than ``dataset``.

    Args:
        model: Denoiser handed to ``sample_conditional``.
        dataset: ``datasets.Dataset`` with ``'y'`` and ``'A'`` columns.
        rng: PRNG object with a ``split()`` method; one key is drawn per batch.
        batch_size: Number of rows per sampling call.
        config: Run config providing ``sampler``, ``discrete`` (used as the
            number of steps) and ``maxiter``.
        **kwargs: Accepted for call-site compatibility but ignored. All
            sampler settings come from ``config``.

    Returns:
        The mapped dataset, kept in memory.
    """
    def transform(batch):
        # NOTE(review): the measurement operator ``batch['A']`` is not handed
        # to the sampler; confirm that conditioning on ``y`` alone is intended.
        x = sample_conditional(
            model=model,
            y_cond=batch['y'],
            key=rng.split(),
            shard=True,
            sampler=config.sampler,
            steps=config.discrete,
            maxiter=config.maxiter,
        )

        # Bring the (possibly sharded) result back to host memory as NumPy.
        return {'x': np.asarray(x)}

    features = Features({'x': Array3D(shape=(320, 320, 1), dtype='float32')})

    return dataset.map(
        transform,
        features=features,
        remove_columns=['y', 'A'],
        keep_in_memory=True,
        batched=True,
        batch_size=batch_size,
        drop_last_batch=True,
    )

# Lap index: 0 means "no previous checkpoint" for the model set-up below.
lap = 0
runid = wandb.util.generate_id()

# Start the wandb run (files under PATH) with the active hyperparameters.
# NOTE(review): runid is freshly generated on every launch, so
# resume='allow' will presumably never match an existing run.
run = wandb.init(
    config=CONFIG,
    dir=PATH,
    id=runid,
    project='priors-conditional-fastmri-kspace',
    resume='allow',
)

# From here on, hyperparameters are read through the run's config object.
config = run.config

# Per-run output directory, named "<run name>_<run id>".
runpath = PATH / f'runs/{run.name}_{run.id}'
runpath.mkdir(exist_ok=True, parents=True)

# Sharding: a one-dimensional mesh over all visible devices, axis named 'i'.
jax.config.update('jax_threefry_partitionable', True)

mesh = jax.sharding.Mesh(jax.devices(), ('i',))

# `replicated` copies arrays to every device; `distributed` splits the
# leading axis across the mesh axis 'i'.
replicated, distributed = (
    jax.sharding.NamedSharding(mesh, spec)
    for spec in (jax.sharding.PartitionSpec(), jax.sharding.PartitionSpec('i'))
)

# RNG
# The seed is derived from (runpath, lap) with a stable checksum. Python's
# built-in hash() of str-backed values such as Path is salted per interpreter
# (PYTHONHASHSEED), so hash((runpath, lap)) changed between launches even for
# the same run path and lap.
seed = zlib.crc32(f'{runpath}:{lap}'.encode()) % 2**16
rng = inox.random.PRNG(seed)

# SDE
# Read the SDE parameters from the wandb run config, like every other
# hyperparameter in this script, so overrides applied to run.config are
# honoured instead of silently falling back to the module-level CONFIG.
sde = VESDE(**config.sde)

# Data
dataset = load_from_disk(PATH / 'hf/fastmri-kspace')
dataset.set_format('numpy')

# TODO: just for testing let's take the first 1024
# The 1024-row training subset is repeated config.duplicate times.
trainset_yA = concatenate_datasets(
    [dataset['train'].select(range(1024))] * config.duplicate,
)
testset_yA = dataset['val']

# Evaluation measurements: every 256th row of the first 1024 validation rows,
# sharded along the mesh axis. The slice is read once; both columns come
# from that single read.
eval_rows = testset_yA[:1024:256]
y_eval, A_eval = jax.device_put((eval_rows['y'], eval_rows['A']), distributed)
del eval_rows

# Previous
# Model used below to generate samples (`previous`):
#   - lap > 0: load the checkpoint written for the previous lap;
#   - lap == 0: no checkpoint exists yet, so moments (mu_x, cov_x) are fitted
#     with fit_moments on a subset of the training measurements and wrapped
#     in a ConditionalGaussianDenoiser.
if lap > 0:
    previous = load_module(PATH / f'checkpoints_conditional/checkpoint_{lap - 1}.pkl')
    # NOTE(review): alternative checkpoint location, kept for reference:
    # runpath / f'checkpoint_{lap - 1}.pkl'
else:
    # Every 4th row among the first 16384 rows of trainset_yA (the slice is
    # clipped to the dataset length). Because trainset_yA is currently a
    # 1024-row subset repeated config.duplicate times, some selected rows are
    # repeats whenever config.duplicate > 1.
    # NOTE(review): the slice trainset_yA[:16384:4] is read twice, once per column.
    y_fit, A_fit = trainset_yA[:16384:4]['y'], trainset_yA[:16384:4]['A']
    y_fit, A_fit = jax.device_put((y_fit, A_fit), distributed)

    mu_x, cov_x = fit_moments(
        features=320 * 320 * 1,  # flattened size of a (320, 320, 1) sample
        rank=64,
        shard=True,
        A=inox.Partial(measure, A_fit, shard=True),  # `measure` with A_fit and shard=True bound
        y=flatten(y_fit),
        cov_y=1e-2**2,  # (1e-2) ** 2: square of the noise scale used in corruption_transform
        sampler='ddim',
        sde=sde,
        steps=256,
        maxiter=5,
        key=rng.split(),  # draws a key from the script-level rng
    )

    # Drop the references to the fit arrays; they are not used past this point.
    del y_fit, A_fit

    previous = ConditionalGaussianDenoiser(mu_x, cov_x)


## Generate latents
## Then corrupt the generated latents

def corruption_transform(row):
    """Build a new masked, noisy observation for one generated sample.

    Intended for a non-batched ``Dataset.map``. The function reads the
    sample ``row['x']`` and returns ``{'A': mask, 'y': observation}``, where:

    * ``mask`` is ``make_mask(r=6)`` converted to a NumPy array;
    * ``observation`` is computed by applying ``fft2c`` to ``x`` (converted
      with ``complex2real``), multiplying by the mask, and adding Gaussian
      noise with scale ``1e-2``. The result is then passed through
      ``real2complex`` and ``ifft2c`` and reduced to its magnitude with
      ``np.abs``.

    NOTE(review): the stored ``'y'`` is therefore presumably an image-space
    magnitude rather than raw k-space; confirm this is what downstream code
    expects.

    The JAX work runs under a scoped CPU default device. The original code
    updated the process-wide ``jax_platform_name`` option on every row, which
    is a global side effect that can leak into later JAX work in this process
    (``map`` runs in-process with ``num_proc=1``).

    NOTE(review): the noise comes from NumPy's global, unseeded RNG, so the
    corruption is not reproducible across runs.
    """
    cpu = jax.devices('cpu')[0]

    with jax.default_device(cpu):
        x = row['x']
        y = complex2real(fft2c(x))
        A = np.array(make_mask(r=6))
        y = np.random.normal(loc=A * y, scale=1e-2)
        y = np.abs(ifft2c(real2complex(y)))

    return {'A': A, 'y': y}

# Feature schema of the corrupted dataset returned by corruption_transform.
types = {
    'A': Array3D(shape=(1, 320, 1), dtype='bool'),
    'y': Array3D(shape=(320, 320, 1), dtype='float32'),
}

# Replicate the previous model's arrays on every device of the mesh and
# rebuild the module around them.
static, arrays = previous.partition()
arrays = jax.device_put(arrays, replicated)
previous = static(arrays)

# Draw one sample per validation row with the previous model. `generate`
# takes its sampler settings (sampler, steps, maxiter) from `config` and
# ignores any extra keyword arguments, so only the arguments it uses are
# passed here.
testset = generate(
    model=previous,
    dataset=testset_yA,
    rng=rng,
    batch_size=config.batch_size,
    config=config,
)

# Replace each generated sample with a fresh (A, y) measurement, one row at a
# time in the current process.
testset_corrupted_yA = testset.map(
    corruption_transform,
    num_proc=1,
    remove_columns=['x'],
    features=Features(types),
)

# trainset = generate(
#     model=previous,
#     dataset=trainset_yA,
#     rng=rng,
#     config = config,
#     batch_size=config.batch_size,
#     shard=True,
#     sampler=config.sampler,
#     sde=sde,
#     steps=config.discrete,
#     maxiter=config.maxiter,
# )

# trainset_corrupted_yA = trainset.map(  
#     corruption_transform,
#     features=Features(types),
#     remove_columns=['x'],
#     num_proc=1,
# )

def log_image(x):
    """Send ``x`` to the active wandb run as an image under the key ``'samples'``.

    ``x`` is converted with ``to_pil`` before being wrapped in ``wandb.Image``.
    """
    image = wandb.Image(to_pil(x))
    run.log({'samples': image})