import torch
import torch.nn as nn
from torch.distributions import Categorical
import torch.nn.functional as F
import gym
import mo_gym
import numpy as np
import os
import argparse
import math
import random
from tensorboardX import SummaryWriter

# Set to True to use the reduced 6x6 map (MINI_MAP) instead of the 11x11
# DEFAULT_MAP; every table below switches on this flag.
USE_MINI_MAP = False
# Seed Python, NumPy and PyTorch (CPU + CUDA) RNGs at import time.
# runPPO/testPPO additionally re-seed torch/NumPy from --seed.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
# Reduced 6x6 grid. Positive entries are treasure values. Presumably -10 marks
# sea-floor (non-navigable) cells and 0 open water -- TODO confirm against the
# mo_gym deep-sea-treasure map encoding.
MINI_MAP = np.array(
    [[0,    0,    0,   0,   0,  0],
     [0.7,  0,    0,   0,   0,  0],
     [-10,  8.2,  0,   0,   0,  0],
     [-10, -10, 11.5,  0,   0,  0],
     [-10, -10, -10, 14.0, 15.1, 16.1],
     [0,    0,    0,   0,   0,  0]]
)
# Full 11x11 grid (11 rows of 11 entries) in the same encoding as MINI_MAP.
DEFAULT_MAP = np.array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0.7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [-10, 8.2, 0, 0, 0, 0, 0, 0, 0, 0, 0], [-10, -10, 11.5, 0, 0, 0, 0, 0, 0, 0, 0], [-10, -10, -10, 14.0, 15.1, 16.1, 0, 0, 0, 0, 0], [-10, -10, -10, -10, -10, -10, 0, 0, 0, 0, 0],
						[-10, -10, -10, -10, -10, -10, 0, 0, 0, 0, 0], [-10, -10, -10, -10, -10, -10, 19.6, 20.3, 0, 0, 0], [-10, -10, -10, -10, -10, -10, -10, -10, 0, 0, 0], [-10, -10, -10, -10, -10, -10, -10, -10, 22.4, 0, 0], [-10, -10, -10, -10, -10, -10, -10, -10, -10, 23.7, 0]])

# Active map (not referenced elsewhere in this file).
MAP = MINI_MAP if USE_MINI_MAP else DEFAULT_MAP
# Per-treasure lookup tables, one entry per treasure, indexed by treasure column
# (the i-th treasure sits in column i; see TERMINAL_STATES):
#   min_step_to_treasure -- row + column of the treasure (Manhattan distance
#                           from the top-left cell).
#   vertical_distance    -- row index of the treasure in that column.
#   treasure_value       -- the treasure's value as written in the map.
min_step_to_treasure = np.array([1, 3, 5, 7, 8, 9]) if USE_MINI_MAP else np.array([
    1, 3, 5, 7, 8, 9, 13, 14, 17, 19])
vertical_distance = np.array([1, 2, 3, 4, 4, 4]) if USE_MINI_MAP else np.array(
    [1, 2, 3, 4, 4, 4, 7, 7, 9, 10])
treasure_value = np.array([0.7, 8.2, 11.5, 14.0, 15.1, 16.1]) if USE_MINI_MAP else np.array(
    [0.7, 8.2, 11.5, 14.0, 15.1, 16.1, 19.6, 20.3, 22.4, 23.7])
# [row, col] of every treasure cell. ActorCritic.act and expert_select_action
# force action 1 when the current position is in this list.
TERMINAL_STATES = [[1, 0], [2, 1], [3, 2], [4, 3], [4, 4], [4, 5]] if USE_MINI_MAP else [
    [1, 0], [2, 1], [3, 2], [4, 3], [4, 4], [4, 5], [7, 6], [7, 7], [9, 8], [10, 9]]
# [row, col] positions one row past the last map row (row 5 / 11), one per
# treasure column, treated as terminal in the same way as TERMINAL_STATES.
# Under state_encode (row * width + col) they map to the largest indices
# (up to 35 / 130), which is why the one-hot width is 36 / 132.
VIRTUAL_TERMINAL_STATES = [[5.0, 0.0], [5.0, 1.0], [5.0, 2.0], [5.0, 3.0], [5.0, 4.0], [5.0, 5.0]] if USE_MINI_MAP else [
    [11.0, 0.0], [11.0, 1.0], [11.0, 2.0], [11.0, 3.0], [11.0, 4.0], [11.0, 5.0], [11.0, 6.0], [11.0, 7.0], [11.0, 8.0], [11.0, 9.0]]

