import tensorflow as tf
from absl import logging
from typing import NamedTuple, Mapping, Any, Sequence
# from common import immutabledict
# from collections import OrderedDict
import ml_collections
import numpy as np
import os
import shutil
import pprint
import itertools
from clu import metric_writers, periodic_actions
from common.custom_writers import create_default_writer
import inspect
from pathlib import Path
from configs import TRAIN_COLLECTION, VAL_COLLECTION, CHECKPOINTS_DIR_NAME

from common.data_lib import get_dataset
from common import utils
from common.custom_metrics import Metrics


# See https://github.com/google/flax/blob/e18a00a3b784afaf42825574836c5fe145688d8c/examples/ogbg_molpcba/train.py
# for example training loop with CLU. Perhaps also https://github.com/google/flax/blob/517b763590262d37fbbbd56ab262785cbbdb2c40/examples/imagenet/train.py
def simple_train_eval_loop(train_eval_config, workdir, model, train_iter, val_data):
  """Run a custom train/eval loop with periodic logging, evaluation and checkpointing.

  Train metrics/images go to `<workdir>/<TRAIN_COLLECTION>`, validation metrics/images to
  `<workdir>/<VAL_COLLECTION>`, and checkpoints to `<workdir>/<TRAIN_COLLECTION>/<CHECKPOINTS_DIR_NAME>`.

  Args:
    train_eval_config: ConfigDict. Reads `num_train_steps`, `num_eval_steps`,
      `log_metrics_every_steps`, `eval_every_steps`, `checkpoint_every_steps`, and optionally
      `log_imgs_every_steps`, `max_ckpts_to_keep` (default 1) and `warm_start`.
    workdir: directory for logs and checkpoints of this run.
    model: object exposing `global_step` (a variable; `.numpy()` is called), `train_step(batch)`
      and `validation_step(batch)`, and optionally `create_images(title=...)`.
    train_iter: iterator of training batches (`next()` is called once per train step).
    val_data: if `num_eval_steps` is None, an iterable (assumed finite) that is fully iterated at
      every evaluation; otherwise an iterator (assumed infinite) from which `num_eval_steps`
      batches are drawn per evaluation.

  Returns:
    None.
  """
  logging.info("TF physical devices:\n%s", str(tf.config.list_physical_devices()))
  # For distributed training, may want to instantiate model within this method, by accepting create_model_fn.
  config = train_eval_config

  # Create writers for logs.
  train_dir = os.path.join(workdir, TRAIN_COLLECTION)
  train_writer = create_default_writer(train_dir)
  train_writer.write_hparams(config.to_dict())

  val_dir = os.path.join(workdir, VAL_COLLECTION)
  val_writer = create_default_writer(val_dir)

  checkpoint_dir = os.path.join(train_dir, CHECKPOINTS_DIR_NAME)
  os.makedirs(checkpoint_dir, exist_ok=True)
  logging.info("Will save checkpoints to %s", checkpoint_dir)
  checkpoint = tf.train.Checkpoint(model=model)
  max_ckpts_to_keep = config.get("max_ckpts_to_keep", 1)
  checkpoint_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir,
                                                  max_to_keep=max_ckpts_to_keep)
  _restore_or_initialize(config, checkpoint, checkpoint_manager)

  initial_step = int(model.global_step.numpy())
  logging.info("Starting train eval loop at step %d.", initial_step)

  # Hooks called periodically during training.
  report_progress = periodic_actions.ReportProgress(
    num_train_steps=config.num_train_steps, writer=train_writer, every_secs=60)
  hooks = [report_progress]

  jit_compile = None  # True could be nice, but may not work.
  train_step_fn = tf.function(model.train_step, jit_compile=jit_compile, reduce_retracing=True)

  def train_multiple_steps(data_iterator):
    """Run one (compiled) train step on the next batch and return its metrics."""
    # TODO: actually run several steps per call; see https://github.com/keras-team/keras/blob/v2.10.0/keras/engine/training.py#L1184
    data_batch = next(data_iterator)
    return train_step_fn(data_batch)

  val_step_fn = tf.function(model.validation_step)

  def evaluate_fn(step):
    """Run validation, merge per-batch metrics and write them to `val_writer` at `step`."""
    if config.num_eval_steps is None:
      # Iterate through the full val_data (assumed finite).
      eval_batches = iter(val_data)
    else:
      # val_data is assumed to be an iterator over an infinite dataset. `islice` stops before
      # pulling batch num_eval_steps + 1, so no batch is fetched and thrown away (this matters when
      # val_data is shared with the training iterator).
      eval_batches = itertools.islice(val_data, config.num_eval_steps)

    metrics_list = []
    val_size = 0
    for data_batch in eval_batches:
      metrics_list.append(val_step_fn(data_batch))
      val_size += data_batch.shape[0]  # Assumes the leading dim is the batch dim.

    metrics = Metrics.merge_metrics(metrics_list)

    val_writer.write_scalars(step, metrics.scalars_numpy)
    val_writer.write_images(step, metrics.images_grid)
    logging.info("Ran validation on %d instances.", val_size)
    return None

  with metric_writers.ensure_flushes(train_writer, val_writer):
    step = initial_step
    # Note: all the logging (either on train or val) at step=`n` occurs at the beginning of iteration `n`. e.g.,
    # the train metrics at step = 0 is computed using model weights that have seen 0 gradient updates, while val metrics
    # at step = 10 is computed using model weights that have seen 10 gradient updates.
    while step < config.num_train_steps:

      metrics = train_multiple_steps(train_iter)
      if step % config.log_metrics_every_steps == 0:
        train_writer.write_scalars(step, metrics.scalars_float)
      for hook in hooks:
        hook(step)

      step += 1

      # Besides the periodic evals, also evaluate once when step == log_metrics_every_steps
      # (presumably an early sanity check).
      if (step == config.log_metrics_every_steps) or (
          config.eval_every_steps > 0 and step % config.eval_every_steps == 0
          and step < config.num_train_steps):  # Will run final eval outside the training loop.
        logging.info("Evaluating at step %d", step)
        with report_progress.timed("eval"):
          evaluate_fn(step)

      # `step` was already incremented above, so the final step is `step == num_train_steps`
      # (same convention as the checkpoint condition below).
      if config.get('log_imgs_every_steps') and hasattr(model, 'create_images') and (
          step % config.log_imgs_every_steps == 0 or step == config.num_train_steps):
        imgs = model.create_images(title=f'step={step}')
        train_writer.write_images(step, imgs)

      if step % config.checkpoint_every_steps == 0 or step == config.num_train_steps:
        checkpoint_path = checkpoint_manager.save(step)
        logging.info("Saved checkpoint %s", checkpoint_path)

    logging.info("Ended training loop at step %d.", step)

    # Final validation.
    if config.eval_every_steps > 0:
      logging.info("Final eval outside of training loop.")
      with report_progress.timed("eval"):
        evaluate_fn(step)

  train_writer.close()  # Will make gif.
  val_writer.close()  # Will make gif.
  return None


