from .mesh_utils import load_mesh, create_uniform_mesh, refine_mesh, mesh_to_coords, coords_to_mesh
from .pde_solver import WeakForm, DirichletBC, PDESolver
from .error_metrics import interpolate_solution, compute_mse
from .trainer import TrainerConfig, Trainer
from .reports import plot_mse_history, export_optimized_mesh, visualize_mesh_solution
from .estimator import EstimatorConfig, GradientEstimator, UniformEstimator, CenterGaussianEstimator, RandomPerturbEstimator, MeshAwareEstimator, RejectionMethodEstimator, ConvexCombinationEstimator, get_estimator

# Public API of the package: names exported by ``from <package> import *``.
# Grouped by the submodule each name is imported from above, plus the
# package-level convenience entry point ``optimize_mesh`` defined below.
__all__ = [
    # mesh_utils
    'load_mesh',
    'create_uniform_mesh',
    'refine_mesh',
    'mesh_to_coords',
    'coords_to_mesh',

    # pde_solver
    'WeakForm',
    'DirichletBC',
    'PDESolver',

    # error_metrics
    'interpolate_solution',
    'compute_mse',

    # trainer
    'TrainerConfig',
    'Trainer',

    # reports
    'plot_mse_history',
    'export_optimized_mesh',
    'visualize_mesh_solution',

    # estimator
    'EstimatorConfig',
    'GradientEstimator',
    'UniformEstimator',
    'CenterGaussianEstimator',
    'RandomPerturbEstimator',
    'MeshAwareEstimator',
    'RejectionMethodEstimator',
    'ConvexCombinationEstimator',
    'get_estimator',

    # package-level pipeline entry point (defined in this module)
    'optimize_mesh',
]

# Default training parameters used by ``optimize_mesh``. Keys are exactly the
# keyword arguments passed to ``TrainerConfig``. Entries in a caller-supplied
# ``config`` dict override these; keys not listed here are ignored.
# Never mutate this dict directly -- it is copied before merging.
_DEFAULT_TRAINER_CONFIG = {
    'step_size': 0.01,
    'fd_radius': 0.05,
    'n_samples': 20,
    'n_iterations': 50,
    'regularization_weight': 0.0001,
    'eval_points_n': 100,
    'estimator_type': 'uniform',
    'gradient_estimator': 'standard',
    'random_radius_min': 0.01,
    'random_radius_max': 0.1,
    'gaussian_std': 1.0,
    'use_wandb': True,
    'wandb_project': 'mesh-opt',
    'wandb_entity': None,
}


def _build_poisson_problem():
    """Build the Poisson problem with a manufactured solution.

    The exact solution is u(x, y) = sin(pi x) sin(pi y), whose negative
    Laplacian is f(x, y) = 2 pi^2 sin(pi x) sin(pi y). Homogeneous Dirichlet
    data is used because u vanishes on the lines x, y in {0, 1}.
    NOTE(review): this assumes the mesh covers the unit square [0, 1]^2 --
    confirm against ``create_uniform_mesh`` / loaded meshes.

    Returns
    -------
    tuple
        ``(solver, bc, ground_truth)`` where ``solver`` is a ``PDESolver``,
        ``bc`` a ``DirichletBC`` and ``ground_truth`` a callable mapping an
        array of points (indexed as ``points[:, 0]``, ``points[:, 1]``) to
        exact solution values.
    """
    import numpy as np
    from skfem.helpers import dot, grad

    # Bilinear form of -Laplace(u): integral of grad(u) . grad(v).
    def poisson_form(u, v, w):
        return dot(grad(u), grad(v))

    def source_term(x):
        return 2 * np.pi**2 * np.sin(np.pi * x[0]) * np.sin(np.pi * x[1])

    # Linear form f * v, with f evaluated at the quadrature points ``w.x``.
    def load_form(v, w):
        return source_term(w.x) * v

    weak_form = WeakForm(bilinear_form=poisson_form, linear_form=load_form)

    def zero_bc(x):
        return 0.0

    bc = DirichletBC(zero_bc)
    solver = PDESolver(weak_form)

    def ground_truth(points):
        x, y = points[:, 0], points[:, 1]
        return np.sin(np.pi * x) * np.sin(np.pi * y)

    return solver, bc, ground_truth


def _build_trainer_config(config):
    """Merge ``config`` over ``_DEFAULT_TRAINER_CONFIG`` and build a TrainerConfig."""
    merged = dict(_DEFAULT_TRAINER_CONFIG)  # copy: never mutate the defaults
    if config is not None:
        merged.update(config)
    # Only the known keys are forwarded; extra keys in ``config`` are ignored.
    return TrainerConfig(**{key: merged[key] for key in _DEFAULT_TRAINER_CONFIG})


def optimize_mesh(
    initial_mesh_file=None,
    output_dir="./output",
    config=None
):
    """
    Run the mesh optimization pipeline from a file or generated mesh.

    This function provides a convenient entry point to run the full
    mesh optimization process from start to finish: build or load a mesh,
    set up a Poisson problem with a known exact solution, run the trainer,
    and write an MSE-history plot and the optimized mesh to ``output_dir``.

    Parameters
    ----------
    initial_mesh_file : str, optional
        Path to initial mesh file. If None, a uniform 10x10 mesh is created.
    output_dir : str, default="./output"
        Directory to save outputs. Created (including parents) if missing.
    config : dict, optional
        Training parameters overriding the defaults. Recognized keys:
        ``step_size``, ``fd_radius``, ``n_samples``, ``n_iterations``,
        ``regularization_weight``, ``eval_points_n``, ``estimator_type``,
        ``gradient_estimator``, ``random_radius_min``, ``random_radius_max``,
        ``gaussian_std``, ``use_wandb``, ``wandb_project``, ``wandb_entity``.
        Other keys are ignored. If None, default parameters are used.

    Returns
    -------
    dict
        The result dict returned by ``Trainer.run``; this function reads its
        ``'loss_history'`` and ``'final_mesh'`` entries.
    """
    import os

    # exist_ok avoids the race between an existence check and creation.
    os.makedirs(output_dir, exist_ok=True)

    # Create or load initial mesh
    if initial_mesh_file is None:
        print("Creating uniform mesh...")
        mesh = create_uniform_mesh(nx=10, ny=10)
    else:
        print(f"Loading mesh from {initial_mesh_file}...")
        mesh = load_mesh(initial_mesh_file)

    print("Setting up PDE problem...")
    solver, bc, ground_truth = _build_poisson_problem()

    print("Setting up trainer...")
    trainer_config = _build_trainer_config(config)
    trainer = Trainer(trainer_config, solver, bc, ground_truth)

    print("Running optimization...")
    result = trainer.run(mesh)

    print("Generating reports...")
    plot_mse_history(
        result['loss_history'],
        output_file=os.path.join(output_dir, "mse_history.png")
    )

    export_optimized_mesh(
        result['final_mesh'],
        output_file=os.path.join(output_dir, "optimized_mesh.npz")
    )

    # Visualizations are handled by the trainer itself (per the original
    # note, it creates its own visualization directory and logs to wandb).

    print(f"Results saved to {output_dir}")

    return result
