import os
import sys
import time

import h5py
import pickle
import hashlib
import itertools

from absl import app
from absl import flags
from absl import logging

import numpy as np
import scipy as sp
from scipy import sparse

from sklearn.metrics import roc_auc_score, average_precision_score
# os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.18'

import jax
import jax.numpy as jnp

from jax import random
from jax.api import jit
from jax.experimental import optimizers


## DATASET
flags.DEFINE_integer('n_arms', 100, 'No. of arms involved.')
flags.DEFINE_integer('skip_top', 2, 'No. of top-prevalence arms to skip.')
flags.DEFINE_integer('k_train', 500, 'No. of training examples per arm.')
flags.DEFINE_integer('n_valid', 40000, 'No. of evaluation points.')
flags.DEFINE_bool('add_bias', True, 'Append vector of ones to the data.')

## MODEL
flags.DEFINE_enum('model_type', 'moe', ('moe', 'dot', 'concat'), 'Model.')
flags.DEFINE_integer('emb_dim', None, 'Number of embedding features.')
flags.DEFINE_integer('feat_dim', None, 'No. of nominator/expert features.')
flags.DEFINE_bool('use_baseline_init', True, 'Initialise biases to base rates.')

## -- MoE
flags.DEFINE_integer('n_experts', 3, 'No. of experts in the MoE model.')
flags.DEFINE_bool('random_gate_init', False, 'Use random gating params init.')
flags.DEFINE_float('eval_cutoff', 0.5, 'Threshold for +1 label assignment.')
flags.DEFINE_bool('fixed_arm_pools', False, 'Fixed pools instead of gating.')
flags.DEFINE_float('likelihood_var', 1.0, 'Likelihood variance')

## OPTIMISATION
flags.DEFINE_integer('n_steps', 50000, 'Number of optimisation steps.')
flags.DEFINE_enum('optimiser', 'rmsprop', ('rmsprop', 'adam'), 'Optimiser.')
# the learning rate is used for both full-batch and mini-batch training
flags.DEFINE_float('lr', 1e-3, 'Base learning rate.')
flags.DEFINE_integer('batch_size', 4096, 'Batch size for `stoch_grad == True`.')

# -- loss
flags.DEFINE_enum('loss_type', 'moe', ('moe', 'mse', 'bce'), 'Loss type.')
flags.DEFINE_bool('use_weighting', False, 'Weigh positive examples higher.')

# -- gradient estimation & scaling params
# NOTE: `momentum` is not read anywhere in this file; rmsprop is built with a
# fixed decay of 0.9 in `init_model_and_optim`.
flags.DEFINE_float('momentum', 0.9, 'Optimiser momentum.')
flags.DEFINE_bool('stoch_grad', True, 'Use mini-batch training.')
flags.DEFINE_float('decay_rate', 0.9, 'Optimiser learning rate decay.')
flags.DEFINE_integer('decay_steps', 2000, 'Optimiser learning rate decay freq.')

# -- early stopping
flags.DEFINE_bool('early_stopping', False, 'Should best parameters be saved.')
flags.DEFINE_string('stopping_criterion', 'ndcg@5', 'Optimised criterion.')

## AUXILIARY FLAGS TO SETUP THE RUN
flags.DEFINE_bool('move_to_device', False, 'Keep train & valid set on device.')
flags.DEFINE_string('data_dir', '../data/amazoncat-13k-dedup', 'Dataset path.')
flags.DEFINE_string('save_dir', '../results/moe', 'Path to results folder.')
flags.DEFINE_integer('eval_freq', 1000, 'Frequency of results evaluation.')
flags.DEFINE_integer('seed', 0, 'PRNG seed.')


FLAGS = flags.FLAGS


def _evaluator_jit(fn):
  """Identity decorator for the `Evaluator` metric methods.

  TODO: the intended behaviour is to wrap `fn` with `jit` (static `self`,
  i.e. `static_argnums=(0,)`) when `FLAGS.move_to_device` is set, and to leave
  it untouched otherwise.
  """
  return fn


