"""Example script to train and evaluate a network."""

from absl import app

import haiku as hk
import jax.numpy as jnp
import numpy as np

from training import constants
from training import curriculum as curriculum_lib
from training import training
from training import utils


def main(unused_argv) -> None:
  """Trains one network on one task and prints its length-generalization score.

  All hyperparameters are hard-coded in the section below; edit them in place.
  See constants.py for the available task and architecture names.

  Args:
    unused_argv: Command-line arguments forwarded by `app.run`; not used.
  """
  del unused_argv  # Unused.

  # Change your hyperparameters here. See constants.py for possible tasks and
  # architectures.
  batch_size = 128
  sequence_length = 40
  task_name = 'even_pairs'
  architecture = 'tape_rnn'
  architecture_params = {
      'hidden_size': 256, 'memory_cell_size': 8, 'memory_size': 40}
  # Longest length used by the range evaluation run after training.
  max_range_test_length = 100

  # Create the task. The training curriculum is built over lengths
  # 1..sequence_length inclusive.
  curriculum = curriculum_lib.UniformCurriculum(
      values=list(range(1, sequence_length + 1)))
  # Keep the name string and the built task in separate variables. The loss
  # and accuracy closures below refer to `task` and must see the task object.
  task = constants.TASK_BUILDERS[task_name]()

  # Create the model. Only the transformer architecture is run
  # autoregressively; the flag is passed both to the pad wrapper and to the
  # training params so they stay in agreement.
  is_autoregressive = (architecture == 'transformer')
  model = constants.MODEL_BUILDERS[architecture](
      output_size=task.output_size,
      return_all_outputs=True,
      **architecture_params)
  model = utils.wrap_model_with_pad(
      model=model, generalization_task=task,
      computation_steps_mult=0, single_output=True,
      is_autoregressive=is_autoregressive)
  model = hk.transform(model)

  # Create the loss and accuracy based on the pointwise ones.
  def loss_fn(output, target):
    """Sums the task's pointwise loss over the last axis, then averages it.

    Returns a `(loss, aux)` pair with an empty aux dict. The pair shape is
    presumably what ClassicTrainingParams.loss_fn expects; confirm in training.
    """
    loss = jnp.mean(jnp.sum(task.pointwise_loss_fn(output, target), axis=-1))
    return loss, {}

  def accuracy_fn(output, target):
    """Returns the task's pointwise accuracy averaged over masked-in positions."""
    mask = task.accuracy_mask(target)
    # NOTE(review): divides by the mask total; an all-zero mask would yield
    # nan/inf. Presumably the task never produces one; confirm.
    return jnp.sum(mask * task.accuracy_fn(output, target)) / jnp.sum(mask)

  # Create the final training parameters.
  training_params = training.ClassicTrainingParams(
      seed=0,
      model_init_seed=0,
      training_steps=10_000,
      log_frequency=100,
      length_curriculum=curriculum,
      batch_size=batch_size,
      task=task,
      model=model,
      loss_fn=loss_fn,
      learning_rate=1e-3,
      l2_weight=0.,
      accuracy_fn=accuracy_fn,
      compute_full_range_test=True,
      max_range_test_length=max_range_test_length,
      range_test_total_batch_size=512,
      range_test_sub_batch_size=64,
      is_autoregressive=is_autoregressive)

  training_worker = training.TrainingWorker(training_params, use_tqdm=True)
  # Only the evaluation results (second element) are used here.
  _, eval_results, _ = training_worker.run()

  # Gather results and print final score: the mean accuracy over the entries
  # past the training range.
  accuracies = [r['accuracy'] for r in eval_results]
  # NOTE(review): this assumes eval_results is ordered by increasing test
  # length starting at 1. If so, the slice starts at length
  # sequence_length + 2 and skips length sequence_length + 1. Confirm the
  # ordering against the range evaluation before changing the offset.
  score = np.mean(accuracies[sequence_length + 1:])
  print(f'Network score: {score}')


if __name__ == '__main__':
  # Hand control to absl: it handles command-line flags, then invokes `main`
  # with the remaining argv.
  app.run(main=main)