# Command-line options, parsed once at import time into the module-level `opt`.
parser = argparse.ArgumentParser(description='PyTorch PPO for continuous controlling')
parser.add_argument('--gpus', default=1, type=int, help='number of gpu')
parser.add_argument('--env', type=str, default='deep-sea-treasure-v0', help='continuous env')
parser.add_argument('--render', default=False, action='store_true', help='Render?')
parser.add_argument('--solved_reward', type=float, default=200, help='stop training if avg_reward > solved_reward')
parser.add_argument('--print_interval', type=int, default=100, help='how many episodes to print the results out')
parser.add_argument('--save_interval', type=int, default=100, help='how many episodes to save a checkpoint')
parser.add_argument('--max_episodes', type=int, default=100)
parser.add_argument('--max_expert_episodes', type=int, default=1000)
parser.add_argument('--max_timesteps', type=int, default=100, help='maxium timesteps in one episode')
parser.add_argument('--update_timesteps', type=int, default=200, help='how many timesteps to update the policy')
parser.add_argument('--K_epochs', type=int, default=4, help='update the policy for how long time everytime')
parser.add_argument('--eps_clip', type=float, default=0.05, help='epsilon for p/q clipped')
parser.add_argument('--gamma', type=float, default=0.99, help='discount factor')
parser.add_argument('--lr', type=float, default=0.003)
parser.add_argument('--seed', type=int, default=0, help='random seed to use')
parser.add_argument('--ckpt_folder', default='checkpoints', help='Location to save checkpoint models')
parser.add_argument('--tb', default=False, action='store_true', help='Use tensorboardX?')
parser.add_argument('--log_folder', default='./logs', help='Location to save logs')
parser.add_argument('--mode', default='train', help='choose train or test')
parser.add_argument('--restore', default=True, action='store_true', help='Restore and go on training?')
# `--restore` is a store_true flag whose default is already True, so it can
# never turn restoring off. `--no_restore` writes False into the same
# `opt.restore` destination (the default stays True).
parser.add_argument('--no_restore', dest='restore', action='store_false', help='Do not restore a checkpoint; train from scratch')
opt = parser.parse_args()

# Every network and tensor in this module is placed on the CPU (--gpus is not
# consulted anywhere).
device = torch.device("cpu")


class Memory:
	"""Rollout buffer filled step by step while a policy interacts with the env.

	Holds six parallel Python lists that callers append to directly:
	``states``, ``actions``, ``rewards``, ``is_terminals``, ``logprobs`` and
	``state_values``.
	"""

	def __init__(self):
		self.states, self.actions, self.rewards = [], [], []
		self.is_terminals, self.logprobs, self.state_values = [], [], []

	def clear_memory(self):
		"""Empty every buffer in place; the list objects themselves are kept."""
		buffers = (self.states, self.actions, self.rewards, self.is_terminals, self.logprobs, self.state_values)
		for buffer in buffers:
			buffer.clear()


