"""Residual GP model - standard GP (pure GPyTorch implementation)"""
import torch
import numpy as np
import gpytorch
from typing import Tuple


class StandardGPModel(gpytorch.models.ExactGP):
    """Exact GP with a constant mean and a scaled ARD RBF kernel."""

    def __init__(self, train_x, train_y, likelihood):
        """Build the mean and covariance modules.

        Args:
            train_x: Training inputs; ``train_x.shape[1]`` is used as the
                number of ARD lengthscale dimensions.
            train_y: Training targets.
            likelihood: Likelihood forwarded to ``ExactGP``.
        """
        super().__init__(train_x, train_y, likelihood)

        # One lengthscale per input dimension (ARD).
        input_dim = train_x.shape[1]

        self.mean_module = gpytorch.means.ConstantMean()
        base_kernel = gpytorch.kernels.RBFKernel(ard_num_dims=input_dim)
        self.covar_module = gpytorch.kernels.ScaleKernel(base_kernel)

    def forward(self, x):
        """Return the multivariate normal given by the mean and kernel at ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )




class ResidualGP:
    """Residual GP model - models δ = y_H - ρ·μ_LF.

    Inputs are normalized before being passed to the GP (using the task
    bounds when given, otherwise the training-data statistics); residuals
    are kept in their original scale.
    """

    def __init__(self, training_iter: int = 100, bounds: np.ndarray = None):
        """Initialize residual GP model.

        Args:
            training_iter: Number of GP hyperparameter optimization iterations (default 100)
            bounds: Input space bounds, shape (d, 2), each row is [min, max].
                Any array-like is accepted; ``None`` selects normalization
                by training-data mean and standard deviation.
        """
        self.model = None
        self.likelihood = None
        self.X_train = None
        self.residuals_train = None

        # Normalization parameters (using task bounds or data statistics)
        self.X_mean = None
        self.X_std = None
        self.bounds = bounds  # Task bounds

        # Training parameters
        self.training_iter = training_iter

    def _fit_normalizer(self, X: np.ndarray) -> None:
        """Set ``X_mean`` / ``X_std`` from the task bounds or from ``X``."""
        if self.bounds is not None:
            bounds = np.asarray(self.bounds, dtype=np.float64)
            lower = bounds[:, 0]
            upper = bounds[:, 1]
            self.X_mean = (lower + upper) / 2  # Bound midpoint
            half_width = (upper - lower) / 2   # Bound half-width
            # A degenerate bound (min == max) would give a zero scale and
            # NaN/inf after division; use a unit scale for that dimension.
            self.X_std = np.where(half_width != 0, half_width, 1.0)
        else:
            # Use data statistics for normalization (backward compatibility)
            self.X_mean = X.mean(axis=0)
            self.X_std = X.std(axis=0) + 1e-6  # Prevent division by zero

    def _normalize(self, X: np.ndarray) -> np.ndarray:
        """Apply the stored normalization to ``X``."""
        return (X - self.X_mean) / self.X_std

    def _posterior(self, X, need_variance: bool):
        """Evaluate the GP posterior at ``X``.

        Shared implementation of ``predict`` and ``predict_with_variance``.

        Args:
            X: Input points, shape (n, d) or (d,); any array-like.
            need_variance: Whether the posterior variance is requested.

        Returns:
            ``(mean, variance, single_point)`` where ``variance`` is ``None``
            when not requested and ``single_point`` tells whether ``X`` was
            given as a single 1-D point.

        Raises:
            ValueError: If the model has not been trained.
        """
        if self.model is None:
            raise ValueError("Model has not been trained")

        # Same coercion as in fit(), so lists / object arrays are accepted.
        X = np.asarray(X, dtype=np.float64)

        # Handle single point case
        single_point = X.ndim == 1
        if single_point:
            X = X.reshape(1, -1)

        test_x = torch.tensor(self._normalize(X), dtype=torch.float64)

        self.model.eval()
        self.likelihood.eval()

        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            # The model is called directly (not through the likelihood), so
            # the variance excludes observation noise.
            posterior_pred = self.model(test_x)
            # Residuals are already in original scale, no denormalization.
            mean = posterior_pred.mean.numpy()
            variance = posterior_pred.variance.numpy() if need_variance else None

        return mean, variance, single_point

    def fit(self, X: np.ndarray, residuals: np.ndarray):
        """Train residual GP model.

        Args:
            X: Training inputs, shape (n, d)
            residuals: Residuals δ = y_H - ρ·μ_LF, shape (n,); an (n, 1)
                column is flattened.

        Raises:
            ValueError: If the training data is empty, ``X`` is not 2-D, or
                the number of residuals does not match the number of rows.
        """
        # Ensure correct data types (handle object type from pandas)
        X = np.asarray(X, dtype=np.float64)
        # Flatten so an (n, 1) column does not broadcast against the (n,)
        # GP output when the marginal log likelihood is evaluated.
        residuals = np.asarray(residuals, dtype=np.float64).reshape(-1)

        if X.size == 0:
            raise ValueError("Training data cannot be empty")
        if X.ndim != 2:
            raise ValueError("X must have shape (n, d)")
        if residuals.shape[0] != X.shape[0]:
            raise ValueError("X and residuals must have the same number of samples")

        # Save original training data
        self.X_train = X
        self.residuals_train = residuals

        # Normalize input X (using task bounds or data statistics)
        self._fit_normalizer(X)

        # Convert to tensor; residuals keep original scale
        train_x = torch.tensor(self._normalize(X), dtype=torch.float64)
        train_y = torch.tensor(residuals, dtype=torch.float64)

        # Create likelihood and standard residual GP model
        self.likelihood = gpytorch.likelihoods.GaussianLikelihood()
        self.model = StandardGPModel(train_x, train_y, self.likelihood)

        # Module hyperparameters are created in torch's default dtype; cast
        # them to float64 so they match the float64 training tensors.
        self.model.double()
        self.likelihood.double()

        self.model.train()
        self.likelihood.train()

        # Optimize all parameters
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.05)
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(self.likelihood, self.model)

        # Optimize hyperparameters (using configured iteration count)
        for _ in range(self.training_iter):
            optimizer.zero_grad()
            output = self.model(train_x)
            loss = -mll(output, train_y)
            loss.backward()
            optimizer.step()

    def predict(self, X: np.ndarray, return_std: bool = False):
        """Predict residuals.

        Args:
            X: Input points, shape (n, d) or (d,)
            return_std: Whether to return standard deviation

        Returns:
            If return_std=False: return mean
            If return_std=True: return (mean, standard deviation)
            For a single (d,) point the values are Python floats,
            otherwise NumPy arrays.

        Raises:
            ValueError: If the model has not been trained.
        """
        mean, variance, single_point = self._posterior(X, need_variance=return_std)

        if not return_std:
            return float(mean[0]) if single_point else mean

        std = np.sqrt(variance)
        if single_point:
            return float(mean[0]), float(std[0])
        return mean, std

    def predict_with_variance(self, X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Predict residual mean and variance.

        Args:
            X: Input points, shape (n, d) or (d,)

        Returns:
            (mean, variance) Mean and variance arrays; for a single (d,)
            point both are returned as Python floats.

        Raises:
            ValueError: If the model has not been trained.
        """
        mean, variance, single_point = self._posterior(X, need_variance=True)

        if single_point:
            return float(mean[0]), float(variance[0])
        return mean, variance
