from typing import Dict

import torch

from erl_lib.agent.model_based.modules.model import Model
from erl_lib.util.env import make_envs


class EnvModel(Model):
    """``Model`` subclass backed by environments built with ``make_envs``.

    Rather than predicting anything itself, this class delegates ``reset``
    and ``sample`` to the environment object stored in ``self.envs``.
    """

    def __init__(self, *args, config_env, num_envs, **kwargs):
        """Create the environments and initialise the base ``Model``.

        Args:
            *args: Accepted for signature compatibility; not forwarded to
                the base class.
            config_env: Environment configuration, passed to ``make_envs``.
            num_envs: Number of environments, passed to ``make_envs``.
            **kwargs: Accepted for signature compatibility; not forwarded
                to the base class.
        """
        # Only the first three values returned by ``make_envs`` are needed:
        # the environment object and the observation/action dimensions.
        env_spec = make_envs(config_env, num_envs)[:3]
        self.envs, dim_obs, dim_act = env_spec

        # The base model is sized from the environment: actions are the
        # input, observations the output. The device is fixed to "cpu".
        super().__init__(dim_input=dim_act, dim_output=dim_obs, device="cpu")

        self.num_members = 0

    def reset(self) -> Dict[str, torch.Tensor]:
        """Reset the wrapped environments and return what ``envs.reset`` returns."""
        initial = self.envs.reset()
        return initial

    def sample(self, act, obs, **_):
        """Step the wrapped environments with ``act``.

        Args:
            act: Action(s) passed straight to ``self.envs.step``.
            obs: Unused; accepted so the call signature matches callers
                that supply the current observation.
            **_: Any additional keyword arguments are ignored.

        Returns:
            Whatever ``self.envs.step(act)`` returns.
        """
        return self.envs.step(act)
