"""MLPerf™ Algorithmic Efficiency API."""

import abc
import enum
import functools
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union, NamedTuple

from absl import logging
import jax
import jax.numpy as jnp
from flax import jax_utils
from jax import lax
import optax
from torch import nn
import torch.nn.functional as F
from jax.experimental.host_callback import call

import chex
from optax._src import base
from optax._src import clipping
from optax._src import numerics
from optax._src import utils
from optax._src import wrappers


class LossType(enum.Enum):
  """Loss family a workload declares through `Workload.loss_type`.

  `Workload.output_activation_fn` uses it to pick the activation that turns
  logits into predictions.
  """
  SOFTMAX_CROSS_ENTROPY = 0
  SIGMOID_CROSS_ENTROPY = 1
  MEAN_SQUARED_ERROR = 2
  CTC_LOSS = 3
  MEAN_ABSOLUTE_ERROR = 4


class ForwardPassMode(enum.Enum):
  """Mode passed to `Workload.model_fn`: a training or an evaluation pass."""
  TRAIN = 0
  EVAL = 1
  # NOTE(review): the original left a "..." placeholder here, so more modes
  # may have been planned.


class ParameterType(enum.Enum):
  """Coarse category of a model parameter; the leaves of a ParameterTypeTree."""
  WEIGHT = 0
  BIAS = 1
  CONV_WEIGHT = 2
  BATCH_NORM = 3
  EMBEDDING = 4


# A framework-specific array; any Tensor carries its own shape and dtype.
# Conceptually:
# Tensor = Union[jnp.array, np.array, tf.Tensor, torch.Tensor, ...]
Tensor = Any


# Wrapping each shape in this class makes it an opaque leaf for pytree
# utilities. Mapping over a parameter-shapes tree then stops at each shape
# instead of descending into the tuple's individual ints.
class ShapeTuple:
  """Opaque holder for one parameter's shape tuple."""

  def __init__(self, shape_tuple):
    # Stored unchanged; callers read it back as `.shape_tuple`.
    self.shape_tuple = shape_tuple


# A parameter's shape: a plain tuple of one to four ints, or a ShapeTuple
# wrapper around such a tuple.
Shape = Union[
    Tuple[int],
    Tuple[int, int],
    Tuple[int, int, int],
    Tuple[int, int, int, int],
    ShapeTuple,
]
ParameterShapeTree = Dict[str, Dict[str, Shape]]

# The shape, container and type trees are meant to share one structure, so
# they can be zipped together to walk matching leaves in pairs.
ParameterKey = str
# Either (arbitrarily nested) dicts of tensors or a torch nn.Module.
ParameterContainer = Union[
    Dict[ParameterKey, Dict[ParameterKey, Tensor]],
    nn.Module,
]
ParameterTypeTree = Dict[ParameterKey, Dict[ParameterKey, ParameterType]]

# Conceptually Union[jax.random.PRNGKey, int, bytes, ...].
RandomState = Any

OptimizerState = Union[Dict[str, Any], Tuple[Any, Any]]
Hyperparameters = Any
Timing = int
Steps = int

# Non-trainable model state, e.g. batch-norm EMAs.
ModelAuxiliaryState = Any
# What Workload.init_model_fn returns: (initial_params, initial_model_state).
ModelInitState = Tuple[ParameterContainer, ModelAuxiliaryState]


