R"""



cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1

python3 -i local_scripts/divis/divis_dev001.py


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/divis/divis_dev001.py
CUDA_VISIBLE_DEVICES=1 python -i local_scripts/divis/divis_dev001.py

"""
from importlib import reload
import itertools

import numpy as np
import tensorflow as tf

from em.datasets import divisibility as div_ds

# Exploratory dataset (divisors 2..3, dividends up to 99_999_999). `ds` stays
# available for inspection under `python -i`; `train_ds` is rebound further down
# before the training run, so this batched view is not what gets fitted.
# Other configurations tried:
#   divisors, dividends, labels = div_ds._make(buffer_size=4 * 1024)
#   ds = div_ds.create_ds(max_divisor=20)
#   ds = div_ds.create_ds(buffer_size=1024 * 1024, min_divisor=2, max_divisor=3)
#   train_ds = ds.batch(256)
#   train_ds = ds.batch(8 * 1024)
ds = div_ds.create_ds(
    buffer_size=1024 * 1024,
    min_divisor=2,
    max_divisor=3,
    max_dividend=99_999_999,
)
train_ds = ds.batch(2048)


class EmbeddingsLayer(tf.keras.layers.Layer):
    """Trainable lookup table mapping integer token ids to dense vectors.

    Args:
        embeddings_size: Width of each embedding vector.
        alphabet_size: Number of distinct token ids, i.e. rows in the table.
            The default of 10 presumably corresponds to decimal digits — TODO
            confirm against ``em.datasets.divisibility``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, embeddings_size: int, alphabet_size: int = 10, **kwargs):
        super().__init__(**kwargs)
        # Remembered so get_config() can reproduce this constructor call when the
        # layer is saved, cloned or deserialised.
        self.embeddings_size = embeddings_size
        self.alphabet_size = alphabet_size

        self.embeddings_table = self.add_weight(
            name='embeddings_table',
            shape=[alphabet_size, embeddings_size],
            trainable=True,
        )

    def call(self, token_ids):
        # Row lookup: output shape is token_ids.shape + (embeddings_size,).
        return tf.gather(self.embeddings_table, token_ids)

    def get_config(self):
        """Return the constructor arguments needed to rebuild this layer."""
        config = super().get_config()
        config.update(
            embeddings_size=self.embeddings_size,
            alphabet_size=self.alphabet_size,
        )
        return config


class ResDense(tf.keras.layers.Dense):
    """Dense layer with an identity skip connection: computes ``x + Dense(x)``.

    ``units`` must match (or broadcast against) the last dimension of ``x`` so
    that the residual sum is well-formed.
    """

    def call(self, x, *args, **kwargs):
        # Extra positional/keyword arguments are accepted but intentionally not
        # forwarded to Dense.call.
        projected = super().call(x)
        return x + projected


class ResBlock(tf.keras.layers.Layer):
    """Residual two-layer MLP block computing ``x + dense2(dense1(x))``.

    Args:
        d_model: Output width of ``dense2``. It must match the last dimension of
            the input so that the residual sum is well-formed.
        d_ff: Width of the hidden ``dense1`` projection.
        activation: Activation applied to BOTH dense layers. The second
            projection is not linear, so with ``'relu'`` the branch only ever
            adds non-negative values to ``x``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, d_model, d_ff, activation, **kwargs):
        super().__init__(**kwargs)
        # Remembered for get_config(); `activation` should be a serialisable
        # value such as a string for saving to work.
        self.d_model = d_model
        self.d_ff = d_ff
        self.activation = activation
        self.dense1 = tf.keras.layers.Dense(d_ff, activation=activation)
        self.dense2 = tf.keras.layers.Dense(d_model, activation=activation)

    def call(self, x):
        return x + self.dense2(self.dense1(x))

    def get_config(self):
        """Return the constructor arguments needed to rebuild this block."""
        config = super().get_config()
        config.update(
            d_model=self.d_model,
            d_ff=self.d_ff,
            activation=self.activation,
        )
        return config


# model = tf.keras.Sequential([
#     EmbeddingsLayer(embeddings_size=32),
#     tf.keras.layers.Flatten(),
#     tf.keras.layers.Dense(64, activation='relu'),
#     tf.keras.layers.Dense(64, activation='relu'),
#     tf.keras.layers.Dense(2, activation=None),
# ])


# model = tf.keras.Sequential([
#     EmbeddingsLayer(embeddings_size=512),
#     tf.keras.layers.Flatten(),
#     tf.keras.layers.Dense(4 * 1024, activation='relu'),
#     tf.keras.layers.Dense(4 * 1024, activation='relu'),
#     tf.keras.layers.Dense(2, activation=None),
# ])


