import gym
import numpy as np
import argparse

class NoisyObsWrapper(gym.ObservationWrapper):
	"""Observation wrapper that perturbs every observation with Gaussian noise.

	Each observation has independent standard-normal samples, multiplied by
	``sigma``, added element-wise. The samples come from NumPy's global
	random state. This class does not override ``observation_space``.

	Args:
		env: Environment to wrap.
		sigma: Multiplier applied to the standard-normal noise.
	"""

	def __init__(self, env, sigma):
		super().__init__(env)
		self.sigma = sigma

	def observation(self, obs):
		"""Return ``obs`` plus ``sigma``-scaled noise of the same shape.

		The noise array is floating point, so for an integer-typed NumPy
		``obs`` the sum is presumably promoted to float -- TODO confirm
		against the dtype downstream consumers expect.
		"""
		perturbation = np.random.standard_normal(size=obs.shape)
		return obs + self.sigma * perturbation

class FinishEarlyWrapper(gym.Wrapper):
	"""Wrapper that ends the episode as soon as a non-zero reward is seen."""

	def reset(self, **kwargs):
		"""Reset the wrapped environment, forwarding all keyword arguments."""
		wrapped = self.env
		return wrapped.reset(**kwargs)

	def step(self, action):
		"""Step the wrapped environment, forcing ``done`` on non-zero reward.

		Args:
			action: Action forwarded unchanged to the wrapped environment.

		Returns:
			The ``(observation, reward, done, info)`` tuple from the wrapped
			environment, with ``done`` replaced by ``True`` whenever
			``reward`` is non-zero.
		"""
		obs, rew, finished, extras = self.env.step(action)
		# Any non-zero reward terminates the episode; otherwise the wrapped
		# environment's own flag passes through untouched.
		finished = True if rew != 0 else finished
		return obs, rew, finished, extras


from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.dqn.policies import CnnPolicy 
from stable_baselines3 import DQN
from stable_baselines3.common.callbacks import EvalCallback
from gym.wrappers import FrameStack, AtariPreprocessing
# Command-line options. Parsing runs at module level, so ``args`` is
# populated as soon as this file is executed (or imported).
parser = argparse.ArgumentParser(description='Pong DQN example')
parser.add_argument(
	'--sigma',
	metavar='N',
	type=float,
	default=0.0,
	help='How much noise to smooth observations',
)
args = parser.parse_args()


if __name__ == '__main__':
	# --sigma is multiplied by 255 before being used as the noise scale;
	# presumably it is given as a fraction of the 8-bit pixel range --
	# TODO confirm.
	sigma = args.sigma * 255.
	# String form of the scaled sigma, reused in every output path below.
	sigma_tag = str(sigma)

	def make_env():
		"""Build one Pong environment with the full wrapper stack."""
		# Wrapper order, innermost first: Atari preprocessing, early
		# termination on the first non-zero reward, Gaussian observation
		# noise, then stacking of the last 4 (already noisy) frames.
		base = AtariPreprocessing(gym.make("PongNoFrameskip-v0"))
		return FrameStack(NoisyObsWrapper(FinishEarlyWrapper(base), sigma), 4)

	# Training and evaluation use separate but identically configured
	# single environments (no vectorised / subprocess env is built here).
	env = make_env()
	eval_env = make_env()

	# NOTE(review): evaluation logs go under ./logs_pong/ while TensorBoard
	# logs go under ./logs_pong_1r/ -- confirm the differing roots are
	# intended.
	eval_callback = EvalCallback(
		eval_env,
		best_model_save_path="pong_1r_sigma_" + sigma_tag,
		log_path="./logs_pong/" + sigma_tag + '/',
		eval_freq=100000,
		n_eval_episodes=100,
	)

	# No custom policy options; the empty dict is passed through explicitly.
	policy_kwargs = {}
	# NOTE(review): learning_starts (100000) is larger than buffer_size
	# (10000) -- confirm this is deliberate.
	model = DQN(
		CnnPolicy,
		env,
		learning_rate=0.0001,
		buffer_size=10000,
		learning_starts=100000,
		exploration_fraction=0.1,
		target_update_interval=1000,
		batch_size=32,
		verbose=1,
		train_freq=4,
		gradient_steps=1,
		exploration_final_eps=0.01,
		policy_kwargs=policy_kwargs,
		optimize_memory_usage=True,
		tensorboard_log="./logs_pong_1r/" + sigma_tag + '/',
	)
	model.learn(total_timesteps=10000000, callback=eval_callback)