class Workload(metaclass=abc.ABCMeta):
  """Abstract base class for a benchmark workload.

  Subclasses supply the input pipeline, model, loss, evaluation and target
  values. This base class provides `output_activation_fn`, `eval_model` and
  the `param_shapes` / `model_params_types` accessors.
  """

  def __init__(self, *args, **kwargs) -> None:
    del args
    del kwargs
    # Filled in by subclasses; the accessors below raise until then.
    self._param_shapes: Optional[ParameterShapeTree] = None
    self._param_types: Optional[ParameterTypeTree] = None
    # NOTE(review): not read in this base class; presumably a per-split cache
    # of eval iterators for subclasses.
    self._eval_iters: Dict[str, Iterator] = {}
    # Optional logger, set via `attach_metrics_logger`.
    self.metrics_logger = None

  @abc.abstractmethod
  def has_reached_validation_target(self, eval_result: Dict[str,
                                                            float]) -> bool:
    """Return whether or not the workload validation goal has been reached."""

  @abc.abstractmethod
  def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool:
    """Return whether or not the workload test goal has been reached."""

  @abc.abstractmethod
  def _build_input_queue(
      self,
      data_rng: RandomState,
      split: str,
      data_dir: str,
      global_batch_size: int,
      cache: Optional[bool] = None,
      repeat_final_dataset: Optional[bool] = None,
      num_batches: Optional[int] = None) -> Iterator[Dict[str, Any]]:
    """Build the input queue for the workload data.

    This is the only function that is NOT allowed to be called by submitters.

    For Jax this should return an iterator over tensors of shape
    (num_devices, per_device_batch_size, ...), and for PyTorch this should
    return tensors of shape (global_batch_size, ...).

    The required keys are 'inputs' and 'targets', and in general the naming
    convention should be plural key names because the values are batches of
    examples.
    """

  def attach_metrics_logger(self, metrics_logger) -> None:
    """Attach a metrics logger to the workload (stored as `metrics_logger`)."""
    self.metrics_logger = metrics_logger

  @property
  @abc.abstractmethod
  def validation_target_value(self) -> float:
    """The validation target value to reach."""

  @property
  @abc.abstractmethod
  def test_target_value(self) -> float:
    """The test target value to reach."""

  @property
  @abc.abstractmethod
  def loss_type(self) -> LossType:
    """The type of loss function."""

  @property
  @abc.abstractmethod
  def num_train_examples(self) -> int:
    """The size of the training set."""

  @property
  @abc.abstractmethod
  def eval_batch_size(self) -> int:
    """The batch size for evaluation."""

  @property
  @abc.abstractmethod
  def num_eval_train_examples(self) -> int:
    """The number of training examples to evaluate metrics on."""

  @property
  @abc.abstractmethod
  def num_validation_examples(self) -> int:
    """The size of the validation set."""

  @property
  @abc.abstractmethod
  def num_test_examples(self) -> int:
    """The size of the test set."""

  @property
  @abc.abstractmethod
  def train_mean(self) -> Any:
    """The mean of the training data."""

  @property
  @abc.abstractmethod
  def train_stddev(self) -> Any:
    """The stddev of the training data."""

  @property
  @abc.abstractmethod
  def max_allowed_runtime_sec(self) -> int:
    """The max allowed runtime of the workload in seconds."""

  @property
  @abc.abstractmethod
  def eval_period_time_sec(self) -> int:
    """The eval period of the workload in seconds."""

  @property
  @abc.abstractmethod
  def step_hint(self) -> int:
    """Max num steps the baseline algo was given to reach the target."""

  @property
  def param_shapes(self):
    """The shapes of the parameters in the workload model.

    Raises:
      ValueError: if the shapes have not been populated yet.
    """
    if self._param_shapes is None:
      raise ValueError(
          'This should not happen, workload.init_model_fn() should be called '
          'before workload.param_shapes!')
    return self._param_shapes

  @property
  def model_params_types(self):
    """The types of the parameters in the workload model.

    Raises:
      ValueError: if the types have not been populated yet.
    """
    if self._param_types is None:
      # The message names this property; there is no `param_types` attribute.
      raise ValueError(
          'This should not happen, workload.init_model_fn() should be called '
          'before workload.model_params_types!')
    return self._param_types

  @abc.abstractmethod
  def is_output_params(self, param_key: ParameterKey) -> bool:
    """Whether a key in ParameterContainer is the output layer parameters."""

  # InitModelFn = Callable[
  #     Tuple[RandomState, Optional[float], Optional[float]],
  #     ParameterContainer]
  @abc.abstractmethod
  def init_model_fn(self,
                    rng: RandomState,
                    dropout_rate: Optional[float] = None,
                    aux_dropout_rate: Optional[float] = None) -> ModelInitState:
    """Return (initial_params, initial_model_state)."""

  # ModelFn = Callable[
  #     Tuple[
  #         ParameterContainer,
  #         Dict[str, Tensor],
  #         ModelAuxiliaryState,
  #         ForwardPassMode,
  #         RandomState,
  #         bool],
  #     Tensor]
  @abc.abstractmethod
  def model_fn(self,
               params: ParameterContainer,
               augmented_and_preprocessed_input_batch: Dict[str, Tensor],
               model_state: ModelAuxiliaryState,
               mode: ForwardPassMode,
               rng: RandomState,
               update_batch_norm: bool) -> Tuple[Tensor, ModelAuxiliaryState]:
    """Return (logits_batch, new_model_state).

    May update batch-norm statistics when `update_batch_norm` is True.
    """

  def output_activation_fn(self, logits_batch: Tensor,
                           framework: str) -> Tensor:
    """Turn logits into predictions according to the `loss_type` property.

    Softmax for softmax cross-entropy and CTC, sigmoid for sigmoid
    cross-entropy, and identity for the regression losses.

    Raises:
      ValueError: if `framework` is neither 'pytorch' nor 'jax'.
    """
    if framework not in ['pytorch', 'jax']:
      raise ValueError(
          f'`framework` has to be either `pytorch` or `jax`, got {framework}.')
    activation_fn = {
        LossType.MEAN_SQUARED_ERROR: lambda z: z,
        LossType.MEAN_ABSOLUTE_ERROR: lambda z: z,
    }
    is_pytorch = framework == 'pytorch'  # If False, framework == 'jax'.
    softmax_fn = (
        functools.partial(F.softmax, dim=-1) if is_pytorch else jax.nn.softmax)
    sigmoid_fn = F.sigmoid if is_pytorch else jax.nn.sigmoid
    activation_fn[LossType.SOFTMAX_CROSS_ENTROPY] = softmax_fn
    activation_fn[LossType.SIGMOID_CROSS_ENTROPY] = sigmoid_fn
    activation_fn[LossType.CTC_LOSS] = softmax_fn
    return activation_fn[self.loss_type](logits_batch)

  # LossFn = Callable[Tuple[Tensor, Tensor], Tensor]
  # Does NOT apply regularization, which is left to the submitter to do in
  # `update_params`.
  @abc.abstractmethod
  def loss_fn(
      self,
      # Dense or one-hot labels, or a tuple of (tensor, padding) for speech.
      label_batch: Union[Tuple[Tensor, Tensor], Tensor],
      logits_batch: Union[Tuple[Tensor, Tensor], Tensor],
      mask_batch: Optional[Tensor] = None,
      label_smoothing: float = 0.0) -> Dict[str, Tensor]:  # differentiable
    """Evaluate the (masked) loss function at (label_batch, logits_batch).

    Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
    valid examples in batch, 'per_example': 1-d array of per-example losses}
    (not synced across devices).
    """

  @abc.abstractmethod
  def _eval_model_on_split(self,
                           split: str,
                           num_examples: int,
                           global_batch_size: int,
                           params: ParameterContainer,
                           model_state: ModelAuxiliaryState,
                           rng: RandomState,
                           data_dir: str,
                           global_step: int = 0) -> Dict[str, float]:
    """Evaluate the model on a given dataset split, return final scalars."""

  def eval_model(self,
                 global_batch_size: int,
                 params: ParameterContainer,
                 model_state: ModelAuxiliaryState,
                 rng: RandomState,
                 data_dir: str,
                 imagenet_v2_data_dir: Optional[str],
                 global_step: int) -> Dict[str, float]:
    """Run a full evaluation of the model.

    Evaluates on 'eval_train' and 'validation', and on 'test' when
    `num_test_examples` is not None. The test split reads from
    `imagenet_v2_data_dir` if it is set, otherwise from `data_dir`.

    Returns:
      A flat dict whose keys are prefixed with 'train/', 'validation/' or
      'test/'. It also includes 'validation/num_examples' and, when the test
      split ran, 'test/num_examples'.
    """
    logging.info('Evaluating on the training split.')
    train_metrics = self._eval_model_on_split(
        split='eval_train',
        num_examples=self.num_eval_train_examples,
        global_batch_size=global_batch_size,
        params=params,
        model_state=model_state,
        rng=rng,
        data_dir=data_dir,
        global_step=global_step)
    eval_metrics = {'train/' + k: v for k, v in train_metrics.items()}
    # We always require a validation set.
    logging.info('Evaluating on the validation split.')
    validation_metrics = self._eval_model_on_split(
        'validation',
        num_examples=self.num_validation_examples,
        global_batch_size=global_batch_size,
        params=params,
        model_state=model_state,
        rng=rng,
        data_dir=data_dir,
        global_step=global_step)
    for k, v in validation_metrics.items():
      eval_metrics['validation/' + k] = v
    eval_metrics['validation/num_examples'] = self.num_validation_examples
    # Evaluate on the test set. TODO(znado): always eval on the test set.
    # A NotImplementedError anywhere in the test evaluation (including inside
    # _eval_model_on_split) is deliberately treated as "no test split".
    try:
      if self.num_test_examples is not None:
        logging.info('Evaluating on the test split.')
        test_metrics = self._eval_model_on_split(
            'test',
            num_examples=self.num_test_examples,
            global_batch_size=global_batch_size,
            params=params,
            model_state=model_state,
            rng=rng,
            data_dir=imagenet_v2_data_dir if imagenet_v2_data_dir else data_dir,
            global_step=global_step)
        for k, v in test_metrics.items():
          eval_metrics['test/' + k] = v
        eval_metrics['test/num_examples'] = self.num_test_examples
    except NotImplementedError:
      pass

    return eval_metrics


