import numpy as np
from abc import ABC, abstractmethod
from typing import Callable, Optional
from dataclasses import dataclass
from scipy.spatial import ConvexHull


@dataclass
class EstimatorConfig:
    """
    Configuration shared by the gradient estimators in this module.

    Every estimator receives the same config object and reads only the
    fields it needs; fields it does not use are ignored.

    Attributes
    ----------
    estimator_type : str
        Name of the estimator to use (e.g. "uniform", "center_gaussian",
        "random_perturb"). NOTE(review): the code that maps this name to an
        estimator class is not visible here -- confirm the full list of
        accepted values against that code.
    fd_radius : float
        Base radius (step length) for the finite difference approximation.
        Returned as-is by the fixed-radius estimators' ``get_fd_radius``.
    n_samples : int
        Number of random directions sampled per gradient estimate
    gradient_estimator : str
        Finite difference scheme: "standard" (forward difference, one extra
        loss evaluation per sample) or "symmetric" (central difference, two
        loss evaluations per sample)
    random_radius_min : float
        Lower bound of the uniformly sampled radius (used by RandomPerturb)
    random_radius_max : float
        Upper bound of the uniformly sampled radius (used by RandomPerturb)
    gaussian_std : float
        Standard deviation for Gaussian direction sampling (used by CenterGaussian)
    sample_fraction : float
        Fraction of interior nodes to transform (intended for
        ConvexCombination; not read by any code visible in this part of
        the file)
    """
    estimator_type: str = "uniform"
    fd_radius: float = 0.05
    n_samples: int = 20
    gradient_estimator: str = "standard"
    random_radius_min: float = 0.01
    random_radius_max: float = 0.1
    gaussian_std: float = 1.0
    sample_fraction: float = 0.3


class GradientEstimator(ABC):
    """
    Abstract base class for gradient estimators.

    Defines the interface shared by the zeroth-order (finite difference)
    gradient estimators: subclasses decide how directions and radii are
    sampled, while ``__call__`` combines the samples into a gradient estimate.
    """

    def __init__(self, config: "EstimatorConfig", loss_function: Callable):
        """
        Initialize the gradient estimator.

        Parameters
        ----------
        config : EstimatorConfig
            Configuration for the estimator
        loss_function : Callable
            Function that computes the loss given parameters
        """
        self.config = config
        self.loss_function = loss_function

    @abstractmethod
    def sample_direction(self, n_params: int) -> np.ndarray:
        """
        Sample a random direction for gradient estimation.

        Parameters
        ----------
        n_params : int
            Number of parameters

        Returns
        -------
        np.ndarray
            Random direction vector
        """

    @abstractmethod
    def get_fd_radius(self) -> float:
        """
        Get the finite difference radius for this sample.

        Returns
        -------
        float
            Finite difference radius
        """

    def __call__(self, params: np.ndarray) -> np.ndarray:
        """
        Estimate the gradient at the given parameters.

        For each of ``config.n_samples`` samples a direction ``u`` and a
        radius ``r`` are drawn, a directional derivative along ``u`` is
        approximated by finite differences, and ``derivative * u`` is
        accumulated. The result is the average over all samples.

        Parameters
        ----------
        params : np.ndarray
            Parameter vector

        Returns
        -------
        np.ndarray
            Estimated gradient

        Raises
        ------
        ValueError
            If ``config.gradient_estimator`` is neither "standard" nor
            "symmetric".
        """
        mode = self.config.gradient_estimator
        if mode not in ("standard", "symmetric"):
            # Fail early with a clear message instead of crashing later on
            # ``f_x_plus - None`` inside the sampling loop.
            raise ValueError(
                f"Unknown gradient_estimator {mode!r}; "
                "expected 'standard' or 'symmetric'"
            )
        use_symmetric = mode == "symmetric"

        n_params = len(params)
        gradient = np.zeros(n_params)

        # The forward difference reuses the loss at the unperturbed point for
        # every sample, so evaluate it once; the central difference never
        # needs it.
        f_x = None if use_symmetric else self.loss_function(params)

        for _ in range(self.config.n_samples):
            u = self.sample_direction(n_params)
            fd_radius = self.get_fd_radius()

            if use_symmetric:
                # Central difference along u
                f_plus = self.loss_function(params + fd_radius * u)
                f_minus = self.loss_function(params - fd_radius * u)
                directional_derivative = (f_plus - f_minus) / (2 * fd_radius)
            else:
                # Forward difference along u
                f_x_plus = self.loss_function(params + fd_radius * u)
                directional_derivative = (f_x_plus - f_x) / fd_radius

            gradient += directional_derivative * u

        # Average over samples
        gradient /= self.config.n_samples

        return gradient