class Evaluator(dict):
  """History of evaluation metrics, one list per metric name.

  Every name in `result_types` is initialised to an empty list; `eval_fn`
  appends one value per metric each time it is called. Metric methods compute
  with `jnp` when `FLAGS.move_to_device` is set and with `np` otherwise (the
  backend is chosen once, in `__init__`).
  """

  def __init__(self, result_types, *args, **kwargs):
    super().__init__(*args, **kwargs)

    # array backend used by every metric method below
    self.np = jnp if FLAGS.move_to_device else np
    # self.eval_dict = {k: [] for k in result_types}
    for rt in result_types:
      self[rt] = []

  @_evaluator_jit
  def mse(self, preds, y):
    """Mean squared error over all elements of `preds` and `y`."""
    return self.np.mean((preds - y)**2)

  @_evaluator_jit
  def rmse(self, preds, y):
    """Root mean squared error over all elements of `preds` and `y`."""
    return self.np.mean((preds - y)**2) ** 0.5

  @_evaluator_jit
  def acc(self, preds, y, c=0.5):
    """Fraction of elements where `preds > c` agrees with `y > c`."""
    pl, yl = (f > c for f in (preds, y))
    return self.np.mean(pl == yl)

  @_evaluator_jit
  def instant_regret(self, preds, y):
    """Mean gap between the best reward and the reward of the top-scored arm.

    Arms are indexed along the last axis of `preds` and `y`.
    """
    idx = self.np.argmax(preds, axis=-1)
    # reward of the highest-scored arm, one value per row
    labels = self.np.take_along_axis(y, idx[..., None], axis=-1).squeeze(-1)
    return self.np.mean(self.np.max(y, axis=-1) - labels)

  @_evaluator_jit
  def prec_recall(self, preds, y, c=0.5):
    """Average precision and recall computation with *binary* labels.

    Arms (last axis) are ranked per row by descending `preds`. An arm counts
    as predicted positive when its score exceeds `c`, which is either a scalar
    or a per-arm NumPy array that gets reordered to follow the ranking.
    Entry `k` of each returned array is precision/recall over the top `k + 1`
    ranked arms, averaged over rows (axis 0):
      * precision divides the true positives by the number of predicted
        positives in the top `k + 1` (at least 1);
      * recall divides them by the row's total number of positives (at least 1).

    Returns:
      `(precs, recalls)`.
    """

    # if not self.np.array_equal(y, y.astype(bool)):
    #   raise ValueError('target vector `y` must contain only zeros and ones')

    # argsort in *descending* order
    idx = self.np.argsort(preds, axis=-1)[..., ::-1]
    if isinstance(c, np.ndarray) and c.size > 1:
      c = self.np.take_along_axis(c, idx, axis=-1)  # class-dependent cut-offs

    yl = self.np.take_along_axis(y, idx, axis=-1)  # .astype(preds.dtype)
    pl = self.np.take_along_axis(preds, idx, axis=-1) > c
    pl = pl.astype(yl.dtype)

    # running count of (labelled positive AND predicted positive) by rank
    _and = self.np.cumsum(yl * pl, axis=-1)
    precs = _and / self.np.maximum(1.0, pl.cumsum(-1))
    recalls = _and / self.np.maximum(1.0, yl.sum(-1, keepdims=True))

    return precs.mean(0), recalls.mean(0)

  @_evaluator_jit
  def ndcg(self, preds, y):
    """NDCG@k for every cut-off k, averaged over rows.

    Expects 2-D `preds`/`y` (rows x arms). Gains are `2**label - 1`, and the
    discount for 0-based rank `r` is `log2(r + 2)`. Entry `k` of the result is
    NDCG over the top `k + 1` ranked arms.
    """
    idx = self.np.argsort(preds, axis=-1)[:, ::-1]
    labels = self.np.take_along_axis(y, idx, axis=-1)
    denoms = self.np.log2(2 + self.np.arange(y.shape[-1]))

    dcg = (2**labels - 1) / denoms[None]
    # ideal ranking: the same labels sorted in descending order
    idcg = (2**(self.np.sort(labels, axis=-1)[..., ::-1]) - 1) / denoms[None]
    dcg, idcg = (df.cumsum(-1) for df in (dcg, idcg))

    # accounts for no-positive-label examples
    # (with binary labels idcg is either 0 or >= 1, so the clamp only
    # affects rows without any positive label, whose NDCG becomes 0)
    ndcg = dcg / self.np.maximum(1.0, idcg)

    return ndcg.mean(0)

  def eval_fn(self, preds_train, preds_valid, y_train, y_valid):
    """Compute all metrics for one evaluation step and append them.

    `preds_train`/`y_train` hold one prediction/label per training example.
    `preds_valid`/`y_valid` are full (examples x arms) matrices, which the
    ranking metrics (regret, precision, recall, NDCG) require. The log line
    reports the @5 values (index 4), so at least 5 arms are assumed.
    """
    if not FLAGS.move_to_device:
      preds_train, preds_valid = map(np.array, (preds_train, preds_valid))

    train_rmse = float(self.rmse(preds_train, y_train))
    valid_rmse = float(self.rmse(preds_valid, y_valid))
    train_acc = float(self.acc(preds_train, y_train, 0.5))
    valid_acc = float(self.acc(preds_valid, y_valid, 0.5))

    # following can only be evaluated on `valid` (full n x k data)
    valid_regret = float(self.instant_regret(preds_valid, y_valid))
    precs, recalls = self.prec_recall(preds_valid, y_valid, FLAGS.eval_cutoff)
    ndcgs = self.ndcg(preds_valid, y_valid)  # the most expensive one

    self['train_rmse'].append(train_rmse)
    self['val_rmse'].append(valid_rmse)
    self['train_acc'].append(train_acc)
    self['val_acc'].append(valid_acc)

    self['regret'].append(valid_regret)
    self['prec'].append(precs)
    self['recall'].append(recalls)
    self['ndcg'].append(ndcgs)

    logging.info(
      f'regret {self["regret"][-1]:.4f}  prec@5 {self["prec"][-1][4]:.4f} '
      f'recall@5 {self["recall"][-1][4]:.4f}  ndcg@5 {self["ndcg"][-1][4]:.4f}')
    # @5 is at the [4] position in the zero-based indexing
    # @5 is at the [4] position in the zero-based indexing


