#!/usr/bin/env python3
"""
Example script for running mesh optimization.

This script demonstrates how to use the ``mesh_opt`` package to build an
initial mesh (or load one from a file) and optimize it for solving a PDE.
"""

import os
import sys
import argparse
from pprint import pprint

# Put the parent of this script's directory on sys.path so that
# `import mesh_opt` resolves when the script is run directly.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import mesh_opt


def _build_parser():
    """Create the command-line parser for the example."""
    parser = argparse.ArgumentParser(description="Mesh Optimization Example")
    parser.add_argument(
        "--mesh",
        type=str,
        default=None,
        help="Path to initial mesh file (optional)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="./output",
        help="Output directory",
    )
    parser.add_argument(
        "--iterations",
        type=int,
        default=20,
        help="Number of optimization iterations",
    )
    parser.add_argument(
        "--step-size",
        type=float,
        default=0.01,
        help="Step size for gradient descent",
    )
    parser.add_argument(
        "--regularization",
        type=float,
        default=0.0001,
        help="Regularization weight",
    )
    return parser


def _validate_args(parser, args):
    """Reject hyperparameter values that cannot describe a sensible run.

    Calls ``parser.error`` (which prints usage and exits with status 2) on
    the first invalid value.
    """
    if args.iterations < 0:
        parser.error("--iterations must be non-negative")
    # Written as ``not (x > 0)`` rather than ``x <= 0`` so NaN is rejected too.
    if not args.step_size > 0:
        parser.error("--step-size must be positive")
    if not args.regularization >= 0:
        parser.error("--regularization must be non-negative")


def _report_losses(loss_history):
    """Print the initial and final loss and the relative improvement in percent."""
    print("\nOptimization Results:")
    # Use len() rather than truthiness. The container type of the history is
    # not visible here; it could be a NumPy array, whose truth value is
    # ambiguous when it has more than one element.
    if len(loss_history) == 0:
        print("No loss values were recorded.")
        return

    initial_loss = loss_history[0]
    final_loss = loss_history[-1]
    print(f"Initial MSE: {initial_loss:.6f}")
    print(f"Final MSE: {final_loss:.6f}")
    if initial_loss == 0:
        # Relative improvement is undefined when the starting error is zero.
        print("Improvement: n/a (initial MSE is zero)")
    else:
        improvement = (initial_loss - final_loss) / initial_loss * 100
        print(f"Improvement: {improvement:.2f}%")


def _report_output_files(output_dir):
    """Print where the example's artifacts are expected inside *output_dir*.

    Paths that do not exist on disk are flagged, so a file that was never
    written is not reported as if it had been.
    """
    print("\nOutput Files:")
    for label, filename in (
        ("MSE History Plot", "mse_history.png"),
        ("Initial Solution", "initial_solution.png"),
        ("Optimized Solution", "optimized_solution.png"),
        ("Optimized Mesh", "optimized_mesh.npz"),
    ):
        path = os.path.join(output_dir, filename)
        suffix = "" if os.path.exists(path) else " (not found)"
        print(f"{label}: {path}{suffix}")


def main(argv=None):
    """Run the mesh optimization example.

    Args:
        argv: Command-line arguments to parse, excluding the program name.
            ``None`` (the default) makes argparse read ``sys.argv[1:]``,
            which matches the previous no-argument behaviour.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)
    _validate_args(parser, args)

    print("Mesh Optimization Example")
    print("------------------------")
    print(f"Initial mesh: {'Generated automatically' if args.mesh is None else args.mesh}")
    print(f"Output directory: {args.output}")
    print(f"Iterations: {args.iterations}")
    print(f"Step size: {args.step_size}")
    print(f"Regularization weight: {args.regularization}")
    print()

    # Hyperparameters forwarded to the optimizer.
    config = {
        'n_iterations': args.iterations,
        'step_size': args.step_size,
        'regularization_weight': args.regularization,
    }

    # Run optimization
    result = mesh_opt.optimize_mesh(
        initial_mesh_file=args.mesh,
        output_dir=args.output,
        config=config,
    )

    _report_losses(result['loss_history'])
    _report_output_files(args.output)


if __name__ == "__main__":
    # Example invocation:
    #   python examples/demo.py --iterations 100 --step-size 0.01 --output ./results
    main()