class UniformEstimator(GradientEstimator):
    """
    Uniform gradient estimator.

    Directions are drawn uniformly from the unit sphere (a standard normal
    vector scaled to unit length) and every sample uses the same fixed
    finite difference radius taken from the config.
    """

    def sample_direction(self, n_params: int) -> np.ndarray:
        """
        Draw a direction uniformly from the unit sphere.

        Parameters
        ----------
        n_params : int
            Number of parameters

        Returns
        -------
        np.ndarray
            Random unit direction vector
        """
        raw = np.random.randn(n_params)
        length = np.linalg.norm(raw)
        return raw / length

    def get_fd_radius(self) -> float:
        """
        Return the fixed finite difference radius from the config.

        Returns
        -------
        float
            Finite difference radius
        """
        radius = self.config.fd_radius
        return radius


class CenterGaussianEstimator(GradientEstimator):
    """
    Center Gaussian gradient estimator.

    Directions are zero-mean Gaussian vectors (not normalized) and the
    directional derivative is always a central (two-sided) difference,
    regardless of ``config.gradient_estimator``.
    """

    def sample_direction(self, n_params: int) -> np.ndarray:
        """
        Draw a direction with i.i.d. zero-mean Gaussian components.

        Parameters
        ----------
        n_params : int
            Number of parameters

        Returns
        -------
        np.ndarray
            Random Gaussian direction vector (not normalized)
        """
        sigma = self.config.gaussian_std
        return np.random.normal(loc=0, scale=sigma, size=n_params)

    def get_fd_radius(self) -> float:
        """
        Return the fixed finite difference radius from the config.

        Returns
        -------
        float
            Finite difference radius
        """
        radius = self.config.fd_radius
        return radius

    def __call__(self, params: np.ndarray) -> np.ndarray:
        """
        Estimate the gradient from central differences along Gaussian directions.

        Note: the symmetric scheme is used unconditionally; the
        ``gradient_estimator`` config field is not consulted.

        Parameters
        ----------
        params : np.ndarray
            Parameter vector

        Returns
        -------
        np.ndarray
            Estimated gradient
        """
        dim = len(params)
        total = np.zeros(dim)

        for _ in range(self.config.n_samples):
            # One unnormalized Gaussian direction and its step length
            direction = self.sample_direction(dim)
            step = self.get_fd_radius()
            offset = step * direction

            # Loss on both sides of the current point
            loss_forward = self.loss_function(params + offset)
            loss_backward = self.loss_function(params - offset)

            # Central-difference slope along the direction
            slope = (loss_forward - loss_backward) / (2 * step)
            total += slope * direction

        # Mean over all sampled directions
        total /= self.config.n_samples
        return total


class RandomPerturbEstimator(GradientEstimator):
    """
    Random perturbation gradient estimator.

    Samples unit directions exactly like the uniform estimator, but draws a
    fresh finite difference radius for every sample instead of using a
    fixed one.
    """

    def sample_direction(self, n_params: int) -> np.ndarray:
        """
        Draw a direction uniformly from the unit sphere.

        Parameters
        ----------
        n_params : int
            Number of parameters

        Returns
        -------
        np.ndarray
            Random unit direction vector
        """
        raw = np.random.randn(n_params)
        length = np.linalg.norm(raw)
        return raw / length

    def get_fd_radius(self) -> float:
        """
        Draw a finite difference radius uniformly between the configured bounds.

        Returns
        -------
        float
            Random finite difference radius
        """
        cfg = self.config
        low, high = cfg.random_radius_min, cfg.random_radius_max
        return np.random.uniform(low, high)