# probably shouldn't be used on the Amazon data except for low arm settings
def concat_model(n_arms, feat_dim, emb_dim, init_bias=None):
  """Linear model over the concatenation of context features and arm embeddings.

  The score of context `x` for arm `a` is `[x, emb[a]] @ lin + bias[a]`.

  Args:
    n_arms: number of arms (rows of the embedding table / length of biases).
    feat_dim: number of context features (columns of `x`).
    emb_dim: dimensionality of the learned arm embeddings.
    init_bias: optional array of shape `(n_arms,)` used to initialise the
      per-arm biases; zeros otherwise.

  Returns:
    `((init_fn, apply_fn), (flatten, unflatten))`. With `binary=True` (the
    default), `apply_fn(params, x, item_ids, binary)` scores row `i` of `x`
    against arm `item_ids[i]`. With `binary=False` it returns a
    `(len(x), len(item_ids))` matrix scoring every row against every arm.

  Raises:
    ValueError: from `init_fn` if `init_bias` has the wrong shape.
  """
  lin_dim = feat_dim + emb_dim

  def init_fn(key):
    lin_param = jnp.zeros(lin_dim)
    emb_param = random.normal(key, shape=(n_arms, emb_dim))
    if init_bias is None:
      bias_param = jnp.zeros(n_arms)
    else:
      if init_bias.shape != (n_arms,):
        raise ValueError(
          f'`init_bias` must have shape `(n_arms,)`; was {init_bias.shape}')
      bias_param = jnp.array(init_bias)
    return lin_param, emb_param, bias_param

  def apply_fn(params, x, item_ids, binary=True):
    lin_param, emb_param, bias_param = params
    item_emb, biases = emb_param[item_ids], bias_param[item_ids]
    if binary:
      return jnp.hstack([x, item_emb]) @ lin_param + biases
    # all (context, arm) pairs: the linear score separates into a context
    # term and an arm term, so compute each once and broadcast their sum
    user_scores = x @ lin_param[:feat_dim]
    item_scores = item_emb @ lin_param[feat_dim:] + biases
    return user_scores[:, None] + item_scores[None]

  def flatten(params):
    return jnp.concatenate([p.ravel() for p in params])

  def unflatten(pf):
    lp_dim, ep_dim = lin_dim, n_arms * emb_dim
    lin_param = pf[:lp_dim]
    emb_param = pf[lp_dim:lp_dim+ep_dim].reshape((n_arms, emb_dim))
    bias_param = pf[lp_dim+ep_dim:]  # .reshape((n_arms,))
    return lin_param, emb_param, bias_param

  return (init_fn, apply_fn), (flatten, unflatten)


def dot_model(n_arms, feat_dim, emb_dim, init_bias=None):
  """Two-tower model: scaled dot product of context and arm embeddings.

  Contexts are projected by a learned `(feat_dim, emb_dim)` matrix, arms are
  looked up in an `(n_arms, emb_dim)` table, and the score is their dot
  product divided by `sqrt(emb_dim)` plus a per-arm bias.

  Args:
    n_arms: number of arms.
    feat_dim: number of context features (columns of `x`).
    emb_dim: shared embedding dimensionality.
    init_bias: optional `(n_arms,)` array for the initial per-arm biases;
      zeros when omitted.

  Returns:
    `((init_fn, apply_fn), (flatten, unflatten))`. With `binary=True`,
    `apply_fn(params, x, item_ids, binary)` scores row `i` of `x` against arm
    `item_ids[i]`; with `binary=False` it scores every row against every arm.

  Raises:
    ValueError: from `init_fn` if `init_bias` has the wrong shape.
  """
  scale = emb_dim**0.5
  lin_shape, emb_shape = (feat_dim, emb_dim), (n_arms, emb_dim)
  lin_size, emb_size = feat_dim * emb_dim, n_arms * emb_dim

  def _initial_bias():
    if init_bias is None:
      return jnp.zeros(n_arms)
    if init_bias.shape != (n_arms,):
      raise ValueError(
        f'`init_bias` must have shape `(n_arms,)`; was {init_bias.shape}')
    return jnp.array(init_bias)

  def init_fn(key):
    lin_key, emb_key = random.split(key)
    # 1/sqrt(feat_dim) keeps the projected contexts at roughly unit scale
    projection = random.normal(lin_key, shape=lin_shape) / feat_dim**0.5
    table = random.normal(emb_key, shape=emb_shape)
    return projection, table, _initial_bias()

  def apply_fn(params, x, item_ids, binary=True):
    projection, table, bias = params
    users = x @ projection
    items, offsets = table[item_ids], bias[item_ids]
    if not binary:
      return users @ items.T / scale + offsets[None]
    return jnp.sum(users * items, axis=-1) / scale + offsets

  def flatten(params):
    return jnp.concatenate([leaf.ravel() for leaf in params])

  def unflatten(flat):
    lin_end = lin_size
    emb_end = lin_end + emb_size
    return (flat[:lin_end].reshape(lin_shape),
            flat[lin_end:emb_end].reshape(emb_shape),
            flat[emb_end:])

  return (init_fn, apply_fn), (flatten, unflatten)