class TrainingCompleteError(Exception):
  """Raised by a submission that believes it has reached the training goal."""


# Training algorithm track submission functions, to be filled in by the
# submitter.

def jax_cosine_warmup(step_hint: int, hyperparameters):
  """Build a linear-warmup-then-cosine-decay optax learning-rate schedule.

  The rate rises linearly from 0 to `hyperparameters.learning_rate` over
  `hyperparameters.warmup_steps` steps. It then follows a cosine decay from
  that peak over `step_hint - warmup_steps` steps (at least 1).

  Args:
    step_hint: total number of steps the schedule should span.
    hyperparameters: object exposing `learning_rate` and `warmup_steps`.

  Returns:
    An optax schedule mapping a step count to a learning rate.
  """
  peak_lr = hyperparameters.learning_rate
  warmup_steps = hyperparameters.warmup_steps
  decay_steps = max(step_hint - warmup_steps, 1)
  phases = [
      optax.linear_schedule(
          init_value=0., end_value=peak_lr, transition_steps=warmup_steps),
      optax.cosine_decay_schedule(init_value=peak_lr, decay_steps=decay_steps),
  ]
  return optax.join_schedules(schedules=phases, boundaries=[warmup_steps])


# Global learning-rate schedule (step -> learning rate), read inside
# pmapped_train_step. It is None until init_optimizer_state assigns it.
lr_schedule_fn = None