def _restore_or_initialize(train_eval_config, checkpoint, checkpoint_manager):
  """Restore model state from `warm_start` if configured and found, else via the manager.

  If `train_eval_config.warm_start` is set and a checkpoint can be located from it, that checkpoint
  is restored into `checkpoint` (partial matches are allowed, with a warning). Otherwise
  `checkpoint_manager.restore_or_initialize()` is used.
  """
  warm_start = train_eval_config.get("warm_start")
  if not warm_start:
    checkpoint_manager.restore_or_initialize()
    return

  try:
    restore_ckpt_path = _locate_warm_start_ckpt_path(warm_start)
  except ValueError:
    logging.warning("Failed to find ckpt from %s", warm_start)
    checkpoint_manager.restore_or_initialize()
    return

  restore_status = checkpoint.restore(restore_ckpt_path)
  try:
    restore_status.assert_consumed()
  except AssertionError:
    # Not every value in the checkpoint matched the current model; accept a partial restore.
    logging.warning("assert_consumed() failed...")
    restore_status.expect_partial()
  logging.info("Restored from %s", restore_ckpt_path)


def _locate_warm_start_ckpt_path(warm_start):
  """Return the latest checkpoint path found from `warm_start`.

  `warm_start` is interpreted, in order, as:
    1. a directory directly containing checkpoints;
    2. a work-unit dir, whose `<TRAIN_COLLECTION>/<CHECKPOINTS_DIR_NAME>` subdir contains checkpoints;
    3. an experiment dir of work-unit dirs whose names contain "wid="; the one whose parsed wid
       equals `utils.get_wid()` is then searched as in 2.

  Raises:
    ValueError: if `warm_start` is not a directory or no checkpoint could be located.
  """
  warm_start = Path(warm_start)
  if not warm_start.is_dir():
    raise ValueError(f"warm_start={warm_start} is not a directory.")

  # Check the provided dir.
  restore_ckpt_path = tf.train.latest_checkpoint(str(warm_start))
  if restore_ckpt_path:
    return restore_ckpt_path

  # Treat it as a wu dir; check the train/checkpoints subdir.
  logging.info("No ckpt in warm_start dir; check its train subdir...")
  restore_ckpt_path = tf.train.latest_checkpoint(
    str(warm_start / TRAIN_COLLECTION / CHECKPOINTS_DIR_NAME))
  if restore_ckpt_path:
    return restore_ckpt_path

  # Treat it as an experiment dir, and load the model with the matching wid as current run.
  logging.info("No ckpt so far; treat warm_start as experiment dir and check for matching work "
               "unit id...")
  wid = utils.get_wid()
  if wid is None:
    raise ValueError("Current run has no work unit id; cannot match a work unit dir.")
  for wu_dir in warm_start.iterdir():
    # Look only at the entry's own name, so "wid=" in a parent path cannot make everything match.
    if wu_dir.is_file() or "wid=" not in wu_dir.name:  # Skip things that aren't wu dirs.
      continue
    parsed_wid = utils.parse_runname(wu_dir.name, parse_numbers=False)["wid"]
    if wid == parsed_wid:
      restore_ckpt_path = tf.train.latest_checkpoint(
        str(wu_dir / TRAIN_COLLECTION / CHECKPOINTS_DIR_NAME))
      if restore_ckpt_path:
        return restore_ckpt_path
      break

  raise ValueError(f"No checkpoint found under warm_start={warm_start}.")