def moe_model(
    experts, n_arms, feat_dim, emb_dim, random_gate_init=False,
    arm_pools=None, eps=1e-31):
  """Mixture of experts over a list of `(init_fn, apply_fn)` expert models.

  The prediction for a (context, arm) pair is the sum of the experts'
  predictions weighted by a softmax over the expert axis (axis 0).

  Gating:
    * `arm_pools is None`: a learned gate with parameters `(lin, emb, bias)`
      of shapes `(feat_dim + emb_dim, n_experts)`, `(n_arms, emb_dim)` and
      `(n_arms, n_experts)`. Each expert's logit is
      `x @ lin[:feat_dim] + emb[arm] @ lin[feat_dim:] + bias[arm]`. `lin`
      and `bias` start at zero (uniform gate) unless `random_gate_init`.
    * otherwise: row `e` of `arm_pools` lists the arms handled by expert `e`.
      An arm is shared (near-)evenly among the experts whose pool contains it.
      Every other expert gets the log-weight `log(eps / n_experts)` instead of
      `log(0)`, which keeps the logits finite.

  Args:
    experts: sequence of `(init_fn, apply_fn)` pairs, where
      `apply_fn(params, x, item_ids, binary=...)` returns predictions.
    n_arms, feat_dim, emb_dim: sizes used by the learned gate.
    random_gate_init: draw the learned gate's `lin`/`bias` at random.
    arm_pools: optional array-like with one row of arm ids per expert.
    eps: floor for out-of-pool gate weights.

  Returns:
    `((init_fn, apply_fn), preds_and_logits)`. `init_fn(key)` returns the
    tuple of expert parameters, paired with the gate parameters when the gate
    is learned. `preds_and_logits(params, x, item_ids, binary=True)` returns
    the stacked expert predictions and the gate logits, expert axis first.

  Raises:
    ValueError: if the number of arm pools differs from the number of experts.
  """
  n_experts = len(experts)
  learned_gate = arm_pools is None

  fixed_expert_weights = None
  if not learned_gate:
    arm_pools = jnp.array(arm_pools)
    # 1 / (number of pools each arm belongs to)
    fixed_expert_weights = 1.0 / jnp.bincount(arm_pools.ravel())
    if len(arm_pools) != n_experts:
      raise ValueError(
        f'{n_experts} experts but {len(arm_pools)} arm pools supplied')

  def _init_gate(key):
    """Learned-gate parameters `(lin, emb, bias)` derived from `key`."""
    lin_dim = feat_dim + emb_dim
    key, emb_key = random.split(key)
    emb = random.normal(emb_key, shape=(n_arms, emb_dim))
    if not random_gate_init:
      lin = jnp.zeros((lin_dim, n_experts))
      bias = jnp.zeros((n_arms, n_experts))
      return lin, emb, bias
    _, lin_key, bias_key = random.split(key, 3)
    lin = random.normal(lin_key, shape=(lin_dim, n_experts)) / lin_dim**0.5
    bias = random.normal(bias_key, shape=(n_arms, n_experts))
    return lin, emb, bias

  def init_fn(key):
    gate_key, expert_key = random.split(key)
    expert_keys = random.split(expert_key, n_experts)
    expert_params = tuple(
      expert_init(k) for (expert_init, _), k in zip(experts, expert_keys))
    if learned_gate:
      return expert_params, _init_gate(gate_key)
    return expert_params

  def _learned_logits(gate_params, x, item_ids, binary):
    lin, emb, bias = gate_params
    # TODO: do not rely on first dim of x being 1 (implement biases instead)
    per_user = (x @ lin[:feat_dim]).T
    per_item = (emb[item_ids] @ lin[feat_dim:] + bias[item_ids]).T
    if binary:
      return per_user + per_item
    return per_user[:, :, None] + per_item[:, None]

  def _pool_logits(item_ids, binary):
    weights = fixed_expert_weights[item_ids]
    masks = jnp.array(
      [weights * jnp.isin(item_ids, pool) for pool in arm_pools])
    if not binary:
      masks = masks[:, None]
    # avoid nan/inf by assigning very low log probability to out-of-pool items
    return jnp.log(masks + eps / n_experts)

  def preds_and_logits(params, x, item_ids, binary=True):
    if learned_gate:
      expert_params, gate_params = params
    else:
      expert_params, gate_params = params, None
    expert_preds = jnp.stack([
      expert_apply(p, x, item_ids, binary=binary)
      for (_, expert_apply), p in zip(experts, expert_params)])
    if learned_gate:
      gate_logits = _learned_logits(gate_params, x, item_ids, binary)
    else:
      gate_logits = _pool_logits(item_ids, binary)
    return expert_preds, gate_logits

  def apply_fn(params, x, item_ids, binary=True):
    expert_preds, gate_logits = preds_and_logits(params, x, item_ids, binary)
    gate = jax.nn.softmax(gate_logits, axis=0)
    return (gate * expert_preds).sum(axis=0)

  return (init_fn, apply_fn), preds_and_logits


# CODE USED TO SAVE THE DEDUPLICATED DATASET
# ctxts, rewards = utils.read_bandit_dataset('amazoncat-13k-bert')
# ctxts, unique_ids = np.unique(ctxts, axis=0, return_index=True)
# rewards = rewards[unique_ids]
# with h5py.File('../data/amazoncat-13k-dedup/features.h5', 'w') as ff:
#   ff.create_dataset('features', data=ctxts)
# sp.sparse.save_npz('../data/amazoncat-13k-dedup/ratings.npz', rewards)


def load_dataset(
    rng, data_dir, feat_dim, n_arms, skip_top, k_train, n_valid,
    add_bias, move_to_device):
  """Load contexts and rewards from `data_dir`, then subsample and preprocess.

  Reads the dense context matrix from dataset `features` of `features.h5` and
  the sparse reward matrix from `ratings.npz` (converted to CSR). Draws the
  random train/valid split with `subsample_data` and standardises/truncates it
  with `preprocess_data`.

  Returns:
    `(train, valid, feat_dim)`, where `feat_dim` is the number of context
    columns kept. When the `feat_dim` argument is `None`, this is the full width
    of the subsampled training contexts, including the bias column if
    `add_bias` is set.
  """
  features_path = os.path.join(data_dir, 'features.h5')
  ratings_path = os.path.join(data_dir, 'ratings.npz')

  with h5py.File(features_path, 'r') as features_h5:
    ctxts = features_h5['features'][:]
  with open(ratings_path, 'rb') as ratings_fh:
    rewards = sp.sparse.load_npz(ratings_fh).tocsr()

  train, valid = subsample_data(
    rng, ctxts=ctxts, rewards=rewards, n_arms=n_arms, skip_top=skip_top,
    k_train=k_train, n_valid=n_valid, add_bias=add_bias)

  if feat_dim is None:
    feat_dim = train[0].shape[-1]  # keep every available column
  train, valid = preprocess_data(
    train=train, valid=valid, feat_dim=feat_dim, move_to_device=move_to_device)
  return train, valid, feat_dim


