# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Distill a teacher model into a FlexiViT student.

Note this file has code that is generic enough to allow using an ensemble
of teachers. This is inherited from `proj/distill/distill.py` and the goal
is to make only minimal changes in a fork of that file. However, this feature
does not really make sense for FlexiViT.
"""
# pylint: disable=consider-using-from-import
import functools
import importlib
import multiprocessing.pool
import os

from absl import app
from absl import flags
from absl import logging
import big_vision.evaluators.common as eval_common
import big_vision.evaluators.proj.distill.distance as dd
import big_vision.input_pipeline as input_pipeline
import big_vision.optax as bv_optax
import big_vision.trainers.proj.flexi.common as flexi
import big_vision.utils as u
from clu import parameter_overview
import flax
import jax
import jax.numpy as jnp
from ml_collections import config_flags
import numpy as np
import optax
import tensorflow as tf

from tensorflow.io import gfile

# pylint: disable=logging-fstring-interpolation


# `--config`: the training configuration; `main` reads it back through
# `flags.FLAGS.config`. It is registered with `lock_config=True`.
config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=True)

# `--workdir`: directory that `main` creates (when set) and in which it places
# `checkpoint.npz`. Defaults to None, in which case no checkpoint path is set.
flags.DEFINE_string("workdir", default=None, help="Work unit directory.")
# `--cleanup`: forwarded to `u.maybe_cleanup_workdir` at the very end of `main`.
flags.DEFINE_boolean("cleanup", default=False,
                     help="Delete workdir (only) after successful completion.")

# Adds jax flags to the program.
jax.config.parse_flags_with_absl()


def getfirst(d, *keys):
  """Looks up `keys` in order and returns the value of the first one in `d`.

  Args:
    d: Mapping to search.
    *keys: Candidate keys, listed from highest to lowest priority.

  Returns:
    `d[k]` for the earliest `k` in `keys` that is contained in `d`. The value
    itself may be anything, including `None`.

  Raises:
    KeyError: If `d` contains none of `keys`.
  """
  for key in keys:
    if key in d:
      return d[key]
  raise KeyError(f"None of {keys} is in {d.keys()}")

def main(argv):
  """Runs FlexiViT distillation training as configured by the absl flags.

  Reads `--config` and `--workdir`, builds the input pipeline, the student and
  teacher models and the optimizer, loads the teachers, optionally resumes from
  a checkpoint or initializes the student, and then runs the training loop with
  periodic metric logging, checkpointing and evaluation.

  Args:
    argv: Command-line arguments handed over by `app.run`; unused.

  Raises:
    ValueError: If the global batch size is not divisible by the device count.
    RuntimeError: If the logged training loss is nan or inf.
  """
  del argv
  tf.config.experimental.set_visible_devices([], "GPU")

  config = flags.FLAGS.config
  workdir = flags.FLAGS.workdir
  logging.info(
      f"\u001b[33mHello from process {jax.process_index()} holding "
      f"{jax.local_device_count()}/{jax.device_count()} devices and "
      f"writing to workdir {workdir}.\u001b[0m")

  save_ckpt_path = None
  if workdir:  # Always create if requested, even if we may not write into it.
    gfile.makedirs(workdir)
    save_ckpt_path = os.path.join(workdir, "checkpoint.npz")

  # The pool is used to perform misc operations such as logging in async way.
  pool = multiprocessing.pool.ThreadPool()

  # Here we register preprocessing ops from modules listed on `pp_modules`.
  for m in config.get("pp_modules", ["ops_general", "ops_image", "ops_text"]):
    importlib.import_module(f"big_vision.pp.{m}")

  # This seed makes the Jax part of things (like model init) deterministic.
  # However, full training still won't be deterministic, for example due to the
  # tf.data pipeline not being deterministic even if we would set TF seed.
  # See (internal link) for a fun read on what it takes.
  rng = jax.random.PRNGKey(config.get("seed", 0))

  # These functions do more stuff internally, for OSS release we mock them by
  # trivial alternatives in order to minimize disruptions in the code.
  xid, wid = -1, -1
  fillin = lambda s: s
  def info(s, *a):
    logging.info("\u001b[33mNOTE\u001b[0m: " + s, *a)
  def write_note(note):
    # Only the first process emits notes.
    if jax.process_index() == 0:
      info("%s", note)

  write_note("Initializing...")

  batch_size = config.input.batch_size
  if batch_size % jax.device_count() != 0:
    raise ValueError(f"Batch size ({batch_size}) must "
                     f"be divisible by device number ({jax.device_count()})")
  info("Global batch size %d on %d hosts results in %d local batch size. With "
       "%d dev per host (%d dev total), that's a %d per-device batch size.",
       batch_size, jax.process_count(), batch_size // jax.process_count(),
       jax.local_device_count(), jax.device_count(),
       batch_size // jax.device_count())

  # First thing after above sanity checks, so we can log "start" ticks.
  mw = u.BigVisionMetricWriter(xid, wid, workdir, config)

  write_note("Initializing train dataset...")
  train_ds, ntrain_img = input_pipeline.training(config.input)

  # Start prefetching already.
  n_prefetch = config.get("prefetch_to_device", 1)
  train_iter = input_pipeline.start_input_pipeline(train_ds, n_prefetch)

  total_steps = u.steps("total", config, ntrain_img, batch_size)
  def get_steps(name, default=ValueError, cfg=config):
    """Shorthand for `u.steps` with the dataset/batch/total sizes filled in."""
    return u.steps(name, cfg, ntrain_img, batch_size, total_steps, default)

  u.chrono.inform(total_steps=total_steps, global_bs=batch_size,
                  steps_per_epoch=ntrain_img / batch_size,
                  measure=mw.measure, write_note=write_note)

  info("Running for %d steps, that means %f epochs",
       total_steps, total_steps * batch_size / ntrain_img)

  # Create student and teacher models
  def get_model_mod(name):  # Used many times.
    """Imports `big_vision.models.<config[f"{name}_name"]>`."""
    mod_name = config[f"{name}_name"]
    return importlib.import_module(f"big_vision.models.{mod_name}")

  write_note("Initializing models...")
  def make_model(name):
    """Builds the `Model` of `name`, using `config[name]` as extra kwargs."""
    return get_model_mod(name).Model(
        num_classes=config.num_classes, **config.get(name, {}))

  # The student comes first, followed by the teachers in config order.
  models = {
      "student": make_model("student"),
      **{t: make_model(t) for t in config.teachers}
  }

  # We want all parameters to be created in host RAM, not on any device, they'll
  # be sent there later as needed, otherwise we already encountered two
  # situations where we allocate them twice.
  def get_init(model, name):
    """Returns a CPU-jitted function mapping an rng to `model`'s init params."""
    @functools.partial(jax.jit, backend="cpu")
    def _init(rng):
      bs = batch_size // jax.device_count()
      # Use the model-specific input spec if there is one, else "image".
      img_size = tuple(getfirst(train_ds.element_spec, name, "image").shape[1:])
      no_image = jnp.zeros((bs,) + img_size, jnp.float32)
      params = flax.core.unfreeze(model.init(rng, no_image))["params"]
      return params
    return _init

  rng, *rng_inits = jax.random.split(rng, len(models) + 1)
  with u.chrono.log_timing("z/secs/init"):
    params_cpu = {
        name: get_init(models[name], name=name)(r)
        for name, r in zip(models, rng_inits)}

  if jax.process_index() == 0:
    for name, params in params_cpu.items():
      parameter_overview.log_parameter_overview(params, msg=f"{name} params")
      mw.measure(f"num_params_{name}",
                 sum(p.size for p in jax.tree_util.tree_leaves(params)))

  write_note(f"Initializing {config.optax_name} optimizer...")
  # For now, we explicitly only optimize the student parameters as there's
  # nothing else to be optimized. If we ever want to add learnable projections
  # or similar for good (we explored but ditched), need to refactor this a bit.
  tx, sched_fns = bv_optax.make(
      config, params_cpu["student"], sched_kw=dict(
          total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img))

  # We jit this, such that the arrays are created on the CPU, not device[0].
  opt_cpu = jax.jit(tx.init, backend="cpu")(params_cpu["student"])
  sched_fns_cpu = [jax.jit(sched_fn, backend="cpu") for sched_fn in sched_fns]

  @jax.named_call
  def loss_fn(student_params, params, data, rngs, **flexi_kw):
    """Computes the distillation loss and auxiliary measurements.

    Args:
      student_params: Student parameters; the argument being differentiated.
      params: Dict of parameters of all models, keyed by model name.
      data: Batch dict; each model reads `data[<model name>]` if present and
        `data["image"]` otherwise.
      rngs: Dict from model name to the rngs passed to that model's `apply`.
      **flexi_kw: Extra keyword arguments forwarded to the student only.

    Returns:
      A `(loss, measurements)` pair where everything was reduced by `jnp.mean`;
      `loss` is the sum of the student-teacher distances over all teachers.
    """
    # Note: need to extract and use `student_params` out of `params` because the
    # first argument of `loss_fn` is what's differentiated wrt. We build a new
    # dict instead of assigning into `params` so that the caller's dict is not
    # modified while tracing.
    params = {**params, "student": student_params}

    def fwd(name, params):
      return jax.named_call(models[name].apply, name=name)(
          {"params": params}, getfirst(data, name, "image"),
          train=name == "student", rngs=rngs.get(name),
          **(flexi_kw if name == "student" else {})
      )[0]  # logits, unused_outputs
    logits = {name: fwd(name, w) for name, w in params.items()}

    measurements = {}
    for name, lg in logits.items():
      measurements[f"entropy_{name}"] = -jnp.sum(
          jax.nn.log_softmax(lg) * jax.nn.softmax(lg), axis=-1)
      if "labels" in data:
        measurements[f"task_loss_{name}"] = u.softmax_xent(
            logits=lg, labels=data["labels"], reduction=False)

    # NOTE: xent is linear in labels, so for KL, this is actually the same as
    # using a teacher-ensemble in probs-space!
    measurements["distill_loss"] = 0.0
    for name in config.teachers:
      l = dd.dist(logits["student"], logits[name], config.get("distance", "kl"),
                  **config.get("distance_kw", {}))
      measurements[f"distill_loss_{name}"] = l
      measurements["distill_loss"] += l

    outputs = (measurements["distill_loss"], measurements)
    return jax.tree_util.tree_map(jnp.mean, outputs)

  # Sorted so that the positional flexi arguments have a fixed order.
  flexi_argnames = sorted(config.flexi)

  # Positional arguments from index 4 on are the flexi values; they are declared
  # as static broadcasted arguments of the pmapped function.
  @functools.partial(
      jax.pmap, axis_name="batch", donate_argnums=(0, 1),
      static_broadcasted_argnums=tuple(range(4, 4 + len(flexi_argnames))))
  def update_fn(params, opt, rng, data, *args):
    """Update step.

    Args:
      params: Dict of parameters of all models, keyed by model name.
      opt: Optimizer state for the student parameters.
      rng: Random key for this step.
      data: The batch dict.
      *args: Flexi values, in the order of `flexi_argnames`.

    Returns:
      `(params, opt, rng, l, measurements)` with the student entry of `params`
      and `opt` updated, the advanced `rng`, the loss `l` and the measurements
      dict (including the l2 norms of gradients, parameters and updates).
    """

    # Mixup. Note: overwrites the `data` entries (that's intended).
    if config.get("mixup") and config.mixup.p:
      to_mix = {name: data[name]
                for name in ("image", "labels") + tuple(models) if name in data}
      rng, _, to_mix = u.mixup(rng, **config.mixup, **to_mix)
      data = {**data, **to_mix}

    # Get device-specific loss rng.
    rng, *rng_models = jax.random.split(rng, len(models) + 1)
    rngs_models_local = {
        name: {"dropout": jax.random.fold_in(rngi, jax.lax.axis_index("batch"))}
        for name, rngi in zip(models, rng_models)
    }

    w = params["student"]  # Need to explicitly pull out the optimized ones.
    # Loss, measurements and gradients are averaged over the "batch" axis.
    (l, measurements), grads = jax.lax.pmean(
        jax.value_and_grad(loss_fn, has_aux=True)(
            w, params, data, rngs=rngs_models_local,
            **dict(zip(flexi_argnames, args))),
        axis_name="batch")
    updates, opt = tx.update(grads, opt, w)
    w = optax.apply_updates(w, updates)
    params["student"] = w

    # Take some logging measurements
    gs = jax.tree_util.tree_leaves(
        bv_optax.replace_frozen(config.schedule, grads, 0.))
    measurements["l2_grads"] = jnp.sqrt(sum([jnp.vdot(g, g) for g in gs]))
    ps = jax.tree_util.tree_leaves(w)
    measurements["l2_params"] = jnp.sqrt(sum([jnp.vdot(p, p) for p in ps]))
    us = jax.tree_util.tree_leaves(updates)
    # The loop variable is deliberately not called `u`, which is the utils
    # module alias in this file.
    measurements["l2_updates"] = jnp.sqrt(
        sum([jnp.vdot(upd, upd) for upd in us]))

    return params, opt, rng, l, measurements

  # We always load the teachers first, because they NEED to be initialized
  # and since we don't ever modify them, we don't store them in checkpoints.
  for name in config.teachers:
    init_def = config[f"{name}_init"]
    write_note(f"Initializing {name} from {init_def}…")
    params_cpu[name] = get_model_mod(name).load(
        params_cpu[name], init_def, config[name],
        **config.get(f"{name}_load", {}))

  # Decide how to initialize training. The order is important.
  # 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job.
  # 2. Resume from a previous checkpoint, e.g. start a cooldown training job.
  # 3. Initialize student from something, e.g. start a fine-tuning job.
  # 4. Train from scratch.
  resume_ckpt_path = None
  if save_ckpt_path and gfile.exists(save_ckpt_path):
    resume_ckpt_path = save_ckpt_path
  elif config.get("resume"):
    resume_ckpt_path = fillin(config.resume)
  if resume_ckpt_path:
    write_note("Resume training from checkpoint...")
    # NOTE: we never change the teachers, so only checkpoint student here.
    checkpoint = {
        "params": params_cpu["student"],
        "opt": opt_cpu,
        "chrono": u.chrono.save(),
    }
    # The freshly initialized values only serve as a template for the tree
    # structure that is handed to the checkpoint loader.
    checkpoint_tree = jax.tree_util.tree_structure(checkpoint)
    loaded = u.load_checkpoint_np(resume_ckpt_path, checkpoint_tree)
    # bfloat16 type gets lost when data is saved to disk, so we recover it.
    checkpoint = jax.tree_util.tree_map(u.recover_dtype, loaded)
    params_cpu["student"], opt_cpu = checkpoint["params"], checkpoint["opt"]
    u.chrono.load(checkpoint["chrono"])
  elif config.get("student_init"):
    write_note(f"Initialize student from {config.student_init}...")
    params_cpu["student"] = get_model_mod("student").load(
        params_cpu["student"], config.student_init, config.get("student"),
        **config.get("student_load", {}))
    if jax.process_index() == 0:
      parameter_overview.log_parameter_overview(
          params_cpu["student"], msg="restored (student) params")

  write_note("Kicking off misc stuff...")
  first_step = bv_optax.get_count(opt_cpu)
  u.chrono.inform(first_step=first_step)
  prof = None  # Keeps track of start/stop of profiler state.

  write_note(f"Replicating...\n{u.chrono.note}")
  params_repl = flax.jax_utils.replicate(params_cpu)
  opt_repl = flax.jax_utils.replicate(opt_cpu)

  # Define predict functions that the evaluators can use:
  def predict_fn(params, *, name, **kw):
    """Applies model `name` to `kw[name]`, or to `kw["image"]` as fallback."""
    image = kw.pop(name, kw.pop("image", None))
    # Ugly API compatibility necessity:
    for k in ("student", *config.teachers):
      kw.pop(k, 0)
    return models[name].apply({"params": params[name]}, image, **kw)

  # 1. One for each variant of the student
  student_pfns = flexi.mkpredictfns(
      functools.partial(predict_fn, name="student"), config.flexi, "student_{x}"
  )
  # 2. One per teacher model
  teacher_pfns = {
      name: functools.partial(predict_fn, name=name)
      for name in config.teachers
  }
  # 3. One for each (student-variant, teacher) pair, eg for distance eval.
  # `sfn`/`tfn` are bound as lambda defaults so that each entry keeps its own
  # pair of functions rather than the last one of the comprehension.
  combined_pfns = {
      f"{sn}_{tn}": lambda *a, sfn=sfn, tfn=tfn, **kw: (sfn(*a, **kw), tfn(*a, **kw))  # pylint: disable=line-too-long
      for sn, sfn in student_pfns.items()
      for tn, tfn in teacher_pfns.items()
  }

  predict_fns = {**student_pfns, **teacher_pfns, **combined_pfns}

  # Cached, so the evaluators are only constructed on the first call.
  @functools.cache
  def evaluators():
    return eval_common.from_config(
        config, predict_fns,
        lambda s: write_note(f"Init evaluator: {s}…\n{u.chrono.note}"),
        lambda key, cfg: get_steps(key, default=None, cfg=cfg),
    )

  rng, rng_loop = jax.random.split(rng, 2)
  rngs_loop = flax.jax_utils.replicate(rng_loop)
  ckpt_writer = None

  write_note(f"First step compilations...\n{u.chrono.note}")

  # Note that training can be pre-empted during the final evaluation (i.e.
  # just after the final checkpoint has been written to disc), in which case we
  # want to run the evals.
  if first_step in (total_steps, 0):
    mw.step_start(first_step)
    for (name, evaluator, _, prefix) in evaluators():
      if config.evals[name].get("skip_first") and first_step != total_steps:
        continue
      write_note(f"{name} evaluation...\n{u.chrono.note}")
      with u.chrono.log_timing(f"z/secs/eval/{name}"):
        for key, value in evaluator.run(params_repl):
          mw.measure(f"{prefix}{key}", value)

  # Using a python integer for step here, because opt.state.step is allocated
  # on TPU during replication.
  for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter):
    mw.step_start(step)

    # Pick this step's flexi values, one per name in `flexi_argnames`.
    np_rng = flexi.mkrng(xid, wid, step)
    flexi_args = [
        flexi.choice(config.flexi[n].v, config.flexi[n].p, np_rng)
        for n in flexi_argnames
    ]

    with jax.profiler.StepTraceAnnotation("train_step", step_num=step):
      with u.chrono.log_timing("z/secs/update0", noop=step > first_step + 1):
        params_repl, opt_repl, rngs_loop, loss_value, measurements = update_fn(
            params_repl, opt_repl, rngs_loop, batch, *flexi_args)

    # On the first host, let's always profile a handful of early steps.
    if jax.process_index() == 0:
      prof = u.startstop_prof(prof, step, first_step, get_steps("log_training"))

    # Report training progress
    if (u.itstime(step, get_steps("log_training"), total_steps, host=0)
        or u.chrono.warmup and jax.process_index() == 0):
      for i, sched_fn_cpu in enumerate(sched_fns_cpu):
        mw.measure(f"global_schedule{i if i else ''}", sched_fn_cpu(step - 1))
      # Index 0 takes the first device's copy of the value.
      l = mw.measure("training_loss", loss_value[0])
      for name, value in measurements.items():
        mw.measure(name, value[0])
      u.chrono.tick(step)
      if not np.isfinite(l):
        raise RuntimeError(f"The loss became nan or inf somewhere within steps "
                           f"[{step - get_steps('log_training')}, {step}]")

    # Checkpoint saving
    if (save_ckpt_path and
        (u.itstime(step, get_steps("ckpt", None), total_steps, host=0) or
         u.itstime(step, get_steps("keep_ckpt", None), total_steps, host=0))):
      u.chrono.pause(wait_for=(params_repl["student"], opt_repl))
      u.checkpointing_timeout(ckpt_writer, config.get("ckpt_timeout", 1))
      # We need to transfer the weights over now or else we risk keeping them
      # alive while they'll be updated in a future step, creating hard to debug
      # memory errors (see (internal link)). Also, takes device 0's params only.
      params_cpu["student"], opt_cpu = jax.tree_util.tree_map(
          lambda x: np.array(x[0]), (params_repl["student"], opt_repl))

      # Check whether we want to keep a copy of the current checkpoint.
      copy_step = None
      if u.itstime(step, get_steps("keep_ckpt", None), total_steps):
        copy_step = step

      ckpt = {"params": params_cpu["student"],
              "opt": opt_cpu,
              "chrono": u.chrono.save()}
      # The write itself is handed to the thread pool.
      ckpt_writer = pool.apply_async(
          u.save_checkpoint, (ckpt, save_ckpt_path, copy_step))
      u.chrono.resume()

    for (name, evaluator, log_steps, prefix) in evaluators():
      if u.itstime(step, log_steps, total_steps, first=False, last=True):
        u.chrono.pause(wait_for=params_repl)
        u.chrono.tick(step)  # Record things like epoch number, core hours etc.
        write_note(f"{name} evaluation...\n{u.chrono.note}")
        with u.chrono.log_timing(f"z/secs/eval/{name}"):
          for key, value in evaluator.run(params_repl):
            mw.measure(f"{prefix}{key}", value)
        u.chrono.resume()
    mw.step_end()

  # Always give a chance to stop the profiler, no matter how things ended.
  # TODO: can we also do this when dying of an exception like OOM?
  if jax.process_index() == 0 and prof is not None:
    u.startstop_prof(prof)

  # Last note needs to happen before the pool's closed =)
  write_note(f"Done!\n{u.chrono.note}")

  pool.close()
  pool.join()
  mw.close()

  # Make sure all hosts stay up until the end of main.
  u.sync()

  u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, info)


# Script entry point: hand `main` to absl's `app.run` when executed directly.
if __name__ == "__main__":
  app.run(main)
