"""MLPerf™ Algorithmic Efficiency API."""

import abc
import enum
import functools
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

from absl import logging
import jax
from jax import lax
import jax.numpy as jnp
import optax
from flax import jax_utils
from torch import nn
import torch.nn.functional as F
from jax.experimental.host_callback import call


class LossType(enum.Enum):
  """Identifies the loss function a workload trains with.

  `Workload.output_activation_fn` uses this value to choose the activation
  that is applied to the logits.
  """
  SOFTMAX_CROSS_ENTROPY = 0
  SIGMOID_CROSS_ENTROPY = 1
  MEAN_SQUARED_ERROR = 2
  CTC_LOSS = 3
  MEAN_ABSOLUTE_ERROR = 4


class ForwardPassMode(enum.Enum):
  """Mode flag handed to `Workload.model_fn` for a forward pass."""
  TRAIN = 0
  EVAL = 1
  # Placeholder from the original author: further modes may be added here.


class ParameterType(enum.Enum):
  """Category tag for a model parameter (leaf type of `ParameterTypeTree`)."""
  WEIGHT = 0
  BIAS = 1
  CONV_WEIGHT = 2
  BATCH_NORM = 3
  EMBEDDING = 4


# Framework-agnostic tensor placeholder. A tensor is expected to carry its own
# shape and dtype, so neither is tracked separately here.
# Conceptually: Tensor = Union[jnp.array, np.array, tf.Tensor, torch.Tensor, ...]
Tensor = Any


# Define this so that, when using pytree iteration utilities, one can iterate
# over the model-shapes pytree without also iterating into the shape tuples
# themselves (a bare tuple would be traversed as a pytree container).
class ShapeTuple:
  """Opaque wrapper around a shape tuple so it acts as a single pytree leaf.

  Attributes are documented inline; the wrapped value is read back through
  `shape_tuple` (e.g. `jnp.zeros(s.shape_tuple)`).
  """

  def __init__(self, shape_tuple):
    # The wrapped shape, stored as given (no validation or copying).
    self.shape_tuple = shape_tuple

  def __repr__(self):
    # Show the wrapped shape instead of the default `<... object at 0x...>`.
    return f'{type(self).__name__}({self.shape_tuple!r})'


# A parameter shape: a plain tuple of 1 to 4 ints, or a `ShapeTuple` wrapper.
Shape = Union[Tuple[int],
              Tuple[int, int],
              Tuple[int, int, int],
              Tuple[int, int, int, int],
              ShapeTuple]
# Two-level mapping from string keys to the shape of each parameter.
ParameterShapeTree = Dict[str, Dict[str, Shape]]

# If necessary, these can be zipped together easily given they have the same
# structure, to get an iterator over pairs of leaves.
ParameterKey = str
# Model parameters: either nested dicts of tensors (dicts can be arbitrarily
# nested) or a `torch.nn.Module`.
ParameterContainer = Union[Dict[ParameterKey, Dict[ParameterKey, Tensor]],
                           nn.Module]
# Same key structure as the parameters, with a `ParameterType` at each leaf.
ParameterTypeTree = Dict[ParameterKey, Dict[ParameterKey, ParameterType]]

RandomState = Any  # Union[jax.random.PRNGKey, int, bytes, ...]

# Optimizer state is either a dict or a 2-tuple; `init_optimizer_state` in this
# file returns the 2-tuple form.
OptimizerState = Union[Dict[str, Any], Tuple[Any, Any]]
Hyperparameters = Any
Timing = int
Steps = int

# Auxiliary (non-trained) model state, e.g. BN EMAs.
ModelAuxiliaryState = Any
# Return type of `Workload.init_model_fn`: (params, auxiliary state).
ModelInitState = Tuple[ParameterContainer, ModelAuxiliaryState]