def subsample_data(
    rng, ctxts, rewards, n_arms, skip_top, k_train, n_valid, add_bias):
  """Draw a bandit-feedback training set and a full-feedback validation set.

  Arms are the columns of `rewards`, ranked by mean reward (descending). The
  `skip_top` highest are dropped and the next `n_arms` kept. Their order is
  then shuffled so the arm index carries no prevalence information.

  Training: `k_train * n_arms` contexts are drawn *with replacement*. Each
  consecutive block of `n_arms` examples is paired with a fresh permutation of
  the arms, so every arm appears exactly `k_train` times and only the reward
  of the paired arm is kept.
  Validation: `n_valid` distinct contexts that were never drawn for training,
  with the rewards of all `n_arms` arms.

  Args:
    rng: `np.random.RandomState` used for all sampling.
    ctxts: context array, one row per example.
    rewards: scipy sparse reward matrix with one row per context and one
      column per arm.
    n_arms, skip_top: arm-selection parameters described above.
    k_train: number of training examples per arm.
    n_valid: number of validation contexts.
    add_bias: prepend a column of ones to the contexts.

  Returns:
    `((x_train, y_train, train_rids), (x_valid, y_valid))`. `train_rids`
    holds the (shuffled) arm index of each training example.

  Raises:
    ValueError: if `rewards` has fewer than `skip_top + n_arms` columns, or
      fewer than `n_valid` contexts remain once training contexts are removed.
  """
  n_train = k_train * n_arms

  # validate before touching `rng` so valid calls consume it as before
  n_total_arms = rewards.shape[1]
  if skip_top + n_arms > n_total_arms:
    raise ValueError(
      f'need skip_top + n_arms = {skip_top + n_arms} arms but the reward '
      f'matrix only has {n_total_arms}')

  # arms sorted by mean reward, most rewarding first
  _orig_arms = np.argsort(np.array(rewards.mean(0))[0])[::-1]
  _orig_arms = _orig_arms[skip_top:skip_top+n_arms]

  # ensure order of arms doesn't carry information!
  _orig_arms = _orig_arms[rng.permutation(n_arms)]
  rwrd_subset = rewards[:, _orig_arms].toarray()

  train_xids = rng.randint(0, len(ctxts), size=n_train)
  train_rids = np.array([rng.permutation(n_arms) for _ in range(k_train)])
  train_rids = train_rids.ravel()
  x_train, y_train = ctxts[train_xids], rwrd_subset[(train_xids, train_rids)]

  # validation contexts are disjoint from every context drawn for training
  valid_xids = np.setdiff1d(np.arange(len(ctxts)), train_xids)
  if len(valid_xids) < n_valid:
    raise ValueError(
      f'only {len(valid_xids)} contexts left for validation after drawing '
      f'the training set; requested n_valid={n_valid}')
  valid_xids = rng.choice(valid_xids, replace=False, size=n_valid)
  x_valid, y_valid = ctxts[valid_xids], rwrd_subset[valid_xids]

  if add_bias:
    x_train = np.hstack([np.ones((n_train, 1)), x_train])
    x_valid = np.hstack([np.ones((n_valid, 1)), x_valid])

  return (x_train, y_train, train_rids), (x_valid, y_valid)


def preprocess_data(train, valid, feat_dim, move_to_device):
  """Standardise contexts, keep `feat_dim` columns, optionally move to device.

  Both splits are shifted and scaled by a single scalar mean/std computed over
  every entry of the *training* contexts. Only the first `feat_dim` columns
  are kept afterwards. With `move_to_device`, every array is converted with
  `jnp.array`; otherwise labels and arm ids are returned as given.

  Returns:
    `((x_train, y_train, train_rids), (x_valid, y_valid))`.
  """
  x_train, y_train, train_rids = train
  x_valid, y_valid = valid

  centre, spread = np.mean(x_train), np.std(x_train)

  def _standardise(x):
    return ((x - centre) / spread)[..., :feat_dim]

  x_train = _standardise(x_train)
  x_valid = _standardise(x_valid)

  if not move_to_device:
    return (x_train, y_train, train_rids), (x_valid, y_valid)

  # eval_cutoff = baseline_rate[None]  (would also need moving to device)
  to_dev = jnp.array
  return ((to_dev(x_train), to_dev(y_train), to_dev(train_rids)),
          (to_dev(x_valid), to_dev(y_valid)))


