"""Utilities for small CIFAR model mNPEFF."""
import tensorflow as tf

from em import datasets as em_datasets


###############################################################################

def make_model(n_classes: int = 2):
    """Build the small convolutional classifier for 32x32 RGB images.

    The network is a plain ``Sequential`` stack: three 3x3 ReLU convolutions
    (32, 64 and 64 filters) with 2x2 max pooling after the first two, then a
    flatten, a 64-unit ReLU dense layer, and a final dense layer with
    ``n_classes`` units and no activation (i.e. it emits raw logits).

    Args:
        n_classes: Number of units in the final dense layer.

    Returns:
        The uncompiled Keras model, with an extra ``num_labels`` attribute
        set to ``n_classes``.
    """
    layers = tf.keras.layers

    # Layers are created in stacking order so that Keras' auto-generated
    # layer names come out the same regardless of how the list is consumed.
    stack = [
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(n_classes),
    ]
    model = tf.keras.models.Sequential(stack)

    # Record the class count on the model itself; presumably read by
    # downstream code that expects a `num_labels` attribute -- verify
    # against callers.
    model.num_labels = n_classes
    return model


def compile_model(model, learning_rate: float):
    """Compile ``model`` in place for integer-label classification on logits.

    Uses the Adam optimizer, sparse categorical cross-entropy computed from
    logits (``from_logits=True``), and tracks the ``'accuracy'`` metric.

    Args:
        model: Keras model to compile; it is modified in place.
        learning_rate: Learning rate handed to the Adam optimizer.

    Returns:
        None.
    """
    optimizer = tf.keras.optimizers.Adam(learning_rate)
    # from_logits=True: the loss applies the softmax itself, so the model is
    # expected to output unnormalized scores.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy'])


###############################################################################

def load_datasets(batch_size: int):
    """Load the batched training and validation datasets.

    Both datasets come from ``em_datasets.load('cifar10/binarized', ...)``
    with no tokenizer and no sequence length. The training data is the
    ``'train'`` split, repeated indefinitely, shuffled with a buffer of 1000
    elements, then batched. The validation data is the ``'test'`` split,
    batched only.

    Because the training dataset repeats without end, anything iterating
    over it must bound the number of steps itself.

    Args:
        batch_size: Batch size applied to both datasets.

    Returns:
        A ``(train_ds, val_ds)`` tuple.
    """
    def _load_split(split: str):
        # Same loader arguments for every split; only the split name varies.
        return em_datasets.load(
            'cifar10/binarized',
            split=split,
            tokenizer=None,
            sequence_length=None,
        )

    train_ds = _load_split('train')
    train_ds = train_ds.repeat()
    train_ds = train_ds.shuffle(1000)
    train_ds = train_ds.batch(batch_size)

    val_ds = _load_split('test').batch(batch_size)

    return train_ds, val_ds