class MeshAwareEstimator(GradientEstimator):
    """
    Mesh-aware gradient estimator.

    Perturbations are checked against the mesh topology: every non-boundary
    node must remain inside the convex hull of its neighbors, since leaving
    it could lead to mesh inversion or other instabilities. A perturbation
    that violates this is retried with a halved radius; if no valid radius
    is found within a fixed number of attempts, the sample is skipped.
    """

    def __init__(self, config: EstimatorConfig, loss_function: Callable, mesh_info: dict):
        """
        Initialize the mesh-aware estimator.

        Parameters
        ----------
        config : EstimatorConfig
            Configuration for the estimator
        loss_function : Callable
            Function that computes the loss given parameters
        mesh_info : dict
            Dictionary containing mesh topology information:
            - 'connectivity': dict mapping node indices to list of neighbor indices
            - 'boundary_mask': boolean array indicating boundary nodes
            - 'coords': initial coordinates of the mesh nodes
        """
        super().__init__(config, loss_function)
        self.mesh_info = mesh_info
        self.shrinkage_count = 0  # Total radius halvings over the estimator's lifetime
        self.shrinkage_history = []  # Radius halvings recorded per __call__
        self.current_iteration_shrinkages = 0

    def sample_direction(self, n_params: int) -> np.ndarray:
        """
        Sample a random direction uniformly from the unit sphere.

        Parameters
        ----------
        n_params : int
            Number of parameters

        Returns
        -------
        np.ndarray
            Random unit direction vector
        """
        u = np.random.randn(n_params)
        return u / np.linalg.norm(u)

    def get_fd_radius(self) -> float:
        """
        Get the finite difference radius (fixed for this estimator).

        Returns
        -------
        float
            Finite difference radius
        """
        return self.config.fd_radius

    def _is_in_convex_hull(self, point: np.ndarray, hull_points: np.ndarray) -> bool:
        """
        Check if a point is inside the convex hull of a set of points.

        Parameters
        ----------
        point : np.ndarray
            2D point to check
        hull_points : np.ndarray
            Array of shape (n_points, 2) defining the convex hull

        Returns
        -------
        bool
            True if point is inside the convex hull (within a 1e-10
            tolerance), False otherwise. With fewer than 3 hull points the
            point only counts as inside when it coincides (within 1e-10)
            with one of them. If the hull cannot be built, False is returned.
        """
        if len(hull_points) < 3:
            # No proper hull can be formed; fall back to a coincidence test.
            distances = np.linalg.norm(hull_points - point, axis=1)
            return bool(np.min(distances) < 1e-10)

        try:
            hull = ConvexHull(hull_points)
        except Exception:
            # Hull construction failed (e.g. degenerate/collinear points):
            # be conservative and report "outside". Unlike a bare ``except``
            # this lets KeyboardInterrupt/SystemExit propagate.
            return False

        # Each row of hull.equations is [a, b, c]; a*x + b*y + c <= 0 holds
        # for points on the inner side of that facet.
        equations = hull.equations
        signed = equations[:, 0] * point[0] + equations[:, 1] * point[1] + equations[:, 2]
        return not bool(np.any(signed > 1e-10))

    def _check_mesh_validity(self, params: np.ndarray) -> tuple[bool, list]:
        """
        Check if the mesh with given parameters is valid.

        A mesh is considered valid when every non-boundary node that has
        neighbors lies inside the convex hull of those neighbors.

        Parameters
        ----------
        params : np.ndarray
            Parameter vector representing mesh coordinates, laid out as all
            x coordinates followed by all y coordinates

        Returns
        -------
        tuple[bool, list]
            (is_valid, violating_nodes) where is_valid is True if mesh is valid,
            and violating_nodes is a list of node indices that violate constraints
        """
        n_nodes = len(params) // 2
        current_coords = np.column_stack([params[:n_nodes], params[n_nodes:]])

        boundary_mask = self.mesh_info['boundary_mask']
        connectivity = self.mesh_info['connectivity']

        violating_nodes = []
        for node_idx in range(n_nodes):
            # Boundary nodes are not constrained by this check
            if boundary_mask[node_idx]:
                continue

            neighbors = connectivity.get(node_idx, [])
            if len(neighbors) == 0:
                continue

            neighbor_coords = current_coords[neighbors]
            if not self._is_in_convex_hull(current_coords[node_idx], neighbor_coords):
                violating_nodes.append(node_idx)

        return len(violating_nodes) == 0, violating_nodes

    def _get_valid_perturbation(self, params: np.ndarray, direction: np.ndarray,
                               initial_radius: float) -> tuple[np.ndarray, float, int]:
        """
        Get a valid perturbation that doesn't violate mesh constraints.

        Starting from ``initial_radius``, the radius is halved after every
        failed validity check, for at most 10 attempts. Each failed attempt
        also increments ``self.shrinkage_count``.

        Parameters
        ----------
        params : np.ndarray
            Current parameter vector
        direction : np.ndarray
            Perturbation direction
        initial_radius : float
            Initial perturbation radius

        Returns
        -------
        tuple[np.ndarray, float, int]
            (perturbed_params, actual_radius, shrinkage_count). When no
            valid perturbation is found, the original ``params`` and a
            radius of 0.0 are returned.
        """
        radius = initial_radius
        max_shrinkages = 10  # Bound the number of retries

        for attempt in range(max_shrinkages):
            perturbed_params = params + radius * direction
            is_valid, _ = self._check_mesh_validity(perturbed_params)
            if is_valid:
                return perturbed_params, radius, attempt

            # Shrink radius and try again
            radius *= 0.5
            self.shrinkage_count += 1

        # No valid perturbation found; signal failure with a zero radius
        return params, 0.0, max_shrinkages

    def __call__(self, params: np.ndarray) -> np.ndarray:
        """
        Estimate the gradient using mesh-aware perturbations.

        Samples for which no valid perturbation could be found are skipped
        and excluded from the average, so they do not pull the estimate
        toward zero. If every sample is skipped, a zero gradient is
        returned. The number of radius halvings performed during this call
        is appended to ``self.shrinkage_history``.

        Parameters
        ----------
        params : np.ndarray
            Parameter vector

        Returns
        -------
        np.ndarray
            Estimated gradient

        Raises
        ------
        ValueError
            If ``config.gradient_estimator`` is neither "standard" nor
            "symmetric".
        """
        mode = self.config.gradient_estimator
        if mode not in ("standard", "symmetric"):
            # Fail early instead of crashing on ``f_x_plus - None`` below.
            raise ValueError(
                f"Unknown gradient_estimator {mode!r}; "
                "expected 'standard' or 'symmetric'"
            )
        use_symmetric = mode == "symmetric"

        n_params = len(params)
        gradient = np.zeros(n_params)

        # Reset iteration shrinkage counter
        self.current_iteration_shrinkages = 0

        # Loss at the current point is only needed by the forward difference
        f_x = None if use_symmetric else self.loss_function(params)

        valid_samples = 0
        for _ in range(self.config.n_samples):
            u = self.sample_direction(n_params)
            fd_radius = self.get_fd_radius()

            if use_symmetric:
                x_plus, radius_plus, shrinkages_plus = self._get_valid_perturbation(
                    params, u, fd_radius)
                x_minus, radius_minus, shrinkages_minus = self._get_valid_perturbation(
                    params, -u, fd_radius)

                self.current_iteration_shrinkages += shrinkages_plus + shrinkages_minus

                if not (radius_plus > 0 and radius_minus > 0):
                    # Skip this sample if either side has no valid perturbation
                    continue

                f_plus = self.loss_function(x_plus)
                f_minus = self.loss_function(x_minus)

                # The two sides may have been shrunk by different amounts,
                # so the evaluation points are radius_plus + radius_minus
                # apart along u (equivalently 2 * the average radius).
                directional_derivative = (f_plus - f_minus) / (radius_plus + radius_minus)
            else:
                x_plus, actual_radius, shrinkages = self._get_valid_perturbation(
                    params, u, fd_radius)

                self.current_iteration_shrinkages += shrinkages

                if not actual_radius > 0:
                    # Skip this sample if we couldn't find a valid perturbation
                    continue

                f_x_plus = self.loss_function(x_plus)
                directional_derivative = (f_x_plus - f_x) / actual_radius

            gradient += directional_derivative * u
            valid_samples += 1

        # Average over the samples that actually contributed
        if valid_samples > 0:
            gradient /= valid_samples

        # Record shrinkages for this iteration
        self.shrinkage_history.append(self.current_iteration_shrinkages)

        return gradient


