from pymoo.factory import get_performance_indicator

import numpy as np
import torch
import gym
import argparse
import os

import utils
import TD3
import OurDDPG
import DDPG

import dst_d

import os
# Default to GPU 0, but respect a CUDA_VISIBLE_DEVICES value the user already
# exported (e.g. "" to force CPU or "1" to pick another GPU).
# NOTE(review): this only takes effect if nothing imported above has already
# initialised CUDA; setting it before `import torch` would guarantee that.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

# Fixed grid of 2-objective preference weights (each pair sums to 1).
# Not referenced elsewhere in this script as shown.
PREF = [
	[0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.4], [0.5, 0.5],
	[0.4, 0.6], [0.3, 0.7], [0.2, 0.8], [0.1, 0.9],
]
# Torch device for the actor network and the tensors fed to it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def sa(policy, state, preference, target_device=None):
	"""Run the actor once and return its action as a flat NumPy array.

	Args:
		policy: callable taking ``(state, preference)`` float32 tensors, each
			reshaped to ``(1, -1)``, and returning an action tensor.
		state: observation; anything ``np.asarray`` accepts.
		preference: objective-weight vector; anything ``np.asarray`` accepts.
		target_device: torch device for the input tensors. ``None`` (default)
			uses the module-level ``device``.

	Returns:
		1-D ``np.ndarray`` holding the flattened policy output.
	"""
	# Only look up the module global when the caller did not supply a device.
	dev = device if target_device is None else target_device
	state_t = torch.as_tensor(np.asarray(state, dtype=np.float32).reshape(1, -1), device=dev)
	pref_t = torch.as_tensor(np.asarray(preference, dtype=np.float32).reshape(1, -1), device=dev)
	# Pure inference: no_grad avoids building an autograd graph on every step.
	with torch.no_grad():
		action = policy(state_t, pref_t)
	return action.detach().cpu().numpy().flatten()


def get_pref(reward_dim):
	"""Sample a random preference (objective-weight) vector.

	Draws ``reward_dim`` values from NumPy's global uniform [0, 1) generator
	(so the result follows ``np.random.seed``), casts them to float32 and
	rescales them so they sum to 1.

	Args:
		reward_dim: number of objectives, i.e. length of the returned vector.

	Returns:
		float32 ``np.ndarray`` of shape ``(reward_dim,)`` whose entries sum to 1
		(up to float rounding).
	"""
	weights = np.random.rand(reward_dim).astype(np.float32)
	total = weights.sum()
	return weights / total

if __name__ == "__main__":
	
	parser = argparse.ArgumentParser()
	parser.add_argument("--policy", default="OurDDPG")                  # Policy name (TD3, DDPG or OurDDPG)
	#parser.add_argument("--env", default="HalfCheetah-v2")          # OpenAI gym environment name
	#parser.add_argument("--env", default="mohopper-v3")          # OpenAI gym environment name
	parser.add_argument("--env", default="dst_d-v0")          # OpenAI gym environment name
	#parser.add_argument("--env", default="dst_d-v0")          # OpenAI gym environment name
	parser.add_argument("--seed", default=0, type=int)              # Sets Gym, PyTorch and Numpy seeds
	parser.add_argument("--start_timesteps", default=25e3, type=int)# Time steps initial random policy is used
	parser.add_argument("--eval_freq", default=1e4, type=int)       # How often (time steps) we evaluate
	parser.add_argument("--max_timesteps", default=100, type=int)   # Max time steps to run environment
	parser.add_argument("--max_episodes", default=10, type=int)   # Max time episodes to run environment
	parser.add_argument("--expl_noise", default=0.1)                # Std of Gaussian exploration noise
	parser.add_argument("--batch_size", default=256, type=int)      # Batch size for both actor and critic
	parser.add_argument("--discount", default=0.99)                 # Discount factor
	parser.add_argument("--tau", default=0.005)                     # Target network update rate
	parser.add_argument("--policy_noise", default=0.2)              # Noise added to target policy during critic update
	parser.add_argument("--noise_clip", default=0.5)                # Range to clip target policy noise
	parser.add_argument("--policy_freq", default=2, type=int)       # Frequency of delayed policy updates
	parser.add_argument("--save_model", action="store_true")        # Save model and optimizer parameters
	parser.add_argument("--load_model", default="")                 # Model load file name, "" doesn't load, "default" uses file_name
	args = parser.parse_args()

	file_name = f"{args.policy}_{args.env}_{args.seed}"
	print("---------------------------------------")
	print(f"Policy: {args.policy}, Env: {args.env}, Seed: {args.seed}")
	print("---------------------------------------")

	'''
	if not os.path.exists("./results"):
		os.makedirs("./results")
	if not os.path.exists(summary_dir):
		os.makedirs(summary_dir)
	if not os.path.exists("./results"):
		os.makedirs("./results")

	if not os.path.exists("./models"):
		os.makedirs("./models")
	if not os.path.exists(model_dir):
		os.makedirs(model_dir)
	'''
	summary_dir = os.path.join("./results", file_name)
	model_dir = os.path.join("./models",file_name)

	env = gym.make(args.env)

	# Set seeds
	env.seed(args.seed)
	env.action_space.seed(args.seed)
	torch.manual_seed(args.seed)
	np.random.seed(args.seed)
	
	state_dim = env.observation_space.shape[0]
	action_dim = env.action_space.shape[0] 
	reward_dim = env.reward_space
	max_action = float(env.action_space.high[0])

	kwargs = {
		"state_dim": state_dim,
		"action_dim": action_dim,
		"reward_dim": reward_dim,
		"max_action": max_action,
		"discount": args.discount,
		"tau": args.tau,
	}
     
	policy = OurDDPG.Actor(state_dim = state_dim, action_dim = action_dim, reward_dim = reward_dim, max_action = max_action).to(device)
	policy.load_state_dict(torch.load(model_dir + "/30_actor"))
	state, done = env.reset(), False
	p = get_pref(reward_dim)
	episode_reward = np.zeros(reward_dim)
	episode_timesteps = 0
	episode_num = 0
	tot_reward = np.zeros((args.max_episodes, reward_dim))

	while episode_num < int(args.max_episodes):
		
		episode_timesteps += 1

		# Select action randomly or according to policy
		action = sa(policy, np.array(state), p)

		# Perform action
		next_state, reward, done, _ = env.step(action) 
		done_bool = float(done) if episode_timesteps < env._max_episode_steps else 0


		state = next_state
		episode_reward += reward

		# Train agent after collecting sufficient data

		if done: 
			# +1 to account for 0 indexing. +0 on ep_timesteps since it will increment +1 even if done=True
			print(f"Episode Num: {episode_num+1} Episode T: {episode_timesteps} Weight{p} Reward: {episode_reward}")
			# Reset environment
			state, done = env.reset(), False
			p = get_pref(reward_dim)
			tot_reward[episode_num] = episode_reward
			episode_reward = np.zeros(reward_dim)
			episode_timesteps = 0
			episode_num += 1 

	print(tot_reward)
	hv = get_performance_indicator("hv", ref_point=np.array([30, 0]))
	print("hv", hv.do(tot_reward))
                   