class Workload(metaclass=abc.ABCMeta):
  """Abstract interface that every benchmark workload implements.

  Subclasses provide the dataset sizes, targets, input pipeline, model, loss
  and per-split evaluation. This base class supplies the guarded
  `param_shapes` / `model_params_types` accessors, `output_activation_fn`,
  and the `eval_model` driver that evaluates all splits.
  """

  def __init__(self, *args, **kwargs) -> None:
    # Positional and keyword arguments are accepted but ignored.
    del args
    del kwargs
    # Both trees start unset; the accessors below raise until they are filled.
    self._param_shapes: Optional[ParameterShapeTree] = None
    self._param_types: Optional[ParameterTypeTree] = None
    self._eval_iters: Dict[str, Iterator] = {}
    self.metrics_logger = None

  @abc.abstractmethod
  def has_reached_validation_target(self, eval_result: Dict[str,
                                                            float]) -> bool:
    """Return whether or not the workload validation goal has been reached."""

  @abc.abstractmethod
  def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool:
    """Return whether or not the workload test goal has been reached."""

  @abc.abstractmethod
  def _build_input_queue(
      self,
      data_rng: RandomState,
      split: str,
      data_dir: str,
      global_batch_size: int,
      cache: Optional[bool] = None,
      repeat_final_dataset: Optional[bool] = None,
      num_batches: Optional[int] = None) -> Iterator[Dict[str, Any]]:
    """Build the input queue for the workload data.

    This is the only function that is NOT allowed to be called by submitters.

    For Jax this should return an iterator over tensors of shape
    (num_devices, per_device_batch_size, ...), and for PyTorch this should
    return tensors of shape (global_batch_size, ...).

    The required keys are 'inputs' and 'targets', and in general the naming
    convention should be plural key names because the values are batches of
    examples.
    """

  def attach_metrics_logger(self, metrics_logger) -> None:
    """Attaches a metric logger to workload."""
    self.metrics_logger = metrics_logger

  @property
  @abc.abstractmethod
  def validation_target_value(self) -> float:
    """The validation target value to reach."""

  @property
  @abc.abstractmethod
  def test_target_value(self) -> float:
    """The test target value to reach."""

  @property
  @abc.abstractmethod
  def loss_type(self) -> LossType:
    """The type of loss function."""

  @property
  @abc.abstractmethod
  def num_train_examples(self) -> int:
    """The size of the training set."""

  @property
  @abc.abstractmethod
  def eval_batch_size(self) -> int:
    """The batch size for evaluation."""

  @property
  @abc.abstractmethod
  def num_eval_train_examples(self) -> int:
    """The number of training examples to evaluate metrics on."""

  @property
  @abc.abstractmethod
  def num_validation_examples(self) -> int:
    """The size of the validation set."""

  @property
  @abc.abstractmethod
  def num_test_examples(self) -> int:
    """The size of the test set."""

  @property
  @abc.abstractmethod
  def train_mean(self) -> Any:
    """The mean of the training data."""

  @property
  @abc.abstractmethod
  def train_stddev(self) -> Any:
    """The stddev of the training data."""

  @property
  @abc.abstractmethod
  def max_allowed_runtime_sec(self) -> int:
    """The max allowed runtime of the workload in seconds."""

  @property
  @abc.abstractmethod
  def eval_period_time_sec(self) -> int:
    """The eval period of the workload in seconds."""

  @property
  @abc.abstractmethod
  def step_hint(self) -> int:
    """Max num steps the baseline algo was given to reach the target."""

  @property
  def param_shapes(self):
    """The shapes of the parameters in the workload model.

    Raises:
      ValueError: if the shapes have not been set yet (`_param_shapes` is
        still None).
    """
    if self._param_shapes is None:
      raise ValueError(
          'This should not happen, workload.init_model_fn() should be called '
          'before workload.param_shapes!')
    return self._param_shapes

  @property
  def model_params_types(self):
    """The types of the parameters in the workload model.

    Raises:
      ValueError: if the types have not been set yet (`_param_types` is
        still None).
    """
    if self._param_types is None:
      # The message names this property; it previously referred to a
      # non-existent `workload.param_types`.
      raise ValueError(
          'This should not happen, workload.init_model_fn() should be called '
          'before workload.model_params_types!')
    return self._param_types

  @abc.abstractmethod
  def is_output_params(self, param_key: ParameterKey) -> bool:
    """Whether a key in ParameterContainer is the output layer parameters."""

  # InitModelFn = Callable[
  #     Tuple[RandomState, Optional[float], Optional[float]],
  #     ParameterContainer]
  @abc.abstractmethod
  def init_model_fn(self,
                    rng: RandomState,
                    dropout_rate: Optional[float] = None,
                    aux_dropout_rate: Optional[float] = None) -> ModelInitState:
    """Return (initial_params, initial_model_state)."""

  # ModelFn = Callable[
  #     Tuple[
  #         ParameterContainer,
  #         Dict[str, Tensor],
  #         ModelAuxiliaryState,
  #         ForwardPassMode,
  #         RandomState,
  #         bool],
  #     Tensor]
  @abc.abstractmethod
  def model_fn(self,
               params: ParameterContainer,
               augmented_and_preprocessed_input_batch: Dict[str, Tensor],
               model_state: ModelAuxiliaryState,
               mode: ForwardPassMode,
               rng: RandomState,
               update_batch_norm: bool) -> Tuple[Tensor, ModelAuxiliaryState]:
    """Return (logits_batch, new_model_state)."""
    # Possible side effect of updating BN.

  def output_activation_fn(self, logits_batch: Tensor,
                           framework: str) -> Tensor:
    """Turn logits into probabilities, according to the loss_type property.

    Args:
      logits_batch: logits to transform.
      framework: either 'pytorch' or 'jax'; selects which library's softmax /
        sigmoid implementation is used.

    Returns:
      The logits unchanged for the MSE / MAE loss types, a softmax for the
      softmax cross-entropy and CTC loss types (over the last dim for
      PyTorch), and a sigmoid for the sigmoid cross-entropy loss type.

    Raises:
      ValueError: if `framework` is neither 'pytorch' nor 'jax'.
    """
    if framework not in ['pytorch', 'jax']:
      raise ValueError(
          f'`framework` has to be either `pytorch` or `jax`, got {framework}.')
    activation_fn = {
        LossType.MEAN_SQUARED_ERROR: lambda z: z,
        LossType.MEAN_ABSOLUTE_ERROR: lambda z: z,
    }
    is_pytorch = framework == 'pytorch'  # If False, framework == 'jax'.
    softmax_fn = (
        functools.partial(F.softmax, dim=-1) if is_pytorch else jax.nn.softmax)
    sigmoid_fn = F.sigmoid if is_pytorch else jax.nn.sigmoid
    activation_fn[LossType.SOFTMAX_CROSS_ENTROPY] = softmax_fn
    activation_fn[LossType.SIGMOID_CROSS_ENTROPY] = sigmoid_fn
    activation_fn[LossType.CTC_LOSS] = softmax_fn
    return activation_fn[self.loss_type](logits_batch)

  # LossFn = Callable[Tuple[Tensor, Tensor], Tensor]
  # Does NOT apply regularization, which is left to the submitter to do in
  # `update_params`.
  @abc.abstractmethod
  def loss_fn(
      self,
      # Dense or one-hot labels, or a tuple of (tensor, padding) for speech.
      label_batch: Union[Tuple[Tensor, Tensor], Tensor],
      logits_batch: Union[Tuple[Tensor, Tensor], Tensor],
      mask_batch: Optional[Tensor] = None,
      label_smoothing: float = 0.0) -> Dict[str, Tensor]:  # differentiable
    """Evaluate the (masked) loss function at (label_batch, logits_batch).

    Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
    valid examples in batch, 'per_example': 1-d array of per-example losses}
    (not synced across devices).
    """

  @abc.abstractmethod
  def _eval_model_on_split(self,
                           split: str,
                           num_examples: int,
                           global_batch_size: int,
                           params: ParameterContainer,
                           model_state: ModelAuxiliaryState,
                           rng: RandomState,
                           data_dir: str,
                           global_step: int = 0) -> Dict[str, float]:
    """Evaluate the model on a given dataset split, return final scalars."""

  def eval_model(self,
                 global_batch_size: int,
                 params: ParameterContainer,
                 model_state: ModelAuxiliaryState,
                 rng: RandomState,
                 data_dir: str,
                 imagenet_v2_data_dir: Optional[str],
                 global_step: int) -> Dict[str, float]:
    """Run a full evaluation of the model.

    Evaluates the 'eval_train' and 'validation' splits, and the 'test' split
    when `num_test_examples` is not None. The same `rng` is passed to every
    split.

    Args:
      global_batch_size: batch size forwarded to `_eval_model_on_split`.
      params: model parameters to evaluate.
      model_state: auxiliary model state to evaluate with.
      rng: random state forwarded to each split evaluation.
      data_dir: data directory for the train and validation splits, and for
        the test split when `imagenet_v2_data_dir` is falsy.
      imagenet_v2_data_dir: if truthy, used as the data directory for the
        test split instead of `data_dir`.
      global_step: forwarded to `_eval_model_on_split`.

    Returns:
      A dict with each split's metrics under 'train/', 'validation/' and
      'test/' key prefixes, plus 'validation/num_examples' and (when the test
      split was evaluated) 'test/num_examples'.
    """
    logging.info('Evaluating on the training split.')
    train_metrics = self._eval_model_on_split(
        split='eval_train',
        num_examples=self.num_eval_train_examples,
        global_batch_size=global_batch_size,
        params=params,
        model_state=model_state,
        rng=rng,
        data_dir=data_dir,
        global_step=global_step)
    eval_metrics = {'train/' + k: v for k, v in train_metrics.items()}
    # We always require a validation set.
    logging.info('Evaluating on the validation split.')
    validation_metrics = self._eval_model_on_split(
        split='validation',
        num_examples=self.num_validation_examples,
        global_batch_size=global_batch_size,
        params=params,
        model_state=model_state,
        rng=rng,
        data_dir=data_dir,
        global_step=global_step)
    for k, v in validation_metrics.items():
      eval_metrics['validation/' + k] = v
    eval_metrics['validation/num_examples'] = self.num_validation_examples
    # Evaluate on the test set. TODO(znado): always eval on the test set.
    # A NotImplementedError raised anywhere in this section (the
    # `num_test_examples` property or the split evaluation itself) is treated
    # as "no test split" and deliberately skipped.
    try:
      if self.num_test_examples is not None:
        logging.info('Evaluating on the test split.')
        test_metrics = self._eval_model_on_split(
            split='test',
            num_examples=self.num_test_examples,
            global_batch_size=global_batch_size,
            params=params,
            model_state=model_state,
            rng=rng,
            data_dir=imagenet_v2_data_dir if imagenet_v2_data_dir else data_dir,
            global_step=global_step)
        for k, v in test_metrics.items():
          eval_metrics['test/' + k] = v
        eval_metrics['test/num_examples'] = self.num_test_examples
    except NotImplementedError:
      pass

    return eval_metrics


