# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Input pipelines (tf.data) for training and inference."""
import collections
import functools
import itertools
import math
import multiprocessing.pool

from absl import logging
from big_vision.datasets import sequence_packing
import big_vision.datasets.core as ds_core
import big_vision.pp.builder as pp_builder
import big_vision.utils as u
import einops
import jax
import numpy as np
import tensorflow as tf


# Default `num_parallel_calls` for the preprocessing `map` step in
# `make_for_train` and `make_for_inference`.
DEFAULT_NUM_PARALLEL_CALLS = 100


def make_for_train(
    data, preprocess_fn, batch_size,
    shuffle_buffer_size=None, cache_raw=False,
    num_parallel_calls=DEFAULT_NUM_PARALLEL_CALLS, prefetch=2,
    *,
    pre_filter_fn=None, post_filter_fn=None,
    pack=None, skip_errors=False,
):
  """Makes an input pipeline for training.

  Stages, in order: host tf.data options, optional pre-filter, optional raw
  cache, optional shuffle, infinite repeat, parallel preprocessing, optional
  post-filter, optional error skipping, optional sequence packing, optional
  per-process batching and optional prefetch.

  Use data filtering at your own risk: the actual split sizes won't be known
  in advance, so epoch-based things won't work correctly.

  Args:
    data: The raw tf.data.Dataset.
    preprocess_fn: Function mapped over every example.
    batch_size: Global batch size; each process batches
      `batch_size // jax.process_count()` examples. Falsy disables batching
      (used when batching happens later, e.g. after mixing).
    shuffle_buffer_size: Shuffle buffer size; falsy disables shuffling.
    cache_raw: Whether to cache the (pre-filtered) raw data.
    num_parallel_calls: Parallelism of the preprocessing `map`.
    prefetch: Number of elements to prefetch; falsy disables it.
    pre_filter_fn: Optional predicate applied before preprocessing.
    post_filter_fn: Optional predicate applied after preprocessing.
    pack: Optional packing config forwarded to
      `sequence_packing.pack_dataset`.
    skip_errors: Whether to drop (and log) elements that raise errors.

  Returns:
    An infinite tf.data.Dataset.
  """
  ds = _add_tpu_host_options(data)

  if pre_filter_fn:
    ds = ds.filter(pre_filter_fn)
  if cache_raw:
    ds = ds.cache()

  # Shuffle *before* repeating (each repetition reshuffles). This way all data
  # of one epoch is seen before the next one starts, which matters a lot for
  # how often each example is seen when training for few epochs.
  if shuffle_buffer_size:
    ds = ds.shuffle(shuffle_buffer_size, reshuffle_each_iteration=True)
  ds = ds.repeat(None)

  ds = ds.map(preprocess_fn, num_parallel_calls=num_parallel_calls)
  if post_filter_fn:
    ds = ds.filter(post_filter_fn)
  if skip_errors:
    ds = ds.ignore_errors(log_warning=True)
  if pack:
    ds = sequence_packing.pack_dataset(ds, pack)

  if batch_size:
    # Dropping the remainder keeps the batch shape fully static.
    per_process_batch = batch_size // jax.process_count()
    ds = ds.batch(per_process_batch, drop_remainder=True)
  if prefetch:  # None would mean autotune, which we never want.
    ds = ds.prefetch(prefetch)
  return ds