class ActorCritic(nn.Module):
	"""Actor-critic network over one-hot encoded grid positions.

	The actor maps a ``state_dim``-wide one-hot state to a softmax distribution
	over ``action_dim`` discrete actions. The critic maps the same input to a
	scalar value estimate. The two heads are separate (64, 64) tanh MLPs.
	"""

	def __init__(self, state_dim, action_dim):
		super(ActorCritic, self).__init__()

		self.actor = nn.Sequential(
			nn.Linear(state_dim, 64),
			nn.Tanh(),
			nn.Linear(64, 64),
			nn.Tanh(),
			nn.Linear(64, action_dim),
			nn.Softmax(dim=-1)  # For discrete actions, we use softmax policy
		)

		self.critic = nn.Sequential(nn.Linear(state_dim, 64), nn.Tanh(), nn.Linear(64, 64), nn.Tanh(), nn.Linear(64, 1))

	def act(self, state, memory):
		"""Choose an action for one raw observation and record the step in ``memory``.

		Args:
			state: raw observation. Presumably the NumPy ``[row, col]`` array
				returned by the env -- TODO confirm. It must support
				``.squeeze().tolist()`` and be accepted by ``state_encode``.
			memory: rollout buffer. The one-hot state ``(1, state_dim)``, the
				action tensor ``(1,)``, its log-probability ``(1,)`` and the
				critic value ``(1, 1)`` are appended.

		Returns:
			The action as a Python ``int``. In any (virtual) terminal state
			action ``1`` is used instead of sampling from the policy.
		"""
		# Take the one-hot width from the network's own input layer so it always
		# matches ``state_dim`` (previously hard-coded to 36/132 via the global flag).
		num_classes = self.actor[0].in_features
		state_index = torch.LongTensor([state_encode(state)])
		state_onehot = F.one_hot(state_index, num_classes=num_classes)
		state_onehot = state_onehot.type(torch.FloatTensor).reshape(1, -1).to(device)  # (1, state_dim)

		action_probs = self.actor(state_onehot)  # (1, action_dim)
		dist = Categorical(action_probs)  # sample() returns an action index
		position = state.squeeze().tolist()
		if position in TERMINAL_STATES or position in VIRTUAL_TERMINAL_STATES:
			# Terminal cells: always emit action 1 rather than sampling.
			action = torch.LongTensor([1]).to(torch.device('cpu'))
		else:
			action = dist.sample().to(torch.device('cpu'))

		action_logprob = dist.log_prob(action)  # (1,)
		state_value = self.critic(state_onehot)  # (1, 1)

		memory.states.append(state_onehot)
		memory.actions.append(action)
		memory.logprobs.append(action_logprob)
		memory.state_values.append(state_value)
		return action.item()  # convert to scalar

	def evaluate(self, state, action):
		"""Score a batch of stored transitions under the current weights.

		Args:
			state: batch of one-hot states, ``(N, state_dim)`` (as stacked and
				squeezed by ``PPO.update``).
			action: action indices, ``(N,)``.

		Returns:
			``(log_probs (N,), state_values (N,), entropy (N,))``.
		"""
		state_value = self.critic(state)  # (N, 1)

		action_probs = self.actor(state)  # (N, action_dim)
		dist = Categorical(action_probs)
		action_logprobs = dist.log_prob(action)  # (N,)
		dist_entropy = dist.entropy()  # (N,)

		return action_logprobs, torch.squeeze(state_value), dist_entropy


class PPO:
	"""Clipped-surrogate PPO agent built on ``ActorCritic``.

	``policy`` is the network being optimised. ``old_policy`` is a frozen copy
	that collects rollouts through :meth:`select_action` and is re-synced from
	``policy`` at the end of every :meth:`update`.
	"""

	def __init__(self, state_dim, action_dim, lr, betas, gamma, K_epochs, eps_clip, restore=False, ckpt=None, first=True):
		"""Build both networks and the Adam optimiser.

		Args:
			state_dim: width of the one-hot state input.
			action_dim: number of discrete actions.
			lr: Adam learning rate.
			betas: Adam beta coefficients.
			gamma: discount factor for the Monte Carlo returns.
			K_epochs: optimisation passes over each batch in :meth:`update`.
			eps_clip: clipping range for the probability ratio.
			restore: load ``ckpt`` into ``policy``, but only when ``first`` is False.
			ckpt: path to a ``state_dict`` readable by ``torch.load``.
			first: when True the checkpoint is never loaded.
		"""
		self.lr = lr
		self.betas = betas
		self.gamma = gamma
		self.eps_clip = eps_clip
		self.K_epochs = K_epochs

		# current policy
		self.policy = ActorCritic(state_dim, action_dim).to(device)
		if restore and not first:
			pretrained_state = torch.load(ckpt, map_location=lambda storage, loc: storage)
			self.policy.load_state_dict(pretrained_state)
		self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr, betas=betas)

		# old policy: initialised with the current policy's parameters
		self.old_policy = ActorCritic(state_dim, action_dim).to(device)
		self.old_policy.load_state_dict(self.policy.state_dict())

		self.MSE_loss = nn.MSELoss()  # to calculate critic loss

	def select_action(self, state, memory):
		"""Act with the frozen ``old_policy`` (delegates to ``ActorCritic.act``)."""
		return self.old_policy.act(state, memory)

	def update(self, memory):
		"""Run ``K_epochs`` of clipped PPO on the transitions stored in ``memory``.

		``memory.rewards``/``memory.is_terminals`` are turned into discounted
		returns (reset at every terminal flag), which are standardised when
		there is more than one. ``memory.states``/``actions``/``logprobs`` are
		the tensors recorded by ``ActorCritic.act``. Afterwards ``old_policy``
		is synced with ``policy``.
		"""
		# Monte Carlo returns, accumulated backwards and reset at episode ends.
		returns = []
		discounted_reward = 0
		for reward, is_terminal in zip(reversed(memory.rewards), reversed(memory.is_terminals)):
			if is_terminal:
				discounted_reward = 0
			discounted_reward = reward + self.gamma * discounted_reward
			returns.append(discounted_reward)
		# Back to chronological order (append + reverse instead of O(n^2) insert(0, ...)).
		returns.reverse()

		rewards = torch.FloatTensor(returns).to(device)
		# The unbiased std() of a single element is NaN, which would poison every
		# weight on the next optimiser step, so only standardise real batches.
		if rewards.numel() > 1:
			rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-7)

		# convert lists to (already detached) tensors
		old_states = torch.squeeze(torch.stack(memory.states).to(device)).detach()
		old_actions = torch.squeeze(torch.stack(memory.actions).to(device)).detach()
		old_logprobs = torch.squeeze(torch.stack(memory.logprobs)).to(device).detach()

		for _ in range(self.K_epochs):
			# Re-evaluate the stored actions under the current policy.
			logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)

			# Importance ratio p/q
			ratios = torch.exp(logprobs - old_logprobs)

			# NOTE: advantages use the *current* critic's values each epoch, not
			# the values recorded at rollout time.
			advantages = rewards - state_values.detach()

			# Clipped surrogate objective
			surr1 = ratios * advantages
			surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
			actor_loss = -torch.min(surr1, surr2)

			# Value regression minus an entropy bonus
			critic_loss = 0.5 * self.MSE_loss(rewards, state_values) - 0.04 * dist_entropy

			loss = actor_loss + critic_loss

			self.optimizer.zero_grad()
			loss.mean().backward()
			self.optimizer.step()

		# Copy new weights to old_policy
		self.old_policy.load_state_dict(self.policy.state_dict())


