import numpy as np
import warnings
import sklearn
from sklearn.cluster import DBSCAN
from functools import wraps

class SafeObservationModel:
    """
    Callable wrapper around an observation model that guards against
    numerical problems.

    Two protections are applied:

    * Non-finite entries (NaN / +-Inf) in an ndarray state are replaced
      before the wrapped model is called.
    * If the wrapped model raises ``ValueError`` whose message contains
      "scale < 0", the call is retried once with ``np.random.normal``
      temporarily replaced by a version that forces a positive scale.

    Note: the retry swaps a module-level NumPy attribute for the duration of
    the call, so it affects every caller of ``np.random.normal`` in the
    process while the retry is running.
    """

    def __init__(self, original_model, min_noise_scale=1e-6):
        """
        Initialize the safe observation model wrapper.

        Args:
            original_model: Callable invoked as ``original_model(state, observation)``.
            min_noise_scale: Lower bound applied to the noise scale during the
                "scale < 0" retry.
        """
        self.original_model = original_model
        self.min_noise_scale = min_noise_scale

    def __call__(self, state, observation=None):
        """
        Call the observation model with safety checks.

        Args:
            state: The state passed to the wrapped model. ndarray states have
                NaN replaced by 0.0 and +-Inf replaced by +-1e10.
            observation: Optional observation forwarded unchanged.

        Returns:
            Whatever the wrapped model returns.

        Raises:
            ValueError: Re-raised unchanged when the message does not contain
                "scale < 0". Any exception from the retry also propagates.
        """
        clean_state = state
        # Check for Inf as well as NaN: nan_to_num is already configured to
        # clamp infinities, so an Inf-only state should be cleaned too.
        if isinstance(state, np.ndarray) and not np.isfinite(state).all():
            clean_state = np.nan_to_num(state, nan=0.0, posinf=1e10, neginf=-1e10)

        try:
            return self.original_model(clean_state, observation)
        except ValueError as e:
            if "scale < 0" not in str(e):
                # Unrelated ValueError: not ours to handle.
                raise

        orig_normal = np.random.normal
        scale_floor = self.min_noise_scale

        def safe_normal(loc=0.0, scale=1.0, size=None):
            # np.maximum/np.abs work element-wise, so array-valued scales are
            # supported (the builtin max() raises on arrays).
            safe_scale = np.maximum(np.abs(scale), scale_floor)
            return orig_normal(loc, safe_scale, size)

        np.random.normal = safe_normal
        try:
            return self.original_model(clean_state, observation)
        finally:
            # Always restore the real sampler, even if the retry fails.
            np.random.normal = orig_normal


# Create a class to safely wrap DBSCAN to handle NaNs
class SafeDBSCAN:
    """
    A wrapper for DBSCAN that tolerates NaN values by filtering out the rows
    that contain them.

    Rows containing NaN are excluded from clustering and reported with the
    DBSCAN noise label (-1). ``labels_`` always has one entry per input row
    and ``core_sample_indices_`` refers to row positions in the original
    input.
    """

    def __init__(self, *args, **kwargs):
        """Create the wrapped ``DBSCAN`` with the given arguments."""
        self.dbscan = DBSCAN(*args, **kwargs)

    def fit(self, X, sample_weight=None):
        """
        Cluster ``X``, ignoring rows that contain NaN.

        Args:
            X: Input data. ndarray inputs are NaN-filtered row-wise; any other
                type is handed to the wrapped estimator unchanged.
            sample_weight: Optional per-row weights, filtered with the same
                row mask as ``X``.

        Returns:
            self
        """
        if not isinstance(X, np.ndarray):
            # Cannot NaN-filter an unknown container; delegate, but still
            # expose the results on this wrapper so fit_predict works.
            self.dbscan.fit(X, sample_weight=sample_weight)
            self.labels_ = self.dbscan.labels_
            self.core_sample_indices_ = self.dbscan.core_sample_indices_
            return self

        valid_mask = ~np.isnan(X).any(axis=1)
        n_samples = len(X)

        if not valid_mask.any():
            # Nothing to cluster: every row is noise and there are no cores.
            self.labels_ = np.full(n_samples, -1, dtype=int)
            self.core_sample_indices_ = np.array([], dtype=int)
            return self

        weights = None
        if sample_weight is not None:
            # asarray so list/tuple weights can be indexed by the boolean mask.
            weights = np.asarray(sample_weight)[valid_mask]

        # sample_weight must be passed by keyword: the second positional
        # parameter of DBSCAN.fit is the (ignored) y argument.
        self.dbscan.fit(X[valid_mask], sample_weight=weights)

        # Map results back to the original row numbering; -1 is the DBSCAN
        # "noise" label and is used for the filtered-out rows.
        labels = np.full(n_samples, -1, dtype=int)
        labels[valid_mask] = self.dbscan.labels_
        self.labels_ = labels

        original_indices = np.flatnonzero(valid_mask)
        self.core_sample_indices_ = original_indices[self.dbscan.core_sample_indices_]
        return self

    def fit_predict(self, X, sample_weight=None):
        """Fit on ``X`` and return the per-row cluster labels."""
        self.fit(X, sample_weight)
        return self.labels_


