"""
main_qm9.py
  Training utilities for the QM9 dataset (the QM9 counterpart of run_lib.py).
"""
import jax_smi
import datetime
import wandb
import os
import tensorflow as tf
import jax
import orbax
import numpy as np
import jax.numpy as jnp
from flax.training import checkpoints
import functools
import optax
import ml_collections
import flax
from typing import Callable, Any
from losses import get_optimizer
from models.utils import variable_ema
from flax.core.frozen_dict import FrozenDict
import logging
import flax.linen as nn

# #*=======================================================================#
# # Training NN model.
# class EGNNModel(nn.Module):
#   config: ml_collections.ConfigDict

#   @nn.compact
#   def __call__(self, batch):
#     config = self.config

#     # Model configuration

#     # Model forward

#*=======================================================================#
def get_qm9_state(rng, config, dummy_loader):
  """Build the initial training state for the QM9 model.

  Initializes the model parameters on an example batch obtained from
  `dummy_loader`, then bundles them with the main optimizer and the EMA
  transformation into a `TrainState` pytree.

  Args:
    rng: JAX PRNG key. It is split into a parameter-initialization key and
      the key stored as `TrainState.dropout_rng`.
    config: ml_collections.ConfigDict. `config.model.initial_count` is read
      here; the whole config is also passed to `EGNNModel` and
      `get_optimizer`.
    dummy_loader: loader handed to `generate_batch` to produce the example
      batch used by `model.init`.

  Returns:
    A `TrainState` with `step=0`, the initialized `params`, and the optimizer
    and EMA states initialized from those params.
  """
  # NOTE(review): `EGNNModel` and `generate_batch` must be provided elsewhere;
  # the EGNNModel definition in this file is commented out.
  model = EGNNModel(config=config)
  fake_batch = generate_batch(dummy_loader)
  # Split the key so the stored dropout key is independent of parameter init.
  params_rng, dropout_rng = jax.random.split(rng)
  variables = model.init(params_rng, fake_batch, train=False)

  # Create optimizer
  optimizer = get_optimizer(config)
  optimizer_ema = variable_ema(initial_count=config.model.initial_count) # customized API for variable rate

  class TrainState(flax.struct.PyTreeNode):
    """Training state: step counter, params, optimizer/EMA transforms and states."""
    step: int
    # Fields marked pytree_node=False are static metadata, not pytree leaves.
    apply_fn: Callable = flax.struct.field(pytree_node=False)
    params: FrozenDict[str, Any]
    tx: optax.GradientTransformation = flax.struct.field(pytree_node=False)
    tx_ema: optax.GradientTransformation = flax.struct.field(pytree_node=False)
    opt_state: optax.OptState
    opt_state_ema: optax.OptState
    dropout_rng: jax.Array

    def apply_gradients(self, *, grads, **kwargs):
      """Updates `step`, `params`, `opt_state` and `**kwargs` in return value.

      Note that internally this function calls `.tx.update()` followed by a call
      to `optax.apply_updates()` to update `params` and `opt_state`.
      The EMA state (`opt_state_ema`) is NOT updated here; pass it through
      `kwargs` (e.g. `opt_state_ema=...`) to replace it.

      Args:
        grads: Gradients that have the same pytree structure as `.params`.
        **kwargs: Additional dataclass attributes that should be `.replace()`-ed.

      Returns:
        An updated instance of `self` with `step` incremented by one, `params`
        and `opt_state` updated by applying `grads`, and additional attributes
        replaced as specified by `kwargs`.
      """
      updates, new_opt_state = self.tx.update(
          grads, self.opt_state, self.params)
      new_params = optax.apply_updates(self.params, updates)
      return self.replace(
          step=self.step + 1,
          params=new_params,
          opt_state=new_opt_state,
          **kwargs,
      )

    @classmethod
    def create(cls, *, apply_fn, params, tx, tx_ema, **kwargs):
      """Creates a new instance with `step=0` and initialized `opt_state`.

      `kwargs` must supply any remaining fields (here: `dropout_rng`).
      """
      opt_state = tx.init(params)
      opt_state_ema = tx_ema.init(params)
      return cls(
          step=0,
          apply_fn=apply_fn,
          params=params,
          tx=tx,
          tx_ema=tx_ema,
          opt_state=opt_state,
          opt_state_ema=opt_state_ema,
          **kwargs,
      )

  return TrainState.create(
    apply_fn=model.apply,
    params=variables['params'], # main parameter
    tx=optimizer, # (Adam) Optimizer
    tx_ema=optimizer_ema, # EMA state that includes delayed EMA parameter
    dropout_rng=dropout_rng, # required TrainState field
  )
