import numpy as np

class Sampler:
    """Collect trajectories by running an agent in an environment.

    The environment is expected to expose ``reset()`` returning an
    observation and ``step(action)`` returning the 4-tuple
    ``(next_obs, reward, done, info)``. The agent is expected to expose
    ``get_action(obs)``.
    """

    def __init__(self,
                 env,
                 agent,
                 action_dim,
                 max_step,
                 device):
        """Store the collaborators and rollout settings.

        Args:
            env: Environment providing ``reset()`` and ``step(action)``.
            agent: Policy object providing ``get_action(obs)``.
            action_dim: Stored as-is; not used by the methods of this class.
            max_step: Number of steps after which a rollout is cut off even
                if the environment has not reported ``done``.
            device: Stored as-is; not used by the methods of this class.
        """
        self.env = env
        self.agent = agent
        self.action_dim = action_dim
        self.max_step = max_step
        self.device = device
        # Kept for backward compatibility; the methods below track their
        # sample count in a local variable instead.
        self.cur_samples = 0

    def obtain_samples(self, max_samples):
        """Roll out episodes until exactly ``max_samples`` steps are collected.

        The last trajectory is truncated so that the total number of
        transitions over all returned trajectories equals ``max_samples``.

        Args:
            max_samples: Total number of transitions to collect. A value
                that is zero or negative yields an empty list.

        Returns:
            A list of trajectory dicts as produced by :meth:`rollout`.

        Raises:
            RuntimeError: If a rollout produces no transitions (for example
                when ``max_step`` is 0), since the sample budget could then
                never be filled and the loop would not terminate.
        """
        trajs = []
        collected = 0

        while collected < max_samples:
            traj = self.rollout()
            traj_len = len(traj["cur_obs"])

            if traj_len == 0:
                raise RuntimeError(
                    "rollout() produced an empty trajectory; cannot collect "
                    "the requested samples (check max_step)"
                )

            # Keep only as many transitions as the remaining budget allows.
            remaining = max_samples - collected
            if traj_len > remaining:
                traj = {
                    key: values[:remaining] for key, values in traj.items()
                }
                traj_len = remaining

            collected += traj_len
            trajs.append(traj)

        return trajs

    def rollout(self):
        """Run one episode and return its transitions.

        The episode ends when the environment reports ``done`` or when the
        step counter reaches ``max_step``, whichever comes first.

        Returns:
            A dict with keys ``"cur_obs"``, ``"actions"``, ``"rewards"``,
            ``"dones"`` and ``"infos"``. Each value is a NumPy array built
            from the per-step values, so its first axis is the step index.
            ``"dones"`` holds the ``done`` flags converted to ``int``.
        """
        observations = []
        actions = []
        rewards = []
        dones = []
        infos = []

        cur_step = 0
        obs = self.env.reset()
        done = False

        while not (done or cur_step == self.max_step):
            action = self.agent.get_action(obs)
            next_obs, reward, done, info = self.env.step(action)

            # Store the flag as an int array; it stays usable as the loop
            # condition because a 0-d array has a well-defined truth value.
            done = np.array(int(done))

            observations.append(obs)
            actions.append(action)
            rewards.append(np.array(reward))
            dones.append(done)
            infos.append(info)

            obs = next_obs
            cur_step += 1

        return dict(
            cur_obs=np.array(observations),
            actions=np.array(actions),
            rewards=np.array(rewards),
            dones=np.array(dones),
            infos=np.array(infos),
        )
