"""Methods for training ResNet-50 on ImageNet using JAX."""

import dataclasses
import functools
import os
from typing import Any, Callable, Dict, Sequence, Tuple, Union, Optional, Type

from absl import logging

from clu import checkpoint
from clu import metric_writers
from clu import metrics
from clu import parameter_overview
from clu import periodic_actions
import flax
import flax.jax_utils as flax_utils
import flax.linen as nn
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import tensorflow as tf
import optax

from ebm_obj.lib import input_pipeline
from ebm_obj.lib import utils
from ebm_obj.lib import losses
from ebm_obj.lib import evaluation as evaluator
from ebm_obj.lib.utils import TrainState


def merge_batch_stats(replicated_state):
  """Average `batch_stats` of a replicated state across devices.

  Args:
    replicated_state: A dataclass instance holding replicated (per-device)
      values. If it has a non-empty `batch_stats` field, it must also provide
      a `replace(**kwargs)` method that returns an updated copy.

  Returns:
    A copy of the state whose `batch_stats` have been mean-reduced across the
    device axis, or `replicated_state` itself when it has no `batch_stats`
    field or that field contains no leaves.
  """
  # Only the field names are needed; `dataclasses.asdict` would deep-copy the
  # entire state just to answer a membership question.
  field_names = {field.name for field in dataclasses.fields(replicated_state)}
  if "batch_stats" not in field_names:
    return replicated_state

  # An empty collection has nothing to synchronize. Return the state unchanged
  # rather than falling through and implicitly returning None.
  if not jax.tree_util.tree_leaves(replicated_state.batch_stats):
    return replicated_state

  cross_replica_mean = jax.pmap(lambda x: jax.lax.pmean(x, "x"), "x")
  return replicated_state.replace(
    batch_stats=cross_replica_mean(replicated_state.batch_stats)
  )


def cross_entropy_loss(*, logits, labels):
  """Per-example negative log-likelihood of integer labels.

  Args:
    logits: Unnormalized scores; the softmax is taken over the last axis.
    labels: Integer class indices, one per leading-axis entry of `logits`.

  Returns:
    The negated log-probabilities selected by `labels` along axis 1. The
    indexed axis is kept with size 1 (no reduction is applied).
  """
  log_probs = jax.nn.log_softmax(logits, axis=-1)
  # Give the labels a size-1 axis at position 1 so they can index axis 1.
  label_index = jnp.expand_dims(labels, axis=1)
  selected = jnp.take_along_axis(log_probs, label_index, axis=1)
  return jnp.negative(selected)