def state_encode(state):
	"""Flatten a ``[row, col]`` position into one row-major index.

	The index is ``row * width + col``, with ``width`` 6 on the mini map and
	11 on the default map.
	"""
	row, col = state[0], state[1]
	width = 11
	if USE_MINI_MAP:
		width = 6
	return row * width + col


def state_decode(encoded_state):
	"""Invert ``state_encode``: flat row-major index -> position tensor.

	Returns a float tensor ``[[row, col]]`` of shape ``(1, 2)`` on ``device``.
	"""
	width = 11
	if USE_MINI_MAP:
		width = 6
	row = math.floor(encoded_state / width)
	col = encoded_state % width
	as_row_vector = np.array([row, col]).reshape(1, -1)
	return torch.FloatTensor(as_row_vector).to(device)


def expert_select_action(state, preference, memory):
	"""Scripted expert: pick down (1) or right (3) and log the step in ``memory``.

	In a (virtual) terminal position the answer is always 1. Otherwise each
	treasure is scored as ``preference[0] * value - preference[1] * min_steps``.
	``cols_to_best`` is the column offset of the best-scoring treasure at or to
	the right of the current column, and ``rows_left`` is how far the current
	column's treasure is below. The expert goes down with probability
	``rows_left / (rows_left + cols_to_best)``, but only when the best treasure
	is in this column or more than one row remains; otherwise it goes right.

	Args:
		state: ``[row, col]`` observation (presumably a NumPy int array -- TODO
			confirm); indexed, squeezed and reshaped here.
		preference: two weights, for treasure value and for step cost.
		memory: rollout buffer; the state as a ``(1, 2)`` float tensor and the
			action as a 1-element ``LongTensor`` are appended.

	Returns:
		The chosen action as a Python ``int``.
	"""
	position = state.squeeze().tolist()
	if position in TERMINAL_STATES or position in VIRTUAL_TERMINAL_STATES:
		chosen = 1
	else:
		scores = preference[0] * treasure_value - preference[1] * min_step_to_treasure
		col = state[1]
		rows_left = (vertical_distance[col] - state[0]).item()
		cols_to_best = np.argmax(scores[col:]).item()
		# random.random() is drawn first, exactly once per non-terminal step.
		go_down = random.random() < rows_left / (rows_left + cols_to_best) and (cols_to_best == 0 or rows_left > 1)
		chosen = 1 if go_down else 3  # 1 = down, 3 = right
	action = torch.LongTensor([chosen])

	memory.states.append(torch.FloatTensor(state.reshape(1, -1)))
	memory.actions.append(action)
	return action.item()