class TrainingCompleteError(Exception):
  """Signals that a submission believes it has reached the goal.

  May be raised from `update_params` to end the run and receive a final
  free eval.
  """


# Training algorithm track submission functions, to be filled in by the
# submitter.

# Signature of `init_optimizer_state`:
# (workload, model_params, model_state, hyperparameters, rng) -> optimizer state.
InitOptimizerFn = Callable[[
    Workload,
    ParameterContainer,
    ModelAuxiliaryState,
    Hyperparameters,
    RandomState
],
                           OptimizerState]


def init_optimizer_state(workload: Workload,
                         model_params: ParameterContainer,
                         model_state: ModelAuxiliaryState,
                         hyperparameters: Hyperparameters,
                         rng: RandomState) -> OptimizerState:
  """Create the initial optimizer state for the `adamk` optimizer.

  Args:
    workload: supplies `num_train_examples` (for the learning-rate schedule)
      and `param_shapes` (to size the optimizer slots).
    model_params: unused.
    model_state: unused.
    hyperparameters: forwarded to `adamk`.
    rng: unused.

  Returns:
    A 2-tuple `(replicated_state, (get_adam, get_product, update_fn))`: the
    state dict built by the optimizer's init function, replicated with
    `jax_utils.replicate`, and the three optimizer callables from `adamk`.
  """
  del model_params
  del model_state
  del rng

  opt_init_fn, (get_adam, get_product, update_fn) = adamk(
      hyperparameters, workload.num_train_examples)
  # Build zero arrays from the shape tree rather than from the parameters.
  # `jax.tree_util.tree_map` replaces the deprecated/removed `jax.tree_map`
  # alias.
  params_zeros_like = jax.tree_util.tree_map(
      lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes)
  optimizer_state = opt_init_fn(params_zeros_like)

  return (jax_utils.replicate(optimizer_state),
          (get_adam, get_product, update_fn))