def init_model_and_optim(
    key, model_type, loss_type, optimiser, emb_dim, feat_dim, n_arms, base_lr,
    decay_steps, decay_rate, init_bias, random_gate_init, neg_count, pos_count,
    n_experts, use_weighting, likelihood_var, fixed_arm_pools):
  """Build the model, the optimiser and a jitted training step.

  Args:
    key: JAX PRNG key; only used to draw the arm pools when
      `model_type == 'moe'` and `fixed_arm_pools` is set.
    model_type: 'concat', 'dot' or 'moe' (a mixture of `n_experts` dot
      models, gated by a learned gate or by fixed random arm pools).
    loss_type: one of
      * 'moe': negative log-likelihood of the gated mixture of Gaussians with
        variance `likelihood_var`, up to an additive constant;
      * 'mse';
      * 'bce': binary cross-entropy on logits.
    optimiser: 'rmsprop' or 'adam'. Either uses a learning rate that decays
      exponentially from `base_lr` by `decay_rate` every `decay_steps` steps.
    init_bias: optional `(n_arms,)` initial per-arm biases.
    neg_count, pos_count: numbers of negative/positive training labels; only
      used with `use_weighting`, to give both classes equal total weight.

  Returns:
    `((init_fn, apply_fn), (opt_init, get_params), update)`, where
    `update(step_id, opt_state, x, y, item_ids)` returns `(loss, opt_state)`.

  Raises:
    ValueError: if there are more experts than arms for fixed pools, or if
      `use_weighting` is set while one of the classes has no examples.
    NotImplementedError: for an unknown model, optimiser or loss type.
  """
  if model_type == 'concat':
    (init_fn, apply_fn), _ = concat_model(
      n_arms, feat_dim=feat_dim, emb_dim=emb_dim, init_bias=init_bias)
  elif model_type == 'dot':
    (init_fn, apply_fn), _ = dot_model(
      n_arms, feat_dim=feat_dim, emb_dim=emb_dim, init_bias=init_bias)
  elif model_type == 'moe':
    if fixed_arm_pools:
      if n_experts > n_arms:
        raise ValueError(f'{n_experts} experts but only {n_arms} arms')

      key, pool_key = random.split(key)
      # named so as not to shadow the module-level `app` (absl) import
      arms_per_pool = max(1, int(np.ceil(n_arms / n_experts)))

      # two independent arm permutations, concatenated and cut into
      # `n_experts` equal pools; the first permutation alone covers every arm
      # and the second tops up pools when `arms_per_pool * n_experts > n_arms`
      arm_perm = [
        random.permutation(skey, n_arms) for skey in random.split(pool_key)]
      arm_perm = jnp.concatenate(arm_perm)[:arms_per_pool * n_experts]
      arm_pools = jnp.array(jnp.split(arm_perm, n_experts))
    else:
      arm_pools = None

    experts = [
      dot_model(n_arms, feat_dim, emb_dim, init_bias)[0]
      for _ in range(n_experts)]

    (init_fn, apply_fn), preds_and_logits = moe_model(
      experts, n_arms=n_arms, feat_dim=feat_dim, emb_dim=emb_dim,
      random_gate_init=random_gate_init, arm_pools=arm_pools)
  else:
    raise NotImplementedError(model_type)

  # perhaps we should use class-conditional weights?!
  if use_weighting:  # caveat: messes up scale-dependent optimisers!
    # a missing class would make its weight infinite and the loss inf/nan
    if neg_count <= 0 or pos_count <= 0:
      raise ValueError(
        '`use_weighting` needs both classes in the training labels; got '
        f'neg_count={neg_count}, pos_count={pos_count}')
    total = neg_count + pos_count
    neg_weight = 0.5 * total / neg_count
    pos_weight = 0.5 * total / pos_count
  else:
    neg_weight, pos_weight = 1.0, 1.0

  # lr = lambda t: base_lr * (decay_rate ** (t / decay_steps))
  lr = optimizers.exponential_decay(base_lr, decay_steps, decay_rate)
  if optimiser == 'rmsprop':
    opt_init, opt_update, get_params = optimizers.rmsprop(lr, 0.9, eps=1e-8)
  elif optimiser == 'adam':
    opt_init, opt_update, get_params = optimizers.adam(lr)
  else:
    raise NotImplementedError(optimiser)

  def _get_weights(y):
    """Per-example loss weight: `pos_weight` where `y == 1`, else `neg_weight`."""
    return jnp.where(y == 1.0, pos_weight, neg_weight)

  if loss_type == 'moe':
    def loss(params, x, y, item_ids):
      expert_preds, gate_logits = preds_and_logits(
        params, x, item_ids, binary=True)
      log_probs = jax.nn.log_softmax(gate_logits, axis=0)
      squares = jnp.square(y[None] - expert_preds)
      # log of gate prob. times Gaussian likelihood (constants dropped)
      log_summands = log_probs - 0.5 * squares / likelihood_var
      nll = - jax.nn.logsumexp(log_summands, axis=0)
      return jnp.mean(_get_weights(y) * nll)
  elif loss_type == 'mse':
    def loss(params, x, y, item_ids):
      preds = apply_fn(params, x, item_ids)
      return jnp.mean(_get_weights(y) * (y - preds) ** 2)
  elif loss_type == 'bce':
    def loss(params, x, y, item_ids):  # numerically stable bce w/ logits
      logits = apply_fn(params, x, item_ids)
      nll = jnp.maximum(0, logits) - y * logits
      nll += jnp.log1p(jnp.exp(-jnp.abs(logits)))
      return jnp.mean(_get_weights(y) * nll)
  else:
    raise NotImplementedError(f'`loss_type == "{loss_type}"` not implemented')

  @jit
  def update(step_id, opt_state, x, y, item_ids):
    params = get_params(opt_state)
    nll, grads = jax.value_and_grad(loss)(params, x, y, item_ids)
    opt_state = opt_update(step_id, grads, opt_state)
    return nll, opt_state

  return (init_fn, apply_fn), (opt_init, get_params), update


def load_model(save_path):
  """Load parameters written by `save_model` and convert leaves to JAX arrays.

  Reads the pickle file `<save_path>/weights`. NOTE: unpickling can execute
  arbitrary code, so only load files this script wrote itself.
  """
  with open(os.path.join(save_path, 'weights'), 'rb') as load_file:
    params = pickle.load(load_file)
  # `jax.tree_map` is deprecated (removed in recent JAX); `tree_util` is stable
  return jax.tree_util.tree_map(jnp.array, params)