def expert_collect_trajectories(preference, num_episodes):
	"""Roll out the scripted expert for ``num_episodes`` fixed-length episodes.

	Each episode runs exactly ``opt.max_timesteps`` steps; it is not cut short
	on ``done``. Only the last step of each episode is flagged terminal, and
	environment rewards are not recorded.

	Args:
		preference: objective weights forwarded to ``expert_select_action``.
		num_episodes: number of episodes to collect.

	Returns:
		A ``Memory`` whose ``states``/``actions``/``is_terminals`` hold one
		entry per step.
	"""
	env = mo_gym.make(opt.env)
	memory = Memory()
	last_step = opt.max_timesteps - 1
	try:
		for _ in range(num_episodes):
			state = env.reset()[0]
			for t in range(opt.max_timesteps):
				action = expert_select_action(state, preference, memory)
				state, _, _, _, _ = env.step(action)
				memory.is_terminals.append(t == last_step)
	finally:
		# The env used to be leaked; release it even if a step raises.
		env.close()
	return memory


def collect_trajectories(ppo, num_episodes):
	"""Roll out ``ppo``'s frozen policy for ``num_episodes`` fixed-length episodes.

	Each episode runs exactly ``opt.max_timesteps`` steps; it is not cut short
	on ``done``. Only the last step of each episode is flagged terminal.

	Args:
		ppo: agent whose ``select_action`` is used to act.
		num_episodes: number of episodes to collect.

	Returns:
		A ``Memory`` holding, per step, the raw state as a ``(1, 2)`` float
		tensor, the action as a 1-element ``LongTensor``, and the terminal flag.
	"""
	env = mo_gym.make(opt.env)
	memory = Memory()
	# select_action insists on logging into a Memory; this one is never read.
	scratch = Memory()
	last_step = opt.max_timesteps - 1
	try:
		# Only sampled actions are needed, so don't build autograd graphs.
		with torch.no_grad():
			for _ in range(num_episodes):
				state = env.reset()[0]
				for t in range(opt.max_timesteps):
					action = ppo.select_action(state, scratch)
					memory.states.append(torch.FloatTensor(state.reshape(1, -1)))
					memory.actions.append(torch.LongTensor([action]))
					state, _, _, _, _ = env.step(action)
					memory.is_terminals.append(t == last_step)
				# Keep the unused scratch buffer from growing across episodes.
				scratch.clear_memory()
	finally:
		env.close()
	return memory


def train(env_name, env, state_dim, action_dim, reward_signal, preference, render, solved_reward, max_episodes, max_timesteps, update_timestep, K_epochs, eps_clip, gamma, lr, betas, ckpt_folder, restore, tb=False, print_interval=10, save_interval=100, first=True):
	"""Train a PPO agent on ``env`` and return it.

	Every episode runs the full ``max_timesteps`` steps (no early stop on
	``done``). The policy is updated every ``update_timestep`` steps.

	Reward source:
	  * ``reward_signal`` non-empty: the learned reward
	    ``reward_signal[state_encode(state), action]`` is dotted with
	    ``preference``. The environment reward is ignored, and only the last
	    step of each episode is marked terminal.
	  * otherwise: the environment reward is used, with the environment's
	    ``done`` as the terminal flag.

	Every ``print_interval`` episodes the mean episode length over that
	window and the exponential running reward are printed.

	The checkpoint path is ``ckpt_folder + '_' + env_name + '.pth''``. It is
	only read (by ``PPO`` when ``restore`` and not ``first``). ``solved_reward``,
	``tb`` and ``save_interval`` are currently unused because saving and
	TensorBoard logging are disabled.

	Returns:
		The trained ``PPO`` agent.
	"""
	ckpt = ckpt_folder + '_' + env_name + '.pth'
	if restore:
		print('Load checkpoint from {}'.format(ckpt))

	memory = Memory()

	ppo = PPO(state_dim, action_dim, lr, betas, gamma, K_epochs, eps_clip, restore=restore, ckpt=ckpt, first=first)

	# Loop invariants, previously recomputed on every step.
	use_reward_signal = len(reward_signal) > 0
	if use_reward_signal:
		preference_cpu = preference.to(torch.device('cpu'))

	best_reward, running_reward, episode_reward, avg_length, time_step = float("-inf"), 0, 0, 0, 0

	# training loop
	for i_episode in range(1, max_episodes + 1):
		state = env.reset()[0]
		for t in range(max_timesteps):
			time_step += 1

			# Run old policy
			action = ppo.select_action(state, memory)
			if use_reward_signal:
				reward = torch.dot(reward_signal[state_encode(state), action], preference_cpu).item()
				state, _, done, _, _ = env.step(action)
			else:
				state, reward, done, _, _ = env.step(action)

			memory.rewards.append(reward)
			if use_reward_signal:
				memory.is_terminals.append(t == max_timesteps - 1)
			else:
				memory.is_terminals.append(done)

			if time_step % update_timestep == 0:
				ppo.update(memory)
				memory.clear_memory()
				time_step = 0

			episode_reward += reward
			if render:
				env.render()

		# `t` is the last 0-based step index, so the episode lasted t + 1 steps.
		avg_length += t + 1
		running_reward = 0.9 * running_reward + 0.1 * episode_reward
		best_reward = max(best_reward, running_reward)
		episode_reward = 0
		# Honour print_interval (was hard-coded to 100).
		if print_interval and i_episode % print_interval == 0:
			avg_length = int(avg_length / print_interval)
			print('Episode {} \t Avg length: {} \t Running reward: {}'.format(
				i_episode, avg_length, running_reward))
			# Start a fresh window; the printed average used to leak into the
			# next window's sum.
			avg_length = 0
	return ppo


