import numpy as np
import gpytorch
import torch
from botorch.posteriors.posterior import Posterior
from botorch.sampling.base import MCSampler
from botorch.sampling.get_sampler import GetSampler
# TODO: modify for sliding window

class SpaceTimeGPPosterior(Posterior):
    """Posterior that delegates every query to a wrapped distribution object.

    The wrapped object must expose ``mean``, ``stddev``, ``variance`` and
    ``rsample``; all properties below simply forward to it.
    """

    def __init__(self, distribution):
        """Store the distribution backing this posterior.

        Args:
            distribution: object providing ``mean``, ``stddev``, ``variance``
                and ``rsample`` (presumably a gpytorch ``MultivariateNormal``
                -- confirm against callers).
        """
        self.distribution = distribution

    @property
    def device(self) -> torch.device:
        """Device on which the distribution's mean tensor lives."""
        mean_tensor = self.distribution.mean
        return mean_tensor.device

    @property
    def dtype(self) -> torch.dtype:
        """Data type of the distribution's mean tensor."""
        mean_tensor = self.distribution.mean
        return mean_tensor.dtype

    def rsample(self, sample_shape: torch.Size = None) -> torch.Tensor:
        """Draw samples from the wrapped distribution via its ``rsample``.

        Args:
            sample_shape: requested sample shape; when ``None`` a single
                sample (``torch.Size([1])``) is drawn.

        Returns:
            torch.Tensor: whatever ``distribution.rsample`` returns.
        """
        shape = torch.Size([1]) if sample_shape is None else sample_shape
        return self.distribution.rsample(shape)

    @property
    def batch_range(self) -> tuple[int, int]:
        """Pair ``(0, k - 1)`` where ``k`` is the number of dims of the mean."""
        n_dims = self.distribution.mean.dim()
        return (0, n_dims - 1)

    @property
    def base_sample_shape(self) -> torch.Size:
        """Shape reported for base samples: the shape of the mean tensor."""
        mean_tensor = self.distribution.mean
        return mean_tensor.shape

    @property
    def mean(self) -> torch.Tensor:
        """Posterior mean, taken from the wrapped distribution."""
        return self.distribution.mean

    @property
    def stddev(self) -> torch.Tensor:
        """Posterior standard deviation, taken from the wrapped distribution."""
        return self.distribution.stddev

    @property
    def variance(self) -> torch.Tensor:
        """Posterior variance, taken from the wrapped distribution."""
        return self.distribution.variance

    def confidence_region(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the bounds ``mean -/+ 1.96 * stddev``.

        For a Gaussian this corresponds to the 95% confidence interval.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: lower bound, then upper bound.
        """
        half_width = 1.96 * self.distribution.stddev
        lower = self.mean - half_width
        upper = self.mean + half_width
        return (lower, upper)

class SpaceTimeGPSampler(MCSampler):
    """Sampler that draws directly from a ``SpaceTimeGPPosterior``.

    NOTE: ``forward`` only passes ``sample_shape`` to the posterior; the seed
    handed to the constructor is not used when drawing.
    """

    def __init__(self, sample_shape: torch.Size, seed: int = None):
        """Forward ``sample_shape`` and ``seed`` to the ``MCSampler`` constructor."""
        super().__init__(seed=seed, sample_shape=sample_shape)

    def forward(self, posterior: SpaceTimeGPPosterior) -> torch.Tensor:
        """Draw ``self.sample_shape`` samples from ``posterior``.

        Args:
            posterior: the posterior to sample from.

        Returns:
            torch.Tensor: the result of ``posterior.rsample``.
        """
        draws = posterior.rsample(sample_shape=self.sample_shape)
        return draws

    @property
    def batch_range(self) -> tuple[int, int]:
        """Constant batch range ``(0, 0)``."""
        return 0, 0

# Register the sampler
@GetSampler.register(SpaceTimeGPPosterior)
def _get_spacetime_sampler(
    posterior: SpaceTimeGPPosterior,
    sample_shape: torch.Size,
    seed: int = None,
) -> SpaceTimeGPSampler:
    """Factory registered with ``GetSampler`` for ``SpaceTimeGPPosterior``.

    Args:
        posterior: the posterior a sampler is requested for (not inspected).
        sample_shape: shape of the samples the sampler should draw.
        seed: optional seed handed to the sampler constructor.

    Returns:
        SpaceTimeGPSampler: a new sampler built from ``sample_shape`` and ``seed``.
    """
    sampler = SpaceTimeGPSampler(sample_shape=sample_shape, seed=seed)
    return sampler


class SpaceTimeGPModel(gpytorch.models.ExactGP):
    """Exact GP surrogate with a scaled (spatial kernel * temporal kernel) covariance.

    The last input column is handled by the temporal kernel; all preceding
    columns are handled by the spatial kernel.
    """

    num_outputs = 1

    def __init__(self, space_kernel, space_args, time_kernel, time_args, train_x, train_y, likelihood):
        """Build the surrogate model.

        Args:
            space_kernel (gpytorch.kernels.Kernel class): the spatial kernel class
            space_args (list): positional arguments for the spatial kernel
            time_kernel (gpytorch.kernels.Kernel class): the temporal kernel class
            time_args (list): positional arguments for the temporal kernel
            train_x (torch.Tensor): the training inputs; a 1-D tensor is
                treated as a single point and reshaped to ``(1, d)``
            train_y (torch.Tensor): the training labels; a trailing singleton
                dimension is squeezed away
            likelihood (gpytorch.likelihoods.Likelihood): the likelihood function
        """
        if train_y.ndim > 1:
            train_y = train_y.squeeze(-1)

        # Reshape a 1-D input BEFORE handing it to the base class so that the
        # training inputs stored by ExactGP and ``self.train_x`` (used in
        # ``fit``) are the same tensor layout.
        if train_x.ndim == 1:
            train_x = train_x.unsqueeze(0)

        super().__init__(train_x, train_y, likelihood)

        # Number of input columns (spatial columns + one time column).
        self.d = train_x.shape[1]
        self.train_x = train_x
        self.train_y = train_y
        self.likelihood = likelihood
        self.mean_module = gpytorch.means.ConstantMean()

        # Spatial kernel acts on columns 0..d-2, temporal kernel on column d-1.
        space_dims = torch.arange(self.d - 1)
        time_dims = torch.tensor([self.d - 1])
        spatial = space_kernel(*space_args, active_dims=space_dims)
        temporal = time_kernel(*time_args, active_dims=time_dims)
        self.covar_module = gpytorch.kernels.ScaleKernel(spatial * temporal)

    @staticmethod
    def _select_device():
        """Return the CUDA device when one is available, otherwise the CPU."""
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def forward(self, x):
        """Use the model on the input.

        Args:
            x (torch.Tensor): the input

        Returns:
            gpytorch.distributions.MultivariateNormal: distribution built from
            the mean module and covariance module evaluated at ``x``
        """
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

    def fit(self, n_iterations=150, lr=0.1):
        """Find the hyperparameters from data.

        Minimises the negative exact marginal log likelihood with Adam, then
        switches the model and likelihood to evaluation mode.

        Args:
            n_iterations (int): number of optimiser steps (default 150)
            lr (float): Adam learning rate (default 0.1)
        """
        # Prefer CUDA, but fall back to CPU instead of crashing without a GPU.
        device = self._select_device()
        self.to(device)
        self.likelihood.to(device)

        # Find optimal model hyperparameters
        self.train()
        self.likelihood.train()

        # Includes GaussianLikelihood parameters
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)

        # "Loss" for GPs - the marginal log likelihood
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(self.likelihood, self)

        # Move the training data once rather than on every iteration.
        train_x = self.train_x.to(device)
        train_y = self.train_y.to(device)

        for _ in range(n_iterations):
            optimizer.zero_grad()
            output = self(train_x)
            loss = -mll(output, train_y).mean()
            loss.backward()
            optimizer.step()

        self.eval()
        self.likelihood.eval()

    def posterior(self, X, posterior_transform=None):
        """Evaluate the model and likelihood at ``X``.

        Args:
            X (torch.Tensor): the query inputs; moved to the selected device
            posterior_transform: accepted for interface compatibility, ignored

        Returns:
            The output of ``self.likelihood(self(X))``.
        """
        device = self._select_device()
        X = X.to(device)
        self.to(device)
        self.likelihood.to(device)

        # NOTE: the raw likelihood output is returned on purpose; it is not
        # wrapped in SpaceTimeGPPosterior.
        return self.likelihood(self(X))

    def get_kernel_log_hyperparameters(self):
        """Return the log of kernel hyperparameters.

        Each hyperparameter is read with ``.item()``, so every one of them
        must hold exactly one element.

        Returns:
            np.array: log of [outputscale, first-kernel lengthscale,
            second-kernel lengthscale, likelihood noise]
        """
        product_kernel = self.covar_module.base_kernel
        values = [
            self.covar_module.outputscale.item(),
            product_kernel.kernels[0].lengthscale.item(),
            product_kernel.kernels[1].lengthscale.item(),
            self.likelihood.noise.item(),
        ]
        return np.log(np.array(values))

    def set_parameters(self, lmbd, lS, lT):
        """Setter for the kernel parameters.

        Args:
            lmbd (float): the scale of the covariance function
            lS (float): the spatial lengthscale of the covariance function
            lT (float): the temporal lengthscale for the covariance function
        """
        product_kernel = self.covar_module.base_kernel
        self.covar_module.outputscale = lmbd
        product_kernel.kernels[0].lengthscale = lS
        product_kernel.kernels[1].lengthscale = lT


def learn_model_space_time(xx_tt, space_kernel, space_kernel_args, time_kernel, time_kernel_args, yy_normalized):
    """Build and fit a space-time surrogate model.

    Creates a fresh ``GaussianLikelihood``, constructs a ``SpaceTimeGPModel``
    from the given kernels and data, and calls its ``fit`` method.

    Args:
        xx_tt: training inputs, passed unchanged to ``SpaceTimeGPModel``
        space_kernel (gpytorch.kernels.Kernel class): the spatial kernel class
        space_kernel_args (list): the arguments for the spatial kernel class
        time_kernel (gpytorch.kernels.Kernel class): the temporal kernel class
        time_kernel_args (list): the arguments for the temporal kernel class
        yy_normalized: training labels, passed unchanged to ``SpaceTimeGPModel``

    Returns:
        SpaceTimeGPModel: the surrogate model after ``fit`` has been called
    """
    noise_model = gpytorch.likelihoods.GaussianLikelihood()
    surrogate = SpaceTimeGPModel(
        space_kernel,
        space_kernel_args,
        time_kernel,
        time_kernel_args,
        xx_tt,
        yy_normalized,
        noise_model,
    )
    surrogate.fit()
    return surrogate