
import numpy as np
from scipy.spatial.distance import cdist
from typing import Dict, Tuple, Optional, Union


def _safe_eigh_psd(S: np.ndarray, ridge: float = 1e-8):
    """
    Eigendecompose a (nearly) symmetric PSD matrix with an eigenvalue floor.

    The input is symmetrized first, then every eigenvalue is raised to at
    least ``ridge * max(1, largest eigenvalue)``.

    Args:
        S: square matrix, expected to be symmetric positive (semi-)definite
        ridge: relative floor applied to the eigenvalues

    Returns:
        (eigenvalues, eigenvectors): floored eigenvalues in ascending order
        and the matching eigenvectors as columns.
    """
    # Average with the transpose to remove round-off asymmetry before eigh.
    symmetric = (S + S.T) * 0.5
    eigenvalues, eigenvectors = np.linalg.eigh(symmetric)

    # The floor scales with the spectrum but never drops below ``ridge``.
    largest = float(eigenvalues.max())
    floor = ridge * max(1.0, largest)

    return np.clip(eigenvalues, floor, None), eigenvectors


class ConditionalTR:
    """
    Conditional transformation Y -> Y_whitened based on a local Sigma(x).

    For each query point x, a local covariance Sigma(x) of the training
    residuals is estimated with one of two weighting schemes:
      - KNN: uniform average over the k nearest training points of x
      - Kernel: weighted sum using a Gaussian kernel K(x, x_i)

    The forward transformation is:
      Y_whitened = invSqrtS(x) @ (Y - mu)

    and the inverse transformation is:
      Y_original = sqrtS(x) @ Y_whitened + mu

    ``mu`` is the global training mean when ``fit`` was called without
    ``Y_pred``. When ``Y_pred`` was supplied no centre is stored, and
    ``transform`` / ``inverse_transform`` apply only the linear map; any
    centring on a model prediction is left to the caller.

    Attributes:
        method: "knn" or "kernel"
        k: number of neighbours for KNN
        bandwidth: Gaussian kernel width (computed in ``fit`` if None)
        ridge: regularization for the covariance
    """

    # Weighting schemes accepted by the constructor.
    _VALID_METHODS = ("knn", "kernel")

    def __init__(
        self,
        method: str = "knn",
        k: int = 50,
        bandwidth: Optional[float] = None,
        ridge: float = 1e-6,
    ):
        """
        Initialize ConditionalTR.

        Args:
            method: "knn" or "kernel"
            k: number of neighbors for KNN method (must be >= 1)
            bandwidth: kernel bandwidth (auto-computed in fit() if None)
            ridge: regularization for covariance matrix

        Raises:
            ValueError: if ``method`` is unknown, ``k < 1`` or ``bandwidth``
                is given but not strictly positive.
        """
        # Fail early: an unknown method would otherwise silently run the
        # kernel branch, and k < 1 would only blow up at query time.
        if method not in self._VALID_METHODS:
            raise ValueError(
                f"method must be one of {self._VALID_METHODS}, got {method!r}"
            )
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")
        if bandwidth is not None and not bandwidth > 0:
            raise ValueError(f"bandwidth must be > 0, got {bandwidth}")

        self.method = method
        self.k = k
        self.bandwidth = bandwidth
        self.ridge = ridge

        # Last bandwidth chosen automatically by fit(); lets a later fit()
        # tell an auto value (to be recomputed) from a user-supplied one.
        self._auto_bandwidth: Optional[float] = None

        # Fitted data (set in fit())
        self.X_train: Optional[np.ndarray] = None
        self.Y_train: Optional[np.ndarray] = None
        self.residuals: Optional[np.ndarray] = None
        self.mu_global: Optional[np.ndarray] = None
        self.d_x: int = 0
        self.d_y: int = 0

        # Cache of (Sigma, sqrtS, invSqrtS) keyed by the rounded query point.
        self._cache: Dict[
            Tuple[float, ...], Tuple[np.ndarray, np.ndarray, np.ndarray]
        ] = {}
        self._cache_max_size: int = 1000

    def fit(
        self,
        X: np.ndarray,
        Y: np.ndarray,
        Y_pred: Optional[np.ndarray] = None
    ) -> "ConditionalTR":
        """
        Fit the conditional TR model.

        Args:
            X: (n, d_x) input features
            Y: (n, d_y) targets
            Y_pred: (n, d_y) predictions mu(x) from a model (optional)
                    If None, uses global mean as center.

        Returns:
            self (fitted)

        Raises:
            ValueError: if X or Y is not 2-D, if they are empty or have a
                different number of rows, or if Y_pred does not have the
                same shape as Y.
        """
        X = np.asarray(X)
        Y = np.asarray(Y)
        if X.ndim != 2 or Y.ndim != 2:
            raise ValueError(
                f"X and Y must be 2-D, got shapes {X.shape} and {Y.shape}"
            )
        if X.shape[0] != Y.shape[0]:
            raise ValueError(
                f"X and Y must have the same number of rows, "
                f"got {X.shape[0]} and {Y.shape[0]}"
            )
        if X.shape[0] == 0:
            raise ValueError("cannot fit on an empty training set")

        self.X_train = X.astype(np.float32)
        self.Y_train = Y.astype(np.float32)
        self.d_x = X.shape[1]
        self.d_y = Y.shape[1]

        # Compute residuals: Y - mu(X)
        if Y_pred is not None:
            Y_pred = np.asarray(Y_pred)
            if Y_pred.shape != Y.shape:
                raise ValueError(
                    f"Y_pred must have shape {Y.shape}, got {Y_pred.shape}"
                )
            self.residuals = (Y - Y_pred).astype(np.float32)
            self.mu_global = None  # mu is model-based, not stored
        else:
            self.mu_global = Y.mean(axis=0).astype(np.float32)
            self.residuals = (Y - self.mu_global).astype(np.float32)

        # Auto bandwidth for the kernel method, using a Scott-style rule:
        #   h = n^(-1/(d_x + 4)) * std(X)
        # It is recomputed when the current value is one a previous fit()
        # chose automatically, so refitting on new data does not silently
        # keep a bandwidth derived from the old data.
        if self.method == "kernel":
            previous_auto = getattr(self, "_auto_bandwidth", None)
            is_auto = (
                self.bandwidth is None
                or (previous_auto is not None
                    and self.bandwidth == previous_auto)
            )
            if is_auto:
                n = X.shape[0]
                scott = (n ** (-1.0 / (self.d_x + 4))) * float(X.std())
                # Ensure minimum bandwidth
                self.bandwidth = max(scott, 1e-3)
                self._auto_bandwidth = self.bandwidth

        # Clear cache: cached matrices belong to the previous training set.
        self._cache = {}

        print(f"  [ConditionalTR] Fitted: method={self.method}, k={self.k}, "
              f"n_train={X.shape[0]}, d_x={self.d_x}, d_y={self.d_y}")
        if self.method == "kernel":
            print(f"    bandwidth={self.bandwidth:.4f}")

        return self

    def _check_is_fitted(self) -> None:
        """Raise RuntimeError when fit() has not been called yet."""
        if (getattr(self, "X_train", None) is None
                or getattr(self, "residuals", None) is None):
            raise RuntimeError(
                f"{type(self).__name__} is not fitted; call fit() first."
            )

    def _get_weights(self, x: np.ndarray) -> np.ndarray:
        """
        Compute weights for each training point relative to query x.

        Args:
            x: (d_x,) single query point

        Returns:
            weights: (n,) float32 weights summing to 1

        Raises:
            RuntimeError: if the model is not fitted, or if the kernel
                branch is reached without a bandwidth.
        """
        self._check_is_fitted()
        query = np.asarray(x).reshape(1, -1)  # (1, d_x)

        # Distances to all training points
        dists = cdist(query, self.X_train, metric='euclidean').flatten()
        n_train = dists.shape[0]

        if self.method == "knn":
            # K-nearest neighbors: uniform weights on k closest
            k_actual = min(self.k, n_train)
            nearest = np.argpartition(dists, k_actual - 1)[:k_actual]
            weights = np.zeros(n_train, dtype=np.float32)
            weights[nearest] = 1.0 / k_actual
        else:
            if self.bandwidth is None:
                raise RuntimeError(
                    "kernel bandwidth is not set; call fit() or pass bandwidth"
                )
            # Gaussian kernel: K(x, x_i) = exp(-||x - x_i||^2 / (2 * h^2)).
            # The smallest exponent is subtracted before exponentiating; the
            # common factor cancels in the normalization, but it stops every
            # weight underflowing to 0 for a query far from the training set.
            scaled_sq = (dists / self.bandwidth) ** 2
            weights = np.exp(-0.5 * (scaled_sq - scaled_sq.min()))
            weights = weights / (weights.sum() + 1e-12)

        return weights.astype(np.float32)

    def estimate_local_cov(
        self,
        x: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Estimate Sigma(x), sqrtS(x), and invSqrtS(x) for a single point x.

        The local covariance is computed as a weighted average of outer
        products of the training residuals:
            Sigma(x) = sum_i w_i(x) * r_i * r_i^T

        where w_i(x) are the KNN or kernel weights. Results are memoized on
        the query rounded to 6 decimals, up to ``_cache_max_size`` entries.

        Args:
            x: (d_x,) query point

        Returns:
            Sigma: (d_y, d_y) local covariance matrix (ridge included)
            sqrtS: (d_y, d_y) matrix square root
            invSqrtS: (d_y, d_y) inverse square root (for whitening)
        """
        x = np.atleast_1d(x).flatten()

        # Check cache
        x_key = tuple(x.round(6))
        cached = self._cache.get(x_key)
        if cached is not None:
            return cached

        # Get weights
        weights = self._get_weights(x)  # (n,)

        # Weighted covariance: Sigma(x) = R^T @ diag(w) @ R
        # Efficient computation: (R * sqrt(w))^T @ (R * sqrt(w))
        R = self.residuals  # (n, d_y)
        sqrt_w = np.sqrt(weights)[:, None]  # (n, 1)
        R_weighted = R * sqrt_w  # (n, d_y)
        Sigma = R_weighted.T @ R_weighted  # (d_y, d_y)

        # Add ridge regularization (scaled by trace for stability)
        scale = float(np.trace(Sigma)) / max(1, self.d_y)
        Sigma = Sigma + (self.ridge * max(1e-12, scale)) * np.eye(self.d_y)

        # Eigendecomposition for stable square root
        eigvals, eigvecs = _safe_eigh_psd(Sigma, ridge=self.ridge)

        sqrtS = (eigvecs * np.sqrt(eigvals)) @ eigvecs.T
        invSqrtS = (eigvecs * (1.0 / np.sqrt(eigvals))) @ eigvecs.T

        result = (
            Sigma.astype(np.float32),
            sqrtS.astype(np.float32),
            invSqrtS.astype(np.float32)
        )

        # Cache result (with size limit; no eviction once full)
        if len(self._cache) < self._cache_max_size:
            self._cache[x_key] = result

        return result

    def _as_paired_rows(
        self,
        X: np.ndarray,
        Y: np.ndarray,
        y_name: str,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Return X and Y as float32 2-D arrays, checking the row counts match."""
        X = np.atleast_2d(X).astype(np.float32)
        Y = np.atleast_2d(Y).astype(np.float32)
        if len(X) != len(Y):
            # Without this check a shorter X would leave trailing output
            # rows silently filled with zeros.
            raise ValueError(
                f"X and {y_name} must have the same number of rows, "
                f"got {len(X)} and {len(Y)}"
            )
        return X, Y

    def transform(
        self,
        X: np.ndarray,
        Y: np.ndarray,
        verbose: bool = False
    ) -> np.ndarray:
        """
        Transform Y to whitened space using local Sigma(x) for each point.

        y_whitened[i] = invSqrtS(x[i]) @ (y[i] - mu)

        The global mean is subtracted only when one was stored by fit()
        (i.e. fit() was called without Y_pred); otherwise Y is mapped as is.

        Args:
            X: (n, d_x) input features
            Y: (n, d_y) targets
            verbose: print progress every 1000 rows

        Returns:
            Y_transformed: (n, d_y) whitened targets (float32)

        Raises:
            ValueError: if X and Y have a different number of rows.
        """
        X, Y = self._as_paired_rows(X, Y, "Y")
        n = len(X)
        center = self.mu_global

        Y_transformed = np.zeros_like(Y)

        for i, x_i in enumerate(X):
            _, _, invSqrtS = self.estimate_local_cov(x_i)

            y_i = Y[i] if center is None else Y[i] - center
            Y_transformed[i] = invSqrtS @ y_i

            if verbose and (i + 1) % 1000 == 0:
                print(f"    Transform: {i+1}/{n}")

        return Y_transformed.astype(np.float32)

    def inverse_transform(
        self,
        X: np.ndarray,
        Y_whitened: np.ndarray,
        verbose: bool = False
    ) -> np.ndarray:
        """
        Transform back from whitened space to original space.

        y_original[i] = sqrtS(x[i]) @ y_whitened[i] + mu

        The global mean is added back only when one was stored by fit().

        Args:
            X: (n, d_x) input features
            Y_whitened: (n, d_y) whitened targets
            verbose: print progress every 1000 rows

        Returns:
            Y_original: (n, d_y) original-space targets (float32)

        Raises:
            ValueError: if X and Y_whitened have a different number of rows.
        """
        X, Y_whitened = self._as_paired_rows(X, Y_whitened, "Y_whitened")
        n = len(X)
        center = self.mu_global

        Y_original = np.zeros_like(Y_whitened)

        for i, x_i in enumerate(X):
            _, sqrtS, _ = self.estimate_local_cov(x_i)

            restored = sqrtS @ Y_whitened[i]
            Y_original[i] = restored if center is None else restored + center

            if verbose and (i + 1) % 1000 == 0:
                print(f"    Inverse transform: {i+1}/{n}")

        return Y_original.astype(np.float32)

    def get_log_det_jacobian(self, x: np.ndarray) -> float:
        """
        Get log|det(sqrtS(x))| for volume correction at point x.

        This is needed to correct volumes when comparing regions
        in transformed space vs original space.

        Args:
            x: (d_x,) query point

        Returns:
            log_det: log of absolute determinant of sqrtS(x)
        """
        _, sqrtS, _ = self.estimate_local_cov(x)
        # Sum of log |eigenvalues|; the clip guards log(0).
        eigvals = np.abs(np.linalg.eigvalsh(sqrtS))
        return float(np.sum(np.log(np.clip(eigvals, 1e-30, None))))

    def get_condition_number(self, x: np.ndarray) -> float:
        """
        Get condition number of Sigma(x) for diagnostics.

        Args:
            x: (d_x,) query point

        Returns:
            cond: condition number (ratio of max/min eigenvalues)
        """
        Sigma, _, _ = self.estimate_local_cov(x)
        eigvals = np.abs(np.linalg.eigvalsh(Sigma))
        eigvals = np.clip(eigvals, 1e-30, None)
        return float(eigvals.max() / eigvals.min())


# ==============================================================================
# Convenience functions (similar to tr_retr.py interface)
# ==============================================================================

def fit_conditional_tr_and_standardize(
    X_train: np.ndarray,
    Y_train: np.ndarray,
    method: str = "knn",
    k: int = 50,
    bandwidth: float = None,
    ridge: float = 1e-6,
) -> Tuple[ConditionalTR, np.ndarray]:
    """
    Build a ConditionalTR, fit it on the training set and whiten that set.

    Conditional counterpart of fit_tr_and_standardize: the model is fitted
    without model predictions, so the global mean of Y_train is the centre.

    Args:
        X_train: (n, d_x) training features
        Y_train: (n, d_y) training targets
        method: "knn" or "kernel"
        k: number of neighbors for KNN
        bandwidth: kernel bandwidth (auto if None)
        ridge: regularization parameter

    Returns:
        tr_cond: fitted ConditionalTR object
        Y_transformed: (n, d_y) transformed training targets
    """
    settings = {
        "method": method,
        "k": k,
        "bandwidth": bandwidth,
        "ridge": ridge,
    }
    model = ConditionalTR(**settings)
    model.fit(X_train, Y_train)

    # Whiten the very data the model was fitted on.
    whitened = model.transform(X_train, Y_train)
    return model, whitened


def conditional_tr_transform(
    X: np.ndarray,
    Y: np.ndarray,
    tr_cond: ConditionalTR,
) -> np.ndarray:
    """
    Whiten Y with an already-fitted conditional TR.

    Functional wrapper that delegates to ``tr_cond.transform(X, Y)``.

    Args:
        X: (n, d_x) features
        Y: (n, d_y) targets
        tr_cond: fitted ConditionalTR

    Returns:
        Y_transformed: (n, d_y) transformed targets
    """
    Y_transformed = tr_cond.transform(X, Y)
    return Y_transformed


def conditional_tr_retransform(
    X: np.ndarray,
    Y_whitened: np.ndarray,
    tr_cond: ConditionalTR,
) -> np.ndarray:
    """
    Map whitened targets back to the original space.

    Functional wrapper that delegates to
    ``tr_cond.inverse_transform(X, Y_whitened)``.

    Args:
        X: (n, d_x) features
        Y_whitened: (n, d_y) whitened targets
        tr_cond: fitted ConditionalTR

    Returns:
        Y_original: (n, d_y) original-space targets
    """
    Y_original = tr_cond.inverse_transform(X, Y_whitened)
    return Y_original


# ==============================================================================
# STABILIZED CONDITIONAL TR
# ==============================================================================

class StabilizedConditionalTR(ConditionalTR):
    """
    Enhanced ConditionalTR with shrinkage and fallback for stability.

    This class addresses ill-conditioned local covariances by:
    1. Shrinkage: Blend local Σ(x) with global Σ to reduce variance
    2. Adaptive ridge: Scale ridge by local variance
    3. Fallback: Use global Σ when condition number exceeds threshold

    Recommended for datasets with heteroscedastic noise where pure local
    estimation can be unstable (e.g., rf1 with high-dimensional output).
    """

    def __init__(
        self,
        method: str = "knn",
        k: int = 50,
        bandwidth: Optional[float] = None,
        ridge: float = 1e-6,
        shrinkage_alpha: float = 0.3,
        fallback_threshold: float = 1e6,
        eigenvalue_shrinkage: float = 0.0,
    ):
        """
        Initialize StabilizedConditionalTR.

        Args:
            method: "knn" or "kernel"
            k: number of neighbors for KNN method
            bandwidth: kernel bandwidth (auto-computed if None)
            ridge: base regularization for covariance matrix
            shrinkage_alpha: blend factor with global cov (0=local only, 1=global only)
            fallback_threshold: condition number threshold for fallback to global
            eigenvalue_shrinkage: shrink eigenvalues toward geometric mean (0=none, 1=full)
                                  Higher values = more spherical cov = smaller regions

        Raises:
            ValueError: if ``shrinkage_alpha`` or ``eigenvalue_shrinkage``
                lies outside [0, 1].
        """
        # Outside [0, 1] the blends below stop being convex combinations and
        # can produce non-PSD matrices or negative eigenvalues (NaN roots).
        if not 0.0 <= shrinkage_alpha <= 1.0:
            raise ValueError(
                f"shrinkage_alpha must be in [0, 1], got {shrinkage_alpha}"
            )
        if not 0.0 <= eigenvalue_shrinkage <= 1.0:
            raise ValueError(
                f"eigenvalue_shrinkage must be in [0, 1], "
                f"got {eigenvalue_shrinkage}"
            )

        super().__init__(method=method, k=k, bandwidth=bandwidth, ridge=ridge)
        self.shrinkage_alpha = shrinkage_alpha
        self.fallback_threshold = fallback_threshold
        self.eigenvalue_shrinkage = eigenvalue_shrinkage

        # Global covariance (computed in fit)
        self._Sigma_global: Optional[np.ndarray] = None
        self._sqrtS_global: Optional[np.ndarray] = None
        self._invSqrtS_global: Optional[np.ndarray] = None

        # Stats for diagnostics
        self._n_fallbacks: int = 0
        self._total_calls: int = 0
        # Cache keys whose stored result is the global fallback, so cache
        # hits can be counted as fallbacks too.
        self._fallback_keys: set = set()

    def fit(
        self,
        X: np.ndarray,
        Y: np.ndarray,
        Y_pred: Optional[np.ndarray] = None
    ) -> "StabilizedConditionalTR":
        """
        Fit with global covariance computation for shrinkage/fallback.

        Args:
            X: (n, d_x) input features
            Y: (n, d_y) targets
            Y_pred: (n, d_y) model predictions (optional); forwarded to the
                    parent fit().

        Returns:
            self (fitted)
        """
        # Call parent fit (stores residuals and clears the cache)
        super().fit(X, Y, Y_pred)

        # Compute global covariance for shrinkage
        R = self.residuals
        n = len(R)
        self._Sigma_global = (R.T @ R) / n

        # Add ridge to global
        scale = float(np.trace(self._Sigma_global)) / max(1, self.d_y)
        self._Sigma_global = self._Sigma_global + (self.ridge * max(1e-12, scale)) * np.eye(self.d_y)

        # Compute global sqrtS and invSqrtS
        eigvals, eigvecs = _safe_eigh_psd(self._Sigma_global, ridge=self.ridge)
        self._sqrtS_global = (eigvecs * np.sqrt(eigvals)) @ eigvecs.T
        self._invSqrtS_global = (eigvecs * (1.0 / np.sqrt(eigvals))) @ eigvecs.T

        # Global condition number for reference
        cond_global = eigvals.max() / eigvals.min()
        print(f"  [StabilizedConditionalTR] shrinkage_alpha={self.shrinkage_alpha:.2f}, "
              f"fallback_thresh={self.fallback_threshold:.0e}")
        print(f"    Global cov condition number: {cond_global:.2e}")

        # Reset stats; the cache was cleared by the parent fit, so the
        # fallback bookkeeping attached to it must be cleared as well.
        self._n_fallbacks = 0
        self._total_calls = 0
        self._fallback_keys = set()

        return self

    def estimate_local_cov(
        self,
        x: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Estimate Sigma(x) with shrinkage and fallback.

        The stabilized covariance is:
            Sigma_stable(x) = (1 - α) * Sigma_local(x) + α * Sigma_global

        If eigenvalue shrinkage is enabled, the eigenvalues of that matrix
        are pulled toward their geometric mean and the returned Sigma is
        rebuilt from them, so it matches the returned square roots.

        If condition number of result exceeds threshold, falls back to global.

        Args:
            x: (d_x,) query point

        Returns:
            Sigma: (d_y, d_y) stabilized (or global) covariance
            sqrtS: (d_y, d_y) matrix square root
            invSqrtS: (d_y, d_y) inverse square root (for whitening)
        """
        x = np.atleast_1d(x).flatten()
        self._total_calls += 1

        # Check cache first
        x_key = tuple(x.round(6))
        cached = self._cache.get(x_key)
        if cached is not None:
            # A cached fallback is still a call answered by the global
            # covariance; count it so get_fallback_rate() stays a true
            # fraction of calls.
            if x_key in self._fallback_keys:
                self._n_fallbacks += 1
            return cached

        # Get weights for local covariance
        weights = self._get_weights(x)  # (n,)

        # Compute local covariance
        R = self.residuals
        sqrt_w = np.sqrt(weights)[:, None]
        R_weighted = R * sqrt_w
        Sigma_local = R_weighted.T @ R_weighted

        # Add adaptive ridge (scaled by local trace)
        local_trace = float(np.trace(Sigma_local))
        scale = local_trace / max(1, self.d_y)
        Sigma_local = Sigma_local + (self.ridge * max(1e-12, scale)) * np.eye(self.d_y)

        # Apply shrinkage: blend with global
        alpha = self.shrinkage_alpha
        Sigma = (1 - alpha) * Sigma_local + alpha * self._Sigma_global

        # Eigendecomposition
        eigvals, eigvecs = _safe_eigh_psd(Sigma, ridge=self.ridge)

        # EIGENVALUE SHRINKAGE: contract toward geometric mean for smaller regions
        if self.eigenvalue_shrinkage > 0:
            log_eigvals = np.log(np.clip(eigvals, 1e-30, None))
            geom_mean = np.exp(np.mean(log_eigvals))
            # Shrink toward geometric mean (makes cov more spherical)
            eigvals = (1 - self.eigenvalue_shrinkage) * eigvals + self.eigenvalue_shrinkage * geom_mean
            # Rebuild Sigma from the shrunk spectrum so that the returned
            # covariance agrees with sqrtS / invSqrtS below.
            Sigma = (eigvecs * eigvals) @ eigvecs.T

        # Check condition number
        cond_num = eigvals.max() / eigvals.min()
        used_fallback = bool(cond_num > self.fallback_threshold)

        if used_fallback:
            # Fallback to global
            self._n_fallbacks += 1
            result = (
                self._Sigma_global.astype(np.float32),
                self._sqrtS_global.astype(np.float32),
                self._invSqrtS_global.astype(np.float32),
            )
        else:
            sqrtS = (eigvecs * np.sqrt(eigvals)) @ eigvecs.T
            invSqrtS = (eigvecs * (1.0 / np.sqrt(eigvals))) @ eigvecs.T
            result = (
                Sigma.astype(np.float32),
                sqrtS.astype(np.float32),
                invSqrtS.astype(np.float32),
            )

        # Cache result (with size limit; no eviction once full)
        if len(self._cache) < self._cache_max_size:
            self._cache[x_key] = result
            if used_fallback:
                self._fallback_keys.add(x_key)

        return result

    def get_fallback_rate(self) -> float:
        """Return fraction of calls that fell back to global covariance."""
        if self._total_calls == 0:
            return 0.0
        return self._n_fallbacks / self._total_calls

    def print_diagnostics(self):
        """Print diagnostic summary."""
        print(f"  [StabilizedConditionalTR Diagnostics]")
        print(f"    Total calls: {self._total_calls}")
        print(f"    Fallbacks:   {self._n_fallbacks} ({self.get_fallback_rate()*100:.1f}%)")


def fit_stabilized_conditional_tr_and_standardize(
    X_train: np.ndarray,
    Y_train: np.ndarray,
    method: str = "knn",
    k: int = 50,
    bandwidth: float = None,
    ridge: float = 1e-6,
    shrinkage_alpha: float = 0.3,
    fallback_threshold: float = 1e6,
    eigenvalue_shrinkage: float = 0.0,
) -> Tuple[StabilizedConditionalTR, np.ndarray]:
    """
    Build a StabilizedConditionalTR, fit it on the training set and whiten it.

    The model is fitted without model predictions, so the global mean of
    Y_train is used as the centre.

    Args:
        X_train: (n, d_x) training features
        Y_train: (n, d_y) training targets
        method: "knn" or "kernel"
        k: number of neighbors for KNN
        bandwidth: kernel bandwidth (auto if None)
        ridge: regularization parameter
        shrinkage_alpha: blend factor with global cov (0=local, 1=global)
        fallback_threshold: condition number threshold for fallback
        eigenvalue_shrinkage: shrink eigenvalues toward geometric mean (0=none, 1=full)

    Returns:
        tr_cond: fitted StabilizedConditionalTR object
        Y_transformed: (n, d_y) transformed training targets
    """
    settings = {
        "method": method,
        "k": k,
        "bandwidth": bandwidth,
        "ridge": ridge,
        "shrinkage_alpha": shrinkage_alpha,
        "fallback_threshold": fallback_threshold,
        "eigenvalue_shrinkage": eigenvalue_shrinkage,
    }
    model = StabilizedConditionalTR(**settings)
    model.fit(X_train, Y_train)

    # Whiten the very data the model was fitted on.
    whitened = model.transform(X_train, Y_train)
    return model, whitened
