from . import language, student, teacher


def create_models(env, FLAGS):
    """Build the student model and the teacher (generator) model for ``env``.

    The student is always ``student.Student``. The teacher is
    ``teacher.LanguageTeacher`` when ``FLAGS.language_goals`` is set (not
    None), otherwise ``teacher.Teacher``.

    Args:
        env: Environment exposing ``observation_space``, ``action_space.n``,
            ``width`` and ``height`` (and ``partial_observation_space`` when
            ``FLAGS.partial_obs`` is truthy).
        FLAGS: Configuration object; reads ``language_goals``,
            ``partial_obs`` and ``generator`` here and is also forwarded to
            both constructors.

    Returns:
        A ``(model, generator_model)`` tuple: the student instance followed
        by the teacher instance.
    """
    # NOTE(review): presumably a dict of language-related constructor
    # kwargs; it is splatted into both constructors below.
    lang_args = language.get_lang(env, FLAGS)

    # Only the teacher class depends on whether language goals are enabled;
    # the student class is the same in both configurations.
    student_fn = student.Student
    teacher_fn = (
        teacher.LanguageTeacher
        if FLAGS.language_goals is not None
        else teacher.Teacher
    )

    # The student may be restricted to the partial observation shape...
    if FLAGS.partial_obs:
        obs_space = env.partial_observation_space.shape
    else:
        obs_space = env.observation_space.shape

    model = student_fn(
        obs_space,
        env.action_space.n,
        FLAGS,
        **lang_args,
        use_intrinsic_rewards=FLAGS.generator,
    )

    # ...whereas the teacher is always built from the full observation shape
    # plus the grid dimensions.
    generator_model = teacher_fn(
        env.observation_space.shape,
        env.width,
        env.height,
        FLAGS,
        **lang_args,
    )

    return model, generator_model


def create_rnd_model(env, FLAGS):
    """Construct a ``teacher.Predictor`` for the environment's full observation shape.

    Args:
        env: Environment whose ``observation_space.shape`` sizes the predictor.
        FLAGS: Configuration object forwarded unchanged to the constructor.

    Returns:
        The newly created ``teacher.Predictor`` instance.
    """
    return teacher.Predictor(env.observation_space.shape, FLAGS)


def create_message_rnd_model(env, FLAGS):
    """Construct a ``teacher.LangPredictor`` configured with the env's language args.

    Args:
        env: Environment passed to ``language.get_lang`` to derive the
            language-specific constructor arguments.
        FLAGS: Configuration object; forwarded to ``language.get_lang`` and
            to the ``LangPredictor`` constructor.

    Returns:
        The newly created ``teacher.LangPredictor`` instance.
    """
    # NOTE(review): get_lang presumably returns a kwargs dict; it is
    # splatted directly into the predictor constructor.
    return teacher.LangPredictor(FLAGS, **language.get_lang(env, FLAGS))