def test(env_name, env, state_dim, action_dim, render, K_epochs, eps_clip, gamma, lr, betas, ckpt, test_episodes, mute, max_timesteps):
	"""Evaluate the PPO policy stored at ``ckpt`` on ``env``.

	A fresh ``PPO`` is built with the checkpoint always restored. Each of the
	``test_episodes`` episodes runs until ``done`` or ``max_timesteps`` steps,
	summing the environment reward. Unless ``mute``, per-episode and summary
	lines are printed.

	Returns:
		``(average episode reward, average episode length)``.
	"""
	if not mute:
		print('Load checkpoint from {}'.format(ckpt))

	memory = Memory()

	ppo = PPO(state_dim, action_dim, lr, betas, gamma, K_epochs, eps_clip, restore=True, ckpt=ckpt, first=False)

	episode_reward, time_step = 0, 0
	avg_episode_reward, avg_length = 0, 0

	# Pure evaluation: gradients are never taken, so don't build autograd graphs.
	with torch.no_grad():
		for i_episode in range(1, test_episodes + 1):
			state = env.reset()[0]
			for t in range(max_timesteps):
				time_step += 1

				# Run old policy
				action = ppo.select_action(state, memory)

				state, reward, done, _, _ = env.step(action)

				episode_reward += reward

				if render:
					env.render()

				if done:
					break

			if not mute:
				print('Episode {} \t Length: {} \t Reward: {}'.format(i_episode, time_step, episode_reward))
			avg_episode_reward += episode_reward
			avg_length += time_step
			memory.clear_memory()
			time_step, episode_reward = 0, 0

	if not mute:
		print('Test {} episodes DONE!'.format(test_episodes))
		print('Avg episode reward: {} | Avg length: {}'.format(avg_episode_reward / test_episodes, avg_length / test_episodes))
	return avg_episode_reward / test_episodes, avg_length / test_episodes


def expert_test(env_name, env, preference, render, test_episodes, mute):
	"""Evaluate the scripted expert on ``env`` for ``test_episodes`` episodes.

	Each episode runs until the environment reports ``done`` (no step cap),
	summing the environment reward. Unless ``mute``, per-episode and summary
	lines are printed.

	Returns:
		``(average episode reward, average episode length)``, matching
		``test()``. Previously the *sums* were returned even though averages
		were printed. With zero episodes, ``(0, 0)`` is returned.
	"""
	memory = Memory()
	total_reward, total_length = 0, 0

	for i_episode in range(1, test_episodes + 1):
		state = env.reset()[0]
		episode_reward, time_step = 0, 0
		done = False
		while not done:
			time_step += 1

			action = expert_select_action(state, preference, memory)

			state, reward, done, _, _ = env.step(action)

			episode_reward += reward

			if render:
				env.render()

		if not mute:
			print('Episode {} \t Length: {} \t Reward: {}'.format(i_episode, time_step, episode_reward))
		total_reward += episode_reward
		total_length += time_step
		memory.clear_memory()

	if test_episodes > 0:
		avg_episode_reward = total_reward / test_episodes
		avg_length = total_length / test_episodes
	else:
		avg_episode_reward, avg_length = 0, 0

	if not mute:
		print('Test {} episodes DONE!'.format(test_episodes))
		print('Avg episode reward: {} | Avg length: {}'.format(avg_episode_reward, avg_length))
	return avg_episode_reward, avg_length


