from functools import partial

import tensorflow as tf

from settings import MODEL_NAMES
from models import get_model
from configuration import configurator_model_comparison

from bayesflow.simulation import MultiGenerativeModel
from bayesflow.networks import SetTransformer, PMPNetwork
from bayesflow.amortizers import AmortizedModelComparison
from bayesflow.trainers import Trainer
from bayesflow.configuration import DefaultModelComparisonConfigurator


# Number of simulated data sets generated once up front for offline training.
# It is also embedded in the checkpoint directory name built by ``get_trainer``.
NUM_TRAINING_SIMS = 50_000
# Number of simulated data sets generated for validation during training.
NUM_VALIDATION_SIMS = 500
# Mini-batch size and number of epochs handed to ``Trainer.train_offline``.
BATCH_SIZE = 128
EPOCHS = 250


def get_trainer(mode, num_models, summary_dim=30):
    """Build a ready-to-call BayesFlow trainer for amortized model comparison.

    Parameters
    ----------
    mode : str
        Summary representation to use. One of ``'expert'`` (no learned
        summary network), ``'learner'`` (a ``SetTransformer`` producing
        ``summary_dim`` outputs) or ``'hybrid'`` (a ``SetTransformer``
        producing ``summary_dim // 2`` outputs). The mode is also forwarded
        to ``configurator_model_comparison``.
    num_models : int
        Number of candidate models being compared. It sizes both the
        default configurator and the ``PMPNetwork`` output so the two agree.
    summary_dim : int, optional
        Output size of the learned summary network (default 30). It is
        ignored in ``'expert'`` mode and halved (floor division) in
        ``'hybrid'`` mode.

    Returns
    -------
    Trainer
        A trainer wrapping an ``AmortizedModelComparison`` amortizer that
        checkpoints to ``checkpoints/model_comparison/{mode}_{NUM_TRAINING_SIMS}``.

    Raises
    ------
    NotImplementedError
        If ``mode`` is not one of the supported values.
    """

    # Resolve the summary network first, so an unknown mode fails before any
    # configurator or network is constructed.
    if mode == 'expert':
        # No learned summary network is attached to the amortizer.
        summary_net = None
    elif mode == 'learner':
        # NOTE(review): assumes each simulated observation has 3 features -- confirm.
        summary_net = SetTransformer(input_dim=3, summary_dim=summary_dim)
    elif mode == 'hybrid':
        # NOTE(review): half-width learned summary, presumably complemented by
        # hand-crafted statistics in configurator_model_comparison -- confirm.
        summary_net = SetTransformer(input_dim=3, summary_dim=summary_dim // 2)
    else:
        raise NotImplementedError('Unknown mode')

    # Size the configurator from ``num_models`` rather than ``len(MODEL_NAMES)``
    # so it always matches the PMPNetwork output layer built below.
    base_config = DefaultModelComparisonConfigurator(num_models)
    configurator = partial(configurator_model_comparison, mode=mode, base_config=base_config)

    inference_net = PMPNetwork(
        num_models=num_models,
        dense_args=dict(units=256, activation='relu', kernel_regularizer=tf.keras.regularizers.L2(5e-5))
    )

    amortizer = AmortizedModelComparison(inference_net, summary_net)
    trainer = Trainer(
        amortizer=amortizer,
        configurator=configurator,
        # One checkpoint directory per mode and training-set size.
        checkpoint_path=f'checkpoints/model_comparison/{mode}_{NUM_TRAINING_SIMS}'
    )
    return trainer


if __name__ == '__main__':
    # Instantiate every registered candidate model and bundle them into a
    # single simulator that draws from all of them.
    generative_models = [get_model(model_name) for model_name in MODEL_NAMES]
    meta_model = MultiGenerativeModel(generative_models)
    num_candidate_models = len(generative_models)

    # Simulate the offline data sets once; every summary mode below reuses them.
    train_data = meta_model(NUM_TRAINING_SIMS)
    validation_data = meta_model(NUM_VALIDATION_SIMS)

    # Fit one amortizer per summary representation, resetting Keras global
    # state between runs so each mode starts from a clean session.
    for summary_mode in ('expert', 'learner', 'hybrid'):
        mode_trainer = get_trainer(summary_mode, num_candidate_models)
        mode_trainer.train_offline(
            train_data,
            batch_size=BATCH_SIZE,
            epochs=EPOCHS,
            validation_sims=validation_data,
        )
        tf.keras.backend.clear_session()