# Create a decorator to patch sklearn functions
def patch_sklearn_estimator(func):
    """
    Decorator that retries ``func`` once with cleaned array arguments when it
    raises a ``ValueError`` whose message contains "contains NaN".

    Every floating-point/complex ndarray argument (positional or keyword) that
    holds a non-finite value is replaced by ``np.nan_to_num(arr, nan=0.0)``
    for the retry. Looking at all arguments rather than only the first one
    means the decorator also works on methods, where the first positional
    argument is ``self``, and on arrays passed by keyword.

    If no argument can be cleaned, the original exception is re-raised and no
    warning is emitted.
    """
    def _clean(value):
        # Return (possibly cleaned value, whether it was changed).
        if (isinstance(value, np.ndarray)
                and np.issubdtype(value.dtype, np.inexact)
                and not np.isfinite(value).all()):
            return np.nan_to_num(value, nan=0.0), True
        return value, False

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except ValueError as e:
            if "contains NaN" not in str(e):
                raise

            cleaned_args = [_clean(a) for a in args]
            cleaned_kwargs = {k: _clean(v) for k, v in kwargs.items()}
            changed = (any(flag for _, flag in cleaned_args)
                       or any(flag for _, flag in cleaned_kwargs.values()))
            if not changed:
                # Nothing we know how to fix; surface the original error.
                raise

            warnings.warn(f"NaN values detected in sklearn function. Applying automatic cleaning. Original error: {e}")
            new_args = tuple(value for value, _ in cleaned_args)
            new_kwargs = {k: value for k, (value, _) in cleaned_kwargs.items()}
            return func(*new_args, **new_kwargs)
    return wrapper


