import tensorflow as tf
from bayesflow.benchmarks import Benchmark

import sys
sys.path.append('../')

from helpers import get_amortized_setup
from settings import MODES

# Simulation budgets: how many draws are requested from the benchmark's
# generative model for the training set and for the validation set.
NUM_TRAINING_SIMS = 5_000
NUM_VALIDATION_SIMS = 500

if __name__ == "__main__":
    # Script entry point: train one setup per entry of MODES, all on the same
    # pre-simulated training and validation data.

    # Optimisation settings applied to every configuration
    epochs, batch_size = 150, 64

    # Benchmark instance; sim_kwargs is handed to Benchmark as-is
    sim_kwargs = {'subsample': None, 'flatten': False}
    benchmark = Benchmark('lotka_volterra', sim_kwargs=sim_kwargs, mode='posterior')

    # Simulate once up front so each configuration is trained on identical data
    train_sims = benchmark.generative_model(NUM_TRAINING_SIMS)
    val_sims = benchmark.generative_model(NUM_VALIDATION_SIMS)

    # Keyword arguments forwarded unchanged to every offline training call
    fit_kwargs = dict(epochs=epochs, batch_size=batch_size, validation_sims=val_sims)

    # Each config is indexed as (mode, summary type); one trainer is built per entry
    for config in MODES:
        network_mode, summary_type = config[0], config[1]
        print(f'Starting training of {network_mode} network with a {summary_type} summary...')

        trainer = get_amortized_setup(benchmark, mode=network_mode, summary_type=summary_type)
        trainer.train_offline(train_sims, **fit_kwargs)  # return value is not used

        print(f'Completed training of {network_mode} network with a {summary_type} summary...')

        # Reset Keras' global state before the next setup is constructed
        tf.keras.backend.clear_session()
