from library.examples.mpi_faust.faust_data_set import load_preprocessed_faust
from library.layers.angular_max_pooling import AngularMaxPooling
from library.layers.conv_chi_squared import ConvChiSquared
from library.models.intrinsic_model import ImCNN
from library.utils.measures import princeton_benchmark

from experiment_scripts.preprocess_data import PREPROCESS_TARGET_DIR, REFERENCE_MESH_PATH

from tensorflow import keras

import tensorflow as tf


def define_model(signal_dim, kernel_size, amt_classes=6890):
    """Build and compile the chi-squared intrinsic CNN (IMCNN).

    The network stacks three ``ConvChiSquared`` layers (named ``gc_0``,
    ``gc_1``, ``gc_2``). Each is followed by the same ``AngularMaxPooling``
    layer object. A dense layer then emits one logit per class.

    Args:
        signal_dim: Length of the per-vertex input signal vector.
        kernel_size: Tuple describing the kernel grid, e.g. ``(5, 8)``. It is
            concatenated with ``(3, 2)`` to form the shape of the ``bc`` input,
            so it must be a tuple, not a list.
        amt_classes: Number of output logits of the final dense layer.
            Defaults to 6890, the value this script always used (presumably
            the FAUST vertex count; confirm against the dataset).

    Returns:
        An ``ImCNN`` model with inputs named ``signal`` and ``bc``, compiled
        with Adam, sparse categorical cross-entropy on logits, and sparse
        categorical accuracy.
    """
    signal_input = keras.layers.Input(shape=(signal_dim,), name="signal")
    bc_input = keras.layers.Input(shape=kernel_size + (3, 2), name="bc")

    # One pooling layer object is created once and applied after every convolution.
    amp = AngularMaxPooling()

    # ``amt_kernel`` for each convolutional layer, in order; the position in this
    # tuple also determines the layer name ``gc_<index>``.
    kernels_per_layer = (1, 2, 1)

    signal = signal_input
    for layer_idx, amt_kernel in enumerate(kernels_per_layer):
        signal = ConvChiSquared(
            output_dim=128,
            amt_kernel=amt_kernel,
            kernel_radius=0.028,
            activation="relu",
            name=f"gc_{layer_idx}",
            splits=10,
            rotation_delta=1,
            dof=3,
        )([signal, bc_input])
        signal = amp(signal)

    # Raw logits: the loss below is configured with ``from_logits=True``.
    output = keras.layers.Dense(amt_classes)(signal)

    # NOTE(review): semantics of ``splits=1`` come from ImCNN, which is not
    # visible here. Confirm it matches the ``splits=10`` used by the conv layers.
    model = ImCNN(splits=1, inputs=[signal_input, bc_input], outputs=[output])
    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    # Presumably a tuned learning rate (e.g. from a hyper-parameter search); confirm its origin.
    opt = keras.optimizers.Adam(learning_rate=0.00092318)
    model.compile(optimizer=opt, loss=loss, metrics=["sparse_categorical_accuracy"])

    return model


def main(signal_dim, kernel_size, zip_file_path, logging_dir, reference_mesh_path):
    """Train the chi-squared IMCNN on preprocessed FAUST data and benchmark it.

    Args:
        signal_dim: Length of the per-vertex input signal.
        kernel_size: Kernel grid tuple, forwarded to ``define_model`` and to the
            data loader.
        zip_file_path: Path of the preprocessed data archive *without* the
            ``.zip`` suffix; the suffix is appended here.
        logging_dir: Directory receiving the TensorBoard logs, the saved model
            (``best_model``) and the benchmark output (``best_model_benchmark``).
        reference_mesh_path: Reference mesh forwarded to ``princeton_benchmark``.
    """
    # Model
    imcnn = define_model(signal_dim, kernel_size)
    print(tf.config.list_physical_devices('GPU'))
    # Model.summary() prints itself and returns None, so it is not wrapped in print().
    imcnn.summary()

    # Load data (set_type 0 / 1 / 2 are used below for train / validation / test)
    preprocess_zip = f"{zip_file_path}.zip"
    print(f"Used data: {preprocess_zip}")
    train_data = load_preprocessed_faust(preprocess_zip, signal_dim=signal_dim, kernel_size=kernel_size, set_type=0)
    val_data = load_preprocessed_faust(preprocess_zip, signal_dim=signal_dim, kernel_size=kernel_size, set_type=1)

    # Define callbacks
    # restore_best_weights: without it the model keeps the weights of the last
    # epoch, which can be up to `patience` epochs past the best val_loss. The
    # model saved as "best_model" would then not be the best one.
    stop = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True)
    tb = tf.keras.callbacks.TensorBoard(
        log_dir=f"{logging_dir}/tensorboard",
        histogram_freq=1,
        write_graph=False,
        write_steps_per_second=True,
        update_freq="epoch",
        profile_batch=(1, 1 * 70)  # Batch as in one tick in the bar, not what the layer perceives as a batch
    )

    imcnn.fit(x=train_data, callbacks=[stop, tb], validation_data=val_data, epochs=200)
    imcnn.save(f"{logging_dir}/best_model")

    test_data = load_preprocessed_faust(preprocess_zip, signal_dim=signal_dim, kernel_size=kernel_size, set_type=2)
    princeton_benchmark(
        imcnn=imcnn,
        test_dataset=test_data,
        ref_mesh_path=reference_mesh_path,
        file_name=f"{logging_dir}/best_model_benchmark"
    )


if __name__ == "__main__":
    # Configuration of this chi-squared IMCNN run; unpacked into main() below.
    run_config = dict(
        signal_dim=544,
        kernel_size=(5, 8),
        zip_file_path=PREPROCESS_TARGET_DIR,
        logging_dir="./logs_chi_sqrd_3",
        reference_mesh_path=REFERENCE_MESH_PATH,
    )
    main(**run_config)
