from bayesflow.simulation import Prior, Simulator, GenerativeModel

from priors import *
from likelihoods import *


def get_model(model_type):
    """Helper function to fetch the prior, simulator and patch into BayesFlow format.

    Looks up ``draw_prior_<model_type>`` and ``simulate_trials_<model_type>``
    (star-imported from the ``priors`` and ``likelihoods`` modules), wraps them
    in BayesFlow ``Prior`` / ``Simulator`` objects and combines them into a
    ``GenerativeModel`` named after ``model_type``.

    Parameters
    ----------
    model_type : str
        Identifier of the model variant. One of ``"m1a"``, ``"m1b"``, ``"m1c"``,
        ``"m2"``, ``"m3"``, ``"m4a"``, ``"m4b"``, ``"m5"`` ... ``"m12"``.

    Returns
    -------
    GenerativeModel
        Built with ``skip_test=True`` and ``name=model_type``.

    Raises
    ------
    ValueError
        If ``model_type`` is not one of the supported identifiers. Previously
        this case silently returned ``None``.
    NameError
        If the prior or simulator function for a supported identifier is not
        defined in this module's namespace.
    """
    supported = (
        "m1a", "m1b", "m1c", "m2", "m3", "m4a", "m4b",
        "m5", "m6", "m7", "m8", "m9", "m10", "m11", "m12",
    )
    if model_type not in supported:
        raise ValueError(
            f"Unknown model_type {model_type!r}; "
            f"expected one of: {', '.join(supported)}"
        )

    # Resolve the per-model functions by naming convention, and only for the
    # requested model. Like the original if-chain, a missing function for some
    # other model never affects this call.
    namespace = globals()
    prior_name = f"draw_prior_{model_type}"
    simulator_name = f"simulate_trials_{model_type}"
    try:
        prior_fun = namespace[prior_name]
        simulator_fun = namespace[simulator_name]
    except KeyError as err:
        # Keep the original failure mode (NameError) for an undefined function.
        raise NameError(f"name {err.args[0]!r} is not defined") from err

    prior = Prior(prior_fun=prior_fun)
    simulator = Simulator(simulator_fun=simulator_fun)
    # NOTE(review): skip_test=True presumably disables BayesFlow's
    # construction-time test simulation; confirm against the BayesFlow docs.
    return GenerativeModel(prior, simulator, skip_test=True, name=model_type)
