from __future__ import annotations
import warnings

# Install a blanket "ignore" filter for every warning category, process-wide.
# It is placed before the torch/gpytorch/botorch imports below so that any
# warnings emitted while those packages are imported are suppressed as well.
warnings.filterwarnings("ignore")

import math
import torch
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.models.gp_regression import SingleTaskGP
from gpytorch.kernels import Kernel, ProductKernel, AdditiveKernel, ScaleKernel
from gpytorch.means import ZeroMean
from gpytorch.likelihoods import GaussianLikelihood

class MLLGP(SingleTaskGP):
    """Zero-mean ``SingleTaskGP`` fitted by exact marginal log likelihood.

    The constructor resets the supplied kernel's hyperparameters to fixed
    starting values, builds the underlying ``SingleTaskGP`` with a
    ``ZeroMean`` mean module and a ``GaussianLikelihood``, and immediately
    runs ``_train_model``. The instance is left in eval mode afterwards.

    Note that the ``kernel`` object passed in is modified in place (both by
    the initial reset and by training); it is also kept on ``self.kernel``.
    """

    def __init__(
        self,
        train_X: torch.Tensor,
        train_Y: torch.Tensor,
        kernel: Kernel,
        noise_var: float = 0.0001,
        n_epochs: int = 500,
        device: str = "cpu",
        dtype: torch.dtype = torch.float64,
        normalize_x=False,
    ):
        """Build the GP and fit its hyperparameters.

        Args:
            train_X: Training inputs; the size of the last dimension is used
                as the feature dimension for kernel initialisation.
            train_Y: Training targets, passed to ``SingleTaskGP`` as-is after
                the device/dtype cast.
            kernel: Covariance module. Its hyperparameters are reset by
                ``init_kernel_params`` before the model is built.
            noise_var: Value assigned to the likelihood's noise before
                training starts.
            n_epochs: Number of ``optimizer.step`` calls made during training.
            device: Device the training data is moved to.
            dtype: Dtype the training data is cast to.
            normalize_x: Stored on the instance and forwarded by
                ``condition_on_observations``; not otherwise used here.
        """
        # Reset the kernel hyperparameters to deterministic starting values.
        feature_dim = train_X.shape[-1]
        self.init_kernel_params(kernel, feature_dim)

        likelihood = GaussianLikelihood()
        likelihood.noise = noise_var

        SingleTaskGP.__init__(
            self,
            train_X=train_X.to(device).to(dtype),
            train_Y=train_Y.to(device).to(dtype),
            likelihood=likelihood,
            mean_module=ZeroMean(),
            covar_module=kernel,
        )

        # All instance attributes are assigned after the base initialiser has
        # run, so attribute registration never depends on the value's type.
        self.normalize_x = normalize_x
        self.noise_var = noise_var
        self.kernel = kernel
        self.n_epochs = n_epochs
        self.device = device
        self.dtype = dtype

        self._train_model()

    def _train_model(self):
        """
        Implements a simple training procedure for the GP model
        (exact marginal log likelihood from gpytorch).

        Runs ``self.n_epochs`` L-BFGS steps on the negative marginal log
        likelihood, prints the loss every 100 steps and on the last step,
        switches the model and likelihood to eval mode, and finally prints
        every single-element parameter (in constrained space when a
        constraint is registered for it).
        """
        self.train()
        self.likelihood.train()

        # Define the optimizer with current tolerance_change
        optimizer = torch.optim.LBFGS(
            self.parameters(),
            line_search_fn="strong_wolfe",
            tolerance_change=1e-7,
        )

        # Loss function for the GP
        mll = ExactMarginalLogLikelihood(self.likelihood, self)

        def closure():
            # L-BFGS may call the closure several times per step; gradients
            # are only cleared/computed when autograd is active.
            if torch.is_grad_enabled():
                optimizer.zero_grad()
            output = self(self.train_inputs[0])  # type: ignore
            loss = -mll(output, self.train_targets)  # type: ignore
            if loss.requires_grad:
                loss.backward()
            return loss

        # Run the optimization
        last_step = self.n_epochs - 1
        for i in range(self.n_epochs):
            loss = optimizer.step(closure)
            if i % 100 == 0 or i == last_step:
                print(f"{i}/{self.n_epochs} - {loss.detach()}")

        self.eval()
        self.likelihood.eval()

        # Report the fitted hyperparameters. Only single-element parameters
        # are printed; multi-element ones are skipped.
        for param_name, param in self.named_parameters():
            constraint = self.constraint_for_parameter_name(param_name)

            if constraint is not None:
                # Transform the parameter to constrained space
                constrained_value = constraint.transform(param.data)
                if constrained_value.numel() == 1:
                    print(f"{param_name::<40} {constrained_value.item():.4f} (constrained)", flush=True)
            else:
                # If no constraint found, print raw parameter
                if param.numel() == 1:
                    print(f"{param_name::<40} {param.item():.4f} (unconstrained)", flush=True)

    def condition_on_observations(
        self, X: torch.Tensor, Y: torch.Tensor, **kwargs
    ) -> MLLGP:
        """
        Returns a new GP conditioned on the provided observations.

        The new observations are appended to this model's training data and
        a fresh ``MLLGP`` is constructed (and therefore re-trained) on the
        combined set. The new model receives a deep copy of ``self.kernel``
        so that re-initialising and re-fitting its hyperparameters does not
        overwrite the hyperparameters of this model.

        Args:
            X: New inputs; cast to the dtype/device of the stored inputs.
            Y: New targets; cast to the dtype/device of the stored targets.
                A 1-D tensor is treated as a single output column.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            A newly trained ``MLLGP`` on the concatenated data.
        """
        # Function-scope import so this method carries its own dependency.
        import copy

        current_X = self.train_inputs[0]
        X = X.to(current_X)
        Y = Y.to(self.train_targets)

        # The stored targets are concatenated with a trailing output
        # dimension below, so a flat Y needs the same trailing dimension.
        if Y.dim() == 1:
            Y = Y.unsqueeze(-1)

        train_X = torch.cat([current_X, X])  # type: ignore
        train_Y = torch.cat([self.train_targets.unsqueeze(-1), Y])

        return MLLGP(
            train_X=train_X,
            train_Y=train_Y,
            # Deep copy: the constructor resets and re-trains the kernel in
            # place, which would otherwise mutate this model's kernel.
            kernel=copy.deepcopy(self.kernel),
            noise_var=self.noise_var,
            n_epochs=self.n_epochs,
            device=self.device,
            dtype=self.dtype,
            normalize_x=self.normalize_x,
        )

    def init_kernel_params(self, kernel, feature_dim):
        """Reset a kernel's hyperparameters to fixed starting values, in place.

        Composite kernels (``ProductKernel``/``AdditiveKernel``) and
        ``ScaleKernel`` wrappers are traversed recursively. For every kernel
        visited: the lengthscale (if the kernel has one) is set to
        ``sqrt(feature_dim)``, and ``outputscale``/``variance`` are set to 1
        when the corresponding raw parameter exists.

        Args:
            kernel: Kernel to initialise; modified in place.
            feature_dim: Input dimensionality used for the lengthscale.
        """
        # Check if the kernel is a composition kernel
        if isinstance(kernel, (ProductKernel, AdditiveKernel)):
            # Recursively initialize sub-kernels
            for sub_kernel in kernel.kernels:
                self.init_kernel_params(sub_kernel, feature_dim)
        elif isinstance(kernel, ScaleKernel):
            self.init_kernel_params(kernel.base_kernel, feature_dim)

        # For kernels with length scale, set it to sqrt(input_dim)
        if kernel.has_lengthscale:
            kernel.lengthscale = math.sqrt(feature_dim)

        # For kernels with outputscale, set it to 1
        if hasattr(kernel, 'raw_outputscale'):
            kernel.outputscale = 1.0

        # For kernels with variance, set it to 1
        if hasattr(kernel, 'raw_variance'):
            kernel.variance = 1.0

    