class ScaleByQlrState(NamedTuple):
  """State of the `scale_by_adam` transformation in this file.

  Besides Adam's moments, it also carries the damping coefficient that
  `pmapped_train_step` reads and adapts.
  """
  # Number of update_fn calls so far (int32 scalar); passed to
  # bias_correction.
  count: chex.Array
  # Damping coefficient: scale_by_adam passes it through unchanged, and
  # pmapped_train_step adapts it every 10 steps, clipped to [1e-6, 100].
  damp: chex.Array
  # EMA of the gradients (first moment).
  mu: base.Updates
  # EMA of the squared gradients (second moment).
  nu: base.Updates


def update_moment(updates, moments, decay, order):
  """Blend `updates ** order` into `moments` as an exponential moving average.

  Each leaf becomes `(1 - decay) * update**order + decay * moment`.
  """

  def _ema(update, moment):
    fresh = update ** order
    return (1 - decay) * fresh + decay * moment

  return jax.tree_util.tree_map(_ema, updates, moments)


@functools.partial(jax.jit, inline=True)
def bias_correction(moment, decay, count):
  """Divide every leaf of `moment` by `1 - decay**count`.

  This removes the shrinkage an EMA started at zero has after `count` steps.
  """
  scale = 1 - decay**count

  def _correct(leaf):
    # Cast the scalar rather than the leaf so the leaf keeps its own dtype.
    return leaf / scale.astype(leaf.dtype)

  return jax.tree_util.tree_map(_correct, moment)