def save_model(params, save_path):
  """Pickle `params` (any pytree) to `<save_path>/weights` as NumPy arrays.

  Leaves are converted with `np.array`, so the file can be unpickled without
  a JAX device. Overwrites any existing `weights` file.
  """
  # `jax.tree_map` is deprecated (removed in recent JAX); `tree_util` is stable
  np_params = jax.tree_util.tree_map(np.array, params)
  with open(os.path.join(save_path, 'weights'), 'wb') as save_file:
    pickle.dump(np_params, save_file)


def main(_):
  """Run one experiment end-to-end from the parsed command-line `FLAGS`.

  Steps:
    1. Load and split the data.
    2. Build the model and optimiser, then train.
    3. Take either the best early-stopped parameters or the final ones.
    4. Compute validation AUC / average precision over all (context, arm)
       pairs.
    5. Pickle `config`, `results` and `weights` into
       `FLAGS.save_dir/<sha1 of the flag configuration>`.
  """
  key = random.PRNGKey(FLAGS.seed)
  rng = np.random.RandomState(FLAGS.seed)
  config_dict = {
    f.name: f.value for f in FLAGS.get_key_flags_for_module(sys.argv[0])
  }

  # results live in a directory named after a hash of the flag configuration
  experiment_id = '-'.join([f'{k}_{v}' for k, v in config_dict.items()])
  experiment_id = hashlib.sha1(str.encode(experiment_id)).hexdigest()
  save_path = os.path.join(FLAGS.save_dir, experiment_id)
  os.makedirs(save_path, exist_ok=True)


  logging.info('load data')

  train, valid, feat_dim = load_dataset(
    rng, data_dir=FLAGS.data_dir, feat_dim=FLAGS.feat_dim, n_arms=FLAGS.n_arms,
    skip_top=FLAGS.skip_top, k_train=FLAGS.k_train, n_valid=FLAGS.n_valid,
    add_bias=FLAGS.add_bias, move_to_device=FLAGS.move_to_device)

  (x_train, y_train, ids_train), (x_valid, y_valid) = train, valid
  ids_valid = jnp.arange(FLAGS.n_arms)  # predict all arms in non-binary mode
  config_dict['feat_dim'] = feat_dim


  logging.info('initialise model & optimiser')

  if FLAGS.emb_dim is None:
    emb_dim = max(10, int(np.ceil(FLAGS.n_arms ** 0.25)))
  else:
    emb_dim = FLAGS.emb_dim
  config_dict['emb_dim'] = emb_dim

  # `minlength=2` keeps the unpacking valid when no positive label was sampled
  neg_count, pos_count = np.bincount(y_train.astype('int32'), minlength=2)
  if FLAGS.use_baseline_init:
    init_bias = np.array([  # compute baseline positivity rate
      y_train[ids_train == ii].mean() for ii in range(FLAGS.n_arms)])
    init_bias = np.maximum(1e-3, init_bias)  # regularisation
  else:
    init_bias = None

  key, init_key = random.split(key)
  (init_fn, apply_fn), (opt_init, get_params), update = init_model_and_optim(
    key=init_key, model_type=FLAGS.model_type, loss_type=FLAGS.loss_type,
    optimiser=FLAGS.optimiser, base_lr=FLAGS.lr, decay_steps=FLAGS.decay_steps,
    decay_rate=FLAGS.decay_rate, emb_dim=emb_dim, feat_dim=feat_dim,
    n_arms=FLAGS.n_arms, init_bias=init_bias, use_weighting=FLAGS.use_weighting,
    pos_count=pos_count, neg_count=neg_count, n_experts=FLAGS.n_experts,
    random_gate_init=FLAGS.random_gate_init, likelihood_var=FLAGS.likelihood_var,
    fixed_arm_pools=FLAGS.fixed_arm_pools)

  def predict_fn(*args, **kwargs):
    ret = apply_fn(*args, **kwargs)
    if FLAGS.loss_type == 'bce':
      # *multilabel* classification -> use sigmoid, not softmax
      ret = jax.nn.sigmoid(ret)
    return ret


  logging.info('train model')

  if FLAGS.stoch_grad:
    batch_size = min(FLAGS.batch_size, len(x_train))
    n_epochs = int(FLAGS.n_steps / (len(x_train) / batch_size))
  else:
    batch_size, n_epochs = len(x_train), FLAGS.n_steps

  key, train_key = random.split(key)
  evaluator, opt_params = train_model(
    key=train_key, rng=rng, x_train=x_train, y_train=y_train, x_valid=x_valid,
    y_valid=y_valid, ids_train=ids_train, ids_valid=ids_valid, init_fn=init_fn,
    opt_init=opt_init, predict_fn=predict_fn, get_params=get_params,
    update=update, n_epochs=n_epochs, batch_size=batch_size,
    early_stopping=FLAGS.early_stopping, eval_freq=FLAGS.eval_freq,
    stopping_criterion=FLAGS.stopping_criterion, model_save_path=save_path,
    move_to_device=FLAGS.move_to_device)

  if FLAGS.early_stopping:
    params = load_model(save_path)
  else:  # just use the final weights
    params = opt_params
    save_model(params, save_path)

  preds_valid = predict_fn(params, x_valid, ids_valid, binary=False)
  fy, fp = (np.array(f.ravel()) for f in (y_valid, preds_valid))
  evaluator['auc'] = roc_auc_score(fy, fp)
  evaluator['aps'] = average_precision_score(fy, fp)

  logging.info(f'auc {evaluator["auc"]:.4f} aps {evaluator["aps"]:.4f}')


  logging.info('save the experiment config & results')

  results_dict = {k: np.array(v) for k, v in evaluator.items()}
  for df, fname in zip((config_dict, results_dict), ('config', 'results')):
    with open(os.path.join(save_path, fname), 'wb') as save_file:
      pickle.dump(df, save_file)

  logging.info('experiment finished')