@functools.partial(
    jax.pmap,
    axis_name='batch',
    in_axes=(None, None, 0, 0, 0, None, 0, 0),
    static_broadcasted_argnums=(0, 1),
    donate_argnums=(2, 3, 4))
def pmapped_train_step(workload,
                       opt_update_fn,
                       model_state,
                       optimizer_state,
                       current_param_container,
                       hyperparameters,
                       batch,
                       rng):
  """Run one pmapped training step and return the updated state.

  `workload` and `opt_update_fn` are static broadcast arguments,
  `hyperparameters` is broadcast, and the remaining arguments are mapped over
  their leading axis. The buffers of `model_state`, `optimizer_state` and
  `current_param_container` are donated.

  Args:
    workload: provides `model_fn` and `loss_fn`.
    opt_update_fn: the `(get_adam, get_product, update_fn)` triple.
    model_state: auxiliary model state.
    optimizer_state: optimizer state dict; its 'damps' entry is read here.
    current_param_container: current model parameters.
    hyperparameters: must expose `l2` (weight-penalty coefficient).
    batch: input batch; `batch['targets']` is passed to the loss.
    rng: random state passed to every `model_fn` call.

  Returns:
    `(new_model_state, new_optimizer_state, updated_params)`.
  """
  get_adam, get_product, update_fn = opt_update_fn

  def _penalized_loss(params):
    """Mean training loss plus an L2 penalty on leaves with ndim > 1."""
    logits, updated_model_state = workload.model_fn(
        params,
        batch,
        model_state,
        ForwardPassMode.TRAIN,
        rng,
        update_batch_norm=True)
    loss_terms = workload.loss_fn(batch['targets'], logits)
    mean_loss = loss_terms['summed'] / loss_terms['n_valid_examples']
    squared_norm = sum(
        jnp.sum(leaf**2)
        for leaf in jax.tree_util.tree_leaves(params)
        if leaf.ndim > 1)
    l2_penalty = hyperparameters.l2 * 0.5 * squared_norm
    return mean_loss + l2_penalty, (updated_model_state, logits)

  def _logits_without_bn_update(params):
    """Forward pass used for the JVPs; batch-norm updates are disabled."""
    return workload.model_fn(
        params,
        batch,
        model_state,
        ForwardPassMode.TRAIN,
        rng,
        update_batch_norm=False)[0]

  # Gradient of the penalized loss, averaged across the 'batch' axis.
  (_, (new_model_state, _)), grad = jax.value_and_grad(
      _penalized_loss, has_aux=True)(current_param_container)
  grad = lax.pmean(grad, axis_name='batch')

  adams, optimizer_state = get_adam(grad, optimizer_state)

  # Directional derivatives of the logits along the Adam direction and along
  # the stored 'damps' direction, both taken at the current parameters.
  primal_point = (current_param_container,)
  _, jvp_adams = jax.jvp(_logits_without_bn_update, primal_point, (adams,))
  logits, jvp_damps = jax.jvp(_logits_without_bn_update,
                              primal_point,
                              (optimizer_state['damps'],))

  raw_products = get_product(grad,
                             adams,
                             logits,
                             jvp_adams,
                             jvp_damps,
                             optimizer_state)
  # Average each product over its leading axis, then across devices.
  product_aa, product_ad, product_dd = (
      lax.pmean(jnp.mean(raw, axis=0), axis_name='batch')
      for raw in raw_products)

  updates, new_optimizer_state = update_fn(grad,
                                           adams,
                                           product_aa,
                                           product_ad,
                                           product_dd,
                                           optimizer_state)
  updated_params = optax.apply_updates(current_param_container, updates)
  return new_model_state, new_optimizer_state, updated_params


