import torch.backends.cudnn as cudnn

from Envs.vec_env.envs import make_vec_envs
from models.ppo.model import Policy
from models.ppo.utils import get_vec_normalize
from models.ppo.storage import RolloutStorage
from models.ppo import algo
import time
from collections import deque
import numpy as np
import pandas as pd
import torch
import os
from shutil import copyfile
from supervisedImg import train_with_oneHot

from cfg import main_config, gym_register

import functools


def rgetattr(oj, atr, *args):
	"""Resolve a dotted attribute path such as ``'a.b.c'`` starting from ``oj``.

	Each dot-separated component of ``atr`` is looked up with ``getattr`` on
	the result of the previous lookup. Any extra positional argument is
	forwarded to every ``getattr`` call, so an optional default applies at
	each hop of the path rather than only at the last one.
	"""
	target = oj
	for component in atr.split('.'):
		target = getattr(target, component, *args)
	return target


if __name__ == '__main__':
	# Script entry point. Depending on the config flags it runs one of three modes:
	#   * config.RLManualControl -> step the env with dummy actions (env debugging)
	#   * config.RLTrain         -> train a PPO policy and periodically save/log
	#   * otherwise              -> evaluate a saved policy and write a CSV report
	config = main_config()
	gym_register(config)
	if config.RLManualControl:  # used for debugging the env
		envs = make_vec_envs(
			env_name=config.RLEnvName,
			seed=0,
			num_processes=1,
			gamma=None,
			device=None,
			randomCollect=False,
			config=config)

		observation = envs.reset()
		for episode in range(10):

			for i in range(config.RLEnvMaxSteps):
				print('step:', i)

				print('step reward', envs.venv.origStepReward)
				action = torch.zeros(config.RLActionDim)  # dummy action. True action is decided in env
				observation, _, _, _ = envs.step(action)
				envs.render()

		# Release the env the same way the evaluation branch does.
		envs.close()

	else:
		device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
		print("Using device:", device)
		cudnn.benchmark = True

		if config.RLTrain:
			torch.set_num_threads(1)
			torch.manual_seed(config.RLEnvSeed)
			torch.cuda.manual_seed_all(config.RLEnvSeed)

			# Keep a copy of the env config next to the saved checkpoints.
			os.makedirs(config.RLModelSaveDir, exist_ok=True)
			copyfile(
				os.path.join('..', 'Envs', config.envFolder, 'RSI1', 'config.py'),
				os.path.join(config.RLModelSaveDir, 'config.py'))

			envs = make_vec_envs(
				env_name=config.RLEnvName,
				seed=config.RLEnvSeed,
				num_processes=config.RLNumEnvs,
				gamma=config.RLGamma,
				device=device,
				randomCollect=False,
				config=config)

			actor_critic = Policy(
				envs.venv.observation_space.spaces,
				envs.action_space,
				config=config,
				base=config.RLPolicyBase,
				base_kwargs={
					'recurrent': config.RLRecurrentPolicy,
					'recurrentInputSize': config.RLRecurrentInputSize,
					'recurrentSize': config.RLRecurrentSize,
					'actionHiddenSize': config.RLActionHiddenSize,
				})
			actor_critic.to(device)

			# load pretrained cnn
			if config.usePretrainedModel:
				# map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
				x = torch.load(config.pretrainedModelLoadDir, map_location=device)
				actor_critic.load_state_dict(x, strict=False)
				if config.freezePretrainedModel:
					for layers in x:
						# strict=False above tolerates checkpoint keys the model does not
						# have, so the freeze step skips those keys instead of raising.
						o = rgetattr(actor_critic, layers, None)
						if o is not None:
							o.requires_grad = False

				print('Loaded pretrained weights from', config.pretrainedModelLoadDir)

			if config.RLModelFineTune:
				print("Load the weights from", config.RLModelLoadDir)
				actor_critic.load_state_dict(
					torch.load(config.RLModelLoadDir, map_location=device))

			agent = algo.PPO(
				actor_critic,
				config.ppoClipParam,
				config.ppoEpoch,
				config.ppoNumMiniBatch,
				config.ppoValueLossCoef,
				config.ppoEntropyCoef,
				lr=config.RLLr,
				eps=config.RLEps,
				max_grad_norm=config.RLMaxGradNorm,
				config=config)

			rollouts = RolloutStorage(
				config.ppoNumSteps, config.RLNumEnvs,
				envs.venv.observation_space.spaces, envs.action_space,
				actor_critic.recurrent_hidden_state_size, config=config)

			# Running per-env sum of envs.venv.origStepReward for the episode in progress.
			env_rewards = np.zeros([config.RLNumEnvs, ])
			# Totals of the 10 most recently finished episodes (used for logging only).
			episode_rewards = deque(maxlen=10)

			progress_path = os.path.join(config.RLModelSaveDir, 'progress.csv')
			# The first log of this run (re)creates progress.csv with a header and
			# every later log appends, so no row of the current run is overwritten.
			progress_started = False

			print('Begin RL training')
			obs = envs.reset()

			# Seed slot 0 of the rollout buffer with the initial observation.
			if isinstance(obs, dict):
				for key in obs:
					rollouts.obs[key][0].copy_(obs[key])
			else:
				rollouts.obs[0].copy_(obs)
			rollouts.to(device)

			start = time.time()
			num_updates = int(
				config.RLTotalSteps) // config.ppoNumSteps // config.RLNumEnvs
			for j in range(0, num_updates):
				for step in range(config.ppoNumSteps):
					# Sample actions
					with torch.no_grad():
						# NOTE(review): rollouts.obs is iterated as a dict here, so
						# dict observations are assumed from this point on.
						rollouts_obs = {key: rollouts.obs[key][step] for key in rollouts.obs}
						value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
							rollouts_obs, rollouts.recurrent_hidden_states[step],
							rollouts.masks[step])

					obs, reward, done, infos = envs.step(action)

					if config.render:
						print('step reward', envs.venv.origStepReward)
						time.sleep(0.5)
						envs.render()

					# Logged rewards come from origStepReward rather than the `reward`
					# returned by step() (presumably the un-normalized value - confirm).
					env_rewards = env_rewards + envs.venv.origStepReward
					for index in np.flatnonzero(done):
						episode_rewards.append(env_rewards[index])
						env_rewards[index] = 0.

					# If done then clean the history of observations.
					masks = torch.FloatTensor(
						[[0.0] if done_ else [1.0] for done_ in done])
					bad_masks = torch.FloatTensor(
						[[0.0] if 'bad_transition' in info.keys() else [1.0]
						 for info in infos])
					rollouts.insert(obs, recurrent_hidden_states, action,
						action_log_prob, value, reward, masks, bad_masks)

				# Value of the last stored observation, passed to compute_returns below.
				with torch.no_grad():
					rollouts_obs = {key: rollouts.obs[key][-1] for key in rollouts.obs}
					next_value = actor_critic.get_value(
						rollouts_obs, rollouts.recurrent_hidden_states[-1],
						rollouts.masks[-1]).detach()

				rollouts.compute_returns(
					next_value, config.ppoUseGAE, config.RLGamma,
					config.ppoGAELambda, config.RLUseProperTimeLimits)

				value_loss, action_loss, dist_entropy, inSightLoss, exiLoss, soundAuxLoss = agent.update(rollouts)

				rollouts.after_update()

				# save for every interval-th update or for the last update
				if (j % config.RLModelSaveInterval == 0
						or j == num_updates - 1) and config.RLModelSaveDir != "":
					save_path = config.RLModelSaveDir
					os.makedirs(save_path, exist_ok=True)
					torch.save(actor_critic.state_dict(), os.path.join(save_path, '%.5i' % j + ".pt"))

				if j % config.RLLogInterval == 0 and len(episode_rewards) > 1:
					total_num_steps = (j + 1) * config.RLNumEnvs * config.ppoNumSteps
					end = time.time()
					fps = int(total_num_steps / (end - start))
					# Every value passed to format() now has a placeholder; the policy
					# entropy and the value/policy losses used to be silently dropped.
					print(
						"Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
						"inSight: {}, exi: {}, soundAux: {}\n"
						"entropy: {}, value loss: {}, policy loss: {}\n"
						.format(
							j, total_num_steps, fps,
							len(episode_rewards), np.mean(episode_rewards),
							np.median(episode_rewards), np.min(episode_rewards),
							np.max(episode_rewards), inSightLoss, exiLoss, soundAuxLoss,
							dist_entropy, value_loss, action_loss))

					df = pd.DataFrame({
						'misc/nupdates': [j], 'misc/total_timesteps': [total_num_steps],
						'fps': fps,
						'eprewmean': [np.mean(episode_rewards)],
						'min': np.min(episode_rewards),
						'max': np.max(episode_rewards),
						'loss/policy_entropy': dist_entropy, 'loss/policy_loss': action_loss,
						'loss/value_loss': value_loss, 'loss/inSightLoss': inSightLoss,
						'loss/exiLoss': exiLoss, 'loss/soundAux': soundAuxLoss})

					if progress_started:
						df.to_csv(progress_path, mode='a', header=False, index=False)
					else:
						df.to_csv(progress_path, mode='w', header=True, index=False)
						progress_started = True

			# Release the training envs, matching the evaluation branch.
			envs.close()

		else:  # evaluate the policy
			num_processes = 1
			eval_envs = make_vec_envs(
				env_name=config.RLEnvName,
				seed=config.RLEnvSeed,
				num_processes=num_processes,
				gamma=None,
				device=device,
				randomCollect=False,
				config=config)
			# The single underlying env; it supplies episodeCounter and the per-class sizes.
			baseEnv = eval_envs.venv.unwrapped.envs[0]

			# load the trained policy
			actor_critic = Policy(
				eval_envs.venv.observation_space.spaces,
				eval_envs.action_space,
				config=config,
				base=config.RLPolicyBase,
				base_kwargs={
					'recurrent': config.RLRecurrentPolicy,
					'recurrentInputSize': config.RLRecurrentInputSize,
					'recurrentSize': config.RLRecurrentSize,
					'actionHiddenSize': config.RLActionHiddenSize,
				})

			assert (config.RLModelLoadDir is not None)
			print("Load the weights from", config.RLModelLoadDir)
			# map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
			actor_critic.load_state_dict(
				torch.load(config.RLModelLoadDir, map_location=device))

			actor_critic.eval()
			print("Weights Loaded!")
			actor_critic.to(device)

			eval_episode_rewards = []
			eval_env_rewards = 0.

			obs = eval_envs.reset()

			eval_recurrent_hidden_states = torch.zeros(
				num_processes, actor_critic.recurrent_hidden_state_size, device=device)
			eval_masks = torch.zeros(num_processes, 1, device=device)

			# Number of evaluation episodes: last entry of the env's size_per_class_cumsum.
			episode_num = baseEnv.size_per_class_cumsum[-1]

			results = []
			goal_area_count_list = []
			# One task index per episode: index k is repeated size_per_class[k] times.
			# NOTE(review): assumes episodes are run in this same class order - confirm in the env.
			objs = np.arange(config.taskNum, dtype=np.int64)
			objs = np.repeat(objs, baseEnv.size_per_class)

			while baseEnv.episodeCounter < episode_num:

				with torch.no_grad():
					_, action, _, eval_recurrent_hidden_states = actor_critic.act(
						obs,
						eval_recurrent_hidden_states,
						eval_masks,
						deterministic=True)

				# Observe reward and next obs
				obs, _, done, infos = eval_envs.step(action)
				if config.render:
					eval_envs.render()

					print('step reward', eval_envs.venv.origStepReward)
				eval_env_rewards = eval_env_rewards + eval_envs.venv.origStepReward

				eval_masks = torch.tensor(
					[[0.0] if done_ else [1.0] for done_ in done],
					dtype=torch.float32,
					device=device)

				# With a single process, `done` holds exactly one flag.
				if done:
					goal_area_count = infos[0]['goal_area_count']
					goal_area_count_list.append(goal_area_count)
					results.append(int(goal_area_count >= config.success_threshold))
					eval_episode_rewards.append(float(eval_env_rewards))
					eval_env_rewards = 0.

			# save the results (skipped when rendering)
			if not config.render:
				df = pd.DataFrame({'objIdx': objs, 'goal area count': goal_area_count_list, 'rewards': eval_episode_rewards, 'results': results})
				# The report is written next to the checkpoint as test_<checkpoint name>.csv
				save_path = os.path.join(os.path.dirname(config.RLModelLoadDir), 'test_' + os.path.splitext(os.path.basename(config.RLModelLoadDir))[0] + '.csv')
				df.to_csv(save_path, mode='w', header=True, index=False)
				print('results saved to', save_path)
				print('success rate', sum(results) * 1. / baseEnv.size_per_class_cumsum[-1])
			eval_envs.close()