def patch_experiment_runner(runner):
    """
    Patch an experiment runner in place so that common numerical failures are
    absorbed instead of aborting the run.

    Depending on which attributes the runner has, this:

    * wraps ``observation_model`` in ``SafeObservationModel``;
    * passes a non-None ``env`` through ``make_safe_environment``;
    * replaces ``cluster_particles`` with a version that falls back to a
      single-cluster result on any exception;
    * replaces ``calculate_mode_coverage`` with a version that returns 0.0
      for NaN/Inf results or on any exception.

    Args:
        runner: The experiment runner to patch

    Returns:
        The same runner object, patched.
    """
    if hasattr(runner, 'observation_model'):
        runner.observation_model = SafeObservationModel(runner.observation_model)

    if hasattr(runner, 'env') and runner.env is not None:
        runner.env = make_safe_environment(runner.env)

    if hasattr(runner, 'cluster_particles'):
        original_cluster_method = runner.cluster_particles

        def safe_cluster_particles(*args, **kwargs):
            try:
                # NOTE(review): this heuristic inspects the method's string
                # representation, which normally does not mention DBSCAN, so
                # the swap below may rarely trigger -- confirm intent.
                if 'DBSCAN' not in str(original_cluster_method):
                    return original_cluster_method(*args, **kwargs)

                original_dbscan = sklearn.cluster.DBSCAN
                sklearn.cluster.DBSCAN = SafeDBSCAN
                try:
                    return original_cluster_method(*args, **kwargs)
                finally:
                    # Restore even when clustering raises, so the global
                    # sklearn attribute is never left patched.
                    sklearn.cluster.DBSCAN = original_dbscan
            except Exception as e:
                warnings.warn(f"Error in cluster_particles: {e}. Returning default clustering.")
                # Default clustering: every particle in cluster 0.
                particles = args[0] if args else kwargs.get('particles', [])
                return np.zeros(len(particles), dtype=int), None

        runner.cluster_particles = safe_cluster_particles

    if hasattr(runner, 'calculate_mode_coverage'):
        original_mode_coverage = runner.calculate_mode_coverage

        def safe_mode_coverage(*args, **kwargs):
            try:
                result = original_mode_coverage(*args, **kwargs)
            except Exception as e:
                warnings.warn(f"Error calculating mode coverage: {e}. Returning default value.")
                return 0.0
            # Include NumPy scalar types: e.g. np.float32 is not a subclass
            # of the builtin float and would otherwise slip through as NaN.
            is_number = isinstance(result, (float, int, np.floating, np.integer))
            if is_number and not np.isfinite(result):
                return 0.0  # Default to 0 coverage for invalid results
            return result

        runner.calculate_mode_coverage = safe_mode_coverage

    return runner


def make_safe_environment(env):
    """
    Patch an environment in place so that ``step`` never returns NaN/Inf in
    an ndarray state.

    The first element of the ``step`` result is treated as the state; all
    remaining elements are passed through untouched. This supports both
    4-element ``(state, reward, done, info)`` results and longer ones such as
    5-element ``(state, reward, terminated, truncated, info)`` results.

    Args:
        env: The environment to patch, or None.

    Returns:
        The same environment object (or None if None was given).
    """
    if env is None:
        return None

    if not hasattr(env, 'step'):
        return env

    original_step = env.step

    def safe_step(*args, **kwargs):
        # Split off the state without assuming how many other values follow.
        state, *rest = original_step(*args, **kwargs)

        if isinstance(state, np.ndarray) and not np.isfinite(state).all():
            warnings.warn("NaN or Inf found in environment state. Sanitizing values.")
            state = np.nan_to_num(state, nan=0.0, posinf=1e10, neginf=-1e10)

        return (state, *rest)

    env.step = safe_step
    return env