class RejectionMethodEstimator(GradientEstimator):
    """
    Rejection method gradient estimator.

    This estimator uses mesh topology information to check if perturbations
    would cause nodes to move outside the convex hull of their neighbors.
    Instead of shrinking the perturbation like MeshAwareEstimator, it
    completely rejects invalid perturbations: rejected samples contribute
    nothing and the estimate is averaged over the accepted samples only.
    If every sample is rejected the returned gradient is zero.
    """

    def __init__(self, config: EstimatorConfig, loss_function: Callable, mesh_info: dict):
        """
        Initialize the rejection method estimator.

        Parameters
        ----------
        config : EstimatorConfig
            Configuration for the estimator
        loss_function : Callable
            Function that computes the loss given parameters
        mesh_info : dict
            Dictionary containing mesh topology information:
            - 'connectivity': dict mapping node indices to list of neighbor indices
            - 'boundary_mask': boolean array indicating boundary nodes
            - 'coords': initial coordinates of the mesh nodes
        """
        super().__init__(config, loss_function)
        self.mesh_info = mesh_info
        self.rejection_count = 0  # Total rejected samples over the estimator's lifetime
        self.rejection_history = []  # Rejected samples recorded per __call__
        self.current_iteration_rejections = 0

    def sample_direction(self, n_params: int) -> np.ndarray:
        """
        Sample a random direction uniformly from the unit sphere.

        Parameters
        ----------
        n_params : int
            Number of parameters

        Returns
        -------
        np.ndarray
            Random unit direction vector
        """
        u = np.random.randn(n_params)
        return u / np.linalg.norm(u)

    def get_fd_radius(self) -> float:
        """
        Get the finite difference radius (fixed for this estimator).

        Returns
        -------
        float
            Finite difference radius
        """
        return self.config.fd_radius

    def _is_in_convex_hull(self, point: np.ndarray, hull_points: np.ndarray) -> bool:
        """
        Check if a point is inside the convex hull of a set of points.

        Parameters
        ----------
        point : np.ndarray
            2D point to check
        hull_points : np.ndarray
            Array of shape (n_points, 2) defining the convex hull

        Returns
        -------
        bool
            True if point is inside the convex hull (within a 1e-10
            tolerance), False otherwise. With fewer than 3 hull points the
            point only counts as inside when it coincides (within 1e-10)
            with one of them. If the hull cannot be built, False is returned.
        """
        if len(hull_points) < 3:
            # No proper hull can be formed; fall back to a coincidence test.
            distances = np.linalg.norm(hull_points - point, axis=1)
            return bool(np.min(distances) < 1e-10)

        try:
            hull = ConvexHull(hull_points)
        except Exception:
            # Hull construction failed (e.g. degenerate/collinear points):
            # be conservative and report "outside". Unlike a bare ``except``
            # this lets KeyboardInterrupt/SystemExit propagate.
            return False

        # Each row of hull.equations is [a, b, c]; a*x + b*y + c <= 0 holds
        # for points on the inner side of that facet.
        equations = hull.equations
        signed = equations[:, 0] * point[0] + equations[:, 1] * point[1] + equations[:, 2]
        return not bool(np.any(signed > 1e-10))

    def _check_mesh_validity(self, params: np.ndarray) -> tuple[bool, list]:
        """
        Check if the mesh with given parameters is valid.

        A mesh is considered valid when every non-boundary node that has
        neighbors lies inside the convex hull of those neighbors.

        Parameters
        ----------
        params : np.ndarray
            Parameter vector representing mesh coordinates, laid out as all
            x coordinates followed by all y coordinates

        Returns
        -------
        tuple[bool, list]
            (is_valid, violating_nodes) where is_valid is True if mesh is valid,
            and violating_nodes is a list of node indices that violate constraints
        """
        n_nodes = len(params) // 2
        current_coords = np.column_stack([params[:n_nodes], params[n_nodes:]])

        boundary_mask = self.mesh_info['boundary_mask']
        connectivity = self.mesh_info['connectivity']

        violating_nodes = []
        for node_idx in range(n_nodes):
            # Boundary nodes are not constrained by this check
            if boundary_mask[node_idx]:
                continue

            neighbors = connectivity.get(node_idx, [])
            if len(neighbors) == 0:
                continue

            neighbor_coords = current_coords[neighbors]
            if not self._is_in_convex_hull(current_coords[node_idx], neighbor_coords):
                violating_nodes.append(node_idx)

        return len(violating_nodes) == 0, violating_nodes

    def _is_valid_perturbation(self, params: np.ndarray, direction: np.ndarray,
                               radius: float) -> bool:
        """
        Check if a perturbation is valid (doesn't violate mesh constraints).

        Parameters
        ----------
        params : np.ndarray
            Current parameter vector
        direction : np.ndarray
            Perturbation direction
        radius : float
            Perturbation radius

        Returns
        -------
        bool
            True if perturbation is valid, False otherwise
        """
        perturbed_params = params + radius * direction
        is_valid, _ = self._check_mesh_validity(perturbed_params)
        return is_valid

    def __call__(self, params: np.ndarray) -> np.ndarray:
        """
        Estimate the gradient using rejection method.

        A sample is accepted only if its perturbed mesh is valid (for the
        symmetric scheme, both the positive and the negative perturbation
        must be valid). Rejected samples are counted in
        ``self.rejection_count`` and the per-call count is appended to
        ``self.rejection_history``.

        Parameters
        ----------
        params : np.ndarray
            Parameter vector

        Returns
        -------
        np.ndarray
            Estimated gradient, averaged over accepted samples (zero if all
            samples were rejected)

        Raises
        ------
        ValueError
            If ``config.gradient_estimator`` is neither "standard" nor
            "symmetric".
        """
        mode = self.config.gradient_estimator
        if mode not in ("standard", "symmetric"):
            # Fail early instead of crashing on ``f_x_plus - None`` below.
            raise ValueError(
                f"Unknown gradient_estimator {mode!r}; "
                "expected 'standard' or 'symmetric'"
            )
        use_symmetric = mode == "symmetric"

        n_params = len(params)
        gradient = np.zeros(n_params)

        # Reset iteration rejection counter
        self.current_iteration_rejections = 0

        # Loss at the current point is only needed by the forward difference
        f_x = None if use_symmetric else self.loss_function(params)

        valid_samples = 0
        for _ in range(self.config.n_samples):
            u = self.sample_direction(n_params)
            fd_radius = self.get_fd_radius()

            # Validate before evaluating the loss so rejected samples cost
            # no loss evaluations; the negative side is only checked when
            # the positive side passed.
            accepted = self._is_valid_perturbation(params, u, fd_radius)
            if accepted and use_symmetric:
                accepted = self._is_valid_perturbation(params, -u, fd_radius)

            if not accepted:
                # Reject this sample - perturbation would cause mesh crossing
                self.current_iteration_rejections += 1
                self.rejection_count += 1
                continue

            x_plus = params + fd_radius * u
            if use_symmetric:
                # Central difference along u
                x_minus = params - fd_radius * u
                f_plus = self.loss_function(x_plus)
                f_minus = self.loss_function(x_minus)
                directional_derivative = (f_plus - f_minus) / (2 * fd_radius)
            else:
                # Forward difference along u
                f_x_plus = self.loss_function(x_plus)
                directional_derivative = (f_x_plus - f_x) / fd_radius

            gradient += directional_derivative * u
            valid_samples += 1

        # Average over valid samples only; with none, gradient stays zero
        if valid_samples > 0:
            gradient /= valid_samples

        # Record rejections for this iteration
        self.rejection_history.append(self.current_iteration_rejections)

        return gradient


