"""Training loop and command-line entry point for policy-evaluation experiments."""

import time
from typing import Any, Mapping, Text, Type, Union

import jsonlines
from absl import app
from absl import flags
from absl import logging

import jax
import jax.numpy as jnp

import numpy as np

from rl.configs import mc_policy_eval as rl_config
from rl import policy_eval_experiment


# Worker / hardware flags.
# NOTE(review): worker_mode is not read anywhere in this module.
flags.DEFINE_string('worker_mode', 'train', 'The mode, train or eval')
flags.DEFINE_string('worker_tpu_driver', '', 'The tpu driver to use')
# main() reads train_steps, batch_size and dataset only when --task_id is not given.
flags.DEFINE_integer('train_steps', 1000, 'Number of training steps')
flags.DEFINE_integer('batch_size', 4096, 'Total batch size')
flags.DEFINE_integer('log_tensors_interval', 60, 'Log tensors every n seconds.')
flags.DEFINE_string('dataset', None, 'Path to dataset.')

flags.DEFINE_string('checkpoint_root', None, 'The directory to save checkpoints to.')
flags.DEFINE_string('eval_log_path', None, 'The directory in which evaluation logs are saved.')
# Multi-run mode: the three flags below must be given together (validated below).
flags.DEFINE_integer('task_id', None, 'The ID of this task used when determining hyperparameters.')
flags.DEFINE_integer('num_tasks', None, 'The number of tasks part of this run.')
flags.DEFINE_string('run_configs', None, 'Path to jsonline file with run hyperparameters.')


def _check_if_all_or_neither_none(flag_values):
    is_none = [v is None for v in flag_values]
    return all(is_none) or not any(is_none)


# Multi-run mode needs all three flags; single-run mode needs none of them.
flags.register_multi_flags_validator(
    ['task_id', 'num_tasks', 'run_configs'],
    lambda flag_dict: _check_if_all_or_neither_none(flag_dict.values()),
    'task_id, num_tasks and run_configs must either all be given or all be omitted.',
)

FLAGS = flags.FLAGS


# The experiment *class* (not an instance) that train_loop instantiates.
Experiment = Type[policy_eval_experiment.PolicyEvalExperiment]


def train_loop(experiment_class: Experiment, config: Mapping[Text, Any]):
    """Runs the training loop for a single experiment configuration.

    The loop optionally resumes from a checkpoint, saves a snapshot every
    `save_every_nsteps` steps (before the step, so initial performance is
    captured), checkpoints after every step when checkpointing is enabled,
    and stops early once the logged scalars are no longer finite.

    Args:
      experiment_class: the experiment class to instantiate (e.g.
        `policy_eval_experiment.PolicyEvalExperiment`). It is called with
        `random_seed=` plus every other entry of `config` as keyword
        arguments.
      config: the experiment config. Must contain `random_seed`, `max_steps`,
        `checkpointing_config['use_checkpointing']` and
        `snapshot_config['save_every_nsteps']`. The mapping passed in is not
        mutated.
    """
    # Shallow copy so popping the seed does not mutate the caller's mapping.
    config = dict(config)
    seed_sequence = np.random.SeedSequence(config.pop('random_seed'))
    loop_seed, experiment_seed = seed_sequence.spawn(2)

    experiment = experiment_class(random_seed=experiment_seed, **config)
    rng = jnp.array(loop_seed.generate_state(2))
    step = 0

    # Values that stay constant for the whole loop, read once.
    use_checkpointing = config['checkpointing_config']['use_checkpointing']
    save_period = config['snapshot_config']['save_every_nsteps']
    max_steps = config['max_steps']
    host_id = jax.process_index()
    local_device_count = jax.local_device_count()
    # Contiguous range of global device indices owned by this host.
    host_devices = slice(host_id * local_device_count,
                         (host_id + 1) * local_device_count)

    last_logging = time.time()
    if use_checkpointing:
        checkpoint_data = experiment.load_checkpoint()
        if checkpoint_data is not None:
            step, rng = checkpoint_data

    scalars = {}
    is_finite = True
    while step < max_steps and is_finite:
        step_rng, rng = tuple(jax.random.split(rng))
        # One key per global device; keep only the keys for this host's devices.
        step_rng_device = jax.random.split(
            step_rng, num=jax.device_count())[host_devices]
        step_device = np.broadcast_to(step, [local_device_count])

        # Periodic snapshot, taken before the step to capture initial performance.
        if save_period and step % save_period == 0:
            experiment.save_snapshot(global_step=step_device, rng=step_rng_device)

        # Perform a training step and get scalars to log.
        scalars = experiment.step(global_step=step_device, rng=step_rng_device)

        if use_checkpointing:
            # `rng` has already been advanced past this step, so record the
            # index of the *next* step to run (as the final save below does);
            # saving `step` would make a resumed run repeat this step.
            experiment.save_checkpoint(step + 1, rng)

        current_time = time.time()
        if current_time - last_logging > FLAGS.log_tensors_interval:
            logging.info('Step %d: %s', step, scalars)
            last_logging = current_time
            # Divergence is only checked when logging, which limits how often
            # the scalars are pulled back from the devices.
            is_finite = bool(jax.tree_util.tree_reduce(
                lambda finite, x: finite and jnp.all(jnp.isfinite(x)),
                scalars, True))
            if not is_finite:
                logging.warning(
                    'Non-finite scalars at step %d; stopping training.', step)

        step += 1

    logging.info('Saving final checkpoint')
    logging.info('Step %d: %s', step, scalars)
    # NOTE(review): unlike the per-step saves, this save is not gated on
    # use_checkpointing; confirm that always saving here is intended.
    experiment.save_checkpoint(step, rng)