# What `update_params` returns: (optimizer state, params, model state).
UpdateReturn = Tuple[OptimizerState, ParameterContainer, ModelAuxiliaryState]
# Signature of `update_params`; argument order matches its parameter list.
UpdateParamsFn = Callable[[
    Workload,
    ParameterContainer,
    ParameterTypeTree,
    ModelAuxiliaryState,
    Hyperparameters,
    Dict[str, Tensor],
    LossType,
    OptimizerState,
    List[Tuple[int, float]],
    int,
    RandomState
],
                          UpdateReturn]


# Each call to this function is considered a "step".
# Can raise a TrainingCompleteError if it believes it has achieved the goal and
# wants to end the run and receive a final free eval. The run will not be
# restarted, and if it has not actually achieved the goal then it will be
# considered as not having achieved the goal and will get an infinite time
# score. Most submissions will likely wait until the next free eval and not
# use this functionality.
def update_params(workload: Workload,
                  current_param_container: ParameterContainer,
                  current_params_types: ParameterTypeTree,
                  model_state: ModelAuxiliaryState,
                  hyperparameters: Hyperparameters,
                  batch: Dict[str, Tensor],
                  loss_type: LossType,
                  optimizer_state: OptimizerState,
                  eval_results: List[Tuple[int, float]],
                  global_step: int,
                  rng: RandomState) -> UpdateReturn:
  """Run one training step.

  `optimizer_state` is the 2-tuple `(state, opt_update_fn)`; the callables in
  `opt_update_fn` are passed through unchanged into the returned optimizer
  state. `current_params_types`, `loss_type` and `eval_results` are unused.

  Returns:
    (updated_optimizer_state, updated_params, updated_model_state).
  """
  del current_params_types, loss_type, eval_results

  replicated_state, opt_update_fn = optimizer_state
  # One RNG key per local device, to match the mapped `rng` argument.
  device_rngs = jax.random.split(rng, jax.local_device_count())

  new_model_state, new_optimizer_state, new_params = pmapped_train_step(
      workload,
      opt_update_fn,
      model_state,
      replicated_state,
      current_param_container,
      hyperparameters,
      batch,
      device_rngs)

  return (new_optimizer_state, opt_update_fn), new_params, new_model_state


