"""Low-fidelity GP model - uses FixedNoiseGP (pure GPyTorch implementation)"""
import torch
import numpy as np
import gpytorch
from typing import Tuple


class FixedNoiseGPModel(gpytorch.models.ExactGP):
    """Exact GP with a constant mean and a scaled ARD-RBF kernel.

    Meant to be used together with a likelihood that already carries the
    per-point observation noise.
    """

    def __init__(self, train_x, train_y, train_noise, likelihood):
        """Build the mean and covariance modules.

        Args:
            train_x: Training inputs; ``train_x.shape[1]`` sets the number of
                ARD lengthscales.
            train_y: Training targets.
            train_noise: Accepted as part of the constructor signature but not
                used inside this class.
            likelihood: Likelihood handed to ``ExactGP``.
        """
        super().__init__(train_x, train_y, likelihood)

        # One lengthscale per input dimension, wrapped in an output scale.
        base_kernel = gpytorch.kernels.RBFKernel(ard_num_dims=train_x.shape[1])
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(base_kernel)

    def forward(self, x):
        """Return the multivariate normal given by the mean and kernel at ``x``."""
        prior_mean = self.mean_module(x)
        prior_covar = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(prior_mean, prior_covar)


class StandardGPModel(gpytorch.models.ExactGP):
    """Exact GP with a constant mean and a scaled ARD-RBF kernel (no fixed noise)."""

    def __init__(self, train_x, train_y, likelihood):
        """Build the mean and covariance modules.

        Args:
            train_x: Training inputs; ``train_x.shape[1]`` sets the number of
                ARD lengthscales.
            train_y: Training targets.
            likelihood: Likelihood handed to ``ExactGP``.
        """
        super().__init__(train_x, train_y, likelihood)

        num_dims = train_x.shape[1]
        self.mean_module = gpytorch.means.ConstantMean()
        # RBF kernel with a separate lengthscale per dimension and a learned
        # output scale.
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(ard_num_dims=num_dims)
        )

    def forward(self, x):
        """Return the multivariate normal given by the mean and kernel at ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )


class LowFidelityGP:
    """Low-fidelity GP model - trains a fixed-noise GP on ICL predictions.

    Inputs are rescaled before they reach the GP: by the task bounds when
    they were supplied, otherwise by the mean/std of the training inputs.
    Targets and observation noise are kept on their original scale.
    """

    # Floor for the per-point observation noise given to the fixed-noise likelihood.
    _MIN_NOISE = 1e-6
    # Floor for the posterior variance reported by the predict methods.
    _MIN_VARIANCE = 1e-4
    # Adam step size used for hyperparameter optimization.
    _LEARNING_RATE = 0.05

    def __init__(self, training_iter: int = 100, bounds: np.ndarray = None):
        """Initialize low-fidelity GP model.

        Args:
            training_iter: Number of GP hyperparameter optimization iterations (default 100)
            bounds: Input space bounds, shape (d, 2), each row is [min, max];
                None means the inputs are normalized with data statistics.
        """
        self.model = None
        self.likelihood = None
        self.X_train = None
        self.y_train = None
        self.noise_train = None

        # Normalization parameters (using task bounds or data statistics)
        self.X_mean = None
        self.X_std = None
        self.bounds = bounds  # Task bounds

        # Training parameters
        self.training_iter = training_iter

    def _input_normalizer(self, X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Return the (center, scale) pair used to normalize inputs shaped like ``X``."""
        if self.bounds is None:
            # Use data statistics for normalization (backward compatibility)
            return X.mean(axis=0), X.std(axis=0) + 1e-6  # Prevent division by zero

        bounds = np.asarray(self.bounds, dtype=np.float64)
        if bounds.shape != (X.shape[1], 2):
            raise ValueError(
                f"bounds must have shape ({X.shape[1]}, 2), got {bounds.shape}"
            )
        lower, upper = bounds[:, 0], bounds[:, 1]
        center = (lower + upper) / 2      # Bound midpoint
        half_width = (upper - lower) / 2  # Bound half-width
        # A degenerate bound (min == max) would divide by zero; leave that
        # dimension unscaled instead.
        scale = np.where(half_width == 0, 1.0, half_width)
        return center, scale

    def fit(self, X: np.ndarray, y: np.ndarray, noise: np.ndarray):
        """Train low-fidelity GP model.

        Args:
            X: Training inputs, shape (n, d)
            y: Training targets (LF prediction mean), shape (n,)
            noise: Observation noise variance (LF prediction variance), shape (n,)

        Raises:
            ValueError: If the training data is empty, ``X`` is not 2-D,
                ``y``/``noise`` do not hold one value per row of ``X``, or the
                configured bounds do not have shape (d, 2).
        """
        # Ensure correct data types (handle object type from pandas); targets
        # and noise are flattened so (n, 1) columns are accepted as well.
        X = np.asarray(X, dtype=np.float32)
        y = np.asarray(y, dtype=np.float32).reshape(-1)
        noise = np.asarray(noise, dtype=np.float32).reshape(-1)

        if X.size == 0:
            raise ValueError("Training data cannot be empty")
        if X.ndim != 2:
            raise ValueError(f"X must have shape (n, d), got {X.shape}")
        n_points = X.shape[0]
        if y.shape[0] != n_points or noise.shape[0] != n_points:
            raise ValueError(
                f"X, y and noise must describe the same number of points, got "
                f"{n_points}, {y.shape[0]} and {noise.shape[0]}"
            )

        # Normalize input X (using task bounds or data statistics)
        center, scale = self._input_normalizer(X)

        # Save original training data and the normalization parameters
        self.X_train = X
        self.y_train = y
        self.noise_train = noise
        self.X_mean = center
        self.X_std = scale

        # Convert to tensor; y and noise keep original scale
        train_x = torch.tensor((X - center) / scale, dtype=torch.float32)
        train_y = torch.tensor(y, dtype=torch.float32)
        train_noise = torch.tensor(noise, dtype=torch.float32)

        if bool(torch.all(train_noise == 0.0)):
            # Observation noise is 0, use standard GaussianLikelihood to let model learn noise
            self.likelihood = gpytorch.likelihoods.GaussianLikelihood()
            self.model = StandardGPModel(train_x, train_y, self.likelihood)
        else:
            # Has observation noise, use FixedNoiseGaussianLikelihood
            # Ensure noise is positive
            train_noise = torch.clamp(train_noise, min=self._MIN_NOISE)
            self.likelihood = gpytorch.likelihoods.FixedNoiseGaussianLikelihood(
                noise=train_noise,
                learn_additional_noise=True,  # Allow learning additional noise
            )
            self.model = FixedNoiseGPModel(train_x, train_y, train_noise, self.likelihood)

        # Train model
        self.model.train()
        self.likelihood.train()

        optimizer = torch.optim.Adam(self.model.parameters(), lr=self._LEARNING_RATE)
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(self.likelihood, self.model)

        # Optimize hyperparameters (using configured iteration count)
        for _ in range(self.training_iter):
            optimizer.zero_grad()
            output = self.model(train_x)
            loss = -mll(output, train_y)
            loss.backward()
            optimizer.step()

    def _posterior(self, X, with_variance: bool):
        """Evaluate the trained model at ``X``.

        Returns:
            (mean, variance, single_point). ``variance`` is floored at
            ``_MIN_VARIANCE`` and is None when ``with_variance`` is False;
            ``single_point`` tells whether ``X`` was given as one 1-D point.

        Raises:
            ValueError: If the model has not been trained.
        """
        if self.model is None:
            raise ValueError("Model has not been trained")

        # Same coercion as in fit, so lists and pandas object arrays work and
        # a training point is normalized exactly as it was during training.
        X = np.asarray(X, dtype=np.float32)

        # Handle single point case
        single_point = X.ndim == 1
        if single_point:
            X = X.reshape(1, -1)

        # Normalize input X and convert to tensor
        test_x = torch.tensor((X - self.X_mean) / self.X_std, dtype=torch.float32)

        self.model.eval()
        self.likelihood.eval()

        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            # The model output is not passed through the likelihood, so the
            # variance excludes observation noise and reflects model uncertainty.
            posterior_pred = self.model(test_x)
            mean = posterior_pred.mean.numpy()  # y is already in original scale
            variance = None
            if with_variance:
                # Add minimum variance to balance decisions
                variance = np.maximum(
                    posterior_pred.variance.numpy(), self._MIN_VARIANCE
                )

        return mean, variance, single_point

    def predict(self, X: np.ndarray, return_std: bool = False):
        """Predict.

        Args:
            X: Input points, shape (n, d) or (d,)
            return_std: Whether to return standard deviation

        Returns:
            If return_std=False: return mean
            If return_std=True: return (mean, standard deviation)
            Arrays for (n, d) input, plain floats for a single (d,) point.

        Raises:
            ValueError: If the model has not been trained.
        """
        mean, variance, single_point = self._posterior(X, with_variance=return_std)

        if not return_std:
            return float(mean[0]) if single_point else mean

        std = np.sqrt(variance)
        if single_point:
            return float(mean[0]), float(std[0])
        return mean, std

    def predict_with_variance(self, X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Predict mean and variance.

        Args:
            X: Input points, shape (n, d) or (d,)

        Returns:
            (mean, variance) Mean and variance arrays; plain floats when a
            single (d,) point was given.

        Raises:
            ValueError: If the model has not been trained.
        """
        mean, variance, single_point = self._posterior(X, with_variance=True)

        if single_point:
            return float(mean[0]), float(variance[0])
        return mean, variance
