from pymoo.factory import get_performance_indicator

import os
# Default to GPU 3, but respect a CUDA_VISIBLE_DEVICES value already exported
# by the caller. This must run before CUDA is initialised.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")
import numpy as np
import torch
import gym
import argparse
import os

import utils
import OurDDPG
import TD3
import DDPG

import half_cheetah_v3
import hopper_v3
import ant_v3
import walker2d_v3

import dst_d


# Torch device used for the policy and its inputs: the first visible GPU when
# CUDA is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def compute_hv(objs, ref_point):
	"""Return the 2-D hypervolume dominated by ``objs`` relative to ``ref_point``.

	Both objectives are maximised. Each point is clamped to the reference
	point, so a coordinate worse than the reference contributes nothing. Only
	the first two entries of each point are used. The input does not need to
	be sorted or Pareto-filtered. Dominated points are absorbed into the union
	of boxes rather than being counted separately.

	Args:
		objs: iterable of points indexable as ``point[0]`` and ``point[1]``.
		ref_point: ``(x_ref, y_ref)`` reference point.

	Returns:
		The area of the union of boxes ``[x_ref, x] x [y_ref, y]`` (0.0 for
		empty input).
	"""
	ref_x, ref_y = ref_point[0], ref_point[1]
	points = sorted((max(ref_x, o[0]), max(ref_y, o[1])) for o in objs)
	hv = 0.0
	best_y = ref_y
	# Sweep from the largest x downwards. The strip between consecutive x
	# values is covered up to the highest y among points at or to its right,
	# not just the y of the point at its right edge.
	for k in range(len(points) - 1, -1, -1):
		x, y = points[k]
		best_y = max(best_y, y)
		left = points[k - 1][0] if k > 0 else ref_x
		hv += (x - left) * (best_y - ref_y)
	return hv

def sa(policy, state, preference):
	"""Query ``policy`` for the action of a single state/preference pair.

	``state`` and ``preference`` are each flattened into a one-row float32
	batch on the module-level ``device``. The policy output is returned as a
	flat NumPy array on the CPU.
	"""
	state = torch.tensor(np.asarray(state).reshape(1, -1), dtype=torch.float32, device=device)
	preference = torch.tensor(np.asarray(preference).reshape(1, -1), dtype=torch.float32, device=device)
	# Pure inference: do not build an autograd graph on every environment step.
	with torch.no_grad():
		action = policy(state, preference)
	return action.cpu().numpy().flatten()


def get_pref(reward_dim):
	"""Sample a random preference vector that sums to one.

	Draws ``reward_dim`` uniform values from the global NumPy RNG, casts them
	to float32 and divides by their sum.
	"""
	weights = np.random.rand(reward_dim).astype(np.float32)
	total = weights.sum()
	return weights / total

# Hypervolume reference point (objective 0, objective 1) for each supported
# environment. --env must be one of these keys.
REF_POINTS = {
	"MO_ant-v0": (0, -3000),
	"MO_half_cheetah-v0": (0, -4000),
	"MO_hopper-v0": (0, -1000),
	"MO_walker-v0": (0, -1000),
	"dst_d-v0": (0, -50),
}


def _evaluate_preference(env, policy, preference, epi_num, reward_dim, episode_num):
	"""Roll out ``epi_num`` episodes under one fixed preference vector.

	Resets the environment before the first episode and after every episode.
	Logs each finished episode with the running counter ``episode_num``.

	Returns:
		``(mean_reward, episode_num)``: the per-objective mean of the episode
		returns, and the counter advanced by ``epi_num``.
	"""
	state = env.reset()
	episode_rewards = np.zeros((epi_num, reward_dim))
	for j in range(epi_num):
		episode_reward = np.zeros(reward_dim)
		episode_timesteps = 0
		done = False
		while not done:
			episode_timesteps += 1
			action = sa(policy, np.array(state), preference)
			state, reward, done, _ = env.step(action)
			episode_reward += reward
		print(f"Episode Num: {episode_num} Episode T: {episode_timesteps} Weight{preference} Reward: {episode_reward}")
		episode_num += 1
		episode_rewards[j] = episode_reward
		state = env.reset()
	return np.mean(episode_rewards, 0), episode_num