# Intended signature of `data_selection`.
# NOTE(review): this alias appears out of sync with `data_selection` in this
# file, which takes a ModelAuxiliaryState (where this alias lists LossType) and
# is annotated to return Dict[str, Tensor] rather than Tuple[Tensor, Tensor].
# Confirm which is authoritative before relying on this alias.
DataSelectionFn = Callable[[
    Workload,
    Iterator[Dict[str, Any]],
    OptimizerState,
    ParameterContainer,
    LossType,
    Hyperparameters,
    int,
    RandomState
],
                           Tuple[Tensor, Tensor]]


# Not allowed to update the model parameters, hyperparameters, global step, or
# optimizer state.
def data_selection(workload: Workload,
                   input_queue: Iterator[Dict[str, Any]],
                   optimizer_state: OptimizerState,
                   current_param_container: ParameterContainer,
                   model_state: ModelAuxiliaryState,
                   hyperparameters: Hyperparameters,
                   global_step: int,
                   rng: RandomState) -> Dict[str, Tensor]:
  """Return the next batch from the input queue.

  The queue is described as infinitely repeating and pre-shuffled, with each
  element being a batch of training examples and labels. Every argument other
  than `input_queue` is ignored.
  """
  del (workload,
       optimizer_state,
       current_param_container,
       model_state,
       hyperparameters,
       global_step,
       rng)
  return next(input_queue)


