import random
import numpy as np
import sys, os
sys.path.append(os.path.join(os.path.dirname(__file__), './'))
from function_preprocessing import getFuntion, countHypervolume, domain
from benchmark_functions import set_noise_level, set_YAHPO_index
import torch
from botorch import fit_gpytorch_mll
from botorch.models import SingleTaskGP, ModelListGP
from gpytorch.constraints import GreaterThan
from gpytorch.kernels import RBFKernel, ScaleKernel, MaternKernel
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.models.transforms.outcome import Standardize
from yahpo_gym import *
from gpytorch.mlls import SumMarginalLogLikelihood

YAHPO_INSTANCES = [3945,168868,167152,167200,168330,189862,189909,167181,167149,126026,168910,167190,189906,146212,189865,167168,168329,189873,34539,189866,167104,189908,168335,126025,189354,126029,167161,167184,189905,168908,167185,167201,7593,168331]

class GaussianProcess:
	"""Independent single-output GPs, one botorch ``SingleTaskGP`` per objective.

	``self.GP[i]`` models objective ``i`` with the kernel named in ``kernel[i]``
	("RBF" or "matern52") and a ``Standardize`` outcome transform.
	"""

	def __init__(self, x, y, kernel, kernel_ls, f_num):
		"""Build one GP per objective, initialised with the given lengthscales.

		Args:
			x: training inputs, convertible to a float64 tensor.
			y: training targets; column ``i`` holds objective ``i``.
			kernel: kernel name per objective ("RBF" or "matern52").
			kernel_ls: lengthscale per objective, assigned to each kernel.
			f_num: number of objectives.

		Raises:
			ValueError: if a kernel name is not recognised.
		"""
		self.kernel_ls = kernel_ls
		self.kernel = kernel
		self.f_num = f_num
		self.GP = self._build_models(x, y, use_fixed_ls=True)

	@staticmethod
	def _make_kernel(name):
		"""Return a fresh gpytorch kernel for ``name``; unknown names raise ValueError."""
		if name == "RBF":
			return RBFKernel()
		if name == "matern52":
			return MaternKernel()
		raise ValueError(f"Unsupported kernel {name!r}; expected 'RBF' or 'matern52'")

	def _build_models(self, x, y, use_fixed_ls):
		"""Create one SingleTaskGP per objective on (x, y).

		When ``use_fixed_ls`` is true, each kernel's lengthscale is set from
		``self.kernel_ls``; otherwise the kernel keeps its default initialisation.
		"""
		train_X = torch.tensor(x, dtype=float)
		train_Y = torch.tensor(y, dtype=float)
		models = []
		for i in range(self.f_num):
			covar_module = self._make_kernel(self.kernel[i])
			if use_fixed_ls:
				covar_module.lengthscale = self.kernel_ls[i]
			models.append(SingleTaskGP(
				train_X=train_X,
				train_Y=train_Y[:, i].unsqueeze(1),
				covar_module=covar_module,
				outcome_transform=Standardize(m=1),
			))
		return models

	def fit(self, x, y):
		"""Rebuild the per-objective GPs on (x, y) and fit them by exact marginal likelihood."""
		# Replace (rather than append to) the model list so that self.GP[i] is the
		# model that was just fitted and is the one used for prediction.
		self.GP = self._build_models(x, y, use_fixed_ls=False)
		for model in self.GP:
			mll = ExactMarginalLogLikelihood(likelihood=model.likelihood, model=model)
			fit_gpytorch_mll(mll)

	def construct_state_action_pair(self, domain, y_star, t):
		"""Return per-point features ``[means | variances | y_star | t]`` as a numpy array.

		Args:
			domain: candidate points, convertible to a tensor (one row per point).
			y_star: incumbent value(s); repeated on every row.
			t: scalar time feature; repeated on every row.

		Returns:
			Array with ``len(domain)`` rows: ``f_num`` posterior means, ``f_num``
			posterior variances, then the ``y_star`` column(s) and one ``t`` column.
		"""
		X = torch.tensor(domain)
		means = []
		variances = []
		for model in self.GP[:self.f_num]:
			posterior = model.posterior(X)
			means.append(posterior.mean.cpu().detach())
			variances.append(posterior.variance.cpu().detach())
		# Concatenate along the output axis so the result stays 2-D even when
		# f_num == 1 (squeezing would collapse it to 1-D and break torch.cat).
		mean = torch.cat(means, dim=-1)
		variance = torch.cat(variances, dim=-1)
		num_points = mean.shape[0]
		y_star_cols = torch.tile(torch.tensor([y_star], dtype=mean.dtype), (num_points, 1))
		t_col = torch.full((num_points, 1), float(t), dtype=mean.dtype)
		return torch.cat((mean, variance, y_star_cols, t_col), dim=1).numpy()

