import numpy as np
import torch
import pdb


class StepSampler(object):
    """Collect a fixed number of environment steps, continuing across calls.

    The sampler keeps the current observation between calls to ``sample``, so
    consecutive calls continue the same trajectory. The environment is reset
    whenever it reports ``done`` or the current trajectory reaches
    ``max_traj_length`` steps.

    Args:
        env: Environment exposing ``reset()`` and a 4-tuple
            ``step(action) -> (obs, reward, done, info)``.
        max_traj_length: Maximum number of steps before a forced reset.
        hidden_dims: Optional numpy index of observation entries that are
            zeroed before the policy sees them and before they are recorded.
            ``None`` disables masking.
    """

    def __init__(self, env, max_traj_length=1000, hidden_dims=None):
        self.max_traj_length = max_traj_length
        self._env = env
        self._traj_steps = 0
        self._current_observation = self.env.reset()
        self.hidden_dims = hidden_dims

    def _mask_hidden(self, observation):
        """Return ``observation`` with the ``hidden_dims`` entries zeroed.

        Masking is applied to a copy, so the array handed back by the
        environment (which may be the environment's own state buffer) is
        never modified in place. Without ``hidden_dims`` the observation is
        returned unchanged.
        """
        if self.hidden_dims is None:
            return observation
        masked = np.array(observation)  # np.array copies by default
        masked[self.hidden_dims] = 0.0
        return masked

    def sample(self, policy, n_steps, deterministic=False, replay_buffer=None, joint_noise_std=0.):
        """Step the environment ``n_steps`` times using ``policy``.

        Args:
            policy: Callable receiving a batch of one observation (built with
                ``np.expand_dims``) and a ``deterministic`` keyword; row 0 of
                its result is used as the action.
            n_steps: Number of transitions to collect.
            deterministic: Forwarded to ``policy``.
            replay_buffer: If given, every transition is passed to
                ``replay_buffer.append(obs, action, reward, next_obs, done)``.
            joint_noise_std: Accepted for interface compatibility; it is
                currently NOT applied to the executed action.

        Returns:
            dict with keys ``observations``, ``actions``, ``rewards``,
            ``next_observations`` and ``dones``, each a float32 array with
            one entry per collected step.
        """
        observations = []
        actions = []
        rewards = []
        next_observations = []
        dones = []

        for _ in range(n_steps):
            self._traj_steps += 1
            observation = self._mask_hidden(self._current_observation)

            action = policy(
                np.expand_dims(observation, 0), deterministic=deterministic
            )[0, :]

            next_observation, reward, done, _ = self.env.step(action)
            next_observation = self._mask_hidden(next_observation)

            observations.append(observation)
            actions.append(action)
            rewards.append(reward)
            dones.append(done)
            next_observations.append(next_observation)

            if replay_buffer is not None:
                # NOTE(review): TrajSampler.sample in this file calls
                # replay_buffer.add_sample for the same purpose; confirm which
                # method the replay buffer type actually implements.
                replay_buffer.append(
                    observation, action, reward, next_observation, done
                )

            self._current_observation = next_observation

            # End of trajectory: either the env terminated it or we hit the cap.
            if done or self._traj_steps >= self.max_traj_length:
                self._traj_steps = 0
                self._current_observation = self.env.reset()

        return dict(
            observations=np.array(observations, dtype=np.float32),
            actions=np.array(actions, dtype=np.float32),
            rewards=np.array(rewards, dtype=np.float32),
            next_observations=np.array(next_observations, dtype=np.float32),
            dones=np.array(dones, dtype=np.float32),
        )

    @property
    def env(self):
        """The wrapped environment."""
        return self._env

# Samples whole trajectories, resetting the env at the start of each one; a trajectory ends when the env reports done or after max_traj_length steps.
class TrajSampler(object):
    """Collect complete trajectories, one environment reset per trajectory.

    Each trajectory runs until the environment reports ``done`` or
    ``max_traj_length`` steps have been taken, whichever comes first.

    Args:
        env: Environment exposing ``reset()`` and a 4-tuple
            ``step(action) -> (obs, reward, done, info)``.
        max_traj_length: Maximum number of steps per trajectory.
        hidden_dims: Optional numpy index of observation entries that are
            zeroed before the policy sees them and before they are recorded.
            ``None`` disables masking.
    """

    def __init__(self, env, max_traj_length=1000, hidden_dims=None):
        self.max_traj_length = max_traj_length
        self._env = env
        self.hidden_dims = hidden_dims

    def _mask_hidden(self, observation):
        """Return ``observation`` with the ``hidden_dims`` entries zeroed.

        Masking is applied to a copy, so the array handed back by the
        environment (which may be the environment's own state buffer) is
        never modified in place. Without ``hidden_dims`` the observation is
        returned unchanged.
        """
        if self.hidden_dims is None:
            return observation
        masked = np.array(observation)  # np.array copies by default
        masked[self.hidden_dims] = 0.0
        return masked

    def sample(self, policy, n_trajs, deterministic=False, replay_buffer=None):
        """Roll out ``n_trajs`` trajectories with ``policy``.

        Args:
            policy: Callable receiving a batch of one observation (built with
                ``np.expand_dims``) and a ``deterministic`` keyword; row 0 of
                its result is used as the action.
            n_trajs: Number of trajectories to collect.
            deterministic: Forwarded to ``policy``.
            replay_buffer: If given, every transition is passed to
                ``replay_buffer.add_sample(obs, action, reward, next_obs, done)``.

        Returns:
            list of ``n_trajs`` dicts, each with keys ``observations``,
            ``actions``, ``rewards``, ``next_observations`` and ``dones``
            holding float32 arrays with one entry per step of that trajectory.
        """
        trajs = []
        for _ in range(n_trajs):
            observations = []
            actions = []
            rewards = []
            next_observations = []
            dones = []

            observation = self._mask_hidden(self.env.reset())

            for _ in range(self.max_traj_length):
                action = policy(
                    np.expand_dims(observation, 0), deterministic=deterministic
                )[0, :]
                next_observation, reward, done, _ = self.env.step(action)
                next_observation = self._mask_hidden(next_observation)

                observations.append(observation)
                actions.append(action)
                rewards.append(reward)
                dones.append(done)
                next_observations.append(next_observation)

                if replay_buffer is not None:
                    replay_buffer.add_sample(
                        observation, action, reward, next_observation, done
                    )

                observation = next_observation

                if done:
                    break

            trajs.append(dict(
                observations=np.array(observations, dtype=np.float32),
                actions=np.array(actions, dtype=np.float32),
                rewards=np.array(rewards, dtype=np.float32),
                next_observations=np.array(next_observations, dtype=np.float32),
                dones=np.array(dones, dtype=np.float32),
            ))

        return trajs

    @property
    def env(self):
        """The wrapped environment."""
        return self._env