def get_batch_size(workload_name: str) -> int:
  """Look up the global batch size used for `workload_name`.

  Raises:
    ValueError: if `workload_name` is not one of the known workloads.
  """
  global_batch_sizes = {
      'criteo1tb': 262_144,
      'fastmri': 32,
      'imagenet_resnet': 1024,
      'imagenet_vit': 512,
      'librispeech_conformer': 256,
      'librispeech_deepspeech': 256,
      'ogbg': 512,
      'wmt': 128,
      'mnist': 16,
      'cifar': 512,
  }
  batch_size = global_batch_sizes.get(workload_name)
  if batch_size is None:
    raise ValueError(f'Unsupported workload name: {workload_name}.')
  return batch_size


def create_learning_rate_fn(hparams: Hyperparameters,
                            steps_per_epoch: int):
  """Build a linear-warmup then cosine-decay learning rate schedule.

  Reads `learning_rate`, `warmup_epochs` and `num_epochs` from `hparams`.
  The peak rate is `hparams.learning_rate` scaled by the 'imagenet_vit'
  global batch size over 512. The cosine phase lasts at least one epoch.
  """
  warmup_steps = hparams.warmup_epochs * steps_per_epoch
  peak_lr = hparams.learning_rate * get_batch_size('imagenet_vit') / 512.

  warmup_schedule = optax.linear_schedule(
      init_value=0., end_value=peak_lr, transition_steps=warmup_steps)

  decay_epochs = max(hparams.num_epochs - hparams.warmup_epochs, 1)
  cosine_schedule = optax.cosine_decay_schedule(
      init_value=peak_lr, decay_steps=decay_epochs * steps_per_epoch)

  # Switch from warmup to cosine decay once the warmup steps have elapsed.
  return optax.join_schedules(
      schedules=[warmup_schedule, cosine_schedule],
      boundaries=[warmup_steps])


