R"""Script for training the divisibility model and saving to disk."""
import dataclasses
import os

from absl import app
from absl import flags
from absl import logging

import tensorflow as tf

from em.datasets import divisibility as divis_ds
from em.models import divis_models
from em.util import hdf5_util

FLAGS = flags.FLAGS


# --- Output ------------------------------------------------------------------
flags.DEFINE_string("output_path", None, "Path to the .h5 file to save the results and weights to.")


# --- Model architecture (all forwarded to divis_models.DivisModelConfig) ------
flags.DEFINE_integer("embeddings_size", None, "Embeddings size forwarded to the model config.")
flags.DEFINE_integer("n_layers", None, "Number of layers forwarded to the model config.")

flags.DEFINE_string("layer_config", None, "Layer configuration string forwarded to the model config.")
flags.DEFINE_string("activation_fn", "relu", "Name of the activation function forwarded to the model config.")


# --- Final dataset (all forwarded to divis_ds.DivisibilityDatasetConfig) ------
flags.DEFINE_integer("min_divisor", 2, "min_divisor of the final dataset config.")
flags.DEFINE_integer("max_divisor", 13, "max_divisor of the final dataset config.")
flags.DEFINE_integer("min_dividend", 100, "min_dividend of the final dataset config.")
flags.DEFINE_integer(
    "max_dividend",
    999_999_999,
    "max_dividend of the final dataset config. Curriculum stages replace it with a per-stage limit.",
)


# --- Training schedule and optimizer -----------------------------------------
flags.DEFINE_integer("steps_per_epoch", None, "Number of steps per epoch, passed to model.fit for every stage.")
flags.DEFINE_list(
    "curriculum_epochs",
    None,
    "Comma-separated epoch counts, one per curriculum stage. Each stage allows one more dividend "
    "digit than the previous one; the last stage uses the digit count of the final dataset config.",
)

flags.DEFINE_integer("batch_size", 2048, "Batch size applied to the training dataset.")
flags.DEFINE_float("learning_rate", 1e-4, "Learning rate of the Adam optimizer.")
flags.DEFINE_float("clipnorm", 0.1, "clipnorm value passed to the Adam optimizer.")


# Buffer size handed to divis_ds.create_ds when building each stage's dataset.
_DS_BUFFER_SIZE = 4 * 1024 * 1024


def get_model():
    """Build and compile the divisibility model described by the command-line flags.

    The model config is assembled from the ``layer_config``, ``n_layers``,
    ``embeddings_size`` and ``activation_fn`` flags. The model is compiled with
    an Adam optimizer (``learning_rate`` and ``clipnorm`` flags), a sparse
    categorical cross-entropy loss on logits, and sparse categorical accuracy
    as its metric.

    Returns:
        A ``(config, model)`` tuple: the ``DivisModelConfig`` and the compiled
        model created from it.
    """
    config = divis_models.DivisModelConfig(
        layer_config=FLAGS.layer_config,
        n_layers=FLAGS.n_layers,
        embeddings_size=FLAGS.embeddings_size,
        activation_fn=FLAGS.activation_fn,
    )
    model = divis_models.create_model(config)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(FLAGS.learning_rate, clipnorm=FLAGS.clipnorm),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        # Keras documents `metrics` as a list; passing a bare metric object is
        # not accepted by every Keras version, so wrap it explicitly.
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )
    return config, model


def get_final_dataset_config():
    """Return the dataset config for the last (full-range) curriculum stage.

    The divisor and dividend bounds are read directly from the corresponding
    command-line flags.
    """
    bounds = dict(
        min_divisor=FLAGS.min_divisor,
        max_divisor=FLAGS.max_divisor,
        min_dividend=FLAGS.min_dividend,
        max_dividend=FLAGS.max_dividend,
    )
    return divis_ds.DivisibilityDatasetConfig(**bounds)


def ds_from_config(config: divis_ds.DivisibilityDatasetConfig):
    """Create the dataset for `config` and batch it with the ``batch_size`` flag."""
    unbatched = divis_ds.create_ds(config, buffer_size=_DS_BUFFER_SIZE)
    return unbatched.batch(FLAGS.batch_size)