class Environment:
	"""Black-box (multi-objective) optimisation environment for a BO agent.

	Each :meth:`step` evaluates the objectives ``self.F`` at a query point,
	records the noise-free (``y_true``) and noisy (``y_observed``) values in
	``self.history`` and returns the best observation so far plus a reward and
	a regret. With ``f_num == 1`` reward/regret use the best observed/true value;
	otherwise they use the hypervolume returned by ``countHypervolume``.
	"""

	def __init__(self, T, domain_size, f_num, function_type, seed=0,
			new_reward=False,
			noise_level=0.1,
			ls_learned_freq=10,
			online_ls=0,
			ls_weight=1):
		"""Create the environment and load the first benchmark function.

		Args:
			T: horizon; stored as ``self.T`` (not used inside this class).
			domain_size: forwarded to ``domain`` to build the candidate set ``self.X``.
			f_num: number of objectives.
			function_type: benchmark name forwarded to ``domain``/``getFuntion``;
				"YAHPO" additionally selects an instance from ``YAHPO_INSTANCES``.
			seed: if > 0, seeds ``numpy``, ``torch`` and ``random``; also passed to ``domain``.
			new_reward: if true (and f_num > 1), normalise the hypervolume reward.
			noise_level: std of the Gaussian observation noise; also passed to ``set_noise_level``.
			ls_learned_freq: with online lengthscales, refit every this many steps.
			online_ls: 0 uses the pre-computed ``kernel_ls``; otherwise learn them online.
			ls_weight: multiplier applied to freshly learned lengthscales.
		"""
		self.T = T
		self.domain_size = domain_size
		self.f_num = f_num
		self.function_type = function_type
		self.seed = seed
		self.new_reward = new_reward
		self.noise_level = noise_level
		self.episode = 0
		set_noise_level(noise_level)
		self.ls_learned_freq = ls_learned_freq
		self.online_ls = online_ls
		self.ls_weight = ls_weight

		self._reset_history()
		self.t = 0
		self._seed_everything(seed)

		self.X = domain(function_type, domain_size, seed)
		self.domain_dim = np.shape(self.X)[-1]
		self.ls = self._initial_lengthscales()
		self._load_function()

	def _reset_history(self):
		"""Start a fresh, empty history (hypervolume traces start at 0)."""
		self.history = dict()
		self.history["x"] = []
		self.history["y_observed"] = []
		self.history["y_true"] = []
		self.history["hypervolume_observed"] = [0]
		self.history["hypervolume_true"] = [0]
		self.history["ls_esti"] = []

	def _initial_lengthscales(self):
		"""Initial lengthscale guess per objective: a (1, domain_dim) tensor of 0.001."""
		return [torch.tensor([[0.001] * self.domain_dim]) for _ in range(self.f_num)]

	@staticmethod
	def _seed_everything(seed):
		"""Seed numpy, torch and ``random`` when ``seed`` is positive."""
		if seed > 0:
			np.random.seed(seed)
			torch.manual_seed(seed)
			random.seed(seed)

	def _load_function(self):
		"""Build the objectives over ``self.X`` and record their ground truth in the history."""
		if self.function_type == "YAHPO":
			set_YAHPO_index(str(YAHPO_INSTANCES[self.episode]))

		self.F, self.pareto_front, self.min_function_values, self.kernel, self.kernel_ls = getFuntion(self.X, f_num=self.f_num, function_type=self.function_type)
		# The single-objective optimum is computed lazily by _domain_max().
		self.domain_max_points = None
		self.history["ls_true"] = self.kernel_ls
		self.history["kernel_true"] = self.kernel
		self.history["pareto_front"] = self.pareto_front

	def reset(self, seed=0):
		"""Clear the history and load the function for the current episode.

		``seed`` (if > 0) reseeds the RNGs; note the domain itself is regenerated
		from the constructor's ``self.seed``, not from ``seed``.
		"""
		self._reset_history()
		self.t = 0
		self.ls = self._initial_lengthscales()
		self._seed_everything(seed)

		self.X = domain(self.function_type, self.domain_size, self.seed)
		self._load_function()
		self.episode += 1

	def getYt(self, x):
		"""Evaluate every objective at ``x`` and append true and noisy values to the history."""
		self.history["x"].append(x)
		y_true = []
		y_observed = []
		# Keep evaluation and noise draws interleaved per objective so the numpy
		# RNG stream is consumed in the same order as before.
		for f in self.F:
			y = float(f(x))
			y_true.append(y)
			y_observed.append(y + np.random.normal(0, self.noise_level, 1)[0])
		self.history["y_true"].append(y_true)
		self.history["y_observed"].append(y_observed)
		self.t = len(self.history["x"])

	def getReward(self):
		"""Return the improvement produced by the latest observation.

		Single objective: gain in the best observed value (0.0 on the first step,
		when there is no previous best to compare against). Multi-objective: gain
		in observed hypervolume, optionally normalised when ``new_reward`` is set.
		"""
		if self.f_num == 1:
			observed = [y[0] for y in self.history["y_observed"]]
			if len(observed) < 2:
				return 0.0
			return max(observed) - max(observed[:-1])

		hypervolumes = self.history["hypervolume_observed"]
		hypervolumes.append(countHypervolume(np.array(self.history["y_observed"]), np.array(self.min_function_values)))
		improvement = hypervolumes[self.t] - hypervolumes[self.t - 1]
		if self.new_reward:
			# NOTE(review): 1.1**f_num presumably is the volume of the reference box, so
			# this divides by the hypervolume still left to gain — confirm.
			return improvement / ((1.1 ** self.f_num) - hypervolumes[self.t])
		return improvement

	def _domain_max(self):
		"""Return (and cache) the best true single-objective value over ``self.X``."""
		if self.domain_max_points is None:
			# Assumes self.F[0] takes one point of self.X, as in getYt. If the benchmark
			# function draws its own noise, this also consumes RNG — TODO confirm.
			self.domain_max_points = max(float(self.F[0](point)) for point in self.X)
		return self.domain_max_points

	def getRegret(self):
		"""Return the gap between the optimum and the best true value/hypervolume found so far."""
		if self.f_num == 1:
			best_true = max(y[0] for y in self.history["y_true"])
			return self._domain_max() - best_true

		self.history["hypervolume_true"].append(countHypervolume(np.array(self.history["y_true"]), np.array(self.min_function_values)))
		return self.history["pareto_front"] - self.history["hypervolume_true"][self.t]

	def fit_gp(self, t):
		"""Build a ModelListGP (one SingleTaskGP per objective) on the history.

		With ``online_ls == 0`` the pre-computed ``kernel_ls`` are used as-is.
		Otherwise the last learned lengthscales are used, and every
		``ls_learned_freq`` steps the hyperparameters are refit by marginal
		likelihood, scaled by ``ls_weight`` and recorded in ``history["ls_esti"]``.
		"""
		train_X = torch.tensor(np.array(self.history["x"]), dtype=float)
		train_Y = torch.tensor(self.history["y_observed"], dtype=float)
		GP_list = [SingleTaskGP(train_X=train_X, train_Y=train_Y[:, i].unsqueeze(1),
				outcome_transform=Standardize(1)) for i in range(self.f_num)]

		if self.online_ls == 0:
			for i, gp in enumerate(GP_list):
				gp.covar_module.base_kernel.lengthscale = self.kernel_ls[i]  # pre-computed ls
			return ModelListGP(*GP_list)

		for i, gp in enumerate(GP_list):
			gp.covar_module.base_kernel.lengthscale = self.ls[i]  # last learned ls

		model = ModelListGP(*GP_list)
		if t % self.ls_learned_freq == 0:
			mll = SumMarginalLogLikelihood(model.likelihood, model)
			try:
				fit_gpytorch_mll(mll)
			except RuntimeError as err:
				# Best effort: keep the current hyperparameters but say why fitting failed.
				print(f"GP hyperparameter fitting failed at t={t}: {err}", file=sys.stderr)
			for i, gp in enumerate(GP_list):
				kernel = gp.covar_module.base_kernel
				kernel.lengthscale = kernel.lengthscale.detach() * self.ls_weight
				# Detach so the stored tensors do not keep autograd graphs alive.
				self.ls[i] = kernel.lengthscale.detach().clone()
			# Store a snapshot; appending self.ls itself would alias every entry.
			self.history["ls_esti"].append([ls.clone() for ls in self.ls])
		return model

	def step(self, x):
		"""Query ``x`` and return ``(y_star, reward, regret)``.

		``y_star`` is ``max`` over the observed value lists, i.e. the
		lexicographically largest observation (the best one when f_num == 1).
		"""
		self.getYt(x)
		reward = self.getReward()
		regret = self.getRegret()
		y_star = max(self.history["y_observed"])
		return y_star, float(reward), float(regret)

