import tensorflow as tf
import os
import sys
sys.path.append('../')


from helpers import get_amortized_setup
from settings import MODES, SIMULATION_BUDGETS, MODEL_NAMES
from models import get_model

BATCH_SIZE = 32
VALIDATION_SIMS = 500
EPOCHS = 100


def train_modes(model, sim_budget, training_sims, validation_sims):
    """Train one network per entry in ``MODES`` for a single model and budget.

    For every mode the checkpoint path
    ``./checkpoints/estimation/{model.name}_{mode}_{sim_budget}`` is checked
    first; if it already exists, that mode is skipped and nothing is trained.

    Parameters
    ----------
    model
        Model object. Its ``name`` attribute is used in log messages and in
        the checkpoint path, and the object itself is handed to
        ``get_amortized_setup``.
    sim_budget
        Simulation budget of this configuration. Used in the checkpoint path
        and forwarded to ``get_amortized_setup`` as ``budget``.
    training_sims
        Training data, passed as the first argument of
        ``trainer.train_offline``.
    validation_sims
        Validation data, passed to ``trainer.train_offline`` as
        ``validation_sims``.

    Returns
    -------
    None
    """

    for mode in MODES:
        ckpt_path = f'./checkpoints/estimation/{model.name}_{mode}_{sim_budget}'

        # Guard clause: never retrain a configuration whose checkpoint path exists.
        if os.path.exists(ckpt_path):
            print(f'Skipping, since {ckpt_path} exists...')
            continue

        print(f'Starting training of {mode} network with a budget of {sim_budget} simulations for model {model.name}...')

        # Forward this function's own argument. The previous code read a name
        # `budget` that only exists as a module-level loop variable when the
        # file is executed as a script, so importing and calling this function
        # raised NameError.
        trainer = get_amortized_setup(model, mode=mode, budget=sim_budget)
        trainer.train_offline(
            training_sims, epochs=EPOCHS, batch_size=BATCH_SIZE, validation_sims=validation_sims)
        print('Completed...')

        # Reset Keras' global state and drop the trainer reference before the
        # next mode is set up.
        tf.keras.backend.clear_session()
        del trainer


if __name__ == "__main__":
    # Outer loop: one pass per model name listed in the settings.
    for model_name in MODEL_NAMES:
        # Build the model and draw a single validation set that is reused
        # for every simulation budget of this model.
        model = get_model(model_name)
        validation_data = model(VALIDATION_SIMS)

        # Inner loop: one training configuration per simulation budget.
        for budget in SIMULATION_BUDGETS:
            # Simulate the training set with the current model instance,
            # then rebuild the model before handing it to the training routine.
            training_data = model(budget)
            model = get_model(model_name)
            train_modes(model, budget, training_data, validation_data)