class ConvexCombinationEstimator(GradientEstimator):
    """
    Convex combination gradient estimator.

    A random subset of interior mesh nodes is expressed as a convex combination
    of its neighbors. Random perturbations are then applied to the full
    parameter vector, and every node that received convex weights is re-placed
    at the weighted average of its (perturbed) neighbors before the loss is
    evaluated.

    The parameter vector is laid out as ``[x_0, ..., x_{n-1}, y_0, ..., y_{n-1}]``
    where ``n`` is the number of mesh nodes.
    """

    # Absolute tolerance shared by the hull-membership and reconstruction checks.
    _TOL = 1e-10

    def __init__(self, config: EstimatorConfig, loss_function: Callable, mesh_info: dict):
        """
        Initialize the convex combination estimator.

        Parameters
        ----------
        config : EstimatorConfig
            Configuration for the estimator
        loss_function : Callable
            Function that computes the loss given parameters
        mesh_info : dict
            Dictionary containing mesh topology information:
            - 'connectivity': dict mapping node indices to list of neighbor indices
            - 'boundary_mask': boolean array indicating boundary nodes
            - 'coords': initial coordinates of the mesh nodes
        """
        super().__init__(config, loss_function)
        self.mesh_info = mesh_info
        self.transformation_count = 0  # Total transformed nodes over all calls
        self.transformation_history = []  # Transformed nodes per call
        self.current_iteration_transformations = 0

    def sample_direction(self, n_params: int) -> np.ndarray:
        """
        Sample a random direction uniformly from the unit sphere.

        Parameters
        ----------
        n_params : int
            Number of parameters

        Returns
        -------
        np.ndarray
            Random unit direction vector
        """
        u = np.random.randn(n_params)
        return u / np.linalg.norm(u)

    def get_fd_radius(self) -> float:
        """
        Get the finite difference radius (fixed for this estimator).

        Returns
        -------
        float
            Finite difference radius
        """
        return self.config.fd_radius

    def _is_in_convex_hull(self, point: np.ndarray, hull_points: np.ndarray) -> bool:
        """
        Check if a point is inside the convex hull of a set of points.

        Parameters
        ----------
        point : np.ndarray
            2D point to check
        hull_points : np.ndarray
            Array of shape (n_points, 2) defining the convex hull

        Returns
        -------
        bool
            True if point is inside (or on the boundary of) the convex hull,
            False otherwise. With fewer than three hull points the answer is
            True only when the point coincides with one of them; an empty
            point set always yields False.
        """
        hull_points = np.asarray(hull_points, dtype=float)

        if len(hull_points) == 0:
            # Nothing can lie inside an empty hull.
            return False

        if len(hull_points) < 3:
            # Too few points for a 2D hull: accept only an (almost) exact match
            # with one of the given points.
            distances = np.linalg.norm(hull_points - point, axis=1)
            return bool(np.min(distances) < self._TOL)

        try:
            hull = ConvexHull(hull_points)
        except Exception:
            # Qhull rejects degenerate input (e.g. collinear points); stay
            # conservative and report "outside" instead of propagating.
            return False

        # Each row of hull.equations is [a, b, c]; a*x + b*y + c <= 0 holds for
        # points on the inner side of that facet.
        eq = hull.equations
        signed = eq[:, 0] * point[0] + eq[:, 1] * point[1] + eq[:, 2]
        return not bool(np.any(signed > self._TOL))

    def _solve_convex_combination(self, target_point: np.ndarray, neighbor_points: np.ndarray) -> tuple[np.ndarray, bool]:
        """
        Solve for convex combination weights given target point and neighbor points.

        We want to find weights w such that:
        target_point = sum(w_i * neighbor_points[i])
        sum(w_i) = 1
        w_i >= 0

        The non-negativity constraint is enforced by the solver itself
        (non-negative least squares). An unconstrained minimum-norm solution is
        not sufficient here: with more than three neighbors it frequently has
        negative entries even though the target lies inside the hull.

        Parameters
        ----------
        target_point : np.ndarray
            2D point to express as convex combination
        neighbor_points : np.ndarray
            Array of shape (n_neighbors, 2) containing neighbor coordinates

        Returns
        -------
        tuple[np.ndarray, bool]
            (weights, is_valid) where weights are for convex combination and
            is_valid indicates if the point is actually inside the convex hull.
            When is_valid is False the weights are a uniform placeholder
            (or an empty array when there are no neighbors).
        """
        # Local import: the enclosing module only imports scipy.spatial.
        from scipy.optimize import nnls

        neighbor_points = np.asarray(neighbor_points, dtype=float)
        n_neighbors = len(neighbor_points)

        if n_neighbors == 0:
            return np.array([]), False

        if n_neighbors == 1:
            # Only one neighbor, point must be exactly at neighbor location
            coincides = bool(np.allclose(target_point, neighbor_points[0], atol=self._TOL))
            return np.array([1.0]), coincides

        uniform = np.ones(n_neighbors) / n_neighbors

        if not self._is_in_convex_hull(target_point, neighbor_points):
            return uniform, False

        # Constraint system: rows are x-coordinates, y-coordinates and the
        # sum-to-one condition, i.e. A @ w = [target_x, target_y, 1].
        A = np.vstack([neighbor_points.T, np.ones(n_neighbors)])
        b = np.hstack([target_point, 1.0])

        try:
            weights, _ = nnls(A, b)
        except RuntimeError:
            # Solver did not converge within its iteration budget.
            return uniform, False

        total = np.sum(weights)
        if total <= 0:
            return uniform, False

        # Remove round-off in the sum-to-one constraint.
        weights = weights / total

        # Accept only weights that really reproduce the target point.
        reconstructed = np.sum(weights[:, np.newaxis] * neighbor_points, axis=0)
        if np.allclose(reconstructed, target_point, atol=self._TOL):
            return weights, True
        return uniform, False

    def _transform_to_convex_coordinates(self, params: np.ndarray, sample_nodes: list) -> tuple[np.ndarray, dict]:
        """
        Compute convex combination weights for the selected nodes.

        Parameters
        ----------
        params : np.ndarray
            Current parameter vector (absolute coordinates)
        sample_nodes : list
            List of node indices to transform

        Returns
        -------
        tuple[np.ndarray, dict]
            (transformed_params, transformation_info). transformed_params is an
            unmodified copy of params. transformation_info maps each node that
            could be expressed as a convex combination of its neighbors to a
            dict with keys 'neighbors', 'weights' and 'original_coord'; boundary
            nodes, nodes without neighbors and nodes outside their neighbors'
            hull are left out.
        """
        n_nodes = len(params) // 2
        current_coords = np.column_stack([params[:n_nodes], params[n_nodes:]])

        transformed_params = params.copy()
        transformation_info = {}

        boundary_mask = self.mesh_info['boundary_mask']
        connectivity = self.mesh_info['connectivity']

        for node_idx in sample_nodes:
            # Boundary nodes are never re-expressed.
            if boundary_mask[node_idx]:
                continue

            neighbors = connectivity.get(node_idx, [])
            if len(neighbors) == 0:
                continue

            current_node_coord = current_coords[node_idx]
            weights, is_valid = self._solve_convex_combination(
                current_node_coord, current_coords[neighbors]
            )

            # Only nodes lying inside their neighbors' hull are recorded.
            if is_valid:
                transformation_info[node_idx] = {
                    'neighbors': neighbors,
                    'weights': weights,
                    'original_coord': current_node_coord.copy()
                }

        return transformed_params, transformation_info

    def _transform_from_convex_coordinates(self, params: np.ndarray, transformation_info: dict) -> np.ndarray:
        """
        Transform back from convex combination coordinates to absolute coordinates.

        Every node listed in transformation_info is moved to the weighted
        average of its neighbors' positions as found in params. Neighbor
        positions are always read from params, never from already re-placed
        nodes, so the result does not depend on iteration order.

        Parameters
        ----------
        params : np.ndarray
            Parameter vector in convex combination coordinate system
        transformation_info : dict
            Information about the transformation (weights, neighbors, etc.)

        Returns
        -------
        np.ndarray
            Parameter vector in absolute coordinates
        """
        n_nodes = len(params) // 2
        current_coords = np.column_stack([params[:n_nodes], params[n_nodes:]])

        absolute_params = params.copy()

        for node_idx, info in transformation_info.items():
            weights = info['weights']
            neighbor_coords = current_coords[info['neighbors']]

            # New position = convex combination of the neighbor positions.
            new_position = np.sum(weights[:, np.newaxis] * neighbor_coords, axis=0)

            absolute_params[node_idx] = new_position[0]  # x coordinate
            absolute_params[n_nodes + node_idx] = new_position[1]  # y coordinate

        return absolute_params

    def _sample_nodes_for_transformation(self, n_nodes: int, sample_fraction: float = 0.3) -> list:
        """
        Sample nodes for transformation.

        Parameters
        ----------
        n_nodes : int
            Total number of nodes
        sample_fraction : float
            Fraction of interior nodes to sample

        Returns
        -------
        list
            List of node indices to transform. At least one node is returned
            whenever the mesh has interior nodes; the list is empty when every
            node is a boundary node.
        """
        boundary_mask = self.mesh_info['boundary_mask']
        interior_nodes = [i for i in range(n_nodes) if not boundary_mask[i]]

        if not interior_nodes:
            # np.random.choice cannot draw from an empty population.
            return []

        # At least one node, and never more than are available (sampling is
        # without replacement).
        n_sample = max(1, int(len(interior_nodes) * sample_fraction))
        n_sample = min(n_sample, len(interior_nodes))
        sample_indices = np.random.choice(interior_nodes, size=n_sample, replace=False)

        return sample_indices.tolist()

    @staticmethod
    def _softmax(values: np.ndarray) -> np.ndarray:
        """Return the softmax of values, shifted by the maximum to avoid overflow."""
        exponentials = np.exp(values - np.max(values))
        return exponentials / np.sum(exponentials)

    def __call__(self, params: np.ndarray) -> np.ndarray:
        """
        Estimate the gradient using convex combination coordinates.

        Parameters
        ----------
        params : np.ndarray
            Parameter vector

        Returns
        -------
        np.ndarray
            Estimated gradient, averaged over config.n_samples random
            directions (all zeros when n_samples is 0)
        """
        n_params = len(params)
        n_nodes = n_params // 2
        gradient = np.zeros(n_params)

        # Pick the nodes to re-express and compute their convex weights.
        sample_nodes = self._sample_nodes_for_transformation(n_nodes, self.config.sample_fraction)
        transformed_params, transformation_info = self._transform_to_convex_coordinates(params, sample_nodes)

        # Count only nodes that were actually transformed; sampled nodes that
        # fall outside their neighbors' hull are skipped.
        self.current_iteration_transformations = len(transformation_info)

        use_symmetric = self.config.gradient_estimator == "symmetric"

        # The forward estimator needs the unperturbed loss. It is used for every
        # non-"symmetric" setting, so compute it for all of them.
        f_x = None if use_symmetric else self.loss_function(params)

        for _ in range(self.config.n_samples):
            u = self.sample_direction(n_params)
            fd_radius = self.get_fd_radius()

            if use_symmetric:
                perturbed_params_plus = transformed_params + fd_radius * u
                perturbed_params_minus = transformed_params - fd_radius * u

                # Keep the raw perturbations only if all entries are positive
                # and at least one of the two vectors sums to at most 1;
                # otherwise project both with a softmax.
                # NOTE(review): the projection is applied to the whole
                # coordinate vector rather than to convex weights - confirm
                # that this is intended.
                sum_within_simplex = (np.sum(perturbed_params_plus) <= 1
                                      or np.sum(perturbed_params_minus) <= 1)
                all_positive = (np.all(perturbed_params_plus > 0)
                                and np.all(perturbed_params_minus > 0))
                if not (sum_within_simplex and all_positive):
                    perturbed_params_plus = self._softmax(perturbed_params_plus)
                    perturbed_params_minus = self._softmax(perturbed_params_minus)

                # Re-place transformed nodes, then evaluate the loss.
                x_plus = self._transform_from_convex_coordinates(perturbed_params_plus, transformation_info)
                x_minus = self._transform_from_convex_coordinates(perturbed_params_minus, transformation_info)

                f_plus = self.loss_function(x_plus)
                f_minus = self.loss_function(x_minus)

                # Central difference
                directional_derivative = (f_plus - f_minus) / (2 * fd_radius)
            else:
                # Standard (forward) gradient estimator
                perturbed_params = transformed_params + fd_radius * u
                x_plus = self._transform_from_convex_coordinates(perturbed_params, transformation_info)
                f_x_plus = self.loss_function(x_plus)

                # Forward difference
                directional_derivative = (f_x_plus - f_x) / fd_radius

            gradient += directional_derivative * u

        # Average over samples; with no samples the gradient stays zero.
        if self.config.n_samples > 0:
            gradient /= self.config.n_samples

        # Record transformations for this iteration
        self.transformation_history.append(self.current_iteration_transformations)
        self.transformation_count += self.current_iteration_transformations

        return gradient


