import torch
from botorch.models.gp_regression import SingleTaskGP
from botorch.fit import fit_gpytorch_mll
from botorch.acquisition.analytic import UpperConfidenceBound
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.kernels import MaternKernel, ScaleKernel
import time
import pickle
import os
import json
import gpytorch
from bayes_opt import simulation
import numpy as np
from bayes_opt.gp import learn_model_space_time
from bayes_opt.WDBO_algo import learn_model_space_time_WDBO
from bayes_opt.BOBA import ActiveInferenceAcquisitionFunction
from ctypes import cdll
# Optionally preload the compiled wdbo_criterion extension (.pyd / shared library)
# before `import wdbo_criterion` below.
# NOTE(review): the original placeholder assignment had no right-hand side, which
# is a SyntaxError. The path is now taken from the WDBO_CRITERION_PATH environment
# variable; when it is unset, nothing is preloaded and the import relies on
# sys.path. TODO confirm the expected location of the library.
path = os.environ.get("WDBO_CRITERION_PATH")
my_dll = cdll.LoadLibrary(path) if path else None

import wdbo_criterion

class Bayesian_Optimization:
    """Time-varying Bayesian optimization over a (space x time) input domain.

    Every input row has ``input_dim`` spatial columns followed by one time
    column. All columns share the same two-element ``bounds`` (see
    ``normalize_x``). Supported strategies: plain GP-UCB, BOBA (active-inference
    acquisition), WDBO (dataset cleaning with the Wasserstein criterion from the
    ``wdbo_criterion`` extension), and WDBO_BOBA.
    """

    # model_type values accepted by bayesian_optimization().
    _SUPPORTED_MODEL_TYPES = ("GP-UCB", "BOBA", "WDBO", "WDBO_BOBA")
    # Folder where fitted GPs and session JSON files are written.
    _RESULTS_DIR = 'Population_models'

    def __init__(self,
                 time_limit=600,
                 num_samples=15,
                 input_dim=3,
                 bounds=(0.0, 1.0),
                 simulation_function='powell',
                 use_evenly_spaced_time=False,
                 num_time_points=100,
                 Y_best=None,
                 debug_timing=False,
                 verbose_timing=True
                 ):
        """Configure the optimizer.

        Args:
            time_limit: wall-clock budget in seconds (real-time mode).
            num_samples: number of random initial observations.
            input_dim: number of spatial dimensions (time is an extra column).
            bounds: (low, high) shared by every input column.
            simulation_function: name of the objective in ``bayes_opt.simulation``.
            use_evenly_spaced_time: use ``num_time_points`` evenly spaced time
                points in [0, 1] instead of wall-clock time.
            num_time_points: number of time points in evenly spaced mode.
            Y_best: prior best objective value. Required by the BOBA variants.
            debug_timing: stored, currently unused.
            verbose_timing: print per-step timing breakdowns.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.num_samples = num_samples
        self.input_dim = input_dim
        self.time_limit = time_limit
        self.bounds = torch.tensor(bounds, dtype=torch.double, device=self.device)
        self.elapsed_time = []  # Wall-clock duration (s) of each BO iteration
        self.dataset = None
        self.X_train_WDBO = None  # Inputs kept by WDBO after cleaning
        self.Y_train_WDBO = None  # Outputs kept by WDBO after cleaning
        self._current_time = None  # Normalized time of the previous clean() call
        self.prev_mean = None
        self._budget = 1.0  # WDBO removal budget (grown and consumed in clean())
        self._alpha = 0.25  # Budget growth rate used in clean()
        self.A = np.zeros((0, 1))  # Empty array with 1 column
        self.G = np.zeros((0, 3))  # Empty array with 3 columns
        # Exponentiated kernel hyperparameters set by fit_space_time_gp_models();
        # presumably (outputscale, spatial lengthscale, temporal lengthscale, noise),
        # in the order returned by get_kernel_log_hyperparameters() — TODO confirm.
        self._lambda, self._lS, self._lT, self._noise = None, None, None, None
        self._spatial_kernel = gpytorch.kernels.MaternKernel
        self._temporal_kernel = gpytorch.kernels.MaternKernel
        self._spatial_kernel_wdbo = self.get_wdbo_kernel_class(gpytorch.kernels.MaternKernel)
        self._temporal_kernel_wdbo = self.get_wdbo_kernel_class(gpytorch.kernels.MaternKernel)
        self.simulation_function = getattr(simulation, simulation_function)
        # torch.tensor(None) raises, so keep None when no prior best is given.
        # Double precision matches Y_train, which Y_best is concatenated with.
        self.Y_best = (None if Y_best is None
                       else torch.as_tensor(Y_best, dtype=torch.double, device=self.device))

        # Time handling options
        self.use_evenly_spaced_time = use_evenly_spaced_time
        self.num_time_points = num_time_points
        self.time_point_index = 0
        if self.use_evenly_spaced_time:
            self.time_points = torch.linspace(0, 1, num_time_points, device=self.device)
        else:
            self.time_points = None
        self.debug_timing = debug_timing
        self.verbose_timing = verbose_timing

    def get_wdbo_kernel_class(self, gpytorch_kernel_class):
        """Map a gpytorch kernel class to the matching wdbo_criterion kernel class.

        Args:
            gpytorch_kernel_class (type): a gpytorch kernel class
                (RBFKernel or MaternKernel).

        Returns:
            The matching ``wdbo_criterion`` kernel class, or None when there is
            no counterpart.
        """
        if gpytorch_kernel_class == gpytorch.kernels.RBFKernel:
            return wdbo_criterion.RBFKernel
        if gpytorch_kernel_class == gpytorch.kernels.MaternKernel:
            return wdbo_criterion.MaternKernel
        return None

    def generate_initial_data(self):
        """Draw ``num_samples`` random initial points and evaluate them.

        Spatial coordinates are uniform within ``bounds``. The last (time)
        column of each row is overwritten with the time at which that row is
        evaluated. This also creates ``self.noisy_powell`` (the noisy objective
        reused by the main loop) and sets ``self.output_dim``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: X with shape
            (num_samples, input_dim + 1) in the original domain, and Y with
            shape (num_samples,).
        """
        X = torch.rand(self.num_samples, self.input_dim + 1, dtype=torch.double, device=self.device)
        X = self.denormalize_x(X)
        Y = torch.empty(self.num_samples, dtype=torch.double, device=self.device)

        # Noisy wrapper around the simulation function (noise_level=0.025).
        # It does not depend on the loop, so build it once.
        self.noisy_powell = simulation.add_noise(self.simulation_function, noise_level=0.025,
                                                 bounds=self.bounds.to(self.device), dim=self.input_dim + 1)

        # Real-time mode measures time from here.
        if not self.use_evenly_spaced_time:
            self.start_time = time.time()
        for i in range(self.num_samples):
            if self.use_evenly_spaced_time:
                norm_time = self.get_current_time_point()
                denorm_time = self.denormalize_x(torch.tensor(norm_time, device=self.device))
            else:
                # Map elapsed seconds in [0, time_limit] linearly onto bounds.
                elapsed = time.time() - self.start_time
                denorm_time = elapsed * ((self.bounds[1] - self.bounds[0]) / self.time_limit) + self.bounds[0]

            X[i, -1] = denorm_time
            Y[i] = self.noisy_powell(X[i], bounds=self.bounds.to(self.device))

        self.output_dim = Y.shape[1] if Y.ndimension() > 1 else 1
        return X, Y

    def normalize_time(self, time):
        """Map elapsed wall-clock seconds onto the normalized [0, 1] time axis.

        ``time`` (in [0, time_limit]) is first mapped linearly onto ``bounds``,
        the domain stored in the time column. It is then normalized back with
        ``normalize_x``, so the result equals ``time / time_limit``.

        Args:
            time (float): seconds since ``self.start_time``. The parameter
                shadows the ``time`` module inside this method only.

        Returns:
            torch.Tensor: 0-dim tensor holding the normalized time.
        """
        denorm_time = time * ((self.bounds[1] - self.bounds[0]) / self.time_limit) + self.bounds[0]
        return self.normalize_x(denorm_time)

    def get_current_time_point(self):
        """Return the current normalized time point.

        Evenly spaced mode returns the next entry of ``self.time_points`` and
        advances ``time_point_index``, wrapping to 0 after the last entry.
        Real-time mode returns the elapsed time since ``self.start_time``
        divided by ``time_limit``.

        Returns:
            float: the current normalized time point.
        """
        if self.use_evenly_spaced_time:
            if self.time_point_index >= len(self.time_points):
                # All time points used: cycle back to the beginning.
                self.time_point_index = 0
            current_time = self.time_points[self.time_point_index].item()
            self.time_point_index += 1
            return current_time

        # Convert to float so both modes return the same type (as documented).
        elapsed = time.time() - self.start_time
        return float(self.normalize_time(elapsed))

    def normalize_x(self, x):
        """Linearly map inputs from ``bounds`` to [0, 1] (every column, time included).

        Args:
            x (torch.Tensor): inputs in the original domain (any shape).

        Returns:
            torch.Tensor: inputs mapped to [0, 1], same shape as ``x``.
        """
        return (x - self.bounds[0]) / (self.bounds[1] - self.bounds[0])

    def denormalize_x(self, x):
        """Linearly map inputs from [0, 1] back to ``bounds`` (inverse of ``normalize_x``).

        A 1-D input is treated as a single row and returned with shape (1, d).
        0-dim and 2-D inputs keep their shape.

        Args:
            x (torch.Tensor): normalized inputs.

        Returns:
            torch.Tensor: inputs in the original domain.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)
        return (x * (self.bounds[1] - self.bounds[0])) + self.bounds[0]

    def normalize_y(self, y):
        """Standardize ``y``: subtract the empirical mean and divide by the empirical std.

        When the std is undefined (fewer than 2 values) or zero (constant
        ``y``), only the mean is removed. This avoids NaN/inf targets.

        Args:
            y (torch.Tensor): observations.

        Returns:
            torch.Tensor: standardized observations, same shape as ``y``.
        """
        centered = y - torch.mean(y)
        if y.numel() < 2:
            return centered
        std = torch.std(y)
        if std.item() == 0.0:
            return centered
        return centered / std

    def fit_space_time_gp_models(self, X, Y):
        """Fit the spatio-temporal WDBO GP on normalized inputs and standardized outputs.

        Also stores the exponentiated kernel hyperparameters in ``self._lambda``,
        ``self._lS``, ``self._lT`` and ``self._noise``; ``clean`` reads them.

        Args:
            X (torch.Tensor): inputs in the original domain; the last column is time.
            Y (torch.Tensor): observations, one per row of ``X``.

        Returns:
            The model returned by ``learn_model_space_time_WDBO``, moved to ``self.device``.
        """
        fit_start_time = time.time()
        if self.output_dim == 1:
            Y = Y.unsqueeze(-1)

        X = X.to(self.device)
        Y = Y.to(self.device)
        X_norm = self.normalize_x(X)
        Y_norm = self.normalize_y(Y)
        normalize_elapsed = time.time() - fit_start_time

        gp_fit_start = time.time()
        # Kernel args are presumably the Matern smoothness nu (2.5 space, 1.5 time) — TODO confirm.
        gp_model = learn_model_space_time_WDBO(xx_tt=X_norm,
                                               space_kernel=self._spatial_kernel,
                                               space_kernel_args=[2.5],
                                               time_kernel=self._temporal_kernel,
                                               time_kernel_args=[1.5],
                                               yy_normalized=Y_norm).to(self.device)
        gp_fit_time = time.time() - gp_fit_start

        hyperparam_start = time.time()
        self._lambda, self._lS, self._lT, self._noise = np.exp(gp_model.get_kernel_log_hyperparameters())
        hyperparam_time = time.time() - hyperparam_start

        total_fit_time = time.time() - fit_start_time
        if self.verbose_timing:
            print(f"fit_space_time_gp_models timing - Normalize: {normalize_elapsed:.3f}s, GP fit: {gp_fit_time:.3f}s, Hyperparams: {hyperparam_time:.3f}s, Total: {total_fit_time:.3f}s")

        return gp_model

    def fit_gp_models(self, X, Y):
        """Fit a BoTorch SingleTaskGP (Matern 5/2 over all input columns) by exact MLL.

        Args:
            X (torch.Tensor): inputs in the original domain; the last column is time.
            Y (torch.Tensor): observations, one per row of ``X``.

        Returns:
            SingleTaskGP: the fitted model, trained on normalized X and standardized Y.
        """
        fit_start_time = time.time()
        if self.output_dim == 1:
            Y = Y.unsqueeze(-1)

        X = X.to(self.device)
        Y = Y.to(self.device)
        X_norm = self.normalize_x(X)
        Y_norm = self.normalize_y(Y)
        normalize_elapsed = time.time() - fit_start_time

        # Time is treated as an ordinary input dimension here.
        kernel_start = time.time()
        base_kernel = MaternKernel(nu=2.5, active_dims=list(range(X.shape[1]))).to(self.device)
        kernel_time = time.time() - kernel_start

        gp_create_start = time.time()
        gp = SingleTaskGP(X_norm, Y_norm, covar_module=base_kernel).to(self.device)
        gp_create_time = time.time() - gp_create_start

        mll_start = time.time()
        mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
        mll_time = time.time() - mll_start

        # Hyperparameters are fitted with BoTorch's default fit_gpytorch_mll settings.
        fit_gpytorch_start = time.time()
        fit_gpytorch_mll(mll)
        fit_gpytorch_time = time.time() - fit_gpytorch_start

        total_fit_time = time.time() - fit_start_time
        if self.verbose_timing:
            print(f"fit_gp_models timing - Normalize: {normalize_elapsed:.3f}s, Kernel: {kernel_time:.3f}s, GP create: {gp_create_time:.3f}s, MLL: {mll_time:.3f}s, Fit: {fit_gpytorch_time:.3f}s, Total: {total_fit_time:.3f}s")

        return gp

    def clean(self, t, verbose=False):
        """Remove irrelevant observations from the WDBO dataset and refit ``self.gp``.

        The removal budget grows geometrically with the normalized time elapsed
        since the previous call, measured in temporal lengthscales. Then, while
        more than 2 observations remain and the budget allows, every
        observation is scored with ``wdbo_criterion.wasserstein_criterion`` and
        the least relevant one is dropped. The GP is always refit, so the
        newest observation is absorbed even when nothing is removed.

        Args:
            t (float): the present (normalized) time.
            verbose (bool, optional): print removal decisions. Defaults to False.
        """
        clean_start_time = time.time()

        if self._current_time is None:
            self._current_time = t

        self._budget *= (1.0 + self._alpha) ** ((t - self._current_time) / self._lT)

        # Timings default to 0 so the summary print works even if the loop never runs.
        criteria_time = sort_time = mask_time = gp_update_time = 0.0
        refitted = False
        min_crit = 0
        clean_loop_start = time.time()
        while self.X_train_WDBO.shape[0] > 2 and self._budget > min_crit:
            criteria_start = time.time()
            # Normalize every column (spatial and time) to [0, 1].
            X_norm = self.normalize_x(self.X_train_WDBO)
            X_spatial_norm = X_norm[:, :-1]
            time_norm = X_norm[:, -1]

            # Per-observation relevancy scores from the C++ extension.
            criteria = wdbo_criterion.wasserstein_criterion(
                np.ascontiguousarray(X_spatial_norm.cpu().numpy(), dtype=np.float64),
                np.ascontiguousarray(self.Y_train_WDBO.cpu().numpy(), dtype=np.float64),
                np.ascontiguousarray(time_norm.cpu().numpy(), dtype=np.float64),
                self.X_train_WDBO.shape[0],
                self.input_dim,
                self._lambda, self._noise,
                self._spatial_kernel_wdbo(*([self._lS] + [2.5])), self._temporal_kernel_wdbo(*([self._lT] + [1.5])),
                t,
                0, 1)
            criteria_time = time.time() - criteria_start

            # Least relevant observation (argsort keeps the original tie-breaking).
            sort_start = time.time()
            idx_min = criteria.argsort()[0]
            min_crit = criteria[idx_min] + 1.0
            idx_min = torch.tensor(idx_min, device=self.device)
            sort_time = time.time() - sort_start

            if verbose:
                print(f"Removal Budget: {self._budget if isinstance(self._budget, float) else self._budget.item()} // Least Relevant Observation: {idx_min.item()} // Relevancy: {min_crit} (i.e. {round(100 * min_crit / (self._budget if isinstance(self._budget, float) else self._budget.item()), 2)}% of budget)")

            # Remove it if the budget allows it; otherwise the loop ends after this pass.
            if min_crit < self._budget:
                # Budget consumption
                if min_crit > 1.0:
                    self._budget = self._budget / min_crit

                if verbose:
                    print(f"Observation {idx_min} is removed")

                mask_start = time.time()
                mask = torch.ones(self.X_train_WDBO.shape[0], dtype=torch.bool, device=self.device)
                mask[idx_min] = False
                self.X_train_WDBO = self.X_train_WDBO[mask]
                self.Y_train_WDBO = self.Y_train_WDBO[mask]
                mask_time = time.time() - mask_start

            gp_update_start = time.time()
            self.gp = self.fit_space_time_gp_models(self.X_train_WDBO, self.Y_train_WDBO)
            gp_update_time = time.time() - gp_update_start
            refitted = True

        if not refitted:
            # The loop did not run: still refit so the GP sees the latest data.
            gp_update_start = time.time()
            self.gp = self.fit_space_time_gp_models(self.X_train_WDBO, self.Y_train_WDBO)
            gp_update_time = time.time() - gp_update_start

        clean_loop_time = time.time() - clean_loop_start
        total_clean_time = time.time() - clean_start_time

        self._current_time = t

        if self.verbose_timing:
            print(f"clean timing - Criteria: {criteria_time:.3f}s, Sort: {sort_time:.3f}s, Mask: {mask_time:.3f}s, GP update: {gp_update_time:.3f}s, Loop: {clean_loop_time:.3f}s, Total: {total_clean_time:.3f}s")

    def _next_session_id(self):
        """Return 1 + the largest N among existing ``GP<N>.pkl`` files (1 if there are none).

        Creates the results folder if needed and ignores GP files whose N is
        not an integer.
        """
        os.makedirs(self._RESULTS_DIR, exist_ok=True)
        ids = [int(f[2:-4]) for f in os.listdir(self._RESULTS_DIR)
               if f.startswith('GP') and f.endswith('.pkl') and f[2:-4].isdigit()]
        return max(ids) + 1 if ids else 1

    def save_gp_models(self, models, session_id):
        """Pickle ``models`` to ``<results dir>/GP{session_id}.pkl``, creating the folder if needed."""
        os.makedirs(self._RESULTS_DIR, exist_ok=True)
        with open(os.path.join(self._RESULTS_DIR, f'GP{session_id}.pkl'), 'wb') as f:
            pickle.dump(models, f)

    def save_session_data(self, X_train, Y_train, session_id):
        """Write training data, iteration timings and BOBA values to ``<results dir>/JSON{session_id}.json``."""
        data = {
            'X_train': X_train.tolist(),
            'Y_train': Y_train.tolist(),
            'time': self.elapsed_time,
            'intrinsic_values': [float(v) for v in np.array(getattr(self, 'intrinsic_values', [])).flatten()],
            'extrinsic_values': [float(v) for v in np.array(getattr(self, 'extrinsic_values', [])).flatten()],
        }
        os.makedirs(self._RESULTS_DIR, exist_ok=True)
        with open(os.path.join(self._RESULTS_DIR, f'JSON{session_id}.json'), 'w') as f:
            json.dump(data, f)

    def bayesian_optimization(self, model_type="GP-UCB", BOBA_beta=1.0, BOBA_normalization=False):
        """Run the time-varying Bayesian optimization loop.

        The loop generates the initial dataset and fits the surrogate. Each
        iteration then picks the next point with the time coordinate pinned to
        the current time, evaluates the noisy objective, appends the
        observation and updates the surrogate. It stops when the time budget
        (real-time mode) or the time points (evenly spaced mode) run out. The
        final GP and session data are saved to the results folder.

        Args:
            model_type: 'GP-UCB', 'BOBA', 'WDBO' or 'WDBO_BOBA'.
            BOBA_beta: Beta value for BOBA that determines the amount of exploration.
            BOBA_normalization: Whether to normalize the BOBA extrinsic and intrinsic values.

        Returns:
            tuple: (X_train, Y_train, best_points, best_values).

        Raises:
            ValueError: if ``model_type`` is unsupported, or a BOBA variant is
                requested without ``Y_best``.
        """
        # Validate before any expensive setup.
        if model_type not in self._SUPPORTED_MODEL_TYPES:
            supported = ", ".join(repr(m) for m in self._SUPPORTED_MODEL_TYPES)
            raise ValueError(f"Unknown model_type: {model_type}. Supported types: {supported}")
        uses_wdbo = model_type in ("WDBO", "WDBO_BOBA")
        uses_boba = model_type in ("BOBA", "WDBO_BOBA")
        if uses_boba and self.Y_best is None:
            raise ValueError(f"model_type {model_type!r} requires Y_best to be provided")

        bo_start_time = time.time()
        if not self.use_evenly_spaced_time:
            self.start_time = time.time()

        init_data_start = time.time()
        self.X_train, self.Y_train = self.generate_initial_data()
        init_data_time = time.time() - init_data_start

        init_gp_start = time.time()
        if uses_wdbo:
            self.gp = self.fit_space_time_gp_models(self.X_train, self.Y_train)
        else:
            self.gp = self.fit_gp_models(self.X_train, self.Y_train)
        init_gp_time = time.time() - init_gp_start

        if uses_wdbo and self.X_train_WDBO is None:
            self.X_train_WDBO = self.X_train.clone()
            self.Y_train_WDBO = self.Y_train.clone()

        best_val_start = time.time()
        if self.Y_train.dim() > 1:
            best_values = self.Y_train[:, 0].max(dim=0).values
            best_points = self.X_train[self.Y_train[:, 0].argmax().item(), :]
        else:
            best_values = self.Y_train.max().item()
            best_points = self.X_train[self.Y_train.argmax().item(), :]
        best_val_time = time.time() - best_val_start

        print(f"Initial data: X = {best_points}, Y = {best_values}")

        # Acquisition bounds in normalized coordinates; the time coordinate is pinned
        # to the current time each iteration. Double precision matches the GP data.
        low_bounds = torch.zeros(self.input_dim + 1, dtype=torch.double)
        up_bounds = torch.ones(self.input_dim + 1, dtype=torch.double)

        self.intrinsic_values = []
        self.extrinsic_values = []

        # Counter starts at the number of initial samples (previously hard-coded
        # to 15, the default num_samples); it also feeds the UCB beta schedule.
        iteration = self.num_samples
        init_total_time = time.time() - bo_start_time
        if self.verbose_timing:
            print(f"BO initialization timing - Init data: {init_data_time:.3f}s, Init GP: {init_gp_time:.3f}s, Best val: {best_val_time:.3f}s, Total init: {init_total_time:.3f}s")

        while (self.use_evenly_spaced_time and self.time_point_index < self.num_time_points) or \
              (not self.use_evenly_spaced_time and time.time() - self.start_time < self.time_limit):
            iteration_start_time = time.time()

            time_point_start = time.time()
            normalized_time = self.get_current_time_point()
            time_point_time = time.time() - time_point_start

            bounds_start = time.time()
            low_bounds[-1] = normalized_time
            up_bounds[-1] = normalized_time
            bounds = torch.stack([low_bounds, up_bounds]).to(self.device)
            beta = 0.2 * self.input_dim * np.log(2 * (iteration + self.num_samples))
            bounds_time = time.time() - bounds_start

            acq_start = time.time()
            if uses_boba:
                # Append the prior best Y_best before standardizing; the standardized
                # maximum is passed to BOBA as its target value.
                Y_obs = self.Y_train_WDBO if uses_wdbo else self.Y_train
                Y_norm = self.normalize_y(torch.cat([Y_obs, self.Y_best.unsqueeze(0)], dim=0))
                Y_best = Y_norm.max().item()
                boba = ActiveInferenceAcquisitionFunction(Y_best, self.gp, np.array(self.input_dim),
                                                          num_observation_levels=2,
                                                          current_time=normalized_time,
                                                          beta=BOBA_beta,
                                                          BOBA_normalization=BOBA_normalization)
                X_next, intrinsic_val, extrinsic_val = boba()
                self.intrinsic_values.append(intrinsic_val)
                self.extrinsic_values.append(extrinsic_val)
            else:
                # GP-UCB / WDBO: standard UCB maximized over the pinned-time bounds.
                ucb = UpperConfidenceBound(self.gp, beta=beta)
                X_next, _ = optimize_acqf(
                    acq_function=ucb,
                    bounds=bounds,
                    q=1,
                    num_restarts=20,
                    raw_samples=512,
                )
            acq_time = time.time() - acq_start

            X_next = self.denormalize_x(X_next.to(self.device))
            Y_next = self.noisy_powell(X_next.squeeze(0), bounds=self.bounds.to(self.device)).to(self.device)
            if Y_next.dim() != self.Y_train.dim():
                Y_next = Y_next.unsqueeze(0)

            data_append_start = time.time()
            self.X_train = torch.cat([self.X_train, X_next])
            self.Y_train = torch.cat((self.Y_train, Y_next), dim=0)
            if uses_wdbo:
                self.X_train_WDBO = torch.cat([self.X_train_WDBO, X_next])
                self.Y_train_WDBO = torch.cat((self.Y_train_WDBO, Y_next), dim=0)
            data_append_time = time.time() - data_append_start

            best_update_start = time.time()
            new_value = Y_next[:, 0].item() if Y_next.dim() > 1 else Y_next.item()
            if new_value > best_values:
                best_values = new_value
                best_points = X_next.squeeze().tolist()
            best_update_time = time.time() - best_update_start

            gp_update_start = time.time()
            if uses_wdbo:
                # WDBO prunes its dataset and refits inside clean().
                self.clean(normalized_time, False)
            else:
                self.gp = self.fit_gp_models(self.X_train, self.Y_train)
            gp_update_time = time.time() - gp_update_start

            iteration_total_time = time.time() - iteration_start_time
            self.elapsed_time.append(iteration_total_time)

            if self.verbose_timing:
                print(f"Iteration {iteration} timing breakdown:")
                print(f"  Time point: {time_point_time:.3f}s")
                print(f"  Bounds: {bounds_time:.3f}s")
                print(f"  Acquisition: {acq_time:.3f}s")
                print(f"  Data append: {data_append_time:.3f}s")
                print(f"  Best update: {best_update_time:.3f}s")
                print(f"  GP update: {gp_update_time:.3f}s")
                print(f"  Total iteration: {iteration_total_time:.3f}s")

            print(f"iteration {iteration} took {iteration_total_time:.4f}s")
            iteration += 1
            if iteration % 20 == 0:
                print(f"Iteration {iteration}: X = {X_next.cpu().numpy()}, Y = {Y_next.cpu().numpy()}, ")
                print(f"Best = {best_values} at X = {best_points} (took {iteration_total_time:.2f}s)")

        save_start = time.time()
        session_id = self._next_session_id()
        self.save_gp_models(self.gp, session_id)
        self.save_session_data(self.X_train, self.Y_train, session_id)
        save_time = time.time() - save_start

        total_bo_time = time.time() - bo_start_time
        if self.verbose_timing:
            print(f"Final save timing: {save_time:.3f}s")
            print(f"Total BO time: {total_bo_time:.3f}s")

        print("Saved data")
        return self.X_train, self.Y_train, best_points, best_values
