# EM-MMPS on CIFAR with blurry corruption
# Status: not working yet; ablations and a FLOP analysis are still needed to quantify how badly it performs.
import inox
import inox.nn as nn
import jax
import numpy as np
import optax
import wandb
import einops
from skimage.filters import gaussian

import re

from datasets import Array3D, Features, load_from_disk
from dawgz import job, schedule
from functools import partial
from tqdm import trange
from typing import *

from imagecorruptions import gaussian_blur as corrupt_uint

from skimage.filters import gaussian
from jax.scipy.signal import convolve2d

from utils import *

# Experiment hyper-parameters. `train` wraps this dict in `Config`, sends it to
# wandb as the run config and forwards every key to `make_model(**CONFIG)`.
CONFIG_ORIGINAL = {
    # Data
    'corruption': 2, # sigma; also selects the dataset directory (`-s{...}`)
    'kernel_size': 5, # selects the dataset directory (`-k{...}`)
    # Architecture
    'hid_channels': (128, 256, 384),
    'hid_blocks': (5, 5, 5),
    'emb_features': 256,
    'heads': {1: 4},
    'dropout': 0.1,
    # Sampling
    'sampler': 'ddpm', # forwarded to `sample` as `sampler=`
    'sde': {'a': 1e-3, 'b': 1e2}, # keyword arguments of `VESDE`
    # How `train` initialises `model.cov_x`: 'zeros', 'ones', 'cov_t' or
    # 'cov_x'; None leaves `model.cov_x` untouched.
    'heuristic': None,
    'discrete': 256, # forwarded to `sample` as `steps=`
    'maxiter': 1, # forwarded to `sample` as `maxiter=`
    # Training
    'epochs': 256,
    'batch_size': 256, # used both for generation and for the training loader
    'scheduler': 'constant',
    'lr_init': 2e-4,
    'lr_end': 1e-6,
    'lr_warmup': 0.0,
    'optimizer': 'adam',
    'weight_decay': None,
    'clip': 1.0,
    'ema_decay': 0.9999,
}

# When True, `train` keeps only the first 1 << 13 rows of each split.
TEST_MODE = True

# Alias (same dict object), kept so an alternative config can be swapped in.
CONFIG = CONFIG_ORIGINAL
corruption_severity = CONFIG['corruption']
kernel_size = CONFIG['kernel_size']
# Pre-corrupted observations, saved with `datasets.save_to_disk`.
DATASET_PATH = f'/data/vision/___/scratch/___ht/cifar_dir/hf/cifar-mask-gaussian-blur-s{corruption_severity}-k{kernel_size}'
# Root for wandb files (`runs/`) and EM checkpoints (`checkpoints/`).
PATH = Path('/data/vision/___/scratch/___ht/cifar_dir/itnog_EMMMPS_blur')