def apply_global_patches():
    """
    Replace ``np.random.normal`` and ``np.random.multivariate_normal`` with
    numerically tolerant versions for the rest of the process.

    * ``normal``: the scale is replaced by ``max(|scale|, 1e-6)``
      (element-wise, so array-valued scales work).
    * ``multivariate_normal``: the covariance is symmetrized, cleaned of
      NaN/Inf, has its eigenvalues floored to a small positive value, and is
      then sampled through its Cholesky factor. If that still fails, diagonal
      jitter and finally a diagonal covariance are used.

    The originals are not restored by this function. Prints a confirmation
    line when done.
    """
    # Keep references to the real samplers; the safe versions delegate to them.
    orig_normal = np.random.normal
    orig_mv_normal = np.random.multivariate_normal

    def safe_normal(loc=0.0, scale=1.0, size=None):
        """Ensure scale is always positive (scalar or array)."""
        # Element-wise ops: the builtin max() raises on array-valued scales.
        safe_scale = np.maximum(np.abs(scale), 1e-6)
        return orig_normal(loc, safe_scale, size)

    def manual_mvn_sample(mean, L, size=None):
        """
        Sample from N(mean, L @ L.T) using the Cholesky factor directly.

        Args:
            mean: Mean vector of length ``dim``.
            L: Lower-triangular Cholesky factor of the covariance.
            size: None, an int, or a sequence of ints giving the batch shape.

        Returns:
            Array of shape ``(*size, dim)`` (``(dim,)`` when size is None).
        """
        dim = len(mean)

        if size is None:
            batch_shape = ()
        elif np.isscalar(size):
            batch_shape = (size,)
        else:
            batch_shape = tuple(size)

        # Standard-normal draws, one row per sample; (z @ L.T)[i] == L @ z[i].
        z = orig_normal(size=batch_shape + (dim,))
        return mean + z @ L.T

    def safe_mv_normal(mean, cov, size=None, check_valid='warn', tol=1e-8):
        """
        Safe multivariate normal sampling with robust PSD enforcement.

        ``check_valid`` and ``tol`` are accepted so calls written against
        NumPy's signature keep working; they are not consulted because the
        covariance is always repaired before sampling.
        """
        try:
            # Accept lists/tuples as NumPy does, and clean NaN in the mean.
            mean = np.nan_to_num(np.asarray(mean, dtype=float), nan=0.0)
            dim = len(mean)

            cov = np.asarray(cov, dtype=float)
            if cov.shape != (dim, dim):
                # Shape does not match the mean: use a simple diagonal one.
                simple_cov = np.eye(dim) * 0.01
                return orig_mv_normal(mean, simple_cov, size)

            # Ensure the matrix is symmetric - required by eigh.
            cov_sym = (cov + cov.T) / 2

            # Replace any NaN or Inf values
            cov_sym = np.nan_to_num(cov_sym, nan=1e-8, posinf=1.0, neginf=-1.0)

            eigvals, eigvecs = np.linalg.eigh(cov_sym)
            max_eigval = np.max(eigvals)
            min_eigval = np.min(eigvals)

            # Tiny negative eigenvalues (relative to the largest) get a floor
            # proportional to the largest eigenvalue; anything else gets a
            # fixed absolute floor.
            if min_eigval < 0 and abs(min_eigval) < max_eigval * 0.001:
                eigval_floor = max(1e-10, max_eigval * 0.0001)
            else:
                eigval_floor = 1e-8

            fixed_eigvals = np.maximum(eigvals, eigval_floor)

            # Rebuild the matrix and re-symmetrize to remove round-off skew.
            cov_psd = eigvecs @ np.diag(fixed_eigvals) @ eigvecs.T
            cov_psd = (cov_psd + cov_psd.T) / 2

            try:
                L = np.linalg.cholesky(cov_psd)
                return manual_mvn_sample(mean, L, size)
            except np.linalg.LinAlgError:
                # Cholesky still failed: add increasing diagonal jitter.
                for i in range(5):
                    jitter = 10.0**(i-6)  # From 1e-6 up to 1e-2
                    try:
                        L = np.linalg.cholesky(cov_psd + np.eye(dim) * jitter)
                        return manual_mvn_sample(mean, L, size)
                    except np.linalg.LinAlgError:
                        continue

                # If all else fails, fall back to a diagonal covariance
                diag_cov = np.diag(np.maximum(np.diag(cov_psd), 1e-4))
                return orig_mv_normal(mean, diag_cov, size)

        except Exception as e:
            # Last resort for any unhandled error: small isotropic covariance.
            print(f"Warning: Error in multivariate_normal, using fallback. Error: {e}")
            dim = len(mean)
            simple_cov = np.eye(dim) * 0.01
            return orig_mv_normal(mean, simple_cov, size)

    # Apply the patches
    np.random.normal = safe_normal
    np.random.multivariate_normal = safe_mv_normal

    print("Applied numerical safety patches")