if __name__ == "__main__":

	parser = argparse.ArgumentParser()
	parser.add_argument("--policy", default="OurDDPG")                  # Policy name (TD3, DDPG or OurDDPG)
	parser.add_argument("--env", default="dst_d-v0")                    # Environment name; must be a key of REF_POINTS
	parser.add_argument("--seed", default=0, type=int)              # Sets Gym, PyTorch and Numpy seeds
	parser.add_argument("--start_timesteps", default=25e3, type=int)# Time steps initial random policy is used
	parser.add_argument("--eval_freq", default=1e4, type=int)       # How often (time steps) we evaluate
	parser.add_argument("--max_timesteps", default=100, type=int)   # Max time steps to run environment
	parser.add_argument("--max_episodes", default=100, type=int)   # Max time episodes to run environment
	parser.add_argument("--expl_noise", default=0.1)                # Std of Gaussian exploration noise
	parser.add_argument("--batch_size", default=256, type=int)      # Batch size for both actor and critic
	parser.add_argument("--discount", default=0.99)                 # Discount factor
	parser.add_argument("--tau", default=0.005)                     # Target network update rate
	parser.add_argument("--policy_noise", default=0.2)              # Noise added to target policy during critic update
	parser.add_argument("--noise_clip", default=0.5)                # Range to clip target policy noise
	parser.add_argument("--policy_freq", default=2, type=int)       # Frequency of delayed policy updates
	parser.add_argument("--save_model", action="store_true")        # Save model and optimizer parameters
	parser.add_argument("--load_model", default="")                 # Model load file name, "" doesn't load, "default" uses file_name
	args = parser.parse_args()

	env_name = args.env
	# Fail fast: without a reference point the hypervolume cannot be computed.
	if env_name not in REF_POINTS:
		parser.error(f"unsupported --env {env_name!r}; choose one of: {', '.join(REF_POINTS)}")
	ref = REF_POINTS[env_name]

	epi_num = 1  # episodes rolled out per preference
	utility = []
	hv = []
	# 100 two-objective preferences (w, 1 - w) for w = 0.00, 0.01, ..., 0.99.
	w = np.arange(0, 1, 0.01)
	table = np.stack((w, 1 - w), -1)
	# np.save below writes into this directory and fails if it is missing.
	os.makedirs("pd", exist_ok=True)

	for seed in range(1, 6):
		file_name = f"{args.policy}_{args.env}_{seed}"
		print("---------------------------------------")
		print(f"Policy: {args.policy}, Env: {args.env}, Seed: {seed}")
		print("---------------------------------------")
		model_dir = os.path.join("./models", file_name)

		env = gym.make(args.env)

		# Set seeds.
		# NOTE(review): these use args.seed, not the loop's `seed`, so every
		# trained model is evaluated under the same seed -- confirm intended.
		env.seed(args.seed)
		env.action_space.seed(args.seed)
		torch.manual_seed(args.seed)
		np.random.seed(args.seed)

		state_dim = env.observation_space.shape[0]
		action_dim = env.action_space.shape[0]
		reward_dim = env.reward_space
		max_action = float(env.action_space.high[0])

		policy = OurDDPG.Actor(state_dim=state_dim, action_dim=action_dim, reward_dim=reward_dim, max_action=max_action).to(device)
		# map_location lets a checkpoint saved on a GPU load on a CPU-only host.
		policy.load_state_dict(torch.load(os.path.join(model_dir, "15_actor"), map_location=device))
		policy.eval()

		# Both results are unused. The calls are kept because get_pref draws
		# from the global NumPy RNG (and env.reset may consume env randomness),
		# so removing them could change the evaluated rollouts.
		env.reset()
		get_pref(reward_dim)

		# One row of mean returns per preference. The width is reward_dim so
		# each row matches the episode-return vectors written into it.
		total_reward_vec = np.zeros((table.shape[0], reward_dim))
		total_reward = np.zeros(table.shape[0])
		episode_num = 1
		for i, p in enumerate(table):
			m, episode_num = _evaluate_preference(env, policy, p, epi_num, reward_dim, episode_num)
			print('=' * 70)
			print(m, np.dot(m, p))
			total_reward_vec[i] = m
			total_reward[i] = np.dot(m, p)
			print('=' * 70)
		env.close()

		total_reward_vec = total_reward_vec[total_reward_vec[:, 0].argsort()]
		hv.append(compute_hv(total_reward_vec, ref))
		utility.append(np.mean(total_reward))
		np.save(f'pd/Vanila_{env_name}_{seed}', total_reward_vec)

	print('Hyper volume')
	print(hv)
	print(np.mean(hv))
	print('utility')
	print(utility)
	print(np.mean(utility))
	print(env_name)
