from nle.env.base import DUNGEON_SHAPE

from . import language, student, teacher


def create_models(env, FLAGS, device="cpu"):
    """Build the student model and the generator (teacher) model.

    Args:
        env: Environment object; only ``len(env._actions)`` is read here and
            is passed to the student as its number of actions.
        FLAGS: Configuration namespace. ``FLAGS.generator`` is forwarded as
            ``use_intrinsic_rewards`` to the student, and
            ``FLAGS.language_goals`` selects which teacher class is built.
        device: Device argument forwarded unchanged to both constructors.

    Returns:
        A ``(model, generator_model)`` tuple: the ``student.Student``
        instance, followed by a ``teacher.Teacher`` when
        ``FLAGS.language_goals`` is ``None``, or a
        ``teacher.LanguageTeacher`` otherwise.
    """
    action_count = len(env._actions)

    # The student is constructed before the teacher to keep the original
    # construction order.
    student_net = student.Student(
        DUNGEON_SHAPE,
        action_count,
        FLAGS,
        device,
        use_intrinsic_rewards=FLAGS.generator,
    )

    teacher_cls = (
        teacher.Teacher
        if FLAGS.language_goals is None
        else teacher.LanguageTeacher
    )
    teacher_net = teacher_cls(DUNGEON_SHAPE, FLAGS, device)

    return student_net, teacher_net


def create_rnd_model(env, FLAGS):
    """Construct a ``teacher.RNDNet`` predictor placed on ``FLAGS.device``.

    Args:
        env: Environment object; ``len(env._actions)`` is passed to the
            network as its number of actions.
        FLAGS: Configuration namespace supplying ``minihack``, ``device``
            and ``separate_message_novelty`` (the latter forwarded as the
            ``message_novelty`` keyword).

    Returns:
        The ``RNDNet`` instance returned by its ``.to(FLAGS.device)`` call.
    """
    return teacher.RNDNet(
        DUNGEON_SHAPE,
        len(env._actions),
        FLAGS.minihack,
        FLAGS.device,
        message_novelty=FLAGS.separate_message_novelty,
    ).to(FLAGS.device)