#*=======================================================================#
# Training entry point: set up logging, build the train state from the NN model, and restore checkpoints.
def train(config, workdir, log_name):
  """Set up logging, the model train state and checkpoint managers for QM9.

  In this version the function returns after setup (restoring from the
  preemption checkpoint if one exists); no training loop is run yet.

  Args:
    config: ml_collections.ConfigDict with at least `seed`, `model.name`,
      `training.snapshot_freq` and `training.snapshot_freq_for_preemption`.
    workdir: directory under which the `wandb`, `samples`, `checkpoints` and
      `checkpoints-meta` subdirectories are used.
    log_name: wandb run name; when None, `<config.model.name>-<timestamp>`
      is used.
  """
  # ====================================================================================================== #
  # Get logger
  jax_smi.initialise_tracking()

  # wandb_dir: Directory of wandb summaries. Create it before `wandb.init` so
  # the run files are actually written there.
  wandb_dir = os.path.join(workdir, "wandb")
  tf.io.gfile.makedirs(wandb_dir)
  if log_name is None:
    current_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    log_name = f"{config.model.name}-{current_time}"
  # Hand the config to `wandb.init` so it is recorded on the run; assigning to
  # `wandb.config` afterwards only replaces the module-level handle.
  wandb.init(
    project="anonymous-repo-qm9",
    name=log_name,
    entity="anonymous",
    resume="allow",
    dir=wandb_dir,
    config=config.to_dict(),
  )

  # Create directories for experimental logs
  sample_dir = os.path.join(workdir, "samples")
  tf.io.gfile.makedirs(sample_dir)
  rng = jax.random.PRNGKey(config.seed)
  # ====================================================================================================== #
  # Initialize model.
  rng, step_rng = jax.random.split(rng)
  # NOTE(review): get_qm9_state(rng, config, dummy_loader) requires a loader
  # for its example batch; this call omits it and raises TypeError as written.
  # Pass the (dummy/training) loader here once it is available.
  state = get_qm9_state(step_rng, config)
  checkpoint_dir = os.path.join(workdir, "checkpoints")
  checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta")
  state_dict = {
    'model': state
  }

  # get manager options, and restore checkpoints.
  mgr_options = orbax.checkpoint.CheckpointManagerOptions(
    save_interval_steps=config.training.snapshot_freq,
    create=True)
  ckpt_mgr = orbax.checkpoint.CheckpointManager(
    checkpoint_dir,
    orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()), mgr_options)

  # Resume training when intermediate checkpoints are detected
  mgr_meta_options = orbax.checkpoint.CheckpointManagerOptions(
    save_interval_steps=config.training.snapshot_freq_for_preemption,
    max_to_keep=1,
    create=True)
  ckpt_meta_mgr = orbax.checkpoint.CheckpointManager(
    checkpoint_meta_dir,
    orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()), mgr_meta_options)
  latest_meta_step = ckpt_meta_mgr.latest_step()
  if latest_meta_step is not None:
    logging.info("Restore checkpoint-meta from step %s.", latest_meta_step)
    # The manager stores each step's item under `<step>/default`.
    state_dict = checkpoints.restore_checkpoint(
      ckpt_dir=os.path.join(checkpoint_meta_dir, str(latest_meta_step), "default"),
      target=state_dict)
    state = state_dict['model']

  # `state.step` is JAX integer on the GPU/TPU devices
  initial_step = int(state.step)

  # # Define dense_state: a model with single linear layer for random projection of auxiliary variables.
  # Cin, Cout = config.data.num_channels, config.model.aug_dim
  # if os.path.exists(os.path.join(checkpoint_dir, "dense.npy")):
  #   kernel_arr = np.load(os.path.join(checkpoint_dir, "dense.npy"))
  # else:
  #   rng, step_rng = jax.random.split(rng)
  #   kernel_arr = jax.random.normal(step_rng, shape=(Cin, Cout))
  #   np.save(os.path.join(checkpoint_dir, "dense.npy"), kernel_arr)
  # assert kernel_arr.shape == (Cin, Cout)
  # dense_state = nn.Dense(features=Cout, use_bias=False)
  # dense_state.init(step_rng, jnp.ones(config.data.data_shape[:-1] + (config.data.data_shape[-1] + config.model.aug_dim,)))
  # dense_fn = functools.partial(dense_state.apply, {'params': {'kernel': kernel_arr}}) # usage: y1 = lambda1 * y0 + lambda2 * dense_fn(x0 + x1)


#*=======================================================================#