@partial(jax.jit, static_argnames='size')
def gaussian_2d_kernel(size: int, sigma: float) -> jax.Array:
    r"""Generates a normalized 2D Gaussian kernel.

    Arguments:
        size: Side length of the (size, size) kernel. It must be a Python int:
            it fixes the output shape, so it is a static argument of `jax.jit`
            (tracing it would make `jnp.linspace` fail for lack of a concrete
            `num`).
        sigma: Standard deviation of the Gaussian, in pixels.

    Returns:
        A float32 array of shape (size, size) whose entries sum to 1.
    """
    # Integer-spaced offsets centred on zero, e.g. [-2, -1, 0, 1, 2] for size=5.
    ax = jnp.linspace(-(size // 2), size // 2, size)
    xx, yy = jnp.meshgrid(ax, ax)
    kernel = jnp.exp(-(xx**2 + yy**2) / (2 * sigma**2))
    kernel /= jnp.sum(kernel)
    return kernel.astype(jnp.float32)

@jax.vmap
def corrupt_func(x):
    r"""Forward (corruption) operator applied to every element of a batch.

    `jax.vmap` maps over the leading axis of `x`. Each element is reshaped with
    `unflatten(..., 32, 32)`, multiplied by a Bernoulli(p=0.5) mask of the same
    shape, and flattened back with `flatten`. Presumably `x` is
    (batch_size, 32 * 32 * 3) and each element unflattens to (32, 32, 3);
    TODO confirm against `utils.unflatten`.

    The mask is drawn from the constant key `PRNGKey(0)`. That key is not
    batched, so every call and every batch element share the same mask.
    """
    image = unflatten(x, 32, 32)

    keep = jax.random.bernoulli(jax.random.PRNGKey(0), p=0.5, shape=image.shape)
    masked = keep * image

    # NOTE(review): no Gaussian blur is applied here, so this operator only
    # masks, although DATASET_PATH names a "gaussian-blur" dataset. Confirm
    # which forward model produced `y`. A blur would convolve each channel
    # image[:, :, c] with a (kernel_size, kernel_size) Gaussian (for example
    # `gaussian_2d_kernel(5, 2)`) using
    # `convolve2d(..., mode='same', boundary='fill', fillvalue=0)`.
    # Stack the per-channel results with `axis=-1`, not the default axis 0,
    # so the channel-last layout indexed above is preserved.

    return flatten(masked)
    

def generate(model, dataset, rng, batch_size, **kwargs):
    r"""Replaces each observation `y` of `dataset` with a sample `x` from `model`.

    Batches of `batch_size` rows are passed to `sample`, with `corrupt_func` as
    the forward operator (`A_is_func=True`) and a fresh key split from `rng`
    for each batch. The last incomplete batch is dropped, and the result lives
    in memory.

    Arguments:
        model: Denoiser used by `sample`.
        dataset: A `datasets.Dataset` with a `'y'` column.
        rng: An `inox.random.PRNG`; it is advanced once per batch.
        batch_size: Rows per call to `sample`.
        kwargs: Extra keyword arguments forwarded to `sample`.

    Returns:
        A dataset with a single float32 column `'x'` of shape (32, 32, 3).
    """
    def to_samples(rows):
        drawn = sample(model, rows['y'], corrupt_func, rng.split(), A_is_func=True, **kwargs)
        return {'x': np.asarray(drawn)}

    schema = Features({'x': Array3D(shape=(32, 32, 3), dtype='float32')})

    return dataset.map(
        to_samples,
        batched=True,
        batch_size=batch_size,
        drop_last_batch=True,
        keep_in_memory=True,
        features=schema,
        remove_columns=['y'],
    )


def train(runid: int, lap: int):
    r"""Runs one EM lap: sample x | y with the previous prior, then fit a new one.

    The lap proceeds in five steps:

    1. Start or resume (`resume='allow'`) the wandb run `runid`.
    2. Get the previous prior. For `lap > 0` this is checkpoint `lap - 1`.
       For lap 0 it is a `GaussianDenoiser` whose moments are fitted to the
       observations with `fit_moments`.
    3. Replace the observations `y` of both splits by samples from the
       previous prior (`generate`), and fit PPCA moments to the samples.
    4. Train a denoiser on the samples with Adam and EMA parameter averaging.
       Train and validation losses are logged every epoch; samples conditioned
       on 16 fixed test observations are logged every 16 epochs.
    5. Dump the EMA model to `PATH / 'checkpoints/checkpoint_{lap}.pkl'`.

    Arguments:
        runid: wandb run identifier.
        lap: Index of the EM iteration.
    """
    import zlib  # stable string hash for the seed below

    run = wandb.init(
        project='priors-cifar-mask',
        id=str(runid),
        resume='allow',
        dir=PATH,
        config=CONFIG,
        name=f"[{lap}, ) EM-MMPS-blurry-severity-{corruption_severity}-k{kernel_size}",
    )

    if run is None:
        print("Run failed to initialize.")
        return

    runpath = PATH / f'runs/{run.name}_{run.id}'
    runpath.mkdir(parents=True, exist_ok=True)

    # Hyper-parameters come from the module-level CONFIG, not from `run.config`.
    config = Config(CONFIG)

    # Sharding
    jax.config.update('jax_threefry_partitionable', True)

    mesh = jax.sharding.Mesh(jax.devices(), 'i')
    replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
    distributed = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('i'))

    # RNG. `hash()` of a str is salted per interpreter (PYTHONHASHSEED), so it
    # cannot give a reproducible seed; the CRC32 of the same label is stable.
    seed = zlib.crc32(f'EM-MMPS-blurry-{lap}'.encode()) % 2**16
    rng = inox.random.PRNG(seed)

    # SDE
    sde = VESDE(**CONFIG.get('sde'))

    # Sampling options shared by `generate` and the periodic evaluation below.
    sampling = dict(
        shard=True,
        sampler=config.sampler,
        sde=sde,
        steps=config.discrete,
        maxiter=config.maxiter,
    )

    # Data
    dataset = load_from_disk(DATASET_PATH)
    dataset.set_format('numpy')

    # Each split has a single column 'y' (50000 training rows for CIFAR).
    trainset_yA = dataset['train']
    testset_yA = dataset['test']

    if TEST_MODE:
        trainset_yA = trainset_yA.select(range(1 << 13))
        testset_yA = testset_yA.select(range(1 << 13))

    # Fixed observations for the sample grid logged during training.
    y_eval = testset_yA[:16]['y']
    y_eval = jax.device_put(y_eval, distributed)

    # Previous
    if lap > 0:
        previous = load_module(PATH / f'checkpoints/checkpoint_{lap - 1}.pkl')
    else:
        y_fit = trainset_yA[:16384]['y']
        y_fit = jax.device_put(y_fit, distributed)

        print(f"y_fit shape: {y_fit.shape}")

        mu_x, cov_x = fit_moments(
            features=32 * 32 * 3,
            rank=64,
            shard=True,
            A=corrupt_func,  # function from x -> y such that x.shape = (batch_size, 32 * 32 * 3)
            y=flatten(y_fit),  # y.shape = (batch_size, 32 * 32 * 3)
            cov_y=1e-3**2,  # TODO: or 1e-2**2
            sampler='ddim',
            sde=sde,
            steps=1,  # NOTE(review): looks like a debug value; previously annotated as 256
            maxiter=None,
            key=rng.split(),
        )

        del y_fit

        previous = GaussianDenoiser(mu_x, cov_x)

    ## Generate
    static, arrays = previous.partition()
    arrays = jax.device_put(arrays, replicated)
    previous = static(arrays)

    trainset = generate(
        model=previous,
        dataset=trainset_yA,
        rng=rng,
        batch_size=config.batch_size,
        **sampling,
    )
    testset = generate(
        model=previous,
        dataset=testset_yA,
        rng=rng,
        batch_size=config.batch_size,
        **sampling,
    )

    ## Moments
    x_fit = trainset[:16384]['x']
    x_fit = flatten(x_fit)

    mu_x, cov_x = ppca(x_fit, rank=64, key=rng.split())

    del x_fit

    # Model
    if lap > 0:
        model = previous
    else:
        model = make_model(key=rng.split(), **CONFIG)

    model.mu_x = mu_x

    if config.heuristic == 'zeros':
        model.cov_x = jnp.zeros_like(mu_x)
    elif config.heuristic == 'ones':
        model.cov_x = jnp.ones_like(mu_x)
    elif config.heuristic == 'cov_t':
        model.cov_x = jnp.ones_like(mu_x) * 1e6
    elif config.heuristic == 'cov_x':
        model.cov_x = cov_x

    model.train(True)

    static, params, others = model.partition(nn.Parameter)

    # Objective
    objective = DenoiserLoss(sde=sde)

    # Optimizer. `generate` and the training loader both drop the last
    # incomplete batch, so each epoch performs len(trainset) // batch_size updates.
    steps = config.epochs * (len(trainset) // config.batch_size)
    optimizer = Adam(
        steps=steps,
        scheduler=config.scheduler,
        lr_init=config.lr_init,
        lr_end=config.lr_end,
        lr_warmup=config.lr_warmup,
        weight_decay=config.weight_decay,
        clip=config.clip,
    )
    opt_state = optimizer.init(params)

    # EMA
    ema = EMA(decay=config.ema_decay)
    avrg = params

    # Training
    avrg, params, others, opt_state = jax.device_put((avrg, params, others, opt_state), replicated)

    @jax.jit
    @jax.vmap
    def augment(x, key):
        keys = jax.random.split(key, 3)

        x = random_flip(x, keys[0], axis=-2)
        x = random_hue(x, keys[1], delta=1e-2)
        x = random_saturation(x, keys[2], lower=0.95, upper=1.05)

        return x

    @jax.jit
    def ell(params, others, x, key):
        keys = jax.random.split(key, 3)

        z = jax.random.normal(keys[0], shape=x.shape)
        t = jax.random.beta(keys[1], a=3, b=3, shape=x.shape[:1])

        return objective(static(params, others), x, z, t, key=keys[2])

    @jax.jit
    def sgd_step(avrg, params, others, opt_state, x, key):
        loss, grads = jax.value_and_grad(ell)(params, others, x, key)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        avrg = ema(avrg, params)

        return loss, avrg, params, opt_state

    for epoch in (bar := trange(config.epochs, ncols=88)):
        loader = trainset.shuffle(seed=seed + lap * config.epochs + epoch).iter(
            batch_size=config.batch_size, drop_last_batch=True
        )

        losses = []

        for batch in prefetch(loader):
            x = batch['x']
            x = jax.device_put(x, distributed)
            x = augment(x, rng.split(len(x)))
            x = flatten(x)

            loss, avrg, params, opt_state = sgd_step(avrg, params, others, opt_state, x, key=rng.split())
            losses.append(loss)

        loss_train = np.stack(losses).mean()

        ## Validation
        loader = testset.iter(batch_size=config.batch_size, drop_last_batch=True)

        losses = []

        for batch in prefetch(loader):
            x = batch['x']
            x = jax.device_put(x, distributed)
            x = flatten(x)

            loss = ell(avrg, others, x, key=rng.split())
            losses.append(loss)

        loss_val = np.stack(losses).mean()

        bar.set_postfix(loss=loss_train, loss_val=loss_val)

        ## Eval
        if (epoch + 1) % 16 == 0:
            model = static(avrg, others)
            model.train(False)

            # Same SDE and sampler settings as `generate`, so the logged
            # samples match the ones used to build the training set.
            x = sample(
                model=model,
                y=y_eval,
                A=corrupt_func,
                key=rng.split(),
                A_is_func=True,
                **sampling,
            )
            x = x.reshape(4, 4, 32, 32, 3)

            run.log({
                'loss': loss_train,
                'loss_val': loss_val,
                'samples': wandb.Image(to_pil(x, zoom=4)),
            })
        else:
            run.log({
                'loss': loss_train,
                'loss_val': loss_val,
            })

    ## Checkpoint
    model = static(avrg, others)
    model.train(False)

    checkpoint = PATH / f'checkpoints/checkpoint_{lap}.pkl'
    checkpoint.parent.mkdir(parents=True, exist_ok=True)
    dump_module(model, checkpoint)

def find_start_lap(checkpoint_directory) -> int:
    r"""Returns the first EM lap that has no checkpoint yet.

    Scans `checkpoint_directory` for files named `checkpoint_<k>.pkl` and
    returns `1 + max(k)`. It returns 0 when there is no such file or when the
    directory does not exist (first run).

    Arguments:
        checkpoint_directory: A `pathlib.Path` to the checkpoint folder.
    """
    if not checkpoint_directory.is_dir():
        return 0

    laps = [
        int(match.group(1))
        for child in checkpoint_directory.iterdir()
        if (match := re.fullmatch(r'checkpoint_(\d+)\.pkl', child.name))
    ]

    return 1 + max(laps, default=-1)


if __name__ == '__main__':

    runid = wandb.util.generate_id()

    checkpoint_directory = PATH / 'checkpoints'
    # Create the folder up front: `train` writes its checkpoint into it.
    checkpoint_directory.mkdir(parents=True, exist_ok=True)

    start_lap = find_start_lap(checkpoint_directory)

    print(f'Start Lap: {start_lap}')

    for lap in range(start_lap, 32):
        train(runid=runid, lap=lap)
        
