import numpy as np
from typing import Dict, Any, Callable, Optional, Tuple, List, Union
from dataclasses import dataclass, fields
from skfem import MeshTri, Basis, ElementTriP1
import yaml
import os
import wandb
import matplotlib.pyplot as plt
from PIL import Image
from ..mesh_utils import mesh_to_coords, coords_to_mesh, create_uniform_mesh
from ..pde_solver import PDESolver, DirichletBC
from ..error_metrics import compute_mse
from ..estimator import EstimatorConfig, get_estimator, GradientEstimator


@dataclass
class TrainerConfig:
    """
    Configuration for the zeroth-order mesh optimization trainer.
    
    Attributes
    ----------
    step_size : float
        Learning rate for gradient descent
    fd_radius : float
        Radius for finite difference approximation
    n_samples : int
        Number of random directions to sample for gradient estimation
    n_iterations : int
        Number of optimization iterations
    regularization_weight : float
        Weight for regularization term
    eval_points_n : int
        Number of evaluation points for error computation (deprecated - now uses all mesh nodes)
    use_wandb : bool
        Whether to use Weights & Biases for experiment tracking
    wandb_entity : Optional[str]
        Weights & Biases entity name
    wandb_project : Optional[str]
        Weights & Biases project name
    config_file : Optional[str]
        Path to YAML configuration file for this experiment
    gradient_estimator : str
        Type of gradient estimator to use ("standard" or "symmetric")
    tag : Optional[str]
        Optional run name; the trainer passes it as the Weights & Biases run name
    estimator_type : str
        Type of estimator ("uniform", "center_gaussian", "random_perturb", "mesh_aware",
        "rejection_method", "convex_combination")
    random_radius_min : float
        Minimum radius for random perturbation (used by RandomPerturb)
    random_radius_max : float
        Maximum radius for random perturbation (used by RandomPerturb)
    gaussian_std : float
        Standard deviation for Gaussian direction sampling (used by CenterGaussian)
    fine_mesh_resolution : int
        Resolution for fine mesh used for ground truth visualization (default: 50)
    perturbation_scale : float
        Standard deviation of the random perturbations added to mesh nodes for MSE
        evaluation (default: 0.001)
    mesh_aware_update : bool
        Whether to use mesh-aware gradient updates; only takes effect with the
        "mesh_aware", "rejection_method" or "convex_combination" estimators (default: True)
    max_update_shrinkages : int
        Maximum number of step size reductions for mesh-aware gradient updates (default: 15)
    update_shrinkage_factor : float
        Factor by which to reduce step size when mesh constraints are violated (default: 0.5)
    sample_fraction : float
        Fraction of points to sample for convex combination (default: 0.3)
    """
    
    step_size: float = 0.01
    fd_radius: float = 0.05
    n_samples: int = 20
    n_iterations: int = 100
    regularization_weight: float = 0.0001
    eval_points_n: int = 100
    use_wandb: bool = False
    wandb_entity: Optional[str] = None
    wandb_project: Optional[str] = None
    config_file: Optional[str] = None
    gradient_estimator: str = "standard"
    tag: Optional[str] = None
    # Estimator-related parameters
    estimator_type: str = "uniform"
    random_radius_min: float = 0.01
    random_radius_max: float = 0.1
    gaussian_std: float = 1.0
    fine_mesh_resolution: int = 50
    perturbation_scale: float = 0.001
    # Mesh-aware gradient update parameters
    mesh_aware_update: bool = True
    max_update_shrinkages: int = 15
    update_shrinkage_factor: float = 0.5
    # ConvexCombination specific parameters
    sample_fraction: float = 0.3
    
    @classmethod
    def from_yaml(cls, yaml_path: str) -> 'TrainerConfig':
        """
        Create a TrainerConfig instance from a YAML file.

        Keys that are not fields of TrainerConfig are silently ignored; fields
        missing from the file keep their dataclass defaults. An empty file
        yields a default configuration.
        
        Parameters
        ----------
        yaml_path : str
            Path to the YAML configuration file
            
        Returns
        -------
        TrainerConfig
            Configuration instance with values from the YAML file

        Raises
        ------
        FileNotFoundError
            If ``yaml_path`` does not exist
        yaml.YAMLError
            If the file is not valid YAML
        TypeError
            If the top-level YAML value is not a mapping
        """
        if not os.path.exists(yaml_path):
            raise FileNotFoundError(f"Configuration file not found: {yaml_path}")
            
        with open(yaml_path, 'r', encoding='utf-8') as fh:
            config_dict = yaml.safe_load(fh)

        # An empty YAML document parses to None: treat it as "no overrides"
        # instead of crashing on None.items().
        if config_dict is None:
            config_dict = {}
        elif not isinstance(config_dict, dict):
            raise TypeError(
                f"Configuration file {yaml_path} must contain a mapping at the top level, "
                f"got {type(config_dict).__name__}"
            )
            
        # Filter out keys that are not attributes of TrainerConfig
        valid_keys = {fld.name for fld in fields(cls)}
        filtered_dict = {k: v for k, v in config_dict.items() if k in valid_keys}
        
        # Create a new instance with the values from the YAML file
        return cls(**filtered_dict)