def train_step(
  model: nn.Module,
  state: utils.TrainState,
  batch: Dict[str, jnp.ndarray],
  loss_fn: losses.LossFn,
  learning_rate_fn: Callable[[int], float],
  train_metrics_cls: Type[metrics.Collection],
  predicted_max_num_instances: int,
  ground_truth_max_num_instances: int,
  max_grad_norm: Optional[float] = None,
) -> Tuple[utils.TrainState, metrics.Collection]:
  """Perform a single training step.

  This function uses the collective axis name "batch" (`jax.lax.axis_index`
  and `jax.lax.pmean`), so it must be called from within a `jax.pmap` (or
  equivalent) that binds an axis with that name.

  Args:
    model: Model used in training step. It is applied to `batch["image"]` with
      `train=True`.
    state: Current train state. Its `rng`, `step`, `state_vars` and
      `optimizer` fields are read here.
    batch: Training inputs for this step. `"image"` is fed to the model; the
      whole batch is passed to `loss_fn`; `"segmentations"`, `"padding_mask"`
      and `"mask"` are forwarded to the metrics if present.
    loss_fn: Loss function that takes model predictions and a batch of data
      and returns a `(loss, loss_aux)` pair. `loss_aux` is unpacked as keyword
      arguments into the metrics update.
    learning_rate_fn: Function that outputs the learning rate given the
      (already incremented) step.
    train_metrics_cls: The metrics collection for computing training metrics.
    predicted_max_num_instances: Maximum number of instances in prediction.
    ground_truth_max_num_instances: Maximum number of instances in ground truth,
      including background (which counts as a separate instance).
    max_grad_norm: Optional float, if not None, clip gradients to the specified
      maximum norm.

  Returns:
    Tuple `(new_state, metrics_update)`: the train state with updated step,
    optimizer, state variables and random key, and the metrics collection
    gathered from this step's outputs.
  """
  logging.info("train_step(batch=%s)", batch)
  # Split the carried key: `rng` is stored back into the new state, while
  # `step_rng` is specialized below for this process, step and device.
  rng = state.rng
  rng, step_rng = jax.random.split(rng)
  step_rng = jax.random.fold_in(step_rng, jax.process_index())
  step_rng = jax.random.fold_in(step_rng, state.step)
  step_rng = jax.random.fold_in(step_rng, jax.lax.axis_index("batch"))

  # One independent key per RNG stream passed to `model.apply` below.
  z_rng, eps_rng, dropout_rng = jax.random.split(step_rng, num=3)
  state_vars = state.state_vars
  # Every state collection, plus "intermediates", may be written by the model.
  mutable_var_keys = list(state_vars.keys()) + ["intermediates"]

  def train_loss_fn(params, state_vars):
    # Returns the scalar loss plus auxiliary outputs: the updated state
    # variables, the raw model outputs and the auxiliary loss values.
    outputs, mutable_vars = model.apply(
      {"params": params, **state_vars},
      batch["image"],
      mutable=mutable_var_keys,
      rngs={"z_state_init": z_rng, "eps": eps_rng, "dropout": dropout_rng},
      train=True,
    )
    # Filter intermediates, as we do not want to store them in the TrainState.
    state_vars = utils.filter_key_from_frozen_dict(
      mutable_vars, key="intermediates"
    )
    loss, loss_aux = loss_fn(outputs, batch)
    return loss, (state_vars, outputs, loss_aux)

  # Differentiate with respect to the first argument (the parameters) only.
  grad_fn = jax.value_and_grad(train_loss_fn, has_aux=True)
  (loss, (state_vars, preds, loss_aux)), grad = grad_fn(
    state.optimizer.target, state_vars
  )

  # Compute average gradient across multiple workers.
  grad = jax.lax.pmean(grad, axis_name="batch")

  # Clipping is applied to the already-averaged gradient.
  if max_grad_norm is not None:
    grad = utils.clip_grads(grad, max_grad_norm)

  # The learning rate is evaluated at the incremented step.
  step = state.step + 1
  lr = learning_rate_fn(step)
  new_opt = state.optimizer.apply_gradient(grad, learning_rate=lr)

  new_state = state.replace(  # pytype: disable=attribute-error
    step=step, optimizer=new_opt, state_vars=state_vars, rng=rng
  )

  metrics_update = train_metrics_cls.gather_from_model_output(
    loss=loss,
    **loss_aux,
    predicted_segmentations=utils.remove_singleton_dim(
      preds["outputs"].get("segmentations")
    ),  # pytype: disable=attribute-error
    ground_truth_segmentations=batch.get("segmentations"),
    predicted_max_num_instances=predicted_max_num_instances,
    ground_truth_max_num_instances=ground_truth_max_num_instances,
    padding_mask=batch.get("padding_mask"),
    mask=batch.get("mask")
  )
  return new_state, metrics_update