def scale_by_adam(
    b1: float = 0.9,
    b2: float = 0.99,
    rho: float = 0.01,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype: Optional[chex.ArrayDType] = None,
) -> base.GradientTransformation:
  """Bias-corrected Adam direction, negated and clipped elementwise.

  The returned updates are `clip(-mu_hat / (sqrt(nu_hat + eps_root) + eps),
  -rho, rho)`, so they already point in a descent direction. The state also
  carries a scalar damping coefficient `damp`, initialised to 1, which this
  transformation passes through unchanged.

  Args:
    b1: decay rate of the first-moment EMA.
    b2: decay rate of the second-moment EMA.
    rho: elementwise clipping bound for the returned updates.
    eps: term added to the denominator outside the square root.
    eps_root: term added to the second moment inside the square root.
    mu_dtype: optional dtype in which to store the first moment.

  Returns:
    An optax `GradientTransformation` whose state is a `ScaleByQlrState`.
  """
  mu_dtype = utils.canonicalize_dtype(mu_dtype)

  def init_fn(params):
    mu = jax.tree_util.tree_map(  # First moment
        lambda t: jnp.zeros_like(t, dtype=mu_dtype), params)
    nu = jax.tree_util.tree_map(jnp.zeros_like, params)  # Second moment
    # A 0-d damping keeps the step-size coefficient computed from it in the
    # train step 0-d as well. A shape-(1,) value would broadcast 0-d parameter
    # leaves to shape (1,) when multiplied in.
    return ScaleByQlrState(
        count=jnp.zeros([], jnp.int32), damp=jnp.ones([]), mu=mu, nu=nu)

  def update_fn(updates, state, params=None):
    del params
    mu = update_moment(updates, state.mu, b1, 1)
    nu = update_moment(updates, state.nu, b2, 2)
    count = state.count + jnp.array(1, dtype=jnp.int32)
    mu_hat = bias_correction(mu, b1, count)
    nu_hat = bias_correction(nu, b2, count)
    # Adam's direction is built from the bias-corrected moments. Bounds are
    # passed positionally, which is valid for both the old (a_min/a_max) and
    # new (min/max) jnp.clip keyword names.
    updates = jax.tree_util.tree_map(
        lambda m, v: jnp.clip(-m / (jnp.sqrt(v + eps_root) + eps), -rho, rho),
        mu_hat, nu_hat)
    # Cast mu to its storage dtype only after it has been used above.
    mu = utils.cast_tree(mu, mu_dtype)

    return updates, ScaleByQlrState(count=count, damp=state.damp, mu=mu, nu=nu)

  return base.GradientTransformation(init_fn, update_fn)