def curriculum_stage(model, final_ds_config, n_digits: int, n_epochs: int):
    """Train `model` for one curriculum stage with dividends of at most `n_digits` digits.

    Args:
        model: Compiled model to train; it is updated in place by ``model.fit``.
        final_ds_config: Dataset config of the last stage. The stage config is
            derived from it with ``dataclasses.replace``.
        n_digits: Maximum number of dividend digits for this stage (>= 1).
        n_epochs: Number of epochs to train in this stage.

    Returns:
        The ``history`` dict of the ``History`` object returned by ``model.fit``.

    Raises:
        ValueError: If `n_digits` is smaller than 1.
    """
    if n_digits < 1:
        # 10**n_digits - 1 would be 0 (or a fraction for negative values),
        # which is not a usable upper bound for dividends.
        raise ValueError(f"n_digits must be at least 1, got {n_digits}.")

    ds_config = dataclasses.replace(
        final_ds_config,
        # Pin the digit count to the final config's value even though this
        # stage uses smaller dividends (presumably so that the input encoding
        # stays the same across stages -- confirm against divis_ds).
        force_n_dividend_digits=final_ds_config.n_dividend_digits,
        # Never exceed the configured final maximum: without the clamp the
        # last stage would train on dividends up to 10**n_digits - 1 even when
        # the final config asks for a smaller bound.
        max_dividend=min(final_ds_config.max_dividend, 10**n_digits - 1),
    )
    train_ds = ds_from_config(ds_config)
    # TODO: Add some quick validation data so we get scores at the end of each epoch.
    history = model.fit(
        train_ds,
        steps_per_epoch=FLAGS.steps_per_epoch,
        epochs=n_epochs,
    )
    return history.history


def main(_):
    """Train the model through all curriculum stages and save weights plus metadata.

    Stage ``i`` (0-based) of ``N`` stages trains on dividends with at most
    ``n_dividend_digits - N + 1 + i`` digits, so the last stage uses the digit
    count of the final dataset config. After training, the model weights are
    written to ``--output_path`` together with the model, dataset and training
    configuration and the per-stage training histories.

    Raises:
        app.UsageError: If ``--output_path`` or ``--curriculum_epochs`` is
            missing, or if there are more curriculum stages than dividend
            digits.
    """
    # Validate the flags needed at the very end *before* training, so a
    # missing flag does not throw away a finished training run.
    if FLAGS.output_path is None:
        raise app.UsageError("--output_path is required.")
    if FLAGS.curriculum_epochs is None:
        raise app.UsageError("--curriculum_epochs is required (comma-separated epoch counts, one per stage).")
    output_path = os.path.expanduser(FLAGS.output_path)

    model_config, model = get_model()
    final_ds_config = get_final_dataset_config()

    curriculum_epochs = [int(e) for e in FLAGS.curriculum_epochs]
    n_curriculum_stages = len(curriculum_epochs)

    # Digit count of the first stage; every later stage adds one digit.
    first_stage_digits = final_ds_config.n_dividend_digits - n_curriculum_stages + 1
    if first_stage_digits < 1:
        raise app.UsageError(
            f"--curriculum_epochs defines {n_curriculum_stages} stages, but the final dataset only has "
            f"{final_ds_config.n_dividend_digits} dividend digits."
        )

    stage_histories = []
    for i, n_epochs in enumerate(curriculum_epochs):
        n_digits = first_stage_digits + i
        logging.info(
            "Curriculum stage %d/%d: %d dividend digit(s), %d epoch(s).",
            i + 1, n_curriculum_stages, n_digits, n_epochs,
        )
        stage_history = curriculum_stage(model, final_ds_config, n_digits=n_digits, n_epochs=n_epochs)
        stage_histories.append(stage_history)

    metadata = {
        'model_config': dataclasses.asdict(model_config),
        'final_ds_config': dataclasses.asdict(final_ds_config),
        'training_config': {
            'curriculum_epochs': curriculum_epochs,
            'steps_per_epoch': FLAGS.steps_per_epoch,
            'batch_size': FLAGS.batch_size,
            'learning_rate': FLAGS.learning_rate,
            'clipnorm': FLAGS.clipnorm,
        },
        'stage_histories': stage_histories,
    }

    hdf5_util.save_np_arrays_with_metadata(output_path, model.get_weights(), metadata)
    logging.info("Saved weights and metadata to %s", output_path)


# Script entry point: hand control to absl, which parses the command-line
# flags defined above and then calls main().
if __name__ == "__main__":
    app.run(main)