def training(input_config):
  """Reads the data from a single dataset, or mixes it from multiple.

  The data is read either from one or mixed from multiple datasets, depending
  on the `input_config`:

  * If `input_config.data.name` is a string, a single dataset is read and its
    per-pipeline options (`shuffle_buffer_size`, `cache_raw`, ...) are taken
    from `input_config` itself.
  * Otherwise `input_config.data` maps dataset names to mixing weights and each
    dataset's options are read from `input_config[name]`. Datasets with a zero
    weight are skipped, the remaining weights are normalized to sum to one and
    examples are sampled from the datasets accordingly. Per-pipeline options
    must then be set per dataset, except `pack`, which may additionally be set
    at the top level to pack the mixed stream. A top-level `keep_only` list
    restricts every preprocessed dataset to those keys before mixing.

  Args:
    input_config: Configures the input pipeline. See input_pipeline_test for
      examples.

  Returns:
    A tuple containing (possibly mixed) tf.data.Dataset and a total number of
    training examples.

  Raises:
    ValueError: If, when mixing, no dataset has a non-zero weight.
  """
  per_pipeline_configs = (
      "shuffle_buffer_size", "cache_raw", "num_parallel_calls",
      "pre_filter_fn", "post_filter_fn", "pack", "skip_errors")
  def config_to_kw(config):
    assert "filter_fn" not in config, "Deprecated; use `pre_filter_fn` instead."
    return {k: config[k] for k in per_pipeline_configs if k in config}

  batch_size = input_config.batch_size
  # Handle separately the common case when no mixing happens.
  if isinstance(input_config.data.get("name"), str):
    train_data = ds_core.get(**input_config.data)
    train_ds = make_for_train(
        data=train_data.get_tfdata(ordered=False),
        batch_size=batch_size,
        preprocess_fn=pp_builder.get_preprocess_fn(input_config.get("pp")),
        prefetch=input_config.get("prefetch", 2),  # Default 2 for bwd compat.
        **config_to_kw(input_config)
    )
    return train_ds, train_data.total_examples

  # A helpful error instead of silent ignore. `pack` is exempt: a top-level
  # `pack` is not ignored, it is applied to the mixed stream further below.
  for k in per_pipeline_configs:
    if k != "pack":
      assert k not in input_config, f"{k} is per-dataset in multi-input."

  # Parallelize the loading of datasets when doing data mixture.
  # For larger mixes, we sometimes spend >5min when doing sequentially.
  # NOTE: functools.cache is thread-safe.
  def _make(name_and_weight):
    name, weight = name_and_weight
    dataset = input_config[name]
    train_data = ds_core.get(**dataset.data)
    dataset = make_for_train(
        data=train_data.get_tfdata(ordered=False),
        # Don't batch the data just yet, it will be done after
        # mixing the different datasets below.
        batch_size=None,
        preprocess_fn=pp_builder.get_preprocess_fn(dataset.get("pp"), name),
        prefetch=0,  # Prefetching each pipeline leads to huge OOMs.
        **config_to_kw(dataset)
    )
    if keys := input_config.get("keep_only"):
      dataset = dataset.map(lambda d, keys=keys: {k: d[k] for k in keys})
    return name, dataset, weight, train_data.total_examples

  # Skip weight=0 datasets as a convenient optimization in sweeps.
  to_make = [(name, w) for name, w in input_config.data.items() if w]
  if not to_make:
    raise ValueError("All datasets in the mixture have a zero weight.")

  # The `with` block shuts the pool's threads down once every dataset is built
  # (`map` has already returned all results by then).
  with multiprocessing.pool.ThreadPool(len(to_make)) as pool:
    built = pool.map(_make, to_make)  # Preserves the order of `to_make`.
  names, datasets, weights, totals = (list(col) for col in zip(*built))

  # Normalize the weights such that they sum up to 1.
  total_weight = sum(weights)
  weights = [w / total_weight for w in weights]

  logging.info(
      "NOTE: Total dataset mix size: %d\nContributions:\n%s", sum(totals),
      "\n".join(f"{ds}: {n} ({w * 100:.1f}%)"
                for ds, n, w in zip(names, totals, weights))
  )

  train_ds = tf.data.Dataset.sample_from_datasets(
      datasets, weights, stop_on_empty_dataset=True)
  if pack := input_config.get("pack"):
    # Pack the *mixed* stream, i.e. sequences from different datasets.
    train_ds = sequence_packing.pack_dataset(train_ds, pack)
  train_ds = train_ds.batch(
      batch_size // jax.process_count(), drop_remainder=True)
  if (pf := input_config.get("prefetch", 2)):
    train_ds = train_ds.prefetch(pf)

  return train_ds, sum(totals)


# The pipeline below is used for evals in multi-{G,T}PU and multi-host settings.
# As the total number of examples may not be evenly divisible across all
# devices, we use the `infinite tf.data padding` trick, which was suggested by
# Andreas Steiner and also implemented by him in the clu library:
# https://github.com/google/CommonLoopUtils/blob/84b777c42dfd3fb6685537138433bfeb5241a006/clu/deterministic_data.py#L304.
def make_for_inference(
    data, preprocess_fn, batch_size, num_ex_per_process,
    cache_raw=False, cache_final=False,
    num_parallel_calls=DEFAULT_NUM_PARALLEL_CALLS, prefetch=1,
):
  """Makes an input pipeline for inference.

  The preprocessed data is followed by an endless stream of all-zero padding
  examples (whose `_mask` is therefore False), batched, and truncated to the
  same number of batches on every process, so that all hosts process all data
  and run exactly the same number of steps.

  Args:
    data: The raw tf.data.Dataset.
    preprocess_fn: Per-example preprocessing; wrapped so that its output also
      carries the internal `_mask` (and, when available, `_id`) fields.
    batch_size: Global batch size; each process batches
      `batch_size // jax.process_count()` examples.
    num_ex_per_process: Number of examples of each process; the maximum over
      processes determines the number of batches on every process.
    cache_raw: Whether to cache the raw data before preprocessing.
    cache_final: Whether to cache the finite sequence of batches.
    num_parallel_calls: Parallelism of the preprocessing `map`.
    prefetch: Number of batches to prefetch; falsy disables it.

  Returns:
    A tuple `(dataset, num_batches)`. The dataset repeats forever, each
    repetition consisting of the same `num_batches` batches.
  """
  ds = _add_tpu_host_options(data)
  if cache_raw:
    ds = ds.cache()
  ds = ds.map(_add_internal_fields(preprocess_fn),
              num_parallel_calls=num_parallel_calls)
  padding = _get_pad_data(ds)
  ds = ds.concatenate(padding)

  per_process_batch = batch_size // jax.process_count()
  # Like `batch`, but elements of differing shapes become tf.RaggedTensors
  # (elements with one fixed shape stay tf.Tensors). Because the padding is
  # infinite, dropping the remainder never loses real examples.
  ds = ds.ragged_batch(batch_size=per_process_batch, drop_remainder=True)

  # Every process must run exactly the same number of batches: derive it from
  # the largest per-process example count.
  largest = max(num_ex_per_process)
  num_batches = math.ceil(largest / per_process_batch)
  ds = ds.take(num_batches)

  # Caching happens only after the sequence was made finite by `take`.
  if cache_final:
    ds = ds.cache()
  ds = ds.repeat()
  if prefetch:
    ds = ds.prefetch(prefetch)
  return ds, num_batches


def _get_pad_data(data):
  """Returns an endless dataset of all-zero elements shaped like `data`'s.

  Every leaf of `data.element_spec` becomes a `tf.zeros` of the same dtype;
  dimensions of unknown size (None) get size 0.
  """
  def _zeros_for(spec):
    static_shape = [dim or 0 for dim in spec.shape]
    return tf.zeros(static_shape, spec.dtype)

  zeros_element = jax.tree.map(_zeros_for, data.element_spec)
  single = tf.data.Dataset.from_tensors(zeros_element)
  return single.repeat()


def _add_internal_fields(pp_fn):
  """Wraps pp_fn so its output also carries the `_mask` and `_id` keys.

  For each internal key the wrapper prefers, in this order: the value already
  produced by `pp_fn`, the value carried over from the raw (not pp_fn'd)
  example, or - for `_mask` only - a freshly added `True`.
  """
  def _pp_fn(example):
    out = pp_fn(example)
    if "_mask" not in out:
      # Real examples get True; padded examples (see _get_pad_data) are all
      # zeros and hence have a False `_mask`.
      out["_mask"] = example.get("_mask", tf.constant(True))
    # Not all data-sources can provide an ID; only carry it over if present.
    if "_id" not in out and "_id" in example:
      out["_id"] = example["_id"]
    return out
  return _pp_fn


def _add_tpu_host_options(data):
  """Returns `data` with this module's host-side tf.data options applied.

  A private threadpool of 48 threads, intra-op parallelism of 1, and no
  automatic prefetch injection.
  """
  opts = tf.data.Options()
  # Stop a whole bunch of magic stuff that eats up all RAM:
  opts.experimental_optimization.inject_prefetch = False
  opts.threading.max_intra_op_parallelism = 1
  opts.threading.private_threadpool_size = 48
  return data.with_options(opts)


def prefetch_iterator(it, n):
  """Runs iterator `it` ahead for `n` steps. Adapted from flax.

  Args:
    it: Any iterable. It is passed through `iter()` first, so a re-iterable
      container (e.g. a list) is consumed once instead of being restarted from
      its beginning on every refill of the buffer.
    n: Number of elements to keep buffered ahead of the consumer. A falsy
      value (0 or None) disables prefetching and simply yields from `it`.

  Yields:
    The elements of `it`, in order.
  """
  it = iter(it)
  if not n:
    yield from it
    return
  queue = collections.deque()

  def enqueue(n_steps):  # Enqueues *up to* `n_steps` elements from `it`.
    # Prefetching will parallelize any processing that happens in a different
    # thread (like `jax.device_put()`), but it will be of no use for
    # processing that happens in the same thread.
    queue.extend(itertools.islice(it, n_steps))

  enqueue(n)  # Fill up the buffer.
  while queue:
    yield queue.popleft()
    enqueue(1)


def threadstart_iterator(it):
  """Starts an iterator right away in a background thread.

  We already want to "start" the iterator in order to start the underlying
  dataset prefetch mechanisms, so the first element is requested from `it` on
  a background thread as soon as this function is *called* (not when the
  returned iterator is first advanced). We don't want to lose that element
  from training, so it is yielded first, followed by the rest of `it`.
  (internal link)

  Args:
    it: Any iterable; it is passed through `iter()` first.

  Returns:
    An iterator over the same elements as `it`. If `it` is empty, the returned
    iterator is empty too, instead of raising `RuntimeError` (which is what a
    `StopIteration` escaping a generator turns into).
  """
  it = iter(it)
  pool = multiprocessing.pool.ThreadPool(processes=1)
  first_ex_promise = pool.apply_async(next, (it,))
  # No more tasks will be submitted: let the worker thread exit as soon as
  # the first element is fetched instead of keeping it alive forever.
  pool.close()

  # A nested generator keeps the fetch above eager: code in a generator body
  # would only run once the caller first advances it.
  def _iterate():
    try:
      first_ex = first_ex_promise.get()
    except StopIteration:
      return  # `it` was empty.
    yield first_ex
    yield from it

  return _iterate()


def tf_to_numpy(x):
  """Convert any TF types to numpy.

  Args:
    x: A `tf.Tensor` or a `tf.RaggedTensor`. Anything that is not a
      `tf.Tensor` is treated as ragged and must expose the attributes used
      below (`shape`, `nested_row_splits`, `flat_values`, `row_splits`,
      `values`).

  Returns:
    A numpy array. String tensors are decoded to `str` with `bytes.decode`
    instead of being left as `bytes`. A ragged input whose ragged dimensions
    all turn out to have a single row length is returned as a regular dense
    array. A truly ragged input is returned as a 1-D object array with one
    (recursively converted) entry per row.
  """
  if isinstance(x, tf.Tensor):
    if x.dtype != tf.string:  # Dense, non-string tensor? Easy!
      return x.numpy()
    else:  # A dense string tensor? Turn into actual strings, not bytes.
      return np.vectorize(bytes.decode, otypes=[str])(x.numpy())

  # The rest deals with RaggedTensors, for two main reasons:
  # - For strings, recursively apply the above conversion
  # - For common cases (eg batch of images), return more reasonable shapes.

  # Replace all None's in the shape by a fixed number, in the (somewhat common)
  # case that they are marked ragged, but really all have the same shape.
  # NOTE(review): the i-th non-batch dimension is paired with
  # `nested_row_splits[i]`. This assumes every `None` dimension is one of the
  # ragged dimensions. A `None` among the dense inner dimensions may index
  # past `nested_row_splits` - TODO confirm such inputs never reach here.
  real_shape = list(x.shape)
  for i, s in enumerate(real_shape[1:]):
    if s is not None: continue
    rowlens = np.diff(x.nested_row_splits[i])  # Row lengths along this dim.
    if len(set(rowlens)) == 1:  # All rows equally long: effectively fixed.
      real_shape[i + 1] = rowlens[0]

  if None not in real_shape:
    # Uniform after all: convert the flat values and reshape to dense.
    return tf_to_numpy(x.flat_values).reshape(real_shape)

  # It's actually ragged, reconstruct the array from the variable length pieces.
  splits = x.row_splits.numpy()
  rows = [tf_to_numpy(x.values[splits[i]:splits[i + 1]])
          for i in range(len(splits) - 1)]
  return np.fromiter(rows, dtype=object)


# Note that the order of global devices for sharding data is important and
# should be compatible with device order used for models params, state, etc.
def start_global(
    data, global_devices, n_prefetch=1, keep_on_cpu=frozenset(), warmup=False):
  """Starts the global input pipeline.

  Each element of `data` is converted leaf by leaf. Leaves whose name is in
  `keep_on_cpu` become host numpy arrays via `tf_to_numpy`. All other leaves
  are assembled into global arrays over `global_devices` with
  `u.make_fsarray_from_local_slice`.

  Args:
    data: Iterable of (nested) batches, e.g. a tf.data.Dataset.
    global_devices: Devices to shard the data over. Their order must be
      compatible with the device order used for the model params, state, etc.
    n_prefetch: Number of converted batches to prefetch.
    keep_on_cpu: Names of leaves that are kept on the host as numpy arrays.
    warmup: If true, fetch the first element in a background thread right
      away (e.g. to start filling shuffle buffers).

  Returns:
    An iterator over the converted batches.
  """
  def _to_global(name, leaf):
    if name not in keep_on_cpu:
      return u.make_fsarray_from_local_slice(leaf, global_devices)
    return tf_to_numpy(leaf)

  batches = iter(data)
  if warmup:  # Actually pre-fill shuffle buffers etc.
    batches = threadstart_iterator(batches)

  def _globalize(stream):
    for batch in stream:
      yield u.tree_map_with_names(_to_global, batch)

  return prefetch_iterator(_globalize(batches), n_prefetch)


##########################################################################
# The code below is pmap-specific and is deprecated, please switch to jit.
##########################################################################


def shard_and_put(x, shard=True, put=True):
  """Views `x` as numpy and optionally shards and places it for pmap.

  Args:
    x: An object supporting the buffer protocol (e.g. a numpy array). It is
      viewed through `memoryview` to avoid a copy.
    shard: If true, split the leading axis `(d * l, ...)` into `(d, l, ...)`
      with `d = jax.local_device_count()`.
    put: If true and `shard` is true, transfer the `d` slices to the local
      devices with `jax.device_put_sharded`.

  Returns:
    The numpy view, its per-device reshape, or the device-resident array.
  """
  host = np.asarray(memoryview(x))  # No-copy conversion: http://(internal link)
  if not shard:
    return host
  per_device = einops.rearrange(
      host, "(d l) ... -> d l ...", d=jax.local_device_count())
  if not put:
    return per_device
  # Device placement only works for pmap (for now).
  return jax.device_put_sharded(list(per_device), jax.local_devices())


def start_input_pipeline(data, n_prefetch=1, shard=True):
  """(Deprecated, pmap) Shards every leaf of every batch and prefetches.

  Device placement (`put`) is enabled only when `n_prefetch` is truthy.
  """
  leaf_fn = functools.partial(shard_and_put, shard=shard, put=n_prefetch)
  batches = iter(data)
  placed = map(functools.partial(jax.tree.map, leaf_fn), batches)
  return prefetch_iterator(placed, n_prefetch)


def start_ragged_input_pipeline(data, n_prefetch=1, shard=True, ragged=None):
  """(Deprecated, pmap) Like `start_input_pipeline`, with pass-through leaves.

  Leaves whose name is in `ragged` are passed through untouched. All others go
  through `shard_and_put` (with its default `put=True`).
  """
  def _maybe_shard_and_put(name, leaf):
    if ragged and name in ragged:
      return leaf
    return shard_and_put(leaf, shard)

  batches = iter(data)
  processed = (u.tree_map_with_names(_maybe_shard_and_put, batch)
               for batch in batches)
  return prefetch_iterator(processed, n_prefetch)