def optimizer(hyperparameters: Hyperparameters, num_train_examples: int):
  """Return the (init_fn, update_fn) pair of this submission's transformation.

  Only `beta1`, `beta2`, `rho` and `epsilon` are read from `hyperparameters`.
  `num_train_examples` is accepted but unused.
  """
  del num_train_examples
  transform = scale_by_adam(
      b1=hyperparameters.beta1,
      b2=hyperparameters.beta2,
      rho=hyperparameters.rho,
      eps=hyperparameters.epsilon,
  )
  return transform.init, transform.update


# Signature of `init_optimizer_state`.
InitOptimizerFn = Callable[
    [
        Workload,
        ParameterContainer,
        ModelAuxiliaryState,
        Hyperparameters,
        RandomState,
    ],
    OptimizerState,
]


def init_optimizer_state(workload: Workload,
                         model_params: ParameterContainer,
                         model_state: ModelAuxiliaryState,
                         hyperparameters: Hyperparameters,
                         rng: RandomState) -> OptimizerState:
  """Create the replicated optimizer state and install the LR schedule.

  Side effect: sets the module-level `lr_schedule_fn` to a warmup+cosine
  schedule spanning 75% of `workload.step_hint`.

  Returns:
    A `(replicated_optimizer_state, opt_update_fn)` pair. The state is
    initialised from zero arrays shaped like `workload.param_shapes` and
    replicated across local devices.
  """
  del model_params
  del model_state
  del rng
  target_setting_step_hint = int(0.75 * workload.step_hint)
  global lr_schedule_fn
  lr_schedule_fn = jax_cosine_warmup(target_setting_step_hint, hyperparameters)
  # `jax.tree_map` is deprecated (and removed in recent JAX). ShapeTuple
  # leaves are not descended into, so each becomes one zero array.
  params_zeros_like = jax.tree_util.tree_map(
      lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes)
  opt_init_fn, opt_update_fn = optimizer(hyperparameters,
                                         workload.num_train_examples)
  optimizer_state = opt_init_fn(params_zeros_like)

  return jax_utils.replicate(optimizer_state), opt_update_fn

# NOTE(review): not referenced anywhere in the visible source (pmapped_train_step
# discards its grad_clip argument); presumably left over from a gradient-clipping
# variant.
_GRAD_CLIP_EPS = 1e-6

@functools.partial(
    jax.pmap,
    axis_name='batch',
    in_axes=(None, None, 0, 0, 0, None, 0, 0, None, None),
    static_broadcasted_argnums=(0, 1),
    donate_argnums=(2, 3, 4))