def patch_svgd_kernels():
    """
    Patch SVGD kernel computations to handle numerical issues.

    Imports ``pomdps.escort.utils.kernels`` and, when present, replaces its
    ``compute_distances`` and ``rbf_kernel`` functions with versions that
    clean NaN/Inf inputs and outputs. This targets the "overflow in square"
    and "invalid value in subtract" warnings.

    Failures to import or patch are reported with ``warnings.warn``; nothing
    is raised and nothing is returned.
    """
    def _as_rows(arr):
        # View an array as (n_rows, n_features) so row-wise sums and products
        # can be vectorized; 1-D input is treated as n rows of one feature.
        if arr.ndim == 1:
            return arr[:, None]
        if arr.ndim > 2:
            return arr.reshape(arr.shape[0], int(np.prod(arr.shape[1:])))
        return arr

    try:
        import importlib

        try:
            kernel_module = importlib.import_module('pomdps.escort.utils.kernels')

            if hasattr(kernel_module, 'compute_distances'):
                original_compute_distances = kernel_module.compute_distances

                def safe_compute_distances(x, y):
                    try:
                        if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
                            # Clean inputs
                            x = np.nan_to_num(x, nan=0.0, posinf=1e10, neginf=-1e10)
                            y = np.nan_to_num(y, nan=0.0, posinf=1e10, neginf=-1e10)

                            # Scale down if values are too large.
                            # NOTE(review): the result is not scaled back, so
                            # distances come out in the rescaled units --
                            # confirm this is what callers expect.
                            max_val = max(np.max(np.abs(x)), np.max(np.abs(y)))
                            if max_val > 1e5:
                                scale_factor = 1e5 / max_val
                                x = x * scale_factor
                                y = y * scale_factor

                        result = original_compute_distances(x, y)

                        if isinstance(result, np.ndarray):
                            result = np.nan_to_num(result, nan=1e-10, posinf=1e10, neginf=0.0)
                            # Ensure distances are non-negative
                            result = np.maximum(result, 0.0)

                        return result
                    except Exception as e:
                        warnings.warn(f"Error in compute_distances: {e}. Using manual implementation.")
                        # Fallback: squared Euclidean distances via
                        # |x|^2 + |y|^2 - 2 x.y, vectorized. asarray lets
                        # list inputs through instead of failing on .shape.
                        x_rows = _as_rows(np.asarray(x, dtype=float))
                        y_rows = _as_rows(np.asarray(y, dtype=float))

                        x_norm = np.sum(x_rows * x_rows, axis=1)[:, None]
                        y_norm = np.sum(y_rows * y_rows, axis=1)[None, :]
                        dist_mat = x_norm + y_norm - 2 * (x_rows @ y_rows.T)

                        # Clean up any numerical issues
                        dist_mat = np.maximum(dist_mat, 0.0)  # Ensure non-negative
                        dist_mat = np.nan_to_num(dist_mat, nan=1e-10)  # Replace NaNs

                        return dist_mat

                kernel_module.compute_distances = safe_compute_distances

            if hasattr(kernel_module, 'rbf_kernel'):
                original_rbf = kernel_module.rbf_kernel

                def safe_rbf_kernel(x, y, h=1.0):
                    # NOTE(review): this computes exp(-||x - y||^2 / h) and
                    # only defers to the original kernel on error -- verify
                    # the bandwidth convention matches the original.
                    try:
                        # Clean inputs
                        if isinstance(x, np.ndarray):
                            x = np.nan_to_num(x, nan=0.0)
                        if isinstance(y, np.ndarray):
                            y = np.nan_to_num(y, nan=0.0)

                        # Pairwise differences by broadcasting; this builds an
                        # (n, m, d) intermediate instead of looping in Python.
                        x_rows = _as_rows(x)
                        y_rows = _as_rows(y)
                        diff = x_rows[:, None, :] - y_rows[None, :, :]
                        sq_dist = np.sum(diff * diff, axis=-1)

                        return np.exp(-sq_dist / h).astype(float, copy=False)
                    except Exception as e:
                        warnings.warn(f"Error in safe_rbf_kernel: {e}. Using original with cleaned inputs.")
                        # Try original with cleaned inputs
                        if isinstance(x, np.ndarray):
                            x = np.nan_to_num(x, nan=0.0)
                        if isinstance(y, np.ndarray):
                            y = np.nan_to_num(y, nan=0.0)

                        result = original_rbf(x, y, h)

                        # Clean output
                        if isinstance(result, np.ndarray):
                            result = np.nan_to_num(result, nan=0.0)

                        return result

                kernel_module.rbf_kernel = safe_rbf_kernel

        except (ImportError, AttributeError) as e:
            warnings.warn(f"Could not patch ESCORT kernels: {e}")

    except Exception as e:
        warnings.warn(f"Error in patch_svgd_kernels: {e}")
