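R"""Script for running gradient descent on synthetic datasets.

For the selected config, the experiment is repeated `n_runs` times and the
final losses, training histories and config are written as JSON to
`<outdir>/gd_<config name>.json`.

Example usage:

cd ~/Desktop/projects/zonotopic_relu
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/zonotopic_relu

python3 scripts/synthetic/gradient_descent.py \
    --outdir="/tmp" \
    --configs_path="exps.synthetic.gradient_descent_configs.CONFIGS" \
    --config="test" \
    --n_runs=2

"""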
import dataclasses
import json
import os
from pydoc import locate

from absl import app
from absl import flags
from absl import logging

if True:
    # Hide all GPUs from CUDA so that TensorFlow runs on the CPU: for whatever
    # reason, using the GPU appears to be slower for these small models and/or
    # this style of feeding data.
    # NOTE(review): this presumably has to execute before the tensorflow
    # import below to take effect, and the `if True:` wrapper looks like a way
    # to keep the assignment in place among the imports -- confirm before
    # moving or unwrapping it.
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import tensorflow as tf

from xoid.datasets import synthetic


FLAGS = flags.FLAGS

# Default dotted Python path of the dict that maps config names to configs.
_CONFIGS_PATH = "exps.synthetic.gradient_descent_configs.CONFIGS"

# The flags are registered only when this file is executed as a script, so
# importing the module does not define them.
if __name__ == "__main__":
    # Where the results JSON is written; an existing directory always works.
    flags.DEFINE_string(
        name="outdir",
        default=None,
        help="Path directory to create where we will write output to.",
    )

    # Together these two flags select the experiment configuration.
    flags.DEFINE_string(
        name="configs_path",
        default=_CONFIGS_PATH,
        help="Python path to configs dict.",
    )
    flags.DEFINE_string(
        name="config",
        default=None,
        help="Name of the entry in the configs dict to use as configuration.",
    )

    flags.DEFINE_integer(
        name="n_runs",
        default=1,
        help="Number of times to repeat the experiment.",
    )

    flags.mark_flags_as_required(["outdir", "configs_path", "config"])


@dataclasses.dataclass
class Config:
    """Settings for one synthetic gradient-descent experiment."""

    # Identifier of the configuration; it becomes part of the output file name.
    name: str

    # Passed to ``synthetic.make_dataset`` together with ``d``; it also sets
    # the requested dataset size, ``(d + 1) * m_gen``.
    # NOTE(review): presumably the hidden width of the generating network --
    # confirm against ``synthetic.make_dataset``.
    m_gen: int
    # Number of units in the hidden ReLU layer of the trained model.
    m_train: int
    # Input dimension of the model; also the first argument to
    # ``synthetic.make_dataset``.
    d: int

    # Learning rate handed to the SGD optimizer.
    lr: float
    # Training step budget; ``do_run`` divides it across its epochs.
    n_steps: int


def make_model(cfg):
    """Build and compile the network that is trained by gradient descent.

    The network is a ReLU hidden layer with ``cfg.m_train`` units followed by
    a single linear output unit. It is compiled with plain SGD at learning
    rate ``cfg.lr`` and a mean-squared-error loss.

    Args:
        cfg: Configuration providing ``m_train``, ``d`` and ``lr``.

    Returns:
        The compiled ``tf.keras.Sequential`` model.
    """
    hidden = tf.keras.layers.Dense(cfg.m_train, activation='relu')
    readout = tf.keras.layers.Dense(1, activation=None)
    net = tf.keras.Sequential([hidden, readout])

    # Call the network once on a symbolic input of width ``cfg.d`` so that its
    # input size is fixed before training starts.
    net(tf.keras.Input([cfg.d]))

    sgd = tf.keras.optimizers.SGD(cfg.lr)
    mse = tf.keras.losses.MeanSquaredError()
    net.compile(optimizer=sgd, loss=mse)
    return net


def do_run(cfg, run_index):
    """Train a fresh model on a freshly generated synthetic dataset.

    The dataset holds ``(cfg.d + 1) * cfg.m_gen`` requested samples and is fed
    to Keras as a single dataset element, so each training step sees all of
    the samples at once.

    Args:
        cfg: Configuration providing ``d``, ``m_gen``, ``n_steps`` and the
            fields read by ``make_model``.
        run_index: Index of this repetition. Currently unused.

    Returns:
        A ``(loss, history)`` tuple: the value returned by ``model.evaluate``
        on the training data after training, and the ``history`` dict of the
        Keras ``History`` object returned by ``model.fit``.

    Raises:
        ValueError: If ``cfg.n_steps`` is not positive.
    """
    if cfg.n_steps < 1:
        raise ValueError(f'cfg.n_steps must be positive, got {cfg.n_steps}.')

    n_samples = (cfg.d + 1) * cfg.m_gen
    X, Y = synthetic.make_dataset(cfg.d, cfg.m_gen, n_samples)
    # Add a trailing axis to the targets.
    # NOTE(review): presumably so they line up with the model's single output
    # unit -- confirm the shape returned by ``synthetic.make_dataset``.
    Y = Y[:, None]

    # One dataset element containing the whole (X, Y) pair.
    ds = tf.data.Dataset.from_tensors((tf.cast(X, tf.float32), tf.cast(Y, tf.float32)))

    model = make_model(cfg)

    # Spread the step budget over at most 10 epochs. Capping the epoch count
    # at ``n_steps`` keeps ``steps_per_epoch`` >= 1; with a fixed 10 epochs a
    # budget below 10 would repeat the dataset zero times and leave ``fit``
    # with nothing to train on.
    # NOTE: when ``n_steps`` is not a multiple of the epoch count, the
    # remainder steps are dropped, as before.
    n_epochs = min(10, cfg.n_steps)
    steps_per_epoch = cfg.n_steps // n_epochs
    history = model.fit(ds.repeat(steps_per_epoch), epochs=n_epochs, validation_data=ds)
    loss = model.evaluate(ds)

    return loss, history.history


def main(_):
    """Run the selected experiment ``n_runs`` times and save the results.

    The config is looked up by name (``--config``) in the dict found at the
    dotted path ``--configs_path``. The final losses, the per-run training
    histories and the config itself are written as JSON to
    ``<outdir>/gd_<cfg.name>.json``; the output directory is created if it
    does not exist yet.

    Args:
        _: Positional command-line arguments passed by ``app.run``; unused.

    Raises:
        ValueError: If nothing can be located at ``--configs_path``.
        KeyError: If ``--config`` is not an entry of the configs dict.
    """
    configs = locate(FLAGS.configs_path)
    if configs is None:
        # ``pydoc.locate`` returns None when the path cannot be resolved;
        # fail here with a clear message rather than on the subscript below.
        raise ValueError(f'Could not locate configs at {FLAGS.configs_path!r}.')
    if FLAGS.config not in configs:
        raise KeyError(
            f'No config named {FLAGS.config!r} in {FLAGS.configs_path!r}.')
    cfg = configs[FLAGS.config]

    losses = []
    histories = []
    for i in range(FLAGS.n_runs):
        loss, history = do_run(cfg, i)
        losses.append(loss)
        histories.append(history)

    results = {
        'final_losses': losses,
        'histories': histories,
        'config': dataclasses.asdict(cfg),
    }

    # Expand ``~`` first so the directory that gets created is the real one.
    outdir = os.path.expanduser(FLAGS.outdir)
    os.makedirs(outdir, exist_ok=True)
    filepath = os.path.join(outdir, f'gd_{cfg.name}.json')
    with open(filepath, 'w') as f:
        json.dump(results, f)
    logging.info('Wrote results to %s', filepath)


if __name__ == "__main__":
    # Script entry point: absl parses the command-line flags and then calls
    # ``main``.
    app.run(main)
