import numpy as np

from .base_sampler import BaseSampler


class VectorizedSampler(BaseSampler):
	"""Sampler that advances several environments by one step per call.

	Each call to `sample` steps every environment in `self.env` once, stores
	the transition in the shared pool and keeps per-environment episode
	statistics (length, return, infos). `self.policy` may either be a single
	policy used for every environment or a sequence with one policy per
	environment (detected through the presence of `__len__`).

	The attributes `_max_path_length`, `_last_n_paths` and `_batch_size` are
	not set here; they are presumably provided by `BaseSampler.__init__`
	(TODO confirm against the base class).
	"""

	def __init__(self, num_envs=1, **kwargs):
		"""Create the per-environment bookkeeping.

		Args:
			num_envs: number of environments that `sample` will step.
			**kwargs: forwarded unchanged to `BaseSampler.__init__`.
		"""
		super(VectorizedSampler, self).__init__(**kwargs)

		self.num_envs = num_envs

		# State of the episode currently running in each environment.
		self._path_length = [0] * num_envs
		self._path_return = [0] * num_envs
		self._infos = [[] for _ in range(num_envs)]
		# `None` marks an environment that has not been reset yet.
		self._current_observation = [None] * num_envs

		# Statistics aggregated over all environments.
		self._last_path_return = 0
		self._max_path_return = -np.inf
		self._n_episodes = 0
		self._total_samples = 0

	def initialize(self, env, policy, pool):
		"""Attach the environments, the policy (or policies) and the pool.

		Args:
			env: wrapper whose `_env.unwrapped.mujoco_envs` holds the
				individual environments to step. The wrapper itself is kept
				as `self.wrap_envs` and consulted by `random_batch`.
			policy: a single policy, or a sequence of policies indexed like
				the environments.
			pool: replay pool receiving every sampled transition.

		Raises:
			AssertionError: if the unwrapped environment has no
				`mujoco_envs` attribute.
		"""
		unwrapped = env._env.unwrapped
		assert hasattr(unwrapped, "mujoco_envs"), (
			"VectorizedSampler requires env._env.unwrapped.mujoco_envs")
		self.wrap_envs = env
		self.env = unwrapped.mujoco_envs
		self.policy = policy
		self.pool = pool

	def sample(self):
		"""Step every environment once and record the transitions.

		Returns:
			A tuple `(observations, rewards, terminals, infos)`.
			`observations` is the sampler's own list of current
			observations (not a copy); for an environment whose episode
			just ended it holds the observation returned by `reset()`
			rather than the final observation of the finished episode.
			The other three are fresh lists with one entry per environment.

		Raises:
			AssertionError: if the number of environments, or the number
				of per-environment policies, differs from `num_envs`.
		"""
		assert len(self.env) == self.num_envs, (
			"number of environments does not match num_envs")

		# A policy object exposing `__len__` is treated as one policy per
		# environment; this cannot change while the loop runs, so it is
		# checked once up front.
		per_env_policies = hasattr(self.policy, '__len__')
		if per_env_policies:
			assert len(self.policy) == self.num_envs, (
				"number of policies does not match num_envs")

		rewards, terminals, infos = [], [], []
		for i, env in enumerate(self.env):
			if self._current_observation[i] is None:
				self._current_observation[i] = env.reset()
			observation = self._current_observation[i]

			policy = self.policy[i] if per_env_policies else self.policy
			# `observation[None]` presumably adds a leading batch axis to a
			# NumPy observation; `[0]` takes the single resulting action.
			action = policy.actions_np([observation[None]])[0]

			next_observation, reward, terminal, info = env.step(action)
			# Drop the 'goal' entry (if present) so it is neither stored in
			# the path infos nor returned to the caller.
			info.pop('goal', None)

			self._path_length[i] += 1
			self._path_return[i] += reward
			self._infos[i].append(info)
			self._total_samples += 1
			rewards.append(reward)
			terminals.append(terminal)
			infos.append(info)

			self.pool.add_sample(
				observations=observation,
				actions=action,
				rewards=reward,
				terminals=terminal,
				next_observations=next_observation)

			if terminal or self._path_length[i] >= self._max_path_length:
				# NOTE(review): the pool is shared by all environments, so
				# if `last_n_batch` returns the most recently added samples
				# this path may interleave transitions from other
				# environments when num_envs > 1 -- confirm against the
				# pool implementation.
				last_path = self.pool.last_n_batch(
					self._path_length[i],
					observation_keys=getattr(env, 'observation_keys', None))
				last_path.update({'infos': self._infos[i]})
				self._last_n_paths.appendleft(last_path)

				# NOTE(review): a single shared policy is deliberately (?)
				# not reset here; only per-environment policies are.
				if per_env_policies:
					policy.reset()
				self._current_observation[i] = env.reset()

				self._max_path_return = max(
					self._max_path_return, self._path_return[i])
				self._last_path_return = self._path_return[i]

				# Start fresh bookkeeping for the next episode of env i.
				self._path_length[i] = 0
				self._path_return[i] = 0
				self._infos[i] = []

				self._n_episodes += 1
			else:
				self._current_observation[i] = next_observation

		return self._current_observation, rewards, terminals, infos

	def random_batch(self, batch_size=None, **kwargs):
		"""Draw a random batch from the pool.

		Args:
			batch_size: number of samples; falls back to `self._batch_size`
				when omitted (or otherwise falsy).
			**kwargs: forwarded to `self.pool.random_batch`. If the wrapped
				environment has a non-None `sample_probs` attribute it
				replaces any `sample_probs` passed by the caller.

		Returns:
			Whatever `self.pool.random_batch` returns.
		"""
		batch_size = batch_size or self._batch_size
		observation_keys = getattr(self.wrap_envs, 'observation_keys', None)

		# Looked up defensively (like `observation_keys`) so wrappers that
		# do not define `sample_probs` do not raise AttributeError.
		sample_probs = getattr(self.wrap_envs, 'sample_probs', None)
		if sample_probs is not None:
			kwargs['sample_probs'] = sample_probs

		return self.pool.random_batch(
			batch_size, observation_keys=observation_keys, **kwargs)

	def get_diagnostics(self):
		"""Return the base diagnostics extended with episode statistics.

		Adds 'max-path-return', 'last-path-return', 'episodes' and
		'total-samples' to the mapping returned by the base class.
		"""
		diagnostics = super(VectorizedSampler, self).get_diagnostics()
		diagnostics.update({
			'max-path-return': self._max_path_return,
			'last-path-return': self._last_path_return,
			'episodes': self._n_episodes,
			'total-samples': self._total_samples,
		})

		return diagnostics
