# coding=utf-8
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Train a learned optimizer with gradients."""
import abc
import functools
import time
from typing import Any, Mapping, Optional, Sequence, Tuple, Union

from absl import logging
import chex
import flax
import gin
import jax
from jax import core
import jax.numpy as jnp
from learned_optimization import checkpoints
from learned_optimization import profile
from learned_optimization import summary
from learned_optimization import tree_utils
from learned_optimization.learned_optimizers import base as lopt_base
from learned_optimization.optimizers import base as opt_base
from outer_trainers import truncated_step as truncated_step_mod
import numpy as onp
from typing_extensions import Protocol


# Type aliases used in annotations throughout this module.
PRNGKey = jnp.ndarray  # jax random key, e.g. from `jax.random.PRNGKey`.
MetaParams = Any  # Learned-optimizer (meta) parameters, as built by `init`.
ThetaModelState = Any  # Model state of the meta-parameters kept by theta_opt.


@flax.struct.dataclass
class GradientLearnerState:
  """State of a `GradientLearner`: just the outer-optimizer state."""
  # State of `theta_opt`; the meta-parameters and their model state are read
  # from it with `theta_opt.get_params` / `theta_opt.get_state`.
  theta_opt_state: Any


@flax.struct.dataclass
class OuterState:
  """Outer-training progress information shipped to workers."""
  # Copied from `theta_opt_state.iteration` in
  # `GradientLearner.get_state_for_worker`.
  outer_iteration: jnp.ndarray


@flax.struct.dataclass
class WorkerWeights:
  """Everything a worker needs from the learner to estimate a gradient."""
  # Current meta-parameters (learned-optimizer weights).
  theta: MetaParams
  # Model state associated with `theta`.
  theta_model_state: Any
  # Outer-training progress; may be None.
  outer_state: Optional[OuterState]


@flax.struct.dataclass
class AggregatedGradient:
  """Gradient estimate a worker sends back to the `GradientLearner`.

  One worker combines the results of all of its gradient estimators into a
  single instance of this class to save bandwidth.
  """
  # Mean of the estimators' gradients w.r.t. the meta-parameters.
  theta_grads: Any
  # Model state of the meta-parameters the gradient was computed with.
  theta_model_state: Any
  # Mean of the estimators' `mean_loss` values.
  mean_loss: jnp.ndarray
  # Mean of the estimators' `mean_valid_loss` values.
  mean_valid_loss: jnp.ndarray


@flax.struct.dataclass
class WorkerComputeOut:
  """Result of `gradient_worker_compute`."""
  # Aggregated gradient handed to `GradientLearner.update` (numpy leaves).
  to_put: AggregatedGradient
  # Next state of each gradient estimator, in the same order as the inputs.
  unroll_states: Any
  # Metrics aggregated across all estimators.
  metrics: Mapping[str, float]
  # List of per-estimator event dicts sampled from the unrolls.
  event_info: Any


@flax.struct.dataclass
class GradientEstimatorState:
  """Base (empty) state class for gradient estimators.

  Concrete estimators presumably define their own state types; this base has
  no fields.
  """
  pass


@flax.struct.dataclass
class UnrollInfo:
  """Per-step information recorded during an unroll, used for event logging."""
  # Indexed as `loss[idx, :]`, so at least 2-D; presumably
  # (num_steps, num_tasks) -- TODO confirm.
  loss: jnp.ndarray
  # Inner iteration per step; indexed as `iteration[idx]`.
  iteration: jnp.ndarray
  # Pytree of task parameters; every leaf is indexed by step.
  task_param: jnp.ndarray
  # Not read in this module; presumably marks finished inner problems.
  is_done: jnp.ndarray


@flax.struct.dataclass
class GradientEstimatorOut:
  """Output of `GradientEstimator.compute_gradient_estimate`."""
  # Mean training loss of the unroll(s).
  mean_loss: jnp.ndarray
  # Mean validation loss of the unroll(s).
  mean_valid_loss: jnp.ndarray
  # Gradient estimate; summed with `tree_add` into a zeros tree shaped like
  # theta, so it presumably has theta's structure.
  grad: Any
  # Next estimator state, fed back in on the following call.
  unroll_state: GradientEstimatorState
  # Optional per-step info; when falsy no event info is logged.
  unroll_info: Optional[UnrollInfo]


@flax.struct.dataclass
class ParameterCheckpoint:
  """State that we write out to disk for using the optimizer."""
  # Meta-parameters of the learned optimizer.
  params: lopt_base.MetaParams
  # Identifier string; meaning not visible here (filled with "" as a
  # placeholder when loading) -- presumably a generation/run id.
  gen_id: str
  # Step counter (filled with 0 as a placeholder when loading).
  step: int


@flax.struct.dataclass
class OptCheckpoint:
  """State that we write out to disk for training the optimizer."""
  # Full learner state, including the outer-optimizer state.
  gradient_learner_state: GradientLearnerState
  # Wall-clock time spent training (units not visible here).
  elapsed_time: Union[float, jnp.ndarray]
  # Total number of inner steps taken so far.
  total_inner_steps: int


class MetaInitializer(Protocol):
  """Structural type for any object exposing a jax-style `init(key)`."""

  def init(self, key: chex.PRNGKey) -> MetaParams:
    """Returns freshly initialized meta-parameters derived from `key`."""


@jax.jit
def _tree_mean(stack):
  """Averages every leaf of the pytree `stack` over its leading axis (jitted)."""
  mean_over_stack = functools.partial(jnp.mean, axis=0)
  return jax.tree_util.tree_map(mean_over_stack, stack)


def _tree_mean_onp(stack):
  """Averages every leaf of the pytree `stack` over its leading axis (numpy)."""
  mean_over_stack = functools.partial(onp.mean, axis=0)
  return jax.tree_util.tree_map(mean_over_stack, stack)


@functools.lru_cache(None)
def _get_theta_update_fn(theta_opt: opt_base.Optimizer):
  """Returns a jitted outer-optimizer update function for `theta_opt`.

  The result is memoized by an unbounded `lru_cache` keyed on `theta_opt`, so
  the jitted function is built (and compiled) once per optimizer and reused.
  The cache also keeps every optimizer passed here alive.

  The returned function is `update` wrapped by `summary.add_with_summary` and
  jitted with `with_summary` as a static argument. `GradientLearner.update`
  calls it as `fn(opt_state, grads, loss, key, model_state,
  sample_rng_key=..., with_summary=...)` and unpacks `(opt_state, metrics)`;
  presumably `add_with_summary` adds those keyword arguments and the metrics
  output -- TODO confirm against `summary`.
  """

  def update(theta_opt_state, grads, loss, key, model_state):
    # Record any summaries emitted by the optimizer under "outer_opt".
    with summary.summary_scope("outer_opt"):
      return theta_opt.update(
          theta_opt_state, grads, loss=loss, key=key, model_state=model_state)

  fn = summary.add_with_summary(update)
  # `with_summary` toggles Python-level behavior, so it must be static.
  return jax.jit(fn, static_argnames=("with_summary",))


@gin.configurable
class GradientLearner:
  """Learner is responsible for training the weights of the learned opt.

  Training state is not stored on the instance: it lives in a
  `GradientLearnerState` (the state of `theta_opt`). `update` averages a list
  of `AggregatedGradient`s from workers and applies them with `theta_opt`.
  """

  def __init__(
      self,
      meta_init: MetaInitializer,
      theta_opt: opt_base.Optimizer,
      init_theta_from_path: Optional[str] = None,
      init_outer_state_from_path: Optional[str] = None,
      reset_outer_iteration: bool = False,
      num_steps: Optional[int] = None,
      init_seed: Optional[int] = None,
  ):
    """Initializer.

    Args:
      meta_init: Object whose `init(key)` builds the initial meta-parameters.
      theta_opt: Optimizer used to train the meta-parameters.
      init_theta_from_path: Optional path of a `ParameterCheckpoint` whose
        params replace the randomly initialized meta-parameters.
      init_outer_state_from_path: Optional path of an `OptCheckpoint` whose
        outer-optimizer state replaces the freshly initialized one.
      reset_outer_iteration: If True and the outer state was loaded from
        `init_outer_state_from_path`, reset its iteration counter to zero.
      num_steps: Number of outer steps, forwarded to `theta_opt.init`.
      init_seed: If set, `init` ignores its `key` argument and uses
        `jax.random.PRNGKey(init_seed)` instead.
    """
    self._theta_opt = theta_opt
    self._meta_init = meta_init
    self._init_theta_from_path = init_theta_from_path
    self._init_outer_state_from_path = init_outer_state_from_path
    self._reset_outer_iteration = reset_outer_iteration
    self._num_steps = num_steps
    self._init_seed = init_seed

  def get_meta_params(self, state: GradientLearnerState) -> MetaParams:
    """Returns the current meta-parameters held in `state`."""
    return self._theta_opt.get_params(state.theta_opt_state)

  def get_meta_model_state(self,
                           state: GradientLearnerState) -> ThetaModelState:
    """Returns the model state of the meta-parameters held in `state`."""
    return self._theta_opt.get_state(state.theta_opt_state)

  def get_state_for_worker(self, state: GradientLearnerState) -> WorkerWeights:
    """Packages what workers need: params, model state and outer iteration."""
    return WorkerWeights(
        theta=self.get_meta_params(state),
        theta_model_state=self.get_meta_model_state(state),
        outer_state=OuterState(state.theta_opt_state.iteration))

  def init(self, key: PRNGKey) -> GradientLearnerState:
    """Initial state of the GradientLearner.

    This can be constructed from a random distribution, or loaded from a path.

    Args:
      key: jax rng key. Ignored when `init_seed` was given.

    Returns:
      gradient_learner_state: A new initial state of the gradient learner.
    """
    if self._init_seed is not None:
      key = jax.random.PRNGKey(self._init_seed)

    theta_init = self._meta_init.init(key)
    # TODO(lmetz) hook up model state for learned optimizers
    model_state = None

    if self._init_theta_from_path:
      logging.info(
          "Got a init from params path %s."
          " Using this instead of random initialization.",
          self._init_theta_from_path)

      # To load a checkpoint, the state of the target object must be specified,
      # so we pass fake values here.
      fake_param_checkpoint = ParameterCheckpoint(
          params=theta_init, gen_id="", step=0)
      real_param_checkpoint = checkpoints.load_state(self._init_theta_from_path,
                                                     fake_param_checkpoint)
      theta_init = real_param_checkpoint.params

    theta_opt_state = self._theta_opt.init(
        theta_init, model_state, num_steps=self._num_steps)

    if self._init_outer_state_from_path:
      logging.info(
          "Got a init from outer state path %s."
          " Using this instead of randomly initializing.",
          self._init_outer_state_from_path)
      fake_checkpoint = OptCheckpoint(
          gradient_learner_state=GradientLearnerState(theta_opt_state),
          elapsed_time=0.0,
          total_inner_steps=1)
      real_checkpoint = checkpoints.load_state(self._init_outer_state_from_path,
                                               fake_checkpoint)
      theta_opt_state = real_checkpoint.gradient_learner_state.theta_opt_state
      if self._reset_outer_iteration:
        # Keep the counter's dtype/array type so the state's leaf types do not
        # change (a bare Python int would alter the pytree's leaf types).
        theta_opt_state = theta_opt_state.replace(
            iteration=onp.zeros_like(theta_opt_state.iteration))

    return GradientLearnerState(theta_opt_state)

  def update(
      self,
      state: GradientLearnerState,
      grads_list: Sequence[AggregatedGradient],
      with_metrics: bool = False,
      key: Optional[PRNGKey] = None
  ) -> Tuple[GradientLearnerState, Mapping[str, float]]:
    """Update the state of the outer-trainer using grads_list.

    This performs one outer weight update by averaging the gradients (and the
    theta model states) in `grads_list` and applying them with `theta_opt`.

    Args:
      state: The state of the outer-trainer.
      grads_list: A non-empty list of gradients to be aggregated and applied.
      with_metrics: To compute metrics, or not.
      key: Jax PRNGKey. Despite the `None` default it is required: it is split
        to seed the outer-optimizer update.

    Returns:
      next_state: The next outer-training state.
      metrics: The computed metrics from this update.

    Raises:
      ValueError: If `grads_list` is empty or `key` is None.
    """
    if not grads_list:
      raise ValueError(
          "grads_list must contain at least one AggregatedGradient.")
    if key is None:
      raise ValueError("A PRNGKey `key` is required to update the learner.")

    metrics = {}
    theta_opt_state = state.theta_opt_state

    with profile.Profile("stack_grad"):
      grads_stack = tree_utils.tree_zip_onp([t.theta_grads for t in grads_list])
    with profile.Profile("mean_grad"):
      grads = _tree_mean_onp(grads_stack)

    with profile.Profile("stack_state"):
      model_state_stack = tree_utils.tree_zip_onp(
          [t.theta_model_state for t in grads_list])
      next_model_state = _tree_mean_onp(model_state_stack)

    with profile.Profile("stack_loss"):
      losses = jnp.asarray([t.mean_loss for t in grads_list])
      mean_loss = jnp.mean(losses)
      min_loss = jnp.min(losses)

    fn = _get_theta_update_fn(self._theta_opt)
    key1, key2 = jax.random.split(key)
    theta_opt_state, theta_update_metrics = fn(
        theta_opt_state,
        grads,
        mean_loss,
        key1,
        next_model_state,
        sample_rng_key=key2,
        with_summary=with_metrics)
    metrics = summary.aggregate_metric_list([metrics, theta_update_metrics])

    # Create fast summaries for all steps, and slower summaries occasionally
    metrics["none||mean_loss"] = mean_loss
    metrics["none||best_of_mean_loss"] = min_loss

    if with_metrics:
      metrics["none||theta_grad_norm"] = tree_utils.tree_norm(grads)
      metrics["none||theta_grad_abs_mean"] = tree_utils.tree_mean_abs(grads)

    return GradientLearnerState(theta_opt_state), metrics  # pytype: disable=bad-return-type  # jax-ndarray


class GradientEstimator(abc.ABC):
  """Base class for classes which estimate grads (via ES, PES, or backprop).

  Subclasses override the hooks below. The defaults either raise
  `NotImplementedError` or return placeholder names.
  """
  truncated_step: truncated_step_mod.TruncatedStep

  def init_worker_state(self, worker_weights: WorkerWeights,
                        key: PRNGKey) -> GradientEstimatorState:
    """Builds the initial estimator state (e.g. inner problem weights)."""
    raise NotImplementedError

  def compute_gradient_estimate(
      self,
      worker_weights: WorkerWeights,
      key: PRNGKey,
      state: GradientEstimatorState,
      with_summary: Optional[bool],
      datas_list: Any = None,
  ) -> Tuple[GradientEstimatorOut, Mapping[str, jnp.ndarray]]:
    """Unrolls from `state` and returns a gradient estimate plus metrics."""
    raise NotImplementedError

  def task_name(self):
    """Name of the task family, used to prefix per-task metrics."""
    default_name = "default_task"
    return default_name

  def cfg_name(self):
    """Name of the task configuration, used to prefix per-config metrics."""
    default_name = "default_cfg"
    return default_name

  def get_datas(self) -> Any:
    """Returns data for the estimator; must be overridden."""
    raise NotImplementedError


@functools.partial(jax.jit, donate_argnums=(0,))
def _jit_nan_to_num(vals, replace):
  """Jitted replacement of NaN, +inf and -inf by `replace` in every leaf.

  `vals` is donated, so its buffers may be reused for the output and should
  not be used by the caller afterwards.
  """

  def _clean(leaf):
    return jnp.nan_to_num(leaf, nan=replace, posinf=replace, neginf=replace)

  return jax.tree_util.tree_map(_clean, vals)


def _nan_to_num(vals, replace, use_jnp=False):
  """Replaces NaN, +inf and -inf in every leaf of `vals` with `replace`.

  Args:
    vals: A pytree of arrays.
    replace: Value substituted for NaN, +inf and -inf.
    use_jnp: If True, use the jitted jax version (which donates `vals`);
      otherwise use numpy.

  Returns:
    A pytree with the structure of `vals` and non-finite values replaced.
  """
  if use_jnp:
    return _jit_nan_to_num(vals, replace)
  # Pass `replace` explicitly: numpy's defaults map NaN to 0 and +/-inf to the
  # dtype's extremes, which would disagree with the jax path above.
  replace_non_finite = functools.partial(
      onp.nan_to_num, nan=replace, posinf=replace, neginf=replace)
  return jax.tree_util.tree_map(replace_non_finite, vals)


def _tree_zeros_on_device(shapes, device):
  """Builds a pytree of zero arrays matching `shapes`, placed on `device`.

  The leaves are flattened into a tuple because the jitted inner helper takes
  them as a static argument, which must be hashable.
  """
  flat_shapes, structure = jax.tree_util.tree_flatten(shapes)
  zeros = _tree_zeros_on_device_inner(tuple(flat_shapes), device)
  return jax.tree_util.tree_unflatten(structure, zeros)


@functools.partial(jax.jit, static_argnums=(0, 1))
def _tree_zeros_on_device_inner(shapes, device):
  """Jitted helper returning a zero-filled array for each shape spec.

  Args:
    shapes: Static (hashable) container of objects with `shape` and `dtype`.
    device: Static jax device the zero fill value is placed on.
  """

  def _zeros_for(spec):
    # Zero scalar of the leaf's dtype on `device`, broadcast to its shape.
    fill = jax.device_put(jnp.asarray(0, dtype=spec.dtype), device)
    return jnp.full(spec.shape, fill)

  return jax.tree_util.tree_map(_zeros_for, shapes)


def _sample_event_info(unroll_info, outer_state):
  """Summarizes one randomly chosen step of an unroll for event logging.

  Args:
    unroll_info: `UnrollInfo` from a gradient estimator. `loss` is indexed as
      `loss[idx, :]`, so it must be at least 2-D.
    outer_state: `OuterState` of the worker weights, or None.

  Returns:
    A dict with the sampled step's loss, task params and inner iteration, plus
    the current outer iteration (None when `outer_state` is None).
  """
  # NOTE: the step is drawn from numpy's global RNG, not from a jax key, so it
  # is not reproducible from the `key` passed to gradient_worker_compute.
  idx = onp.random.randint(0, len(unroll_info.loss))

  def extract_one(x):
    return x[idx] if x is not None else None

  onp_task_params = jax.tree_util.tree_map(onp.asarray, unroll_info.task_param)
  iteration = (
      unroll_info.iteration[idx] if unroll_info.iteration is not None else None)
  outer_iteration = (
      outer_state.outer_iteration if outer_state is not None else None)
  return {
      "loss": unroll_info.loss[idx, :],
      "task_param": jax.tree_util.tree_map(extract_one, onp_task_params),
      "iteration": iteration,
      "outer_iteration": outer_iteration,
  }


def _with_task_metrics(metrics, family_name, cfg_name, estimator_out,
                       with_metrics, start_time):
  """Returns a copy of `metrics` with per-task-family and per-config entries.

  Args:
    metrics: Metrics returned by the estimator; keys look like "<agg>||<name>".
    family_name: `estimator.task_name()`.
    cfg_name: `estimator.cfg_name()`.
    estimator_out: The estimator's `GradientEstimatorOut`.
    with_metrics: If True, also copy every metric under the family and config
      names and add gradient statistics.
    start_time: `time.time()` taken when this estimator started.

  Returns:
    A new dict; `metrics` itself is not modified.
  """
  metrics = dict(metrics)
  if with_metrics:
    # Metrics don't take into account which task they are coming from.
    # Let's add additional metrics with the task name pulled out.
    with profile.Profile("metric_computation"):
      for k in list(metrics):
        v = metrics[k]
        assert "||" in k, f"bad metric format? Got: {k}"
        agg, name = k.split("||")
        metrics[f"{agg}||{family_name}/{name}"] = v
        metrics[f"{agg}||{cfg_name}/{name}"] = v

      mean_abs = tree_utils.tree_mean_abs(estimator_out.grad)
      metrics[f"mean||{family_name}/grad_mean_abs"] = mean_abs
      metrics[f"mean||{cfg_name}/grad_mean_abs"] = mean_abs

      norm = tree_utils.tree_norm(estimator_out.grad)
      metrics[f"mean||{family_name}/grad_norm"] = norm
      metrics[f"mean||{cfg_name}/grad_norm"] = norm

  metrics[f"mean||{family_name}/mean_loss"] = estimator_out.mean_loss
  metrics[f"mean||{cfg_name}/mean_loss"] = estimator_out.mean_loss
  # Measure once so the family and config entries report the same duration.
  elapsed = time.time() - start_time
  metrics[f"sample||{family_name}/time"] = elapsed
  metrics[f"sample||{cfg_name}/time"] = elapsed
  return metrics


@gin.configurable
@profile.wrap()
def gradient_worker_compute(
    worker_weights: WorkerWeights,
    gradient_estimators: Sequence[GradientEstimator],
    unroll_states: Sequence[GradientEstimatorState],
    key: PRNGKey,
    with_metrics: bool,
    clip_nan_loss_to_value: Optional[float] = 20.0,
    extra_metrics: bool = True,
    device: Optional[jax.Device] = None,
) -> WorkerComputeOut:
  """Compute a gradient signal to meta-train with.

  This function performs unrolls for each of the unroll_states with the
  corresponding gradient_estimator. The results from each of the gradient
  estimators get merged into a single gradient. This aggregation is done
  to save bandwidth when collecting gradients from workers.

  Args:
    worker_weights: Weights created by the GradientLearner and represent the
      current parameters and model state of the learned optimizer.
    gradient_estimators: The gradient estimators used to update the unroll state
    unroll_states: state of the gradient estimator (e.g. inner problem weights)
    key: jax rng
    with_metrics: compute with summary metrics or not
    clip_nan_loss_to_value: value that NaN/inf losses are replaced with; None
      disables the replacement (0.0 is a valid replacement value).
    extra_metrics: log out additional metrics.
    device: The jax device to run the computation on

  Returns:
    worker_compute_out: The results of the computation.
      This contains a gradient estimate, the next unroll states, metrics.
      A subset of which get passed to the GradientLearner.
  """
  if device is None:
    device = jax.local_devices(0)[0]

  theta = worker_weights.theta
  theta_model_state = worker_weights.theta_model_state

  theta_shape = jax.tree_util.tree_map(
      lambda x: core.ShapedArray(x.shape, x.dtype), theta
  )
  grads_accum = _tree_zeros_on_device(theta_shape, device)

  metrics_list = []
  unroll_states_out = []
  losses = []
  valid_losses = []
  event_info = []

  assert len(gradient_estimators) == len(unroll_states), (
      "Expected one unroll state per gradient estimator.")

  for si, (estimator, unroll_state) in enumerate(
      zip(gradient_estimators, unroll_states)):
    with profile.Profile(f"estimator{si}"):
      stime = time.time()
      key, rng = jax.random.split(key)

      family_name = estimator.task_name()
      cfg_name = estimator.cfg_name()
      logging.info(
          "compute_gradient_estimate for estimator name %s and cfg name %s",
          family_name, cfg_name)
      with profile.Profile(f"unroll__metrics{with_metrics}"):
        estimator_out, metrics = estimator.compute_gradient_estimate(
            worker_weights, rng, unroll_state, with_summary=with_metrics)

      unroll_states_out.append(estimator_out.unroll_state)
      losses.append(estimator_out.mean_loss)
      valid_losses.append(estimator_out.mean_valid_loss)
      with profile.Profile("tree_add"):
        grads_accum = tree_utils.tree_add(grads_accum, estimator_out.grad)

      # grab a random iteration from the trajectory
      if estimator_out.unroll_info:
        event_info.append(
            _sample_event_info(estimator_out.unroll_info,
                               worker_weights.outer_state))
      else:
        logging.warning("No out specified by learner. "
                        "Not logging any events data.")

      if extra_metrics:
        metrics = _with_task_metrics(metrics, family_name, cfg_name,
                                     estimator_out, with_metrics, stime)
      else:
        metrics = dict(metrics)
      metrics_list.append(metrics)

  with profile.Profile("mean_grads"):
    grads_accum = tree_utils.tree_div(grads_accum, len(gradient_estimators))
    mean_loss = jnp.mean(jnp.asarray(losses))
    mean_valid_loss = jnp.mean(jnp.asarray(valid_losses))

  # block here to better account for costs with profile profiling.
  with profile.Profile("blocking"):
    stime = time.time()
    mean_loss.block_until_ready()
    mean_valid_loss.block_until_ready()
    block_time = time.time() - stime

  with profile.Profile("summary_aggregation"):
    metrics = summary.aggregate_metric_list(metrics_list)
  metrics["mean||block_time"] = block_time

  with profile.Profile("strip_nan"):
    # this should ideally never be NAN
    # TODO(lmetz) check if we need these checks.
    grads_accum = _nan_to_num(grads_accum, 0.0, use_jnp=True)
    # Compare against None so an explicit 0.0 still clips.
    if clip_nan_loss_to_value is not None:
      mean_loss = _nan_to_num(mean_loss, clip_nan_loss_to_value, use_jnp=True)
      mean_valid_loss = _nan_to_num(
          mean_valid_loss, clip_nan_loss_to_value, use_jnp=True)

  with profile.Profile("grads_to_onp"):
    to_put = AggregatedGradient(
        theta_grads=grads_accum,
        theta_model_state=theta_model_state,
        mean_loss=mean_loss,
        mean_valid_loss=mean_valid_loss)

    return WorkerComputeOut(
        to_put=jax.tree_util.tree_map(onp.asarray, to_put),
        unroll_states=unroll_states_out,
        metrics=metrics,
        event_info=event_info)


@flax.struct.dataclass
class SingleMachineState:
  """State of `SingleMachineGradientLearner`."""
  # State of the wrapped `GradientLearner`.
  gradient_learner_state: GradientLearnerState
  # One state per gradient estimator, in the estimators' order.
  gradient_estimator_states: Sequence[GradientEstimatorState]


class SingleMachineGradientLearner:
  """Train with gradient estimators on a single machine.

  This is a convenience wrapper calling the multi-worker interface -- namely
  both `GradientLearner` and `gradient_worker_compute`.
  """

  def __init__(self,
               meta_init: MetaInitializer,
               gradient_estimators: Sequence[GradientEstimator],
               theta_opt: opt_base.Optimizer,
               num_steps: Optional[int] = None):
    """Initializer.

    Args:
      meta_init: Class containing an init function to construct outer params.
      gradient_estimators: Sequence of gradient estimators used to calculate
        gradients.
      theta_opt: The optimizer used to train the weights of the learned opt.
      num_steps: Number of meta-training steps used by optimizer for schedules.
    """
    self.gradient_learner = GradientLearner(
        meta_init, theta_opt, num_steps=num_steps)
    self.gradient_estimators = gradient_estimators

  def init(self, key: PRNGKey) -> SingleMachineState:
    """Initial state.

    This initializes the learned optimizer weights randomly, and sets up
    optimizer variables for these weights. Additionally the first state of
    each gradient estimator is initialized, each with its own rng key.

    Args:
      key: jax rng

    Returns:
      The initial state
    """
    key1, key = jax.random.split(key)
    theta_state = self.gradient_learner.init(key1)
    worker_weights = self.gradient_learner.get_state_for_worker(theta_state)
    estimator_keys = jax.random.split(key, len(self.gradient_estimators))
    unroll_states = [
        grad_est.init_worker_state(worker_weights, estimator_key)
        for estimator_key, grad_est in zip(estimator_keys,
                                           self.gradient_estimators)
    ]

    return SingleMachineState(
        gradient_learner_state=theta_state,
        gradient_estimator_states=unroll_states)

  def update(
      self,
      state: SingleMachineState,
      key: PRNGKey,
      with_metrics: Optional[bool] = False
  ) -> Tuple[SingleMachineState, jnp.ndarray, jnp.ndarray,
             Mapping[str, jnp.ndarray]]:
    """Perform one outer-update to train the learned optimizer.

    Args:
      state: State of this class
      key: jax rng
      with_metrics: To compute metrics or not

    Returns:
      state: The next state from this class
      loss: mean training loss from the current iteration (numpy array)
      valid_loss: mean validation loss from the current iteration (numpy
        array)
      metrics: dictionary of metrics computed
    """
    key1, key2 = jax.random.split(key)
    worker_weights = self.gradient_learner.get_state_for_worker(
        state.gradient_learner_state)
    worker_compute_out = gradient_worker_compute(
        worker_weights,
        self.gradient_estimators,
        state.gradient_estimator_states,
        key=key1,
        with_metrics=with_metrics)

    next_theta_state, metrics = self.gradient_learner.update(
        state.gradient_learner_state, [worker_compute_out.to_put],
        key=key2,
        with_metrics=with_metrics)

    metrics = summary.aggregate_metric_list(
        [worker_compute_out.metrics, metrics])

    next_state = SingleMachineState(
        gradient_learner_state=next_theta_state,
        gradient_estimator_states=worker_compute_out.unroll_states)
    to_put = worker_compute_out.to_put
    return next_state, to_put.mean_loss, to_put.mean_valid_loss, metrics

  def get_meta_params(self, state: SingleMachineState) -> lopt_base.MetaParams:
    """Get the weights of the learned optimizer."""
    return self.gradient_learner.get_meta_params(state.gradient_learner_state)