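R"""
Train a dense MLP to imitate a randomly generated structured DAG network.

A random DAG neural network (the "teacher") is sampled with
`em.datasets.synthetic.structured_network`. A Keras MLP (the "student") is then
fit on an endless stream of standard-normal inputs, each labeled with the
teacher's softmax output distribution.

Usage:

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1

python3 -i local_scripts/synth_ds/struct_net001.py

# Or restricted to GPU 0 only:
CUDA_VISIBLE_DEVICES=0 python -i local_scripts/synth_ds/struct_net001.py

"""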
import dataclasses
from importlib import reload
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from em.datasets.synthetic import structured_network as sn


# --- Teacher network --------------------------------------------------------
# Sample a random layered DAG, wrap it as a neural network, and smoke-test it.
# The statement order below matters: each `sn` helper presumably consumes
# random draws, so reordering would change the sampled network.

# Random number of nodes per level (bounds as interpreted by `sn`).
nodes_by_level = sn.make_uniformly_random_nodes_by_level(
    n_levels=2,
    # n_levels=5,
    min_nodes=8,
    max_nodes=10
)
# Return value is discarded, so this presumably adds edges to
# `nodes_by_level` in place — confirm against `sn`.
sn.add_uniformly_random_dag_connections(
    nodes_by_level,
    # p_connection=0.2,
    # NOTE(review): a one-element list with n_levels=2; presumably one
    # connection probability per level transition — confirm against `sn`.
    p_connection=[0.1],
)
# Drop source nodes that are not on the first level (per the helper's name;
# exact semantics live in `sn`).
nodes_by_level = sn.remove_sources_not_at_first_level(nodes_by_level)
# Return value is discarded; presumably assigns each node a random width
# between min_units and max_units in place — confirm against `sn`.
sn.assign_uniformly_random_units(
    nodes_by_level,
    min_units=8,
    max_units=32,
)
dag = sn.Dag(
    nodes_by_level=nodes_by_level,
)

# Callable teacher network. `get_ds` below applies a softmax over its last
# axis, i.e. its outputs are treated as `n_classes` logits.
dnn = sn.DagNeuralNetwork(
    dag=dag,
    # n_classes=3,
    n_classes=32,
    activation=tf.nn.relu,
)


# Smoke test: push one random batch of 32 inputs through the teacher to make
# sure it builds and runs. `q` is kept only for interactive inspection.
q = dnn(tf.random.normal([32, dnn.get_input_size()]))


def get_ds(dnn, buffer_size: int = 2048):
    """Build an endless dataset of ``(input, soft_label)`` pairs from ``dnn``.

    Inputs are drawn from a standard normal distribution in chunks of
    ``buffer_size``. Each chunk goes through ``dnn`` in a single call, and the
    softmax of the outputs (over the last axis) is used as the target
    distribution for the corresponding input.

    Args:
        dnn: Callable network exposing ``get_input_size()`` and ``n_classes``.
            It is called on a float tensor of shape
            ``[buffer_size, get_input_size()]`` and its outputs are treated as
            logits over ``n_classes`` classes.
        buffer_size: Number of examples generated per forward pass of ``dnn``.

    Returns:
        An unbatched, never-ending ``tf.data.Dataset`` whose elements are
        ``(x, y)`` with ``x`` of shape ``[get_input_size()]`` and ``y`` of
        shape ``[n_classes]``, both ``tf.float32``.
    """
    input_size = dnn.get_input_size()

    def gen():
        while True:
            X = tf.random.normal([buffer_size, input_size])
            probs = tf.math.softmax(dnn(X), axis=-1)
            # Convert each chunk to NumPy once. Iterating an eager tensor
            # row by row dispatches one slicing op per example, which makes
            # this Python generator the input-pipeline bottleneck.
            yield from zip(X.numpy(), probs.numpy())

    return tf.data.Dataset.from_generator(
        gen,
        output_signature=(
            tf.TensorSpec(shape=[input_size], dtype=tf.float32),
            tf.TensorSpec(shape=[dnn.n_classes], dtype=tf.float32),
        ),
    )


# --- Student network --------------------------------------------------------
# Plain MLP: 2 * n_levels hidden ReLU layers of width 2048, then a linear head
# producing n_classes logits. Other configurations considered: hidden width
# 1024 or 4 * 1024, and n_levels (rather than 2 * n_levels) hidden layers.
model = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(2 * 1024, activation='relu')
        for _ in range(2 * dag.n_levels)
    ]
    + [tf.keras.layers.Dense(dnn.n_classes)]
)
# Targets are the teacher's softmax probabilities (soft labels), so the
# student is trained with cross-entropy on its raw logits.
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3, clipnorm=1),
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
)

# The generator in `get_ds` never terminates, so steps_per_epoch defines the
# epoch length. prefetch lets the Python generator (which runs the teacher)
# prepare upcoming batches while the current training step executes, instead
# of strictly alternating with it. A shorter run used epochs=16.
model.fit(
    get_ds(dnn).batch(128).prefetch(tf.data.AUTOTUNE),
    epochs=256,
    steps_per_epoch=512,
)