def pmapped_train_step(workload,
                       opt_update_fn,
                       model_state,
                       optimizer_state,
                       current_param_container,
                       hyperparameters,
                       batch,
                       rng,
                       global_step,
                       grad_clip,):
  """One data-parallel training step, executed per device under `jax.pmap`.

  1. Compute the mean training loss and its gradient, averaged across devices.
  2. Get a clipped Adam-style direction `updates` from `opt_update_fn`.
  3. Use a Jacobian-vector product of the model outputs along `updates`, with
     the softmax probabilities, to build a curvature term. Derive a clipped
     step coefficient `optim` = <grad, updates> / (curvature
     + damp * |updates|^2).
  4. On every 10th step, evaluate the loss at a trial point and adapt the
     damping stored in the optimizer state.
  5. Apply `-lr * (optim * updates + weight_decay * params)`.

  `grad_clip` is accepted for interface compatibility and ignored.

  Returns:
    (updated_optimizer_state, updated_params, new_model_state, loss), where
    `loss` is the device-averaged training loss before the update.
  """
  del grad_clip

  def _loss_fn(params):
    """Mean training loss at `params`, plus (new_model_state, logits)."""
    logits, new_model_state = workload.model_fn(
        params,
        batch,
        model_state,
        ForwardPassMode.TRAIN,
        rng,
        update_batch_norm=True)
    loss_dict = workload.loss_fn(
        label_batch=batch['targets'],
        logits_batch=logits,
        mask_batch=batch.get('weights'))
    loss = loss_dict['summed'] / loss_dict['n_valid_examples']
    return loss, (new_model_state, logits)

  def _logits_fn(params):
    """Training-mode model outputs at `params`, without updating BN stats."""
    return workload.model_fn(
        params,
        batch,
        model_state,
        ForwardPassMode.TRAIN,
        rng,
        update_batch_norm=False)[0]

  def _dot_product(v1, v2):
    """Sum of elementwise products over all leaves of two pytrees."""
    product_tree = jax.tree_util.tree_map(lambda x, y: x * y, v1, v2)
    return sum(jnp.sum(x) for x in jax.tree_util.tree_leaves(product_tree))

  # 1. Loss and gradient, averaged across devices.
  grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
  (loss, (new_model_state, _)), grad = grad_fn(current_param_container)
  (loss, grad) = lax.pmean((loss, grad), axis_name='batch')

  # 2. Clipped Adam update direction.
  updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
                                               current_param_container)

  # 3. Outputs and their directional derivative along `updates`.
  logits, jvp_updates = jax.jvp(_logits_fn, (current_param_container,),
                                (updates,))

  # Per example: sum_i p_i (Jd)_i^2 - (sum_i p_i (Jd)_i)^2, averaged over the
  # local batch and then across devices.
  prob = jax.nn.softmax(logits, axis=-1)
  term1 = jnp.sum(prob * jnp.square(jvp_updates), axis=-1)
  term2 = jnp.square(jnp.sum(prob * jvp_updates, axis=-1))
  mean_product = lax.pmean(jnp.mean(term1 - term2), axis_name='batch')

  damp = new_optimizer_state.damp
  denom = mean_product + _dot_product(updates, updates) * damp
  optim = _dot_product(grad, updates) / (denom + 1e-7)
  # Positional bounds work with both old (a_min/a_max) and new (min/max)
  # jnp.clip keyword names.
  optim = jnp.clip(optim, -2e-2, 1e-4)

  # 4. Damping adaptation, only on every 10th step.
  def _adapt_damping(_):
    trial_updates = jax.tree_util.tree_map(lambda x: -optim * x, updates)
    trial_params = optax.apply_updates(current_param_container, trial_updates)

    next_loss, _ = _loss_fn(trial_params)
    next_loss = lax.pmean(next_loss, axis_name='batch')

    # Compare the actual loss change with the change predicted by the
    # model term 0.5 * <grad, trial_updates>.
    real_diff = next_loss - loss
    pred_diff = 0.5 * _dot_product(grad, trial_updates)
    ratio = real_diff / (pred_diff + 1e-9)

    # Good agreement (> 3/4): relax damping; loss moved the wrong way (< 0):
    # double it; otherwise keep it.
    new_damp = jnp.where(ratio > 3 / 4, 0.75 * damp,
                         jnp.where(ratio < 0.0, 2. * damp, damp))
    new_damp = jnp.clip(new_damp, 1e-6, 100.)
    return new_optimizer_state._replace(damp=new_damp)

  updated_optimizer_state = lax.cond(global_step % 10 == 0, _adapt_damping,
                                     lambda _: new_optimizer_state, None)

  # 5. Scheduled step plus weight decay applied directly to the parameters.
  lr_t = lr_schedule_fn(global_step)
  wd = hyperparameters.weight_decay
  scaled_updates = jax.tree_util.tree_map(
      lambda d, p: -lr_t * (optim * d + wd * p), updates,
      current_param_container)
  updated_params = optax.apply_updates(current_param_container, scaled_updates)
  return updated_optimizer_state, updated_params, new_model_state, loss


# (updated_optimizer_state, updated_params, updated_model_state)
UpdateReturn = Tuple[OptimizerState, ParameterContainer, ModelAuxiliaryState]
# Signature of `update_params`.
UpdateParamsFn = Callable[
    [
        Workload,
        ParameterContainer,
        ParameterTypeTree,
        ModelAuxiliaryState,
        Hyperparameters,
        Dict[str, Tensor],
        LossType,
        OptimizerState,
        List[Tuple[int, float]],
        int,
        RandomState,
    ],
    UpdateReturn,
]


