"""KOH (Kennedy–O'Hagan) fusion module: fuses low-fidelity and residual GP predictions."""
from typing import Tuple, Union

import numpy as np

from .models.low_fidelity_gp import LowFidelityGP
from .models.residual_gp import ResidualGP
from .models.rho_manager import RhoManager


class KOHFusion:
    """Kennedy–O'Hagan (KOH) fusion of a low-fidelity GP and a residual GP.

    The high-fidelity response is modelled as ``f_H(x) = ρ · f_LF(x) + δ(x)``.
    ``f_LF`` is predicted by the low-fidelity GP, ``δ`` by the residual GP,
    and the scale factor ``ρ`` is obtained from a rho manager.
    """

    def __init__(self, lf_gp: LowFidelityGP, residual_gp: ResidualGP, rho_manager: RhoManager):
        """Initialize KOH fusion.

        Args:
            lf_gp: Low-fidelity GP model; must provide ``predict_with_variance(X)``
                returning ``(mean, variance)``.
            residual_gp: Residual GP model; must provide
                ``predict_with_variance(X)`` returning ``(mean, variance)``.
            rho_manager: ρ manager; must provide ``get_rho()``.
        """
        self.lf_gp = lf_gp
        self.residual_gp = residual_gp
        self.rho_manager = rho_manager

    def predict(
        self, X: np.ndarray
    ) -> Union[Tuple[np.ndarray, np.ndarray], Tuple[float, float]]:
        """KOH posterior prediction.

        μ_H = ρ * μ_LF + μ_δ
        σ²_H = ρ² * σ²_LF + σ²_δ

        The variance is a plain sum. This treats ρ as a fixed constant (no
        uncertainty in ρ is propagated) and the two GPs as independent.

        Args:
            X: Input points, shape (n, d), or a single point of shape (d,).
                Any array-like is accepted; it is converted with ``np.asarray``.

        Returns:
            (mean, variance): for 2-D input, the fused mean and variance arrays
            built from the sub-models' outputs; for a single 1-D point, two
            Python floats.
        """
        X = np.asarray(X)

        # A 1-D input is interpreted as one point with d features.
        single_point = X.ndim == 1
        if single_point:
            X = X.reshape(1, -1)

        # ρ is fetched on every call rather than cached at construction time.
        rho = self.rho_manager.get_rho()

        mu_LF, sigma2_LF = self.lf_gp.predict_with_variance(X)
        mu_delta, sigma2_delta = self.residual_gp.predict_with_variance(X)

        # KOH fusion
        mu_H = rho * mu_LF + mu_delta
        sigma2_H = rho**2 * sigma2_LF + sigma2_delta

        if single_point:
            # Flatten before indexing so this works whether the sub-models
            # return shape (1,) or a column vector (1, 1). Calling float() on a
            # non-0-d array is deprecated in NumPy >= 1.25.
            return float(np.ravel(mu_H)[0]), float(np.ravel(sigma2_H)[0])
        return mu_H, sigma2_H

    def predict_lf_only(self, X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Predict using only the low-fidelity GP (for debugging).

        Unlike ``predict``, ``X`` is forwarded to the GP unchanged. It is not
        reshaped from 1-D, and the result is not converted to floats.

        Args:
            X: Input points, shape (n, d)

        Returns:
            (mean, variance) exactly as returned by the low-fidelity GP.
        """
        return self.lf_gp.predict_with_variance(X)

    def predict_residual_only(self, X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Predict using only the residual GP (for debugging).

        Unlike ``predict``, ``X`` is forwarded to the GP unchanged. It is not
        reshaped from 1-D, and the result is not converted to floats.

        Args:
            X: Input points, shape (n, d)

        Returns:
            (mean, variance) exactly as returned by the residual GP.
        """
        return self.residual_gp.predict_with_variance(X)

