import tensorflow as tf
from bayesflow.benchmarks import Benchmark

import sys
sys.path.append('../')

from helpers import get_amortized_setup
from settings import MODES
from heteroskedastic import get_pretrained

# Size of the simulated training set; the script also passes this value as the
# `budget` argument when building each trainer.
NUM_TRAINING_SIMS = 1_000
# Size of the simulated validation set handed to offline training.
NUM_VALIDATION_SIMS = 500

if __name__ == "__main__":
    # Script entry point: simulate one shared train/validation data set,
    # attach the pretrained expert's summaries to both, then train one
    # amortized network offline for every entry of MODES.

    # Training hyperparameters used by every run below.
    epochs, batch_size = 200, 32

    # Lotka-Volterra benchmark in 'posterior' mode; the simulator receives
    # subsample=None and flatten=False.
    simulator_options = {'subsample': None, 'flatten': False}
    benchmark = Benchmark('lotka_volterra', sim_kwargs=simulator_options, mode='posterior')

    # The same simulations are reused across all settings.
    train_sims = benchmark.generative_model(NUM_TRAINING_SIMS)
    val_sims = benchmark.generative_model(NUM_VALIDATION_SIMS)

    # Store the pretrained expert's predictions under the 'summaries' key of
    # each simulation dict.
    # NOTE(review): assumes expert.predict returns an object exposing
    # .numpy() (e.g. a tensor rather than an ndarray) -- confirm against
    # heteroskedastic.get_pretrained.
    expert, configurator = get_pretrained(benchmark)
    for sims in (train_sims, val_sims):
        expert_inputs = configurator(sims)['summary_conditions']
        sims['summaries'] = expert.predict(expert_inputs).numpy()

    # One offline training run per entry of MODES; the first two items of an
    # entry are used as the network mode and the summary type.
    for setting in MODES:
        mode, summary_type = setting[0], setting[1]
        print(f'Starting training of {mode} network with a {summary_type} summary...')
        trainer = get_amortized_setup(
            benchmark,
            mode=mode,
            summary_type=summary_type,
            ai_expert=True,
            budget=NUM_TRAINING_SIMS,
            num_total_summaries=16,
        )
        # The value returned by train_offline is not used by this script.
        trainer.train_offline(
            train_sims,
            epochs=epochs,
            batch_size=batch_size,
            validation_sims=val_sims,
        )
        print(f'Completed training of {mode} network with a {summary_type} summary...')
        # Reset Keras' global state before the next network is built.
        tf.keras.backend.clear_session()
