import os
from prompta.utils.import_helper import import_pipeline_module
from prompta.env import *


class BaseActor:
    """Run an agent against an environment, both built from one config.

    Attributes set by ``__init__``:
        whole_cfg: the full configuration object passed in.
        cfg: the ``actor`` sub-section of ``whole_cfg``.
        exp_dir: ``<cwd>/experiments/<common.experiment_name>``.
        agent: instance created by :meth:`setup_agent`.
        env: instance created by :meth:`setup_env`.
    """

    def __init__(self, cfg) -> None:
        """Store the config, derive the experiment directory, build agent and env.

        Args:
            cfg: configuration object exposing ``actor``,
                ``common.experiment_name``, ``agent.pipeline`` and
                ``env.env_func`` via attribute access.
        """
        self.whole_cfg = cfg
        self.cfg = self.whole_cfg.actor
        self.exp_dir = os.path.join(
            os.getcwd(), 'experiments', self.whole_cfg.common.experiment_name,
        )
        # The agent is constructed before the environment.
        self.setup_agent()
        self.setup_env()

    def setup_agent(self):
        """Look up the ``Agent`` class of the configured pipeline and instantiate it."""
        agent_cls = import_pipeline_module(self.whole_cfg.agent.pipeline, "Agent")
        self.agent = agent_cls(self.whole_cfg)

    def setup_env(self):
        """Resolve ``env.env_func`` to a callable and build the environment with it."""
        # SECURITY NOTE(review): ``eval`` executes whatever expression the
        # config contains. This is only acceptable while configs are trusted;
        # the expression is resolved against this module's globals (which
        # include the wildcard import from ``prompta.env``). Consider an
        # explicit name -> factory registry if configs can come from outside.
        env_func = eval(self.whole_cfg.env.env_func)
        self.env = env_func(self.whole_cfg)

    def collect_episodes(self, num_episodes):
        """Play ``num_episodes`` episodes, printing every step reward.

        After all episodes finish, ``info['score']`` from the final step of
        the last episode is printed. When no episode is run
        (``num_episodes <= 0``) nothing is printed and no error is raised.

        Args:
            num_episodes: number of episodes to play.
        """
        ran_any = False
        info = None
        for _ in range(num_episodes):
            ran_any = True
            obs, info = self.env.reset()
            self.agent.reset()
            # The agent's first call of an episode sees a zero reward and
            # done=False alongside the reset observation.
            done, rew = False, 0
            while not done:
                action = self.agent.step(obs, rew, done, info)
                obs, rew, done, info = self.env.step(action)
                print(rew)
        # ``info`` is only meaningful once at least one episode has run;
        # without this guard a zero-episode call crashed on an unbound name.
        # NOTE(review): only the last episode's score is reported — confirm
        # whether per-episode reporting was intended.
        if ran_any:
            print(info['score'])
        
    