def evaluate(
  model,
  state,
  eval_ds,
  loss_fn_eval,
  eval_metrics_cls,
  config,
  writer,
  step: int,
  num_eval_steps: int = -1,
):
  """Evaluate the model and write scalar and image summaries.

  Args:
    model: Model to evaluate; forwarded to `evaluator.evaluate`.
    state: Train state to evaluate; forwarded to `evaluator.evaluate`.
    eval_ds: Evaluation dataset; forwarded to `evaluator.evaluate`.
    loss_fn_eval: Loss function used during evaluation.
    eval_metrics_cls: Metrics collection class for evaluation metrics.
    config: Experiment configuration. Reads `num_slots`,
      `data.max_instances`, and the optional keys `metrics_on_cpu`,
      `num_eval_steps`, `n_samples` and `logging_min_n_colors`.
    writer: Metric writer providing `write_scalars` and `write_images`.
    step: Step under which the summaries are written.
    num_eval_steps: Number of evaluation steps to run. When left at the
      default of -1, the value of `config.num_eval_steps` (or -1 if the
      config does not define it) is used instead.

  Returns:
    None. Results are only reported through `writer`.
  """
  # Honor an explicitly passed value; previously this argument was ignored and
  # the config value was always used. Fall back to the config for the default
  # so callers that omit the argument behave as before.
  if num_eval_steps == -1:
    num_eval_steps = config.get("num_eval_steps", -1)

  eval_metrics, eval_batch, eval_preds = evaluator.evaluate(
    model,
    state,
    eval_ds,
    loss_fn_eval,
    eval_metrics_cls,
    predicted_max_num_instances=config.num_slots,
    ground_truth_max_num_instances=config.data.max_instances + 1,  # Incl. bg.
    metrics_on_cpu=config.get("metrics_on_cpu", False),
    num_eval_steps=num_eval_steps,
  )

  # Convert everything to NumPy arrays before handing it to the writer.
  metrics_res = eval_metrics.compute()
  writer.write_scalars(
    step,
    jax.tree_util.tree_map(
      np.array, utils.flatten_named_dicttree(metrics_res)
    ),
  )
  writer.write_images(
    step,
    jax.tree_util.tree_map(
      np.array,
      utils.prepare_images_for_logging(
        config,
        eval_batch,
        eval_preds,
        n_samples=config.get("n_samples", 5),
        min_n_colors=config.get("logging_min_n_colors", 1),
      ),
    ),
  )