def runPPO(reward_signal, preference, id, first=True):
	"""Create the environment and train a PPO agent for one preference vector.

	Args:
		reward_signal: flattened learned reward. When non-empty it is reshaped
			to ``(-1, action_dim, preference.shape[0])`` and used by ``train``
			instead of the env reward. When empty, the env reward is
			scalarised with ``mo_gym.LinearReward``.
		preference: objective weights.
		id: agent index, embedded in the checkpoint path
			``./checkpoints/agent/agent<id>_<opt.ckpt_folder>_<opt.env>.pth``.
		first: forwarded to ``train``; when True no checkpoint is loaded.

	Returns:
		The trained ``PPO`` agent.
	"""
	torch.manual_seed(opt.seed)
	np.random.seed(opt.seed)

	env_name = opt.env
	env = mo_gym.make(env_name)
	# One-hot width; see VIRTUAL_TERMINAL_STATES for why it exceeds the grid size.
	state_dim = 36 if USE_MINI_MAP else 132
	action_dim = env.action_space.n

	if len(reward_signal) > 0:
		reward_signal = reward_signal.reshape(-1, action_dim, preference.shape[0])
	else:
		env = mo_gym.LinearReward(env, weight=np.array(preference))
	# os.mkdir failed whenever ./checkpoints did not exist yet.
	os.makedirs('./checkpoints/agent', exist_ok=True)
	try:
		ppo = train(env_name,
					env,
					state_dim,
					action_dim,
					reward_signal,
					preference,
					render=opt.render,
					solved_reward=opt.solved_reward,
					max_episodes=opt.max_episodes,
					max_timesteps=opt.max_timesteps,
					update_timestep=opt.update_timesteps,
					K_epochs=opt.K_epochs,
					eps_clip=opt.eps_clip,
					gamma=opt.gamma,
					lr=opt.lr,
					betas=[0.9, 0.999],
					ckpt_folder='./checkpoints/agent/agent' + str(id) + '_' + opt.ckpt_folder,
					restore=opt.restore,
					tb=opt.tb,
					print_interval=opt.print_interval,
					save_interval=opt.save_interval,
					first=first)
	finally:
		env.close()
	# No SummaryWriter is created in this module, so an unconditional
	# `writer.close()` raised NameError under --tb; close one only if defined.
	writer = globals().get('writer')
	if opt.tb and writer is not None:
		writer.close()
	return ppo


def testPPO(preference, id, expert, mute, ckpt):
	"""Evaluate either the PPO checkpoint ``ckpt`` or the scripted expert.

	The env reward is scalarised with ``preference`` through
	``mo_gym.LinearReward``. 100 test episodes are run (capped at 100 steps
	each for the PPO policy). ``id`` is accepted for interface compatibility
	but not used.

	Returns:
		The ``(reward, length)`` statistics returned by ``test`` (when
		``expert`` is False) or ``expert_test`` (when True).
	"""
	torch.manual_seed(opt.seed)
	np.random.seed(opt.seed)

	env_name = opt.env
	env = mo_gym.make(env_name, render_mode="human" if opt.render else None)
	env = mo_gym.LinearReward(env, weight=np.array(preference))
	state_dim = 36 if USE_MINI_MAP else 132
	action_dim = env.action_space.n
	try:
		if not expert:
			rew, length = test(env_name, env, state_dim, action_dim, render=opt.render, K_epochs=opt.K_epochs, eps_clip=opt.eps_clip, gamma=opt.gamma, lr=opt.lr, betas=[0.9, 0.990], ckpt=ckpt, test_episodes=100, mute=mute, max_timesteps=100)
		else:
			rew, length = expert_test(env_name, env, preference, render=opt.render, test_episodes=100, mute=mute)
	finally:
		# Release the env (and its render window when --render is set).
		env.close()

	# No SummaryWriter is created in this module, so an unconditional
	# `writer.close()` raised NameError under --tb; close one only if defined.
	writer = globals().get('writer')
	if opt.tb and writer is not None:
		writer.close()

	return rew, length
