import os
import time
import requests
import functools
import jax
from jax import config
import jax.numpy as jnp
import flax
import numpy as onp
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
import diffusion_distillation
from tensorboardX import SummaryWriter
from tqdm.auto import tqdm
from flax.serialization import to_bytes
import yaml 
import sys
from jax_smi import initialise_tracking
# Start jax-smi tracking for this process (presumably background device-memory
# profiling for the `jax-smi` tool -- confirm against the jax_smi docs).
initialise_tracking()
# Set up JAX's multi-process runtime. This runs before any other JAX call in
# the script; the process/device queries below (jax.process_index(),
# jax.device_count()) are made after it.
jax.distributed.initialize()

# ---- Model ------------------------------------------------------------------
# Base CIFAR configuration. The cifar_cone / cifar_cone1 configs used to be
# selectable through sys.argv[1] ("0" / "1"); that switch is currently disabled.
# NOTE: this rebinds the name `config` that was imported from jax above.
config = diffusion_distillation.config.cifar_base.get_config()
model = diffusion_distillation.model.Model(config)

# ---- Initial state ----------------------------------------------------------
# Build the initial state, fetch it to host memory, then replicate it so it
# can be fed to the pmapped training step.
state = flax.jax_utils.replicate(jax.device_get(model.make_init_state()))

# ---- Training step ----------------------------------------------------------
# model.step_fn gets its first two positional arguments bound (a PRNG key
# derived from config.seed, and the literal True). The bound function is
# scanned with jax.lax.scan so one call runs several substeps, and the scan is
# pmapped over the local devices with the state argument donated.
train_step = jax.pmap(
    functools.partial(
        jax.lax.scan,
        functools.partial(model.step_fn, jax.random.PRNGKey(config.seed), True),
    ),
    axis_name='batch',
    donate_argnums=(0,),
)

# ---- Input pipeline ---------------------------------------------------------
# The configured batch size is split evenly across all devices.
total_bs = config.train.batch_size
device_bs = total_bs // jax.device_count()
# Leading axes, in order: local devices (pmap), substeps (lax.scan),
# per-device batch.
batch_shape = (jax.local_device_count(), config.train.substeps, device_bs)
train_ds = model.dataset.get_shuffled_repeated_dataset(
    split='train',
    batch_shape=batch_shape,
    local_rng=jax.random.PRNGKey(0),
    augment=True,
)
train_iter = diffusion_distillation.utils.numpy_iter(train_ds)
# run training
# Every process iterates over the same fixed number of steps. Only process 0
# wraps the range in a tqdm progress bar and owns the logging state below
# (`writer`, `logging_dir`, `start_time`).
pbar = range(10000)
if jax.process_index() == 0:
    # One timestamped run directory per launch. It is created explicitly so
    # that writing config.yaml does not depend on SummaryWriter having created
    # the directory first.
    logging_dir = os.path.join('logs', time.strftime('%Y%m%d-%H%M%S'))
    os.makedirs(logging_dir, exist_ok=True)
    writer = SummaryWriter(logging_dir)
    # Save the configuration used for this run next to the event files.
    with open(os.path.join(logging_dir, "config.yaml"), 'w') as f:
        yaml.dump(config.to_dict(), f)
    # Reference time for the time-based checkpoint interval in the loop.
    start_time = time.time()
    pbar = tqdm(pbar)
# Main training loop. Each iteration pulls one batch and runs the pmapped
# training step; process 0 additionally handles progress display, TensorBoard
# logging, sample images and checkpoints.
for step in pbar:
    batch = next(train_iter)
    state, metrics = train_step(state, batch)
    if jax.process_index() == 0:
        # Take one device's copy of the metrics and reduce each entry to a
        # Python float by averaging over its leading axis (the per-substep
        # values stacked by lax.scan).
        metrics = jax.device_get(flax.jax_utils.unreplicate(metrics))
        # jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
        metrics = jax.tree_util.tree_map(
            lambda x: float(x.mean(axis=0)), metrics)
        pbar.set_postfix(metrics)
        # log metrics
        if step % config.train.log_loss_every_steps == 0:
            writer.add_scalar("train/loss", metrics["train/loss"], step)
            # The gradient norm gets its own tag; it used to be written under
            # "train/loss", which mixed it into the loss curve.
            writer.add_scalar("train/gnorm", metrics["train/gnorm"], step)
        # log images
        if step % config.train.log_images_every_steps == 0:
            # Four samples, all with label 0, drawn from the EMA parameters
            # with a fixed PRNG key.
            labels = jnp.zeros((4,), dtype=jnp.int32)
            ema_params = jax.device_get(
                flax.jax_utils.unreplicate(state.ema_params))
            samples = model.samples_fn(
                rng=jax.random.PRNGKey(2),
                labels=labels,
                params=ema_params,
                num_steps=4096)
            # NOTE(review): the uint8 cast assumes samples_fn already returns
            # pixel values in [0, 255] -- confirm against the model code.
            samples = jax.device_get(samples).astype(onp.uint8)
            writer.add_images(
                'train/samples', samples, step, dataformats='NHWC')
        # checkpoint once more than config.train.checkpoint_every_secs seconds
        # have passed since the previous checkpoint (or since start-up)
        if time.time() - start_time > config.train.checkpoint_every_secs:
            state_offdevice = jax.device_get(flax.jax_utils.unreplicate(state))
            # Serialize before touching the file system, write to a temporary
            # file, then swap it into place, so an interruption or a
            # serialization error cannot leave a truncated unet.msgpack.
            model_bytes = to_bytes(state_offdevice)
            tmp_path = "unet.msgpack.tmp"
            with open(tmp_path, "wb") as f:
                f.write(model_bytes)
            os.replace(tmp_path, "unet.msgpack")
            start_time = time.time()
        # evaluate FID
        # TODO: FID evaluation is not implemented yet. The earlier draft
        # (Inception statistics -> mu/sigma -> fid_score -> "train/FID"
        # scalar) referenced names that are not defined in this script.
        if step % config.train.eval_every_steps == 0:
            pass