import torch
import numpy as np

from baselines.common.vec_env import VecEnvWrapper

from plr.level_sampler import LevelSampler


class VecPLRWrapper(VecEnvWrapper):
	"""Vectorized-env wrapper that chooses the level seed of each sub-environment.

	When a level sampler is configured, seeds come from it; otherwise a
	random seed is drawn with NumPy every time a sub-environment finishes an
	episode. The wrapped ``venv`` must provide ``num_envs``,
	``seed(seed, index)``, ``reset()``, ``step_async()`` and ``step_wait()``.
	"""

	def __init__(self, 
		venv,
		device='cpu', 
		force_uniform_full_dist=False, 
		level_sampler=None,
		seeds=None,
		level_replay_kwargs=None):
		"""Wrap ``venv`` and set up the (optional) level sampler.

		Args:
			venv: The vectorized environment to wrap.
			device: Stored on ``self.device``; not otherwise used by this class.
			force_uniform_full_dist: If true, no level sampler is used, even
				when ``level_sampler`` or ``seeds`` is supplied.
			level_sampler: A ready-made sampler to use. Takes precedence
				over ``seeds``.
			seeds: Seeds handed to a newly built ``LevelSampler`` when no
				``level_sampler`` is given.
			level_replay_kwargs: Extra keyword arguments for the
				``LevelSampler`` constructor. ``None`` means no extras.
		"""
		super().__init__(venv)

		self.venv = venv
		self.device = device

		self.level_sampler = None
		if not force_uniform_full_dist:
			if level_sampler is not None:
				self.level_sampler = level_sampler
			elif seeds is not None:
				# `level_replay_kwargs` defaults to None, which cannot be
				# unpacked with `**`; fall back to an empty mapping.
				self.level_sampler = LevelSampler(
					seeds=seeds,
					num_actors=self.venv.num_envs,
					**(level_replay_kwargs or {}))

		# One slot per sub-environment; nothing in this class writes to it
		# after construction.
		self.seeds = torch.zeros(self.venv.num_envs)

	@property
	def raw_venv(self):
		"""Return the innermost env reached by following ``.venv`` attributes."""
		rvenv = self.venv
		while hasattr(rvenv, 'venv'):
			rvenv = rvenv.venv
		return rvenv

	def reset(self):
		"""Seed every sub-environment from the sampler (if any), then reset.

		Returns:
			Whatever ``self.venv.reset()`` returns.
		"""
		if self.level_sampler is not None:
			for env_idx in range(self.venv.num_envs):
				seed = self.level_sampler.sample('sequential')
				self.venv.seed(seed, env_idx)

		return self.venv.reset()

	def step_async(self, actions):
		"""Forward ``actions`` (a torch tensor) to the wrapped env as a NumPy array.

		A tensor with more than one dimension has its second dimension
		squeezed first, so ``(num_envs, 1)`` becomes ``(num_envs,)``.
		"""
		# Only squeeze when dimension 1 actually exists: calling
		# `squeeze(1)` on a 1-D tensor raises IndexError.
		if len(actions.shape) > 1:
			actions = actions.squeeze(1)
		self.venv.step_async(actions.cpu().numpy())

	def step_wait(self):
		"""Collect the step results and re-seed sub-environments that finished.

		A sub-environment is treated as finished when its info dict contains
		the key ``'episode'``. Its entry in ``obs`` is overwritten with the
		return value of ``self.venv.seed(seed, i)``.

		Returns:
			The tuple ``(obs, reward, done, infos)``.
		"""
		obs, reward, done, infos = self.venv.step_wait()

		for i, info in enumerate(infos):
			if 'episode' not in info:
				continue

			if self.level_sampler is not None:
				seed = self.level_sampler.sample()
			else:
				# Integer bound and explicit 64-bit dtype: the upper bound
				# does not fit in a 32-bit integer.
				seed = int(np.random.randint(1, int(1e12), dtype=np.int64))

			# seed() is relied on to reset that level and return its
			# observation.
			obs[i] = self.venv.seed(seed, i)

		return obs, reward, done, infos