def train_and_eval(config, model_cls, experiments_dir, runname):
  """Set up seeds, model, data and a work dir for one run, then run the custom train/eval loop.

  :param config: a ml_collections.ConfigDict containing configurations. This is usually read from/
    defined in a config file. Reads `train_data_config`, `eval_data_config`, `train_eval_config`,
    `model_config`, and optionally `config_filename` (a path that is copied into the experiment dir).
  :param model_cls: model class, instantiated as `model_cls(**config["model_config"])`. The source
    of the module defining it is saved (best effort) to `<workdir>/models.py`.
  :param experiments_dir: root dir; the run's workdir is `<experiments_dir>/<xid>/[wid=<wid>-]<runname>`.
  :param runname: name of this run; the last component of the workdir.
  :return: the return value of `simple_train_eval_loop`.
  """
  # Ensure reproducible tf.data.
  if config.train_data_config.get('fixed_batch') or config.eval_data_config.get('fixed_batch'):
    tf.config.experimental.enable_op_determinism()

  # Set seed before initializing model.
  seed = config.train_eval_config.seed
  tf.random.set_seed(seed)
  np.random.seed(seed)

  # Init model.
  model_config = config["model_config"]
  model = model_cls(**model_config)

  # Create data.
  train_ds = get_dataset(**config.train_data_config).prefetch(tf.data.AUTOTUNE)
  train_iter = iter(train_ds)
  if config.train_data_config.get('fixed_batch'):
    # Repeat the first batch forever.
    train_iter = itertools.repeat(next(train_iter))

  if config.eval_data_config.get('reuse_train'):  # Useful when the true source is discrete.
    # Eval shares the training dataset/iterator; when num_eval_steps is set, evaluation draws
    # batches from the same iterator as training.
    eval_ds = train_ds
    eval_iter = train_iter
  else:
    eval_ds = get_dataset(**config.eval_data_config)
    eval_iter = iter(eval_ds)
    if config.eval_data_config.get('fixed_batch'):
      eval_iter = itertools.repeat(next(eval_iter))
  if config.train_eval_config.num_eval_steps is None:  # Then use (supposed finite) full dataset.
    val_data = eval_ds
  else:
    val_data = eval_iter  # Then use infinite dataset iter.

  ##################### BEGIN: Good old bookkeeping #########################
  xid = utils.get_xid()
  # Here, each runname is associated with a different work unit (Slurm calls this an 'array job task')
  # within the same experiment. We add the work unit id prefix to make it easier to warm start
  # with the matching wid later.
  wid = utils.get_wid()
  wid_prefix = '' if wid is None else f'wid={wid}-'
  workdir = os.path.join(experiments_dir, xid, wid_prefix + runname)
  # e.g., 'train_xms/21965/3-mshyper-rd_lambda=0.08-latent_ch=320-base_ch=192'
  os.makedirs(workdir, exist_ok=True)
  # absl logs from this point on will be saved to files in workdir.
  logging.get_absl_handler().use_absl_log_file(program_name="trainer", log_dir=workdir)

  logging.info("Using workdir:\n%s", workdir)
  logging.info("Input config:\n%s", pprint.pformat(config))

  # Save the config provided.
  with open(os.path.join(workdir, "config.json"), "w", encoding="utf-8") as f:
    f.write(config.to_json(indent=2))
  if "config_filename" in config:
    # Written to the experiment (xid) dir, i.e. shared by all work units of the experiment.
    shutil.copy2(config["config_filename"], os.path.join(experiments_dir, xid, "config_script.py"))

  # Log more info.
  utils.log_run_info(workdir=workdir)
  # Write a copy of the models source code. Best effort: getmodule() may return None (TypeError in
  # getsource) and the source may be unavailable (OSError); neither should abort training.
  try:
    model_source_str = inspect.getsource(inspect.getmodule(model_cls))
  except (TypeError, OSError) as e:
    logging.warning("Could not save source code of %s: %s", model_cls, e)
  else:
    with open(os.path.join(workdir, "models.py"), "w", encoding="utf-8") as f:
      f.write(model_source_str)

  # Run my custom training loop.
  return simple_train_eval_loop(config.train_eval_config, workdir, model, train_iter, val_data)