def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs a training and evaluation loop.

  When `config.run_eval` is true, a single evaluation is run (reported under
  step 0), a checkpoint is saved, and the function returns without training.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
  """
  tf.io.gfile.makedirs(workdir)

  rng = jax.random.PRNGKey(config.seed)

  # Build input pipeline.
  rng, data_rng = jax.random.split(rng)
  data_rng = jax.random.fold_in(data_rng, jax.process_index())
  train_ds, eval_ds = input_pipeline.create_datasets(config, data_rng)
  train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types

  # Initialize model.
  model = utils.build_model_from_config(config.model)

  def init_model(rng):
    """Initializes the model and returns `(state_vars, initial_params)`."""
    rng, model_rng, z_rng, eps_rng, dropout_rng = jax.random.split(rng, num=5)

    init_inputs = jnp.ones([1] + list(config.data.input_shape), jnp.float32)
    initial_vars = model.init(
      {
        "params": model_rng,
        "z_state_init": z_rng,
        "eps": eps_rng,
        "dropout": dropout_rng,
      },
      init_inputs,
    )

    # Split into state variables (e.g. for batchnorm stats) and model params.
    # Note that `pop()` on a FrozenDict performs a deep copy.
    state_vars, initial_params = initial_vars.pop(
      "params"
    )  # pytype: disable=attribute-error

    # Filter out intermediates (we don't want to store these in the TrainState).
    state_vars = utils.filter_key_from_frozen_dict(state_vars, key="intermediates")
    return state_vars, initial_params

  rng, init_rng = jax.random.split(rng)
  state_vars, initial_params = init_model(init_rng)

  # Optimizer
  learning_rate_fn = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=config.learning_rate,
    warmup_steps=config.warmup_steps,
    decay_steps=config.num_train_steps,
  )
  optimizer_def = flax.optim.Adam(learning_rate=config.learning_rate)
  parameter_overview.log_parameter_overview(
    initial_params
  )  # pytype: disable=wrong-arg-types
  optimizer = optimizer_def.create(initial_params)

  # TrainState
  # `state.step` counts completed training steps, so a fresh state starts at
  # 0. Together with `initial_step = state.step + 1` below, a fresh run then
  # begins at loop step 1 and performs exactly `config.num_train_steps`
  # updates (starting at 1 silently skipped the first step).
  state = utils.TrainState(
    step=0, optimizer=optimizer, rng=rng, state_vars=state_vars
  )

  # Metrics
  # Construct TrainMetrics and EvalMetrics, metrics collections.
  train_metrics_cls = utils.make_metrics_collection(
    "TrainMetrics", config.train_metrics_spec
  )
  eval_metrics_cls = utils.make_metrics_collection(
    "EvalMetrics", config.eval_metrics_spec
  )

  loss_fn = functools.partial(losses.compute_full_loss, loss_config=config.losses)

  # Set up checkpointing of the model and the input pipeline.
  checkpoint_dir = os.path.join(workdir, "checkpoints")
  ckpt = checkpoint.MultihostCheckpoint(
    checkpoint_dir, {"train_iter": train_iter}, max_to_keep=2
  )
  state = ckpt.restore_or_initialize(state)
  # The next step to run is one past the number of completed steps.
  initial_step = int(state.step) + 1

  # Distribute training.
  state = flax_utils.replicate(state)
  # Only `state` (1) and `batch` (2) are mapped over devices; every other
  # argument of `train_step` is a static, broadcast Python value.
  p_train_step = jax.pmap(
    train_step,
    axis_name="batch",
    donate_argnums=(1, 2),
    static_broadcasted_argnums=(0, 3, 4, 5, 6, 7, 8),
  )

  # Logging
  writer = metric_writers.create_default_writer(
    workdir, just_logging=jax.process_index() > 0
  )
  writer.write_hparams(utils.prepare_dict_for_logging(config.to_dict()))

  logging.info("Starting training loop at step %d.", initial_step)
  hooks = []
  report_progress = periodic_actions.ReportProgress(
    num_train_steps=config.num_train_steps, writer=writer
  )
  # Progress reporting and profiling hooks only run on the first process;
  # `report_progress.timed` is still used on every process below.
  if jax.process_index() == 0:
    hooks += [
      report_progress,
      periodic_actions.Profile(num_profile_steps=5, logdir=workdir),
    ]
  train_metrics = None
  with metric_writers.ensure_flushes(writer):
    if config.run_eval:
      # Evaluation-only mode: evaluate once, save a checkpoint and stop.
      with report_progress.timed("eval"):
        evaluate(
          model,
          state,
          eval_ds,
          loss_fn,
          eval_metrics_cls,
          config,
          writer,
          step=0,
          num_eval_steps=config.num_eval_steps,
        )
      with report_progress.timed("checkpoint"):
        ckpt.save(flax_utils.unreplicate(state))
      return

    for step in range(initial_step, config.num_train_steps + 1):
      # `step` is a Python integer. `state.step` is JAX integer on the GPU/TPU
      # devices.
      is_last_step = step == config.num_train_steps

      with jax.profiler.StepTraceAnnotation("train", step_num=step):
        batch = jax.tree_util.tree_map(np.asarray, next(train_iter))
        state, metrics_update = p_train_step(
          model,
          state,
          batch,
          loss_fn,
          learning_rate_fn,
          train_metrics_cls,
          config.num_slots,
          config.data.max_instances + 1,  # Incl. background.
          config.get("max_grad_norm"),
        )

        # Accumulate metrics until the next logging step.
        metric_update = flax_utils.unreplicate(metrics_update)
        train_metrics = (
          metric_update
          if train_metrics is None
          else train_metrics.merge(metric_update)
        )

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
      for h in hooks:
        h(step)

      if step % config.log_loss_every_steps == 0 or is_last_step:
        writer.write_scalars(step, train_metrics.compute())
        # Start a fresh accumulation window after each report.
        train_metrics = None

      if step % config.eval_every_steps == 0 or is_last_step:
        with report_progress.timed("eval"):
          evaluate(
            model,
            state,
            eval_ds,
            loss_fn,
            eval_metrics_cls,
            config,
            writer,
            step=step,
            num_eval_steps=config.num_eval_steps,
          )

      if step % config.checkpoint_every_steps == 0 or is_last_step:
        with report_progress.timed("checkpoint"):
          state = merge_batch_stats(state)
          ckpt.save(flax_utils.unreplicate(state))

  logging.info("Finishing training at step %d", config.num_train_steps)
