import copy

import torch

from .. import buffers as B
from . import babyai


def create_models_and_buffers(env, FLAGS):
    """Build the models, their learner copies, and the rollout buffers.

    Args:
        env: Environment exposing ``observation_space.shape`` and
            ``partial_observation_space.shape``.
        FLAGS: Configuration object. ``FLAGS.device`` is where the learner
            copies are placed; the object is also forwarded unchanged to
            ``babyai.create_models`` and ``B.create_buffers``.

    Returns:
        Tuple ``(model, generator_model, learner_model,
        learner_generator_model, buffers)``. The first two are the objects
        returned by ``babyai.create_models``; they are not moved to any
        device here. The two learner entries are deep copies moved to
        ``FLAGS.device``. ``buffers`` is the result of ``B.create_buffers``.
    """
    base_model, base_generator = babyai.create_models(env, FLAGS)
    instruction_count = len(babyai.language.INSTRS)

    # Each observation key maps to a (shape, dtype) pair; both are declared uint8.
    observation_specs = dict(
        frame=(env.observation_space.shape, torch.uint8),
        partial_frame=(env.partial_observation_space.shape, torch.uint8),
    )

    # Deep-copy before moving so the originals returned above are left untouched
    # by the device transfer.
    learner_model, learner_generator = (
        copy.deepcopy(net).to(FLAGS.device) for net in (base_model, base_generator)
    )

    buffers = B.create_buffers(
        observation_specs,
        base_model.num_actions,
        base_generator.logits_size,
        base_generator.raw_goal_size,
        instruction_count,
        FLAGS,
    )

    return (base_model, base_generator, learner_model, learner_generator, buffers)