# Experiment: fit a low-rank Gaussian prior with fit_moments and log samples drawn from it.

import inox
import inox.nn as nn
import jax
import numpy as np
import optax
import wandb

import regex as re

from datasets import Array3D, Features, concatenate_datasets, load_from_disk
from dawgz import job, schedule
from functools import partial
from tqdm import trange
from typing import *

# isort: split
from utils import *

# Hyper-parameters of the experiment; passed as the run config to wandb.init in train().
CONFIG = dict(
    # -- Data
    duplicate=2,
    # -- Architecture
    hid_channels=(128, 256, 384, 512),
    hid_blocks=(3, 3, 3, 3),
    kernel_size=(3, 3),
    emb_features=256,
    heads={3: 4},
    dropout=0.1,
    # -- Sampling
    sampler='ddpm',
    heuristic=None,
    sde={'a': 1e-3, 'b': 1e2},
    discrete=256,  # default: 256
    maxiter=3,  # default: 3
    # -- Training
    epochs=64,
    batch_size=256,
    scheduler='constant',
    lr_init=1e-4,
    lr_end=1e-6,
    lr_warmup=0.0,
    optimizer='adam',
    weight_decay=None,
    clip=1.0,
    ema_decay=0.999,
    # -- MMPS / fit_moments
    rank=64,  # default: 64 (256 was also tried)
    fit_size=2**16,  # 65536; default: 2**14 = 16384
    cov_y=1e-2**2,  # (1e-2)**2, presumably a noise std of 1e-2
    fit_moments_steps=4,
    fit_moments_maxiter=5,
)

RUN_NAME = "cov3_rank96"  # NOTE(review): name says rank96 but CONFIG['rank'] is 64 -- confirm
TEST_MODE = False  # True: write to em-mmps-test and read the *_test dataset

# The output folder (PATH) and the input dataset (DATASET_PATH) both live under
# this scratch root; TEST_MODE switches both to their test variants.
_SCRATCH = '/data/vision/___/scratch/___ht'

if TEST_MODE:
    PATH = Path(f'{_SCRATCH}/celeba_dir/em-mmps-test')
    DATASET_PATH = f'{_SCRATCH}/celeba_64_mask50_test/'
else:
    PATH = Path(f'{_SCRATCH}/celeba_dir/em-mmps_{RUN_NAME}')
    DATASET_PATH = f'{_SCRATCH}/celeba_64_mask50/'

# Create the output folder and its checkpoints/ sub-folder at import time.
PATH.mkdir(parents=True, exist_ok=True)
(PATH / 'checkpoints').mkdir(parents=True, exist_ok=True)

def generate(model, dataset, rng, batch_size, **kwargs):
    """Map a dataset of observations ``(y, A)`` to samples ``x`` drawn with ``model``.

    Args:
        model: denoiser handed to ``sample``.
        dataset: ``datasets.Dataset`` with ``'y'`` and ``'A'`` columns; both
            columns are dropped from the result.
        rng: PRNG object; one key is split off for every batch.
        batch_size: rows per call to ``sample``. A trailing partial batch is
            dropped (``drop_last_batch=True``), so the result can be shorter
            than ``dataset``.
        **kwargs: forwarded unchanged to ``sample`` (e.g. sampler, sde, steps).

    Returns:
        In-memory dataset with a single ``'x'`` column of float32 (64, 64, 3)
        arrays.
    """
    out_features = Features({'x': Array3D(shape=(64, 64, 3), dtype='float32')})

    def to_samples(batch):
        # Argument order matters for the RNG: the key is split after reading y and A.
        drawn = sample(model, batch['y'], batch['A'], rng.split(), **kwargs)
        return {'x': np.asarray(drawn)}

    return dataset.map(
        to_samples,
        batched=True,
        batch_size=batch_size,
        drop_last_batch=True,
        keep_in_memory=True,
        remove_columns=['y', 'A'],
        features=out_features,
    )