def train_model(
    key, rng, x_train, y_train, ids_train, x_valid, y_valid, ids_valid,
    init_fn, opt_init, predict_fn, get_params, update, n_epochs, batch_size,
    stopping_criterion, early_stopping, eval_freq, move_to_device,
    model_save_path):
  """Mini-batch training loop with periodic evaluation and early stopping.

  Each epoch shuffles the training set and runs
  `len(x_train) // batch_size` update steps; a trailing partial batch is
  dropped. Shuffling uses `random.permutation` on `key` when
  `move_to_device`, and the NumPy `rng` otherwise. Every `eval_freq` steps
  (including step 0), metrics are appended to an `Evaluator`. With
  `early_stopping`, the parameters are also saved to `model_save_path` each
  time the stopping criterion improves.

  Args:
    stopping_criterion: metric name, optionally with a 1-based cut-off as
      `"<metric>@k"` (e.g. `"ndcg@5"`). With `early_stopping`, per-rank
      metrics (`prec`, `recall`, `ndcg`) need `@k`, and scalar ones
      (`val_acc`, `val_rmse`, `regret`) must not have it.
    predict_fn: `predict_fn(params, x, ids, binary=...)` giving predictions;
      training examples are scored pairwise, validation against all arms.

  Returns:
    `(evaluator, params)`: the metric history (plus `train_loss`, one value
    per step) and the parameters of the *final* optimiser state.

  Raises:
    ValueError: for a malformed `stopping_criterion`, or (with
      `early_stopping`) a criterion not evaluated during training.
    NotImplementedError: if the max/min preference of the metric is unknown.
  """
  train_losses = []
  evaluator = Evaluator((
    'train_rmse', 'val_rmse', 'train_acc', 'val_acc', 'regret',
    'prec', 'recall', 'ndcg'))

  # parse "<metric>[@k]" and decide whether larger or smaller is better
  substrings = stopping_criterion.split('@')
  stop_key, stop_idx, stop_val = substrings[0], None, None
  if len(substrings) == 2:
    stop_idx = int(substrings[1]) - 1  # convert to 0-based indexing
  elif len(substrings) > 2:
    raise ValueError(stopping_criterion)
  if stop_key in ('val_acc', 'auc', 'aps', 'prec', 'recall', 'ndcg'):
    is_1better2 = lambda a, b: a > b  # 'max'
  elif stop_key in ('val_rmse', 'regret'):
    is_1better2 = lambda a, b: a < b  # 'min'
  else:
    raise NotImplementedError(f'unknown max/min preference for "{stop_key}"')

  if early_stopping:
    # only metrics appended by `evaluator.eval_fn` can be monitored here
    # (`auc`/`aps` are not among them)
    if stop_key not in evaluator:
      raise ValueError(
        f'cannot early-stop on "{stop_key}": it is not evaluated during '
        'training')
    # per-rank metrics are arrays (one entry per cut-off) and need `@k`;
    # the others are scalars and cannot be indexed
    if (stop_key in ('prec', 'recall', 'ndcg')) != (stop_idx is not None):
      raise ValueError(
        f'stopping criterion "{stopping_criterion}": use "<metric>@k" for '
        'prec/recall/ndcg and a bare metric name otherwise')

  key, init_key = random.split(key)
  init_params = init_fn(init_key)
  opt_state = opt_init(init_params)

  start_train = time.time()
  counter = itertools.count()
  for epoch_id in range(n_epochs):
    if move_to_device:
      key, perm_key = random.split(key)
      perm = random.permutation(perm_key, len(x_train))
    else:
      perm = rng.permutation(len(x_train))

    for batch_id in range(len(x_train) // batch_size):
      idx = perm[batch_id * batch_size:(batch_id + 1) * batch_size]
      x_batch, y_batch, ids_batch = x_train[idx], y_train[idx], ids_train[idx]

      step_id = next(counter)
      nll, opt_state = update(step_id, opt_state, x_batch, y_batch, ids_batch)
      train_losses.append(float(nll))

      if step_id % eval_freq == 0:
        train_duration = time.time() - start_train

        # eval
        start_eval = time.time()

        params = get_params(opt_state)
        preds_train = predict_fn(params, x_train, ids_train)
        preds_valid = predict_fn(params, x_valid, ids_valid, binary=False)
        evaluator.eval_fn(preds_train, preds_valid, y_train, y_valid)

        eval_duration = time.time() - start_eval
        logging.info(
          f'STEP {step_id} (epoch {epoch_id}): nll {train_losses[-1]:.4f} '
          f'train time {train_duration:.4f}s eval time {eval_duration:.4f}s')

        # early stopping
        if early_stopping:
          last_val = evaluator[stop_key][-1]
          if stop_idx is not None:  # per-rank metric -> pick the cut-off
            last_val = last_val[stop_idx]
          last_val = float(last_val)
          if stop_val is None or is_1better2(last_val, stop_val):
            stop_val = last_val
            save_model(params, model_save_path)

        # restart the time for training
        start_train = time.time()

  # store the traning losses
  evaluator['train_loss'] = train_losses

  return evaluator, get_params(opt_state)


if __name__ == '__main__':
  # script entry point: run `main` through absl's `app.run`; `main` reads
  # its configuration from the module-level `FLAGS`
  app.run(main)