def adamk(hyperparameters: Hyperparameters, num_train_examples: int):
  """Build the optimizer functions used by this submission.

  Args:
    hyperparameters: must expose `beta1`, `beta2`, `epsilon` and `l2`, plus
      whatever `create_learning_rate_fn` reads.
    num_train_examples: training-set size; with the 'imagenet_vit' global
      batch size it determines the schedule's steps per epoch.

  Returns:
    `(init_fn, (get_adam, get_product, update_fn))`:
      * `init_fn(params)` builds the optimizer state dict.
      * `get_adam(grads, state)` updates the moment estimates and returns the
        bias-corrected Adam direction together with the new state.
      * `get_product(...)` returns three per-example quadratic forms of the
        given JVPs under the softmax kernel `diag(p) - p p^T`.
      * `update_fn(...)` solves a 2x2 system for the mixing coefficients of
        the Adam and 'damps' directions and returns `(update, new_state)`.
    `get_adam` and `update_fn` return a new state dict and leave the dict
    they were given unmodified.
  """
  steps_per_epoch = num_train_examples // get_batch_size('imagenet_vit')
  learning_rate_fn = create_learning_rate_fn(hyperparameters, steps_per_epoch)
  # `jax.tree_map` is a deprecated alias that newer JAX releases removed;
  # `jax.tree_util.tree_map` is the stable spelling.
  tree_map = jax.tree_util.tree_map

  def init_fn(params):
    """Initializes the optimizer state."""
    state = {}
    state['b1'] = hyperparameters.beta1
    state['b2'] = hyperparameters.beta2
    state['eps'] = hyperparameters.epsilon
    state['m'] = tree_map(jnp.zeros_like, params)
    state['v'] = tree_map(jnp.zeros_like, params)
    state['t'] = 0
    state['wd'] = hyperparameters.l2
    state['damps'] = tree_map(jnp.zeros_like, params)

    return state

  def get_adam(grads, state):
    """Update 'm' and 'v' and return (adam_direction, new_state)."""
    # Shallow copy so the caller's dict is not mutated in place.
    state = dict(state)

    def update_m(m, grad):
      return (state['b1'] * m + (1.0 - state['b1']) * grad)

    def update_v(v, grad):
      return (state['b2'] * v + (1.0 - state['b2']) * jnp.square(grad))

    def update_mv(m, v):
      # Bias correction uses t + 1 because 't' is only incremented later, in
      # `update_fn`.
      m_hat = m / (1 - state['b1'] ** (state['t'] + 1))
      v_hat = v / (1 - state['b2'] ** (state['t'] + 1))
      return m_hat / (jnp.sqrt(v_hat) + state['eps'])

    state['m'] = tree_map(update_m, state['m'], grads)
    state['v'] = tree_map(update_v, state['v'], grads)

    adams = tree_map(update_mv, state['m'], state['v'])

    return adams, state

  def _dot_product(v1, v2):
    """Inner product of two pytrees with identical structure."""
    product_tree = tree_map(lambda x, y: x * y, v1, v2)
    leaves = jax.tree_util.tree_leaves(product_tree)
    return jnp.sum(
        jnp.concatenate([jnp.asarray(leaf).flatten() for leaf in leaves]))

  def get_product(grads, adams, logits, jvp_a, jvp_d, state):
    """Return (a^T F a, a^T F d, d^T F d) for each leading-axis entry.

    `F` is `diag(p) - p p^T` with `p = softmax(logits)` over the last axis;
    `a` and `d` are `jvp_a` and `jvp_d`. The einsum specs require 2-d
    `logits`, `jvp_a` and `jvp_d`. `grads`, `adams` and `state` are unused.
    """
    del grads, adams, state

    def _fisher_kernel_func(logits):
      prob = jax.nn.softmax(logits, axis=-1)
      diagonal = jnp.einsum('bi, ij -> bij', prob, jnp.eye(logits.shape[-1]))
      rankone = jnp.einsum('bi, bj -> bij', prob, prob)

      return diagonal - rankone

    def _vMv_product(vector1, vector2, matrix):
      return jnp.einsum('bi, bij, bj -> b', vector1, matrix, vector2)

    fisher_kernel = _fisher_kernel_func(logits)
    adam_F_adam = _vMv_product(jvp_a, jvp_a, fisher_kernel)
    adam_F_damp = _vMv_product(jvp_a, jvp_d, fisher_kernel)
    damp_F_damp = _vMv_product(jvp_d, jvp_d, fisher_kernel)

    return adam_F_adam, adam_F_damp, damp_F_damp

  def update_fn(grads, adams, product_aa, product_ad, product_dd, state):
    """Combine the Adam and 'damps' directions; return (update, new_state)."""
    # Shallow copy so the caller's dict is not mutated in place.
    state = dict(state)
    lr = learning_rate_fn(state['t'])
    damps = state['damps']

    # Symmetric 2x2 matrix: kernel products plus `wd`-weighted inner products,
    # with 10 * eps added on the diagonal.
    mat11 = (product_aa + state['wd'] * _dot_product(adams, adams) +
             state['eps'] * 10)
    mat12 = product_ad + state['wd'] * _dot_product(adams, damps)
    mat21 = mat12
    mat22 = (product_dd + state['wd'] * _dot_product(damps, damps) +
             state['eps'] * 10)

    mat = jnp.array([[mat11, mat12], [mat21, mat22]])

    vec1 = _dot_product(grads, adams)
    vec2 = _dot_product(grads, damps)
    vec = jnp.array([[vec1], [vec2]])

    optim = jnp.linalg.solve(mat, vec)
    # Bounds are passed positionally: the `a_min`/`a_max` keywords are
    # deprecated in newer JAX releases.
    # NOTE(review): the bounds are asymmetric (-0.03, 0.3) — confirm intended.
    optim = jnp.clip(optim, -0.03, 0.3)

    def get_damps(adams, damps):
      return (adams * optim[0] + damps * optim[1]) * 100

    # 'damps' is stored scaled by 100; the update below divides by 100 again.
    state['damps'] = tree_map(get_damps, adams, damps)
    state['t'] += 1
    update = tree_map(lambda x: -lr*x/100, state['damps'])
    return update, state

  return init_fn, (get_adam, get_product, update_fn)
