import numpy as np
from Environment.env import Environment
import collections
import pickle
from tqdm import tqdm
import random
from pymoo.factory import get_performance_indicator

class ExpectedHypervolumeImprovement:
    """Select the candidate row with the largest Monte-Carlo hypervolume gain.

    For every candidate ``i`` the objective values are sampled from
    independent Gaussians ``N(mu[i, j], sigma[i, j])``. Each sample ``y`` is
    turned into the point ``domain_min_points - y``, and the resulting point
    sets are scored with pymoo's hypervolume indicator, whose reference point
    is the origin.
    """

    def __init__(self, domain_num, f_num, domain_min_points, sample_Gaussian_num, sub_sampling, sub_sampling_num):
        self.domain_num = domain_num  # number of candidate rows scored by select_action
        self.f_num = f_num  # number of objectives (columns of mu / sigma)
        self.domain_min_points = domain_min_points  # per-objective offsets the samples are subtracted from
        self.sample_Gaussian_num = sample_Gaussian_num  # Monte-Carlo samples drawn per candidate
        # NOTE(review): sub_sampling / sub_sampling_num are stored but never read in this class.
        self.sub_sampling = sub_sampling
        self.sub_sampling_num = sub_sampling_num
        # pymoo hypervolume indicator; built lazily on first use, then reused.
        self._hv_indicator = None

    def countHypervolume(self, point):
        """Return the hypervolume of the point set ``point``, using the origin as the reference point.

        Args:
            point: sequence of points, each of length ``f_num``.

        The indicator object is created once per instance instead of on every
        call, because ``select_action`` calls this repeatedly per candidate.
        """
        # getattr keeps this working for instances created without __init__.
        hv = getattr(self, "_hv_indicator", None)
        if hv is None:
            hv = get_performance_indicator("hv", ref_point=np.array([0] * self.f_num))
            self._hv_indicator = hv
        return hv.do(np.array(point))

    def select_action(self, state_action_pairs):
        """Return the index of the candidate with the largest estimated gain.

        Args:
            state_action_pairs: 2-D array. For each row, columns ``[:f_num]``
                are read as Gaussian means and columns ``[f_num:2*f_num]`` as
                standard deviations. Only the first ``domain_num`` rows are
                scored.

        Returns:
            The ``np.argmax`` index of the best-scoring row. Ties go to the
            lowest index.
        """
        mu = state_action_pairs[:, :self.f_num]
        sigma = state_action_pairs[:, self.f_num:self.f_num * 2]
        n_samples = self.sample_Gaussian_num

        scores = []
        for i in range(self.domain_num):
            # One column of n_samples draws per objective. The draws are made
            # objective by objective, so the global RNG stream is consumed in a
            # fixed, reproducible order. Result shape: (n_samples, f_num).
            y = np.column_stack([
                np.random.normal(mu[i][j], sigma[i][j], n_samples)
                for j in range(self.f_num)
            ])
            # -y + min_point for every sampled objective vector.
            point = [[-k + l for k, l in zip(row, self.domain_min_points)] for row in y]

            # Summing the incremental gains HV(point[:j+2]) - HV(point[:j+1])
            # for j = 0 .. n-2 telescopes to HV(point) - HV(point[:1]), so two
            # hypervolume evaluations suffice instead of 2*(n-1).
            # NOTE(review): this is the gain of samples 2..n over sample 1, not
            # an improvement over an existing Pareto front -- confirm intent.
            if n_samples > 1:
                gain = self.countHypervolume(point) - self.countHypervolume(point[:1])
            else:
                gain = 0  # a single sample contributes no incremental gain
            scores.append(gain / n_samples)

        return np.argmax(scores)

# Collect an offline dataset of (observation, next_observation, action, reward,
# terminal) transitions from the Environment. Each episode is driven either
# entirely by uniformly random actions (probability eps) or by the EHI
# selector, and the result is pickled for the Trajectory-Transformer pipeline.
from pathlib import Path

DATA_KEYS = ('observations', 'next_observations', 'actions', 'rewards', 'terminals')

data = []
episodes = 100
f_num = 3
domain_num = 1000
view_num = 100
T = 100
function_type = "train"
env = Environment(T=T, domain_num=domain_num, view_num=view_num, f_num=f_num, function_type=function_type, seed=0)
episode_data = {k: [] for k in DATA_KEYS}
eps = 0.5  # probability that a whole episode uses random actions

# Resolve and create the output directory before the long collection run, so a
# missing folder does not crash the final dump after all episodes are done.
out_path = Path('Trajectory-Transformer/trajectory/datasets/random_data') / (
    '{}_f{}_view{}_reward_e{}'.format(function_type, f_num, view_num, episodes) + '.pkl'
)
out_path.parent.mkdir(parents=True, exist_ok=True)

for e in range(episodes):
    # NOTE(review): `actions` is never used. The draw is kept only because it
    # advances NumPy's global RNG, which the EHI sampler also consumes.
    actions = np.random.choice(domain_num, T, replace=True)
    state_actions, ehi_sa = env.reset(seed=3100 + e * 10, new_ls=True)
    r = random.random()
    use_random = r < eps
    print("e: {} | G: {}".format(e, "random" if use_random else "EI"))
    for t in tqdm(range(T)):
        if use_random:
            action = random.randint(0, domain_num - 1)
        else:
            # Rebuilt every step so it picks up the current
            # env.domain_min_points (presumably updated by env.step -- confirm;
            # if it is constant per episode, hoist this out of the loop).
            learner = ExpectedHypervolumeImprovement(domain_num=domain_num, f_num=f_num, domain_min_points=env.domain_min_points,
                sample_Gaussian_num=2, sub_sampling=0, sub_sampling_num=100)
            action = learner.select_action(ehi_sa)
        # NOTE(review): `done` is ignored; episodes always run exactly T steps.
        next_state_actions, ehi_sa, reward, done, _ = env.step(env.X[action])
        episode_data['observations'].append(state_actions.reshape(-1).astype('float32'))
        episode_data['next_observations'].append(next_state_actions.reshape(-1).astype('float32'))
        episode_data['actions'].append(np.array([action]))
        episode_data['rewards'].append(reward.astype('float32'))
        episode_data['terminals'].append(t == T - 1)
        state_actions = next_state_actions

for k in DATA_KEYS:
    episode_data[k] = np.stack(episode_data[k])
    print(episode_data[k].shape)
# All episodes are concatenated into one dict; `terminals` (True on each
# episode's last step) marks the episode boundaries.
data.append(episode_data)
with open(out_path, 'wb') as f:
    pickle.dump(data, f)