def train(runid: str, lap: int):
    """Fit a Gaussian prior to the masked CelebA observations and log samples.

    Steps:
      1. open (or resume) the wandb run ``runid`` and create its run folder;
      2. load the dataset at ``DATASET_PATH``, rescale the pixels and multiply
         them by the mask ``A``;
      3. fit the moments of a low-rank Gaussian prior (``fit_moments``) on the
         first ``config.fit_size`` (y, A) pairs;
      4. draw ``x`` for at most 1024 pairs with that prior (``generate``) and
         log a 4x4 grid of the first 16 samples to wandb.

    Args:
        runid: wandb run id (the ``__main__`` guard passes
            ``wandb.util.generate_id()``).
        lap: combined with the run folder to derive the RNG seed.
    """
    import zlib  # process-independent hashing for the seed

    run = wandb.init(
        project='fit_moments_experiment',
        id=runid,
        resume='allow',
        dir=PATH,
        name=f'{CONFIG["rank"]}_{np.log2(CONFIG["fit_size"])}_{CONFIG["cov_y"]}_{CONFIG["fit_moments_steps"]}',
        config=CONFIG,
    )

    runpath = PATH / f'runs/{run.name}_{run.id}'
    runpath.mkdir(parents=True, exist_ok=True)

    # Read every hyper-parameter from the wandb config, so the values logged
    # for the run are exactly the values used.
    config = run.config

    # Sharding: data is split along mesh axis 'i', the prior is replicated.
    jax.config.update('jax_threefry_partitionable', True)

    mesh = jax.sharding.Mesh(jax.devices(), 'i')
    replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
    distributed = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('i'))

    # RNG. The builtin hash() of a str/Path is salted per process
    # (PYTHONHASHSEED), so hash((runpath, lap)) yielded a different seed on
    # every launch, even for the same run and lap. CRC32 of the same pair is
    # stable across processes.
    seed = zlib.crc32(f'{runpath}|{lap}'.encode()) % 2**16
    rng = inox.random.PRNG(seed)

    # SDE
    sde = VESDE(**config.sde)

    # Data
    dataset = load_from_disk(DATASET_PATH)
    dataset.set_format('numpy')

    def normalize_map(batch):
        # Rescale linearly so that 0 -> -2 and 256 -> 2 (presumably 8-bit
        # pixels), then multiply by the mask A. Corrupted pixels therefore end
        # up at 0 instead of -2. TODO: does this choice matter?
        y, A = batch['y'], batch['A']
        return {'y': A * ((y * 4 / 256) - 2), 'A': A}

    # The transform is purely elementwise, so a batched map produces the same
    # rows as a per-example map with far fewer Python calls.
    trainset_yA = dataset.map(normalize_map, batched=True)

    # Index the dataset once and take both columns from that single slice.
    fit_batch = trainset_yA[:config.fit_size]
    y_fit, A_fit = jax.device_put((fit_batch['y'], fit_batch['A']), distributed)
    del fit_batch  # drop the host-side copy

    print("Starting fit_moments...")

    mu_x, cov_x = fit_moments(
        features=64 * 64 * 3,
        rank=config.rank,
        shard=True,
        A=inox.Partial(measure, A_fit),
        y=flatten(y_fit),
        cov_y=config.cov_y,
        sampler='ddim',  # fixed for the fit, independent of config.sampler
        sde=sde,
        steps=config.fit_moments_steps,
        maxiter=config.fit_moments_maxiter,
        key=rng.split(),
    )

    del y_fit, A_fit

    print("Finished fit_moments.")

    # Build the Gaussian prior and replicate its arrays on every device.
    previous = GaussianDenoiser(mu_x, cov_x)
    static, arrays = previous.partition()
    previous = static(jax.device_put(arrays, replicated))

    # Generate for (at most) the first 1024 pairs; clamp so a smaller dataset
    # does not make select() go out of range.
    subset = trainset_yA.select(range(min(len(trainset_yA), 1 << 10)))

    trainset = generate(
        model=previous,
        dataset=subset,
        rng=rng,
        batch_size=config.batch_size,
        shard=True,
        sampler=config.sampler,
        sde=sde,
        steps=config.discrete,
        maxiter=config.maxiter,
    )

    # Log a 4x4 grid of the first 16 samples. Slicing rows before picking the
    # column reads only those 16 rows rather than the whole 'x' column.
    grid = trainset[:16]['x'].reshape(4, 4, 64, 64, 3)
    wandb.log({"sample_image": wandb.Image(to_pil(grid))})

    run.finish()

if __name__ == '__main__':
    # Start a brand-new wandb run (fresh id) and run lap 0 of it.
    train(lap=0, runid=wandb.util.generate_id())