# Each call to this function is considered a "step".
# Can raise a TrainingCompleteError if it believes it has achieved the goal and
# wants to end the run and receive a final free eval. It will not be restarted,
# and if it has not actually achieved the goal, it will be considered as not
# having achieved the goal and get an infinite time score. Most submissions will
# likely wait until the next free eval and not use this functionality.
def update_params(workload: Workload,
                  current_param_container: ParameterContainer,
                  current_params_types: ParameterTypeTree,
                  model_state: ModelAuxiliaryState,
                  hyperparameters: Hyperparameters,
                  batch: Dict[str, Tensor],
                  loss_type: LossType,
                  optimizer_state: OptimizerState,
                  eval_results: List[Tuple[int, float]],
                  global_step: int,
                  rng: RandomState) -> UpdateReturn:
  """Run one training step on `batch` and return the new training state.

  `optimizer_state` is the `(replicated_state, opt_update_fn)` pair produced
  by `init_optimizer_state`, and the same pair shape is returned. If a metrics
  logger is attached, the first device's loss is logged under 'loss'.

  Returns:
    ((new_optimizer_state, opt_update_fn), new_params, new_model_state)
  """
  del current_params_types, loss_type, eval_results

  replicated_opt_state, opt_update_fn = optimizer_state
  device_rngs = jax.random.split(rng, jax.local_device_count())

  new_opt_state, new_params, new_model_state, loss = pmapped_train_step(
      workload,
      opt_update_fn,
      model_state,
      replicated_opt_state,
      current_param_container,
      hyperparameters,
      batch,
      device_rngs,
      global_step,
      None,  # grad_clip (ignored by the train step).
  )

  logger = workload.metrics_logger
  if logger is not None:
    logger.append_scalar_metrics({'loss': loss[0]}, global_step)

  return (new_opt_state, opt_update_fn), new_params, new_model_state


# Signature of `data_selection`. The fifth argument is the model's auxiliary
# state (not a LossType), and the result is a single batch dict, matching
# what `data_selection` actually takes and returns.
DataSelectionFn = Callable[[
    Workload,
    Iterator[Dict[str, Any]],
    OptimizerState,
    ParameterContainer,
    ModelAuxiliaryState,
    Hyperparameters,
    int,
    RandomState
],
                           Dict[str, Tensor]]


def data_selection(workload: Workload,
                   input_queue: Iterator[Dict[str, Any]],
                   optimizer_state: OptimizerState,
                   current_param_container: ParameterContainer,
                   model_state: ModelAuxiliaryState,
                   hyperparameters: Hyperparameters,
                   global_step: int,
                   rng: RandomState) -> Dict[str, Tensor]:
  """Take the next batch from the input queue without any reselection.

  The queue is infinitely repeating and pre-shuffled, and each element is a
  batch of training examples and labels. All other arguments are ignored.
  """
  del workload, optimizer_state, current_param_container
  del model_state, hyperparameters, global_step, rng

  batch = next(input_queue)
  return batch


def get_batch_size(workload_name: str) -> int:
  """Return the global batch size to use for a given workload.

  Raises:
    ValueError: if `workload_name` is not a supported workload.
  """
  batch_sizes = {
      'criteo1tb': 262_144,
      'fastmri': 32,
      'imagenet_resnet': 1024,
      'imagenet_vit': 1024,
      'librispeech_conformer': 256,
      'librispeech_deepspeech': 256,
      'ogbg': 512,
      'wmt': 256,
      'mnist': 16,
      'cifar': 128,
  }
  try:
    return batch_sizes[workload_name]
  except KeyError:
    raise ValueError(f'Unsupported workload name: {workload_name}.') from None