def construct_state_action_pair(domain, gp, y_star, t):
	"""Turn a GP posterior over ``domain`` into a per-point feature matrix.

	Row ``k`` is ``[posterior mean(s) | posterior variance(s) | y_star | t]`` for
	the k-th point of ``domain``; ``y_star`` and ``t`` are repeated on every row.

	Args:
		domain: candidate points, convertible with ``torch.tensor``.
		gp: any model exposing ``posterior(X)`` with ``.mean``/``.variance``.
		y_star: incumbent value(s) to append to each row.
		t: time feature to append to each row.

	Returns:
		numpy array with ``len(domain)`` rows.
	"""
	num_points = len(domain)
	posterior = gp.posterior(torch.tensor(domain))
	feature_blocks = [
		posterior.mean.cpu(),
		posterior.variance.cpu(),
		torch.tile(torch.tensor([y_star]), (num_points, 1)),
		torch.tile(torch.tensor([t]), (num_points, 1)),
	]
	return torch.cat(feature_blocks, dim=1).detach().numpy()

if __name__ == '__main__':
	# Smoke run: drive a 2-objective synthetic environment with random queries
	# and build the RL state features after every step.
	f_num = 2
	T = 100
	domain_size = 1000
	env = Environment(T=T, domain_size=domain_size, f_num=f_num, function_type="RBF_0.05", seed=1)

	for step_idx in range(T):
		# NOTE(review): the single-coordinate query assumes the "RBF_0.05" domain is 1-D — confirm.
		y_star, reward, regret = env.step([random.random()])
		gp = env.fit_gp(step_idx)
		# Time feature normalised by the same horizon the environment was built with.
		state_action_pairs = construct_state_action_pair(env.X, gp, y_star, step_idx / T)