def main(_):
    """Builds run configs from the command-line flags and trains each one.

    Without `--task_id`, a single config is built from `--dataset`,
    `--train_steps`, `--batch_size`, `--checkpoint_root` and, when set,
    `--eval_log_path`. With `--task_id`/`--num_tasks`/`--run_configs`, this
    task takes every `num_tasks`-th line of the JSON Lines file, starting at
    index `task_id`. Each line supplies keyword arguments for
    `rl_config.get_config`. Configs rejected by `rl_config.validate_config`
    are skipped with a warning.

    Raises:
      app.UsageError: if `--task_id` is not in `[0, num_tasks)`.
    """
    if FLAGS.worker_tpu_driver:
        jax.config.update('jax_xla_backend', 'tpu_driver')
        jax.config.update('jax_backend_target', FLAGS.worker_tpu_driver)
        logging.info('Backend: %s %r', FLAGS.worker_tpu_driver, jax.devices())

    experiment_class = policy_eval_experiment.PolicyEvalExperiment

    if FLAGS.task_id is None:
        # Forward --eval_log_path only when it was set, so get_config keeps
        # its own default otherwise.
        optional_kwargs = {}
        if FLAGS.eval_log_path is not None:
            optional_kwargs['eval_log_path'] = FLAGS.eval_log_path
        configs = [
            rl_config.get_config(
                FLAGS.dataset, FLAGS.train_steps, FLAGS.batch_size,
                FLAGS.checkpoint_root, **optional_kwargs)
        ]
    else:
        # A task_id outside [0, num_tasks) would silently select the wrong or
        # overlapping lines. A non-positive num_tasks would make the slice
        # below raise or run backwards.
        if not 0 <= FLAGS.task_id < FLAGS.num_tasks:
            raise app.UsageError(
                'task_id must be in [0, num_tasks); got task_id=%d, num_tasks=%d.'
                % (FLAGS.task_id, FLAGS.num_tasks))

        with jsonlines.open(FLAGS.run_configs) as reader:
            run_params = list(reader)[FLAGS.task_id::FLAGS.num_tasks]

        configs = []
        for params in run_params:
            config = rl_config.get_config(
                checkpoint_root=FLAGS.checkpoint_root,
                eval_log_path=FLAGS.eval_log_path,
                **params,
            )
            if rl_config.validate_config(config):
                configs.append(config)
            else:
                logging.warning(
                    'Skipping run config rejected by validate_config: %s', params)

    if not configs:
        logging.warning('No valid run configs to train.')

    for config in configs:
        logging.info('Starting run logged in %s', config["evaluation_config"]["log_path"])
        train_loop(experiment_class, config)


# Script entry point: absl's app.run handles command-line flags, then calls main.
if __name__ == '__main__':
    app.run(main)
