import inox
import inox.nn as nn
import jax
import numpy as np
import optax
import wandb
import click

import datetime

from dawgz import job, schedule
from tqdm import tqdm, trange
from typing import *

# isort: split
from utils import *

# Default hyper-parameters of the experiment, grouped by concern.
# Key order is kept stable because the mapping may be logged or iterated.
CONFIG = dict(
    # --- data ---
    seed=0,
    samples=2**16,
    features=5,
    observe=2,
    noise=0.01,
    # --- architecture ---
    features_latent=5,
    # NOTE(review): looks like observe + 2 * features — confirm
    features_cond=2 + 2 * 5,
    hid_features=(256,) * 3,
    emb_features=64,
    normalize=True,
    # --- sampling ---
    sampler='pc',
    heuristic='cov_x',
    sde=dict(a=0.001, b=10.0),
    discrete=2**12,
    maxiter=None,
    # --- training ---
    laps=64,
    epochs=2**16,
    batch_size=2**10,
    scheduler='linear',
    lr_init=0.001,
    lr_end=0.000001,
    lr_warmup=0.0,
    optimizer='adam',
    weight_decay=None,
    clip=1.0,
)

def fun():
    """Build the experiment objects, plot the latent data and open a debugger.

    Wraps ``CONFIG`` in ``MyDict``, creates a PRNG and a ``VESDE``, generates
    the latent samples with ``smooth_manifold``, rescales every feature to
    ``[-2, 2]``, passes the result to ``show_corner`` and finally stops in
    ``breakpoint()`` so that the locals can be inspected interactively.

    Returns:
        None.
    """
    # Local import: only needed to derive the reproducible seed below.
    import zlib

    config = MyDict(CONFIG)

    # RNG
    # The builtin hash() of a str is salted per interpreter process (unless
    # PYTHONHASHSEED is fixed), which made this seed differ on every run.
    # CRC32 of the same tag is stable across processes and platforms.
    seed = zlib.crc32(b"___") % 2**16
    rng = inox.random.PRNG(seed)

    # SDE
    sde = VESDE(**config.sde)

    # Data
    keys = jax.random.split(jax.random.key(config.seed))

    ## Latent
    x = smooth_manifold(keys[0], shape=(config.samples,), m=1, n=config.features)
    # Min-max normalize along axis 0 to [0, 1], then map affinely to [-2, 2].
    # assumes axis 0 indexes samples, so each feature is scaled independently — TODO confirm
    x = (x - x.min(axis=0)) / (x.max(axis=0) - x.min(axis=0))
    x = 4.0 * x - 2.0

    # presumably returns a corner-plot figure of the samples; verify in utils
    z = show_corner(x)

    # Deliberate stop: rng, sde, x and z are kept as locals so they can be
    # inspected from the debugger session.
    breakpoint()

# Script entry point: run the interactive experiment only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    fun()