class Trainer:
    """
    Zeroth-order mesh optimization trainer.
    
    This class implements a gradient-free optimization approach that treats
    mesh node coordinates as parameters and minimizes the error between the
    PDE solution on the mesh and a ground truth solution. MSE is computed
    using all mesh nodes with small random perturbations as evaluation points.
    
    Attributes
    ----------
    config : TrainerConfig
        Configuration object with hyperparameters
    solver : PDESolver
        PDE solver instance
    bc : DirichletBC
        Boundary condition for the PDE
    ground_truth : Callable
        Ground truth function for computing the error
    original_mesh : MeshTri
        Original mesh topology (used for preserving boundary)
    eval_points : np.ndarray
        Points for error evaluation (deprecated - now uses mesh nodes with perturbations)
    """
    def __init__(
        self,
        config: TrainerConfig,
        solver: PDESolver,
        bc: DirichletBC,
        ground_truth: Callable
    ):
        """
        Initialize the trainer.

        If ``config.config_file`` is set, that YAML file is loaded first and
        every field of ``config`` whose value differs from its dataclass default
        overrides the YAML value (so command-line values take precedence). If
        the file is missing or not valid YAML, a warning is printed and
        ``config`` is used unchanged. A wandb run is always initialised; it is
        in "disabled" mode unless ``use_wandb`` is set.
        
        Parameters
        ----------
        config : TrainerConfig
            Configuration object with hyperparameters
        solver : PDESolver
            PDE solver instance
        bc : DirichletBC
            Boundary condition for the PDE
        ground_truth : Callable
            Ground truth function for computing the error
        """
        self.config = self._resolve_config(config)
            
        self.solver = solver
        self.bc = bc
        self.ground_truth = ground_truth
        self.original_mesh = None
        self.eval_points = None
        self.wandb = wandb
        self.wandb.init(
            entity=self.config.wandb_entity,
            project=self.config.wandb_project,
            config={f.name: getattr(self.config, f.name) for f in fields(self.config)},
            mode="disabled" if not self.config.use_wandb else "online",
            # An empty tag is passed as None.
            name=self.config.tag or None,
        )
        
        # Initialize visualization tracking
        self.visualization_files = []
        
        # Initialize gradient update shrinkage tracking
        self.update_shrinkage_count = 0
        self.update_shrinkage_history = []
        self.current_iteration_update_shrinkages = 0

    @staticmethod
    def _resolve_config(config: TrainerConfig) -> TrainerConfig:
        """
        Merge ``config`` over the YAML file it references, if any.

        Limitation: a field explicitly set to its default value cannot override
        the YAML value, because it is indistinguishable from "not provided".

        Parameters
        ----------
        config : TrainerConfig
            Configuration supplied by the caller

        Returns
        -------
        TrainerConfig
            The merged configuration, or ``config`` itself when no file is set
            or the file could not be loaded
        """
        if config.config_file is None:
            return config

        try:
            yaml_config = TrainerConfig.from_yaml(config.config_file)
        except (FileNotFoundError, yaml.YAMLError) as e:
            print(f"Warning: Failed to load config file {config.config_file}: {e}")
            return config

        for fld in fields(config):
            # config_file is restored explicitly below.
            if fld.name == 'config_file':
                continue
            value = getattr(config, fld.name)
            # Non-default values (e.g. from the command line) take precedence over YAML.
            if value != fld.default:
                setattr(yaml_config, fld.name, value)

        # Keep track of which file the merged config came from (otherwise it is
        # reset to whatever the YAML holds, normally None, and lost from logging).
        yaml_config.config_file = config.config_file
        return yaml_config
    
    def initialize(self, mesh: MeshTri) -> np.ndarray:
        """
        Store ``mesh`` as the reference mesh and return its parameter vector.

        The stored mesh provides the topology and the fixed boundary positions
        used by later parameter/mesh conversions.

        Parameters
        ----------
        mesh : MeshTri
            Initial mesh

        Returns
        -------
        np.ndarray
            Initial parameter vector (all x coordinates followed by all y coordinates)
        """
        self.original_mesh = mesh
        initial_params = self._mesh_to_params(self.original_mesh)
        return initial_params
    
    def _mesh_to_params(self, mesh: MeshTri) -> np.ndarray:
        """
        Flatten the node coordinates of ``mesh`` into a parameter vector.

        Parameters
        ----------
        mesh : MeshTri
            Input mesh

        Returns
        -------
        np.ndarray
            Vector laid out as [x_0, ..., x_{n-1}, y_0, ..., y_{n-1}]
        """
        node_coords = mesh_to_coords(mesh)
        x_part = node_coords[:, 0]
        y_part = node_coords[:, 1]
        return np.hstack((x_part, y_part))
    
    def _params_to_mesh(self, params: np.ndarray, original_mesh: MeshTri) -> MeshTri:
        """
        Rebuild a mesh from a parameter vector, keeping boundary nodes fixed.

        Nodes lying on the edges of the unit square in ``original_mesh`` are
        reset to their original positions, whatever ``params`` says.

        Parameters
        ----------
        params : np.ndarray
            Parameter vector (x coordinates followed by y coordinates)
        original_mesh : MeshTri
            Mesh whose topology and boundary positions are preserved

        Returns
        -------
        MeshTri
            Mesh with updated node coordinates
        """
        count = original_mesh.p.shape[1]
        coords = np.column_stack((params[:count], params[count:2 * count]))

        reference = mesh_to_coords(original_mesh)
        ref_x, ref_y = reference[:, 0], reference[:, 1]
        # Boundary detection assumes the domain is the unit square.
        pinned = (
            np.isclose(ref_x, 0.0)
            | np.isclose(ref_x, 1.0)
            | np.isclose(ref_y, 0.0)
            | np.isclose(ref_y, 1.0)
        )
        coords[pinned] = reference[pinned]

        return coords_to_mesh(coords, original_mesh)
    
    def compute_loss(self, params: np.ndarray) -> float:
        """
        Compute the loss for a parameter vector.

        The loss is the MSE between the PDE solution on the mesh encoded by
        ``params`` and the ground truth, plus (when ``regularization_weight > 0``)
        a penalty on the variance of the mesh edge lengths.

        MSE evaluation points are the mesh nodes jittered by Gaussian noise with
        standard deviation ``config.perturbation_scale``, clipped to
        [0.01, 0.99] in each coordinate.

        Parameters
        ----------
        params : np.ndarray
            Parameter vector (all x coordinates followed by all y coordinates)

        Returns
        -------
        float
            Loss value
        """
        mesh = self._params_to_mesh(params, self.original_mesh)
        basis = Basis(mesh, ElementTriP1())
        solution = self.solver.solve(basis, self.bc)

        # Coordinates of the mesh that was actually solved on, i.e. with boundary
        # nodes pinned back by _params_to_mesh (assumes mesh_to_coords inverts
        # coords_to_mesh).
        mesh_coords = mesh_to_coords(mesh)

        # NOTE(review): the jitter uses the global NumPy RNG, so repeated calls with
        # identical params return different losses; gradient estimators that
        # difference loss values also see this evaluation-point noise.
        perturbations = np.random.normal(0, self.config.perturbation_scale, mesh_coords.shape)
        # Keep evaluation points strictly inside the unit square.
        perturbed_coords = np.clip(mesh_coords + perturbations, 0.01, 0.99)

        loss = compute_mse(solution, perturbed_coords, self.ground_truth)

        if self.config.regularization_weight > 0:
            # Unique undirected edges of all triangles.
            t = mesh.t
            edges = np.vstack([
                np.column_stack((t[0, :], t[1, :])),
                np.column_stack((t[1, :], t[2, :])),
                np.column_stack((t[2, :], t[0, :])),
            ])
            edges = np.unique(np.sort(edges, axis=1), axis=0)

            # Measure edges on the boundary-corrected mesh rather than on raw
            # params: boundary entries of params are discarded when building the
            # mesh, so they must not influence the penalty (or its gradient).
            edge_lengths = np.linalg.norm(
                mesh_coords[edges[:, 0]] - mesh_coords[edges[:, 1]], axis=1
            )

            # Penalty for non-uniform edge lengths (variance of the lengths).
            reg_term = np.mean((edge_lengths - np.mean(edge_lengths)) ** 2)
            loss += self.config.regularization_weight * reg_term

        return loss
    
    def _compute_mesh_connectivity(self, mesh: MeshTri) -> dict:
        """
        Collect node adjacency and boundary information for ``mesh``.

        Parameters
        ----------
        mesh : MeshTri
            Input mesh

        Returns
        -------
        dict
            'connectivity': node index -> list of node indices sharing a triangle with it
            'boundary_mask': boolean array, True for nodes on the unit-square boundary
            'coords': node coordinates of ``mesh``
        """
        from collections import defaultdict

        coords = mesh_to_coords(mesh)
        xs, ys = coords[:, 0], coords[:, 1]
        on_boundary = (
            np.isclose(xs, 0.0)
            | np.isclose(xs, 1.0)
            | np.isclose(ys, 0.0)
            | np.isclose(ys, 1.0)
        )

        # Every vertex of a triangle is adjacent to the other two.
        neighbor_sets = defaultdict(set)
        for a, b, c in mesh.t.T:
            neighbor_sets[a].update((b, c))
            neighbor_sets[b].update((a, c))
            neighbor_sets[c].update((a, b))

        return {
            'connectivity': {node: list(adjacent) for node, adjacent in neighbor_sets.items()},
            'boundary_mask': on_boundary,
            'coords': coords,
        }

    def get_estimator(self) -> 'GradientEstimator':
        """
        Build a gradient estimator from the current configuration.

        Mesh-aware estimator types ("mesh_aware", "rejection_method",
        "convex_combination", case-insensitive) additionally receive the
        connectivity information of the original mesh.

        Returns
        -------
        GradientEstimator
            Configured gradient estimator instance

        Raises
        ------
        ValueError
            If a mesh-aware estimator is requested before ``initialize()`` was called
        """
        cfg = self.config
        estimator_config = EstimatorConfig(
            estimator_type=cfg.estimator_type,
            fd_radius=cfg.fd_radius,
            n_samples=cfg.n_samples,
            gradient_estimator=cfg.gradient_estimator,
            random_radius_min=cfg.random_radius_min,
            random_radius_max=cfg.random_radius_max,
            gaussian_std=cfg.gaussian_std,
            sample_fraction=cfg.sample_fraction
        )

        needs_mesh = cfg.estimator_type.lower() in ("mesh_aware", "rejection_method", "convex_combination")
        if not needs_mesh:
            return get_estimator(estimator_config, self.compute_loss)

        if self.original_mesh is None:
            raise ValueError("Mesh information not available. Call initialize() first.")
        topology = self._compute_mesh_connectivity(self.original_mesh)
        return get_estimator(estimator_config, self.compute_loss, topology)

    def estimate_gradient(self, params: np.ndarray) -> np.ndarray:
        """
        Estimate the loss gradient at ``params`` with a freshly built estimator.

        Parameters
        ----------
        params : np.ndarray
            Current parameter vector

        Returns
        -------
        np.ndarray
            Estimated gradient
        """
        return self.get_estimator()(params)
    
    def step(self, params: np.ndarray) -> np.ndarray:
        """
        Perform a single optimization step.
        
        Parameters
        ----------
        params : np.ndarray
            Current parameter vector
            
        Returns
        -------
        np.ndarray
            Updated parameter vector
        """
        # Estimate gradient
        estimator = self.get_estimator()
        gradient = estimator(params)
        
        # Apply mesh-aware gradient update if enabled
        if self.config.mesh_aware_update and self.config.estimator_type.lower() in ["mesh_aware", "rejection_method", "convex_combination"]:
            new_params = self._mesh_aware_gradient_update(params, gradient)
        else:
            # Standard gradient update
            new_params = params - self.config.step_size * gradient
        
        return new_params
    
    def _mesh_aware_gradient_update(self, params: np.ndarray, gradient: np.ndarray) -> np.ndarray:
        """
        Perform mesh-aware gradient update with constraint checking.

        Tries the step sizes ``step_size * factor**k`` for
        ``k = 0, 1, ..., max_update_shrinkages`` (i.e. the full step plus up to
        ``max_update_shrinkages`` reductions) and returns the first candidate that
        passes ``_check_update_validity``. With ``max_update_shrinkages = 0`` the
        full step is still tried once.

        Side effects: sets ``self.current_iteration_update_shrinkages`` to the
        number of reductions used and adds it to ``self.update_shrinkage_count``.

        Parameters
        ----------
        params : np.ndarray
            Current parameter vector
        gradient : np.ndarray
            Estimated gradient

        Returns
        -------
        np.ndarray
            The first valid candidate, or ``params`` unchanged if none is valid
        """
        # Reset iteration shrinkage counter
        self.current_iteration_update_shrinkages = 0

        step_size = self.config.step_size
        max_shrinkages = max(self.config.max_update_shrinkages, 0)
        shrinkage_factor = self.config.update_shrinkage_factor

        # Get mesh connectivity information
        mesh_info = self._compute_mesh_connectivity(self.original_mesh)

        # range(max + 1): attempt 0 is the full step, attempts 1..max are reductions.
        for shrinkage_count in range(max_shrinkages + 1):
            candidate_params = params - step_size * gradient

            if self._check_update_validity(candidate_params, mesh_info):
                if shrinkage_count > 0:
                    print(f"    Gradient update step size reduced by factor {(1/shrinkage_factor)**shrinkage_count:.3f}")

                # Record shrinkage information
                self.current_iteration_update_shrinkages = shrinkage_count
                self.update_shrinkage_count += shrinkage_count
                return candidate_params

            # Shrink step size and try again
            step_size *= shrinkage_factor

        # If we can't find a valid update, return original parameters
        print(f"    Warning: Could not find valid gradient update after {max_shrinkages} shrinkages")

        # Record shrinkage information even for failed updates
        self.current_iteration_update_shrinkages = max_shrinkages
        self.update_shrinkage_count += max_shrinkages

        return params
    
    def _check_update_validity(self, params: np.ndarray, mesh_info: dict) -> bool:
        """
        Decide whether ``params`` describes an acceptable mesh.

        Boundary nodes are first pinned back to their original positions (as
        ``_params_to_mesh`` would do); then every interior node with at least one
        neighbor must lie inside the convex hull of its neighbors.

        Parameters
        ----------
        params : np.ndarray
            Parameter vector (x coordinates followed by y coordinates)
        mesh_info : dict
            Output of ``_compute_mesh_connectivity`` ('boundary_mask', 'connectivity')

        Returns
        -------
        bool
            True if every checked interior node passes, False otherwise
        """
        half = len(params) // 2
        coords = np.column_stack([params[:half], params[half:]])

        # Judge the mesh that would actually be built: boundary nodes are fixed.
        pinned = mesh_info['boundary_mask']
        coords[pinned] = mesh_to_coords(self.original_mesh)[pinned]

        for node in range(half):
            if pinned[node]:
                continue
            ring = mesh_info['connectivity'].get(node, [])
            if ring and not self._is_in_convex_hull(coords[node], coords[ring]):
                return False
        return True
    
    def _is_in_convex_hull(self, point: np.ndarray, hull_points: np.ndarray) -> bool:
        """
        Check if a point is inside the convex hull of a set of points.
        
        Parameters
        ----------
        point : np.ndarray
            2D point to check
        hull_points : np.ndarray
            Array of shape (n_points, 2) defining the convex hull
            
        Returns
        -------
        bool
            True if point is inside the convex hull, False otherwise
        """
        from scipy.spatial import ConvexHull
        
        if len(hull_points) < 3:
            # If less than 3 points, we can't form a proper convex hull
            # For gradient updates, allow reasonable movement but not too much
            distances = np.linalg.norm(hull_points - point, axis=1)
            max_edge_length = 0.02  # More restrictive edge length constraint
            return np.min(distances) < max_edge_length
        
        try:
            hull = ConvexHull(hull_points)
            # Check if point is inside the convex hull
            # We do this by checking if the point is on the correct side of all faces
            # Use a small tolerance for numerical stability
            tolerance = 1e-10
            for simplex in hull.equations:
                # simplex format: [a, b, c] where ax + by + c <= 0 defines inside
                if simplex[0] * point[0] + simplex[1] * point[1] + simplex[2] > tolerance:
                    return False
            return True
        except Exception as e:
            # If ConvexHull fails (e.g., degenerate case), be more restrictive
            # Check if point is reasonably close to existing points
            distances = np.linalg.norm(hull_points - point, axis=1)
            max_movement = 0.01  # Allow only small movement in degenerate cases
            return np.min(distances) < max_movement
    
    def compute_mesh_mse(self, params: np.ndarray) -> float:
        """
        Mean squared node displacement of interior nodes relative to the original mesh.

        Boundary nodes (on the edges of the unit square) are excluded because
        they never move.

        Parameters
        ----------
        params : np.ndarray
            Current parameter vector (x coordinates followed by y coordinates)

        Returns
        -------
        float
            Mean over interior nodes of the squared Euclidean displacement,
            or 0.0 if there are no interior nodes
        """
        reference = mesh_to_coords(self.original_mesh)
        count = len(reference)
        current = np.column_stack((params[:count], params[count:2 * count]))

        ref_x, ref_y = reference[:, 0], reference[:, 1]
        movable = ~(
            np.isclose(ref_x, 0.0)
            | np.isclose(ref_x, 1.0)
            | np.isclose(ref_y, 0.0)
            | np.isclose(ref_y, 1.0)
        )

        if not np.any(movable):
            return 0.0

        squared_displacement = np.sum((reference[movable] - current[movable]) ** 2, axis=1)
        return np.mean(squared_displacement)
    
    def _create_ground_truth_solution_object(self, mesh: MeshTri) -> Any:
        """
        Create a solution-like object for the ground truth function on a given mesh.
        
        Parameters
        ----------
        mesh : MeshTri
            Mesh to evaluate ground truth on
            
        Returns
        -------
        Any
            Solution-like object with mesh and value attributes
        """
        from types import SimpleNamespace
        
        # Get mesh points
        mesh_points = np.vstack((mesh.p[0, :], mesh.p[1, :])).T
        
        # Evaluate ground truth at mesh points
        ground_truth_values = self.ground_truth(mesh_points)
        
        # Create a simple solution-like object
        solution = SimpleNamespace()
        solution.mesh = mesh
        solution.value = ground_truth_values
        
        return solution
    
    def _visualize_and_log_ground_truth(self, vis_dir: Optional[str]) -> None:
        """
        Plot the ground truth on a fine uniform mesh and optionally log it to wandb.

        The fine mesh (``fine_mesh_resolution`` in each direction) and the
        ground-truth values are always computed; the image is only written when
        ``vis_dir`` is given, and only logged when wandb is enabled.

        Parameters
        ----------
        vis_dir : Optional[str]
            Directory to save visualizations, or None to skip saving
        """
        resolution = self.config.fine_mesh_resolution
        fine_mesh = create_uniform_mesh(nx=resolution, ny=resolution)
        gt_solution = self._create_ground_truth_solution_object(fine_mesh)

        if vis_dir is None:
            print("Ground truth visualization skipped (no visualization directory)")
            return

        gt_vis_file = os.path.join(vis_dir, "ground_truth_fine_mesh.png")
        self.visualize_mesh_solution(
            fine_mesh,
            gt_solution,
            output_file=gt_vis_file,
            title="Ground Truth on Fine Mesh",
            colorbar_label="Ground Truth Value"
        )
        print(f"Ground truth visualization saved to: {gt_vis_file}")

        if not (self.config.use_wandb and self.wandb):
            return
        # Logged at step 0, the same step as the initial solution.
        self.wandb.log({"ground_truth_visualization": self.wandb.Image(gt_vis_file)}, step=0)
        print("Ground truth visualization logged to wandb")
    
    def visualize_mesh_solution(
        self,
        mesh: MeshTri, 
        solution: Any,  # Solution from PDESolver.solve
        output_file: Optional[str] = None,
        title: str = "Solution on Mesh",
        show: bool = False,
        colorbar_label: str = "Value",
        figsize: tuple = (10, 8)
    ) -> None:
        """
        Visualize a solution on a mesh as a filled contour plot with mesh edges.

        The figure opened here is always closed afterwards (also when plotting
        or saving raises), except when ``show`` is True and display succeeds.
        
        Parameters
        ----------
        mesh : MeshTri
            The mesh
        solution : Any
            Object with a ``value`` attribute holding one value per mesh node
            (e.g. the result of PDESolver.solve)
        output_file : Optional[str], default=None
            Path to save the visualization. If None, the image is not saved
        title : str, default="Solution on Mesh"
            Title for the visualization; an empty string omits the title
        show : bool, default=False
            Whether to display the visualization
        colorbar_label : str, default="Value"
            Label for the colorbar
        figsize : tuple, default=(10, 8)
            Figure size
        """
        fig = plt.figure(figsize=figsize)
        keep_open = False
        try:
            u = solution.value

            # Node coordinates and triangles as (n_triangles, 3) for matplotlib.
            x = mesh.p[0, :]
            y = mesh.p[1, :]
            triangles = mesh.t.T

            plt.tricontourf(x, y, triangles, u, cmap='viridis', levels=50)
            cbar = plt.colorbar()
            cbar.set_label(colorbar_label, fontsize=12)

            # Overlay mesh edges
            plt.triplot(x, y, triangles, 'k-', alpha=0.3, linewidth=0.5)

            plt.xlabel('x', fontsize=12)
            plt.ylabel('y', fontsize=12)
            if title:
                plt.title(title, fontsize=14)

            plt.axis('equal')
            plt.tight_layout()

            if output_file:
                plt.savefig(output_file, dpi=150, bbox_inches='tight')

            if show:
                plt.show()
                keep_open = True
        finally:
            # Close the figure we opened so repeated calls do not leak figures,
            # including when any plotting/saving call above raised.
            if not keep_open:
                plt.close(fig)

    def run(self, mesh: MeshTri) -> Dict[str, Any]:
        """
        Run the optimization process.
        
        Parameters
        ----------
        mesh : MeshTri
            Initial mesh
            
        Returns
        -------
        Dict[str, Any]
            Optimization results
        """
        # Initialize
        params = self.initialize(mesh)
        
        # Create output directory for visualizations
        vis_dir = None
        if self.config.use_wandb and self.wandb:
            vis_dir = os.path.join(os.getcwd(), f"visualizations_{self.wandb.run.id}")
            os.makedirs(vis_dir, exist_ok=True)
        elif self.config.use_wandb:  # wandb is enabled but self.wandb might be None
            # Create a generic visualization directory
            vis_dir = os.path.join(os.getcwd(), "visualizations")
            os.makedirs(vis_dir, exist_ok=True)
        else:
            # Create a basic visualization directory even when wandb is disabled
            # to ensure ground truth visualization is always available
            vis_dir = os.path.join(os.getcwd(), "visualizations")
            os.makedirs(vis_dir, exist_ok=True)
        
        # Initial loss
        initial_loss = self.compute_loss(params)
        loss_history = [initial_loss]
        
        # Initialize histories for additional metrics
        gradient_norm_history = []
        mesh_mse_history = [] 
        
        # Create and log ground truth visualization first
        self._visualize_and_log_ground_truth(vis_dir)
        
        # Create initial visualization
        if vis_dir is not None:
            initial_mesh_iter = self._params_to_mesh(params, self.original_mesh)
            initial_basis = Basis(initial_mesh_iter, ElementTriP1())
            initial_solution = self.solver.solve(initial_basis, self.bc)
            
            vis_file = os.path.join(vis_dir, f"mesh_solution_iter_{0:04d}.png")
            self.visualize_mesh_solution(
                initial_mesh_iter,
                initial_solution,
                output_file=vis_file,
                title=""  # No title for wandb logging
            )
            self.visualization_files.append(vis_file)
            
            # Log to wandb
            if self.config.use_wandb and self.wandb:
                self.wandb.log({"mesh_visualization": self.wandb.Image(vis_file)}, step=0)
        
        # Log initial metrics
        if self.config.use_wandb:
            metrics = {"loss": initial_loss}
            mesh_mse = self.compute_mesh_mse(params)
            metrics["mesh_mse"] = mesh_mse
            mesh_mse_history.append(mesh_mse)
            
            # Initialize shrinkage metrics for MeshAwareEstimator
            estimator = self.get_estimator()
            if hasattr(estimator, 'current_iteration_shrinkages'):
                metrics["shrinkage_count"] = 0
                metrics["total_shrinkage_count"] = 0
                metrics["avg_shrinkage_per_iteration"] = 0.0
                metrics["max_shrinkage_per_iteration"] = 0
            
            # Initialize rejection metrics for RejectionMethodEstimator
            if hasattr(estimator, 'current_iteration_rejections'):
                metrics["rejection_count"] = 0
                metrics["total_rejection_count"] = 0
                metrics["avg_rejection_per_iteration"] = 0.0
                metrics["max_rejection_per_iteration"] = 0
            
            self.wandb.log(metrics, step=0)
        
        # Optimization loop
        for i in range(self.config.n_iterations):
            estimator = self.get_estimator()
            gradient = estimator(params)
            gradient_norm = np.linalg.norm(gradient)
            gradient_norm_history.append(gradient_norm)
            
            # Track shrinkage information for MeshAwareEstimator
            shrinkage_count = 0
            if hasattr(estimator, 'current_iteration_shrinkages'):
                shrinkage_count = estimator.current_iteration_shrinkages
            
            # Track rejection information for RejectionMethodEstimator
            rejection_count = 0
            if hasattr(estimator, 'current_iteration_rejections'):
                rejection_count = estimator.current_iteration_rejections
            
            # Perform optimization step
            params = self.step(params)
            
            # Record gradient update shrinkage information
            self.update_shrinkage_history.append(self.current_iteration_update_shrinkages)
            
            # Compute and record loss
            loss = self.compute_loss(params)
            loss_history.append(loss)
            
            # Create visualization every 100 steps
            if vis_dir is not None and (i + 1) % 100 == 0:
                current_mesh = self._params_to_mesh(params, self.original_mesh)
                current_basis = Basis(current_mesh, ElementTriP1())
                current_solution = self.solver.solve(current_basis, self.bc)
                
                vis_file = os.path.join(vis_dir, f"mesh_solution_iter_{i+1:04d}.png")
                self.visualize_mesh_solution(
                    current_mesh,
                    current_solution,
                    output_file=vis_file,
                    title=""  # No title for wandb logging
                )
                self.visualization_files.append(vis_file)
                
                # Log to wandb
                if self.config.use_wandb and self.wandb:
                    self.wandb.log({"mesh_visualization": self.wandb.Image(vis_file)}, step=i+1)
            
            # Log metrics
            if self.config.use_wandb and self.wandb:
                metrics = {"loss": loss, "iteration": i+1}
            
                metrics["gradient_norm"] = gradient_norm
                mesh_mse = self.compute_mesh_mse(params)
                metrics["mesh_mse"] = mesh_mse
                mesh_mse_history.append(mesh_mse)
                
                # Add shrinkage information - always record for MeshAwareEstimator
                if hasattr(estimator, 'current_iteration_shrinkages'):
                    # For MeshAwareEstimator, always record shrinkage metrics
                    metrics["shrinkage_count"] = shrinkage_count
                    if hasattr(estimator, 'shrinkage_count'):
                        metrics["total_shrinkage_count"] = estimator.shrinkage_count
                    # Also record shrinkage history if available
                    if hasattr(estimator, 'shrinkage_history') and estimator.shrinkage_history:
                        metrics["avg_shrinkage_per_iteration"] = np.mean(estimator.shrinkage_history)
                        metrics["max_shrinkage_per_iteration"] = np.max(estimator.shrinkage_history)
                
                # Add rejection information for RejectionMethodEstimator
                if hasattr(estimator, 'current_iteration_rejections'):
                    # For RejectionMethodEstimator, record rejection metrics
                    metrics["rejection_count"] = estimator.current_iteration_rejections
                    if hasattr(estimator, 'rejection_count'):
                        metrics["total_rejection_count"] = estimator.rejection_count
                    # Also record rejection history if available
                    if hasattr(estimator, 'rejection_history') and estimator.rejection_history:
                        metrics["avg_rejection_per_iteration"] = np.mean(estimator.rejection_history)
                        metrics["max_rejection_per_iteration"] = np.max(estimator.rejection_history)
                
                # Add gradient update shrinkage information for mesh-aware updates
                if self.config.mesh_aware_update and self.config.estimator_type.lower() in ["mesh_aware", "rejection_method"]:
                    metrics["update_shrinkage_count"] = self.current_iteration_update_shrinkages
                    metrics["total_update_shrinkage_count"] = self.update_shrinkage_count
                    if self.update_shrinkage_history:
                        metrics["avg_update_shrinkage_per_iteration"] = np.mean(self.update_shrinkage_history)
                        metrics["max_update_shrinkage_per_iteration"] = np.max(self.update_shrinkage_history)
                
                self.wandb.log(metrics, step=i+1)
            
            # Print progress
            print(f"Iteration {i+1}/{self.config.n_iterations}, Loss: {loss:.6f}")
            print(f"Gradient Norm: {gradient_norm:.6f}")
            if mesh_mse_history:  # Check if list is not empty
                print(f"Mesh MSE: {mesh_mse_history[-1]:.6f}")
            if shrinkage_count > 0:
                print(f"Gradient estimation shrinkages this iteration: {shrinkage_count}")
            if rejection_count > 0:
                print(f"Gradient estimation rejections this iteration: {rejection_count}")
            if self.current_iteration_update_shrinkages > 0:
                print(f"Gradient update shrinkages this iteration: {self.current_iteration_update_shrinkages}")
        
        # Create final mesh
        final_mesh = self._params_to_mesh(params, self.original_mesh)
        
        # Create final visualization
        if vis_dir is not None:
            final_basis = Basis(final_mesh, ElementTriP1())
            final_solution = self.solver.solve(final_basis, self.bc)
            
            vis_file = os.path.join(vis_dir, f"mesh_solution_iter_{self.config.n_iterations:04d}.png")
            self.visualize_mesh_solution(
                final_mesh,
                final_solution,
                output_file=vis_file,
                title=""  # No title for wandb logging
            )
            self.visualization_files.append(vis_file)
            
            # Log final visualization to wandb
            if self.config.use_wandb and self.wandb:
                self.wandb.log({"mesh_visualization": self.wandb.Image(vis_file)}, step=self.config.n_iterations)
        
        # Generate GIF/video from visualizations
        if len(self.visualization_files) > 1:
            gif_path = self._create_visualization_gif(vis_dir)
            if gif_path and self.config.use_wandb and self.wandb:
                # Log GIF to wandb
                self.wandb.log({"mesh_evolution_gif": self.wandb.Video(gif_path)})
        
        # Finalize wandb run if enabled
        if self.config.use_wandb and self.wandb:
            self.wandb.finish()
        
        # Prepare results dict
        results = {
            'params': params,
            'initial_mesh': mesh,
            'final_mesh': final_mesh,
            'loss_history': loss_history,
            'visualization_files': self.visualization_files
        }
        
        results['gradient_norm_history'] = gradient_norm_history
        results['mesh_mse_history'] = mesh_mse_history
        
        # Add shrinkage information if available
        estimator = self.get_estimator()
        if hasattr(estimator, 'shrinkage_history'):
            results['shrinkage_history'] = estimator.shrinkage_history
            results['total_shrinkage_count'] = estimator.shrinkage_count
        
        # Add rejection information if available
        if hasattr(estimator, 'rejection_history'):
            results['rejection_history'] = estimator.rejection_history
            results['total_rejection_count'] = estimator.rejection_count
        
        # Add gradient update shrinkage information if available
        if self.update_shrinkage_history:
            results['update_shrinkage_history'] = self.update_shrinkage_history
            results['total_update_shrinkage_count'] = self.update_shrinkage_count
            
        return results 

    def _create_visualization_gif(self, vis_dir: str) -> Optional[str]:
        """
        Create an animated GIF from the saved visualization files.

        Frames are taken from ``self.visualization_files`` and ordered by the
        iteration number encoded in each file name (``..._iter_<N>.png``).
        Paths listed more than once are used only once. Files that no longer
        exist on disk are skipped. All opened images are closed before
        returning.

        This is best-effort: any failure is reported as a printed warning and
        results in ``None`` rather than an exception.

        Parameters
        ----------
        vis_dir : str
            Directory in which ``mesh_evolution.gif`` is written.

        Returns
        -------
        Optional[str]
            Path to the created GIF file, or None if there were no frames or
            creation failed.
        """
        if not self.visualization_files:
            return None

        def _iteration_of(path: str) -> int:
            # Parse only the file name so that '_iter_' (or a '.') appearing in a
            # parent directory cannot corrupt the extracted iteration number.
            stem = os.path.splitext(os.path.basename(path))[0]
            return int(stem.rsplit('_iter_', 1)[1])

        images = []
        try:
            # The same path can be recorded twice (e.g. the last in-loop frame
            # and the final frame share a name); dict.fromkeys dedupes while
            # preserving insertion order.
            unique_files = list(dict.fromkeys(self.visualization_files))
            sorted_files = sorted(unique_files, key=_iteration_of)

            for file_path in sorted_files:
                if os.path.exists(file_path):
                    images.append(Image.open(file_path))

            if not images:
                return None

            gif_path = os.path.join(vis_dir, "mesh_evolution.gif")
            # Save as GIF with reasonable duration (500ms per frame), looping forever.
            images[0].save(
                gif_path,
                save_all=True,
                append_images=images[1:],
                duration=500,
                loop=0
            )
        except Exception as e:
            # Deliberately best-effort: a failed GIF must not abort training.
            print(f"Warning: Could not create GIF: {e}")
            return None
        finally:
            # Image.open keeps the underlying file open until closed.
            for img in images:
                img.close()

        print(f"GIF created: {gif_path}")
        return gif_path