def get_estimator(config: EstimatorConfig, loss_function: Callable, mesh_info: Optional[dict] = None) -> GradientEstimator:
    """
    Factory function to create the appropriate gradient estimator.

    Parameters
    ----------
    config : EstimatorConfig
        Configuration specifying the estimator type and parameters. The
        estimator type is matched case-insensitively.
    loss_function : Callable
        Function that computes the loss given parameters
    mesh_info : dict, optional
        Mesh topology information required for MeshAwareEstimator,
        RejectionMethodEstimator and ConvexCombinationEstimator

    Returns
    -------
    GradientEstimator
        The requested gradient estimator instance

    Raises
    ------
    ValueError
        If the estimator type is not recognized or required parameters are missing
    """
    estimator_type = config.estimator_type.lower()

    # Estimator types that cannot be built without mesh topology, mapped to the
    # class name used in the error message.
    mesh_required = {
        "mesh_aware": "MeshAwareEstimator",
        "rejection_method": "RejectionMethodEstimator",
        "convex_combination": "ConvexCombinationEstimator",
    }
    if estimator_type in mesh_required and mesh_info is None:
        raise ValueError(f"mesh_info is required for {mesh_required[estimator_type]}")

    # Estimators that only need the configuration and the loss function.
    if estimator_type == "uniform":
        return UniformEstimator(config, loss_function)
    if estimator_type == "center_gaussian":
        return CenterGaussianEstimator(config, loss_function)
    if estimator_type == "random_perturb":
        return RandomPerturbEstimator(config, loss_function)

    # Estimators that additionally consume the mesh topology.
    if estimator_type == "mesh_aware":
        return MeshAwareEstimator(config, loss_function, mesh_info)
    if estimator_type == "rejection_method":
        return RejectionMethodEstimator(config, loss_function, mesh_info)
    if estimator_type == "convex_combination":
        return ConvexCombinationEstimator(config, loss_function, mesh_info)

    raise ValueError(f"Unknown estimator type: {config.estimator_type}")