# model = tf.keras.Sequential([
#     EmbeddingsLayer(embeddings_size=512),
#     tf.keras.layers.Flatten(),
#     tf.keras.layers.Dense(4 * 1024, activation='relu'),
#     tf.keras.layers.LayerNormalization(),
#     tf.keras.layers.Dense(4 * 1024, activation='relu'),
#     tf.keras.layers.LayerNormalization(),
#     tf.keras.layers.Dense(2, activation=None),
# ])

# model = tf.keras.Sequential([
#     EmbeddingsLayer(embeddings_size=1024),
#     tf.keras.layers.Flatten(),
#     tf.keras.layers.Dense(8 * 1024, activation='relu'),
#     tf.keras.layers.Dense(8 * 1024, activation='relu'),
#     tf.keras.layers.Dense(4 * 1024, activation='relu'),
#     tf.keras.layers.Dense(4 * 1024, activation='relu'),
#     tf.keras.layers.Dense(2 * 1024, activation='relu'),
#     tf.keras.layers.Dense(2 * 1024, activation='relu'),
#     tf.keras.layers.Dense(2, activation=None),
# ])


# model = tf.keras.Sequential([
#     EmbeddingsLayer(embeddings_size=1024),
#     tf.keras.layers.Flatten(),
#     tf.keras.layers.Dense(4 * 1024, activation='relu'),
#     ResDense(4 * 1024, activation='relu'),
#     ResDense(4 * 1024, activation='relu'),
#     ResDense(4 * 1024, activation='relu'),
#     ResDense(4 * 1024, activation='relu'),
#     ResDense(4 * 1024, activation='relu'),
#     tf.keras.layers.Dense(2, activation=None),
# ])

# layers = [
#     EmbeddingsLayer(embeddings_size=1024),
#     tf.keras.layers.Flatten(),
#     tf.keras.layers.Dense(1024, activation='relu'),
#     tf.keras.layers.LayerNormalization(),
# ]
# for _ in range(24):
#     layers.append(ResDense(1024, activation='relu'))
#     layers.append(tf.keras.layers.LayerNormalization())
# layers.append(tf.keras.layers.Dense(2, activation=None))
# model = tf.keras.Sequential(layers)


# model = tf.keras.Sequential([
#     EmbeddingsLayer(embeddings_size=32),
#     tf.keras.layers.Flatten(),
#     tf.keras.layers.Dense(32 * 1024, activation='relu'),
#     tf.keras.layers.Dense(4 * 1024, activation='relu'),
#     tf.keras.layers.Dense(4 * 1024, activation='relu'),
#     tf.keras.layers.Dense(2, activation=None),
# ])


# Residual MLP over the flattened per-token embeddings.
d_model = 2 * 1024
d_ff = 8 * 1024
model = tf.keras.Sequential([
    EmbeddingsLayer(embeddings_size=32),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(d_model, activation='relu'),
    # 11 residual blocks, built in order so auto-generated layer names are stable.
    *[ResBlock(d_model, d_ff, activation='relu') for _ in range(11)],
    # Two output logits, presumably "not divisible" / "divisible" — confirm the
    # label encoding in em.datasets.divisibility.
    tf.keras.layers.Dense(2, activation=None),
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4, clipnorm=0.1),
    # Keras expects a list/tuple/dict here; Keras 3 rejects a bare Metric object.
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)


# batch_size = 2048
batch_size = 8 * 1024

# max_divisor = 3
max_divisor = 7

# Curriculum: fit on progressively larger dividends (up to 4, 5, 6, 7, then 8
# digits). force_n_dividend_digits=8 presumably pads every dividend to 8 digits
# so the Flatten/Dense model always sees a fixed-width input — confirm in
# em.datasets.divisibility.
curriculum_max_dividends = (9_999, 99_999, 999_999, 9_999_999, 99_999_999)

for max_dividend in curriculum_max_dividends:
    # Each stage builds a fresh dataset; `train_ds` ends bound to the last stage.
    train_ds = div_ds.create_ds(
        buffer_size=1024 * 1024,
        min_divisor=2,
        max_divisor=max_divisor,
        force_n_dividend_digits=8,
        max_dividend=max_dividend,
    ).batch(batch_size)
    model.fit(
        train_ds,
        steps_per_epoch=128,
        epochs=2,
    )


# Curriculum learning works great!
