import numpy as np
import matplotlib.pyplot as plt
from typing import List, Optional, Any, Union
from skfem import MeshTri


def plot_mse_history(
    loss_history: List[float],
    output_file: Optional[str] = None,
    title: str = "MSE History",
    show: bool = False
) -> None:
    """
    Plot the MSE history from training.

    Parameters
    ----------
    loss_history : List[float]
        List of MSE values from training, plotted against their index
    output_file : Optional[str], default=None
        Path to save the plot (300 dpi, tight bounding box). If None (or
        empty), the plot is not saved
    title : str, default="MSE History"
        Title for the plot
    show : bool, default=False
        Whether to display the plot with ``plt.show()``. If False, the
        figure is closed after (optionally) being saved

    Notes
    -----
    If drawing or saving raises, the figure is closed before the exception
    propagates, so failed calls do not accumulate open figures in pyplot.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(loss_history, 'b-', linewidth=2)
        ax.set_xlabel('Iteration', fontsize=12)
        ax.set_ylabel('Mean Squared Error', fontsize=12)
        ax.set_title(title, fontsize=14)
        ax.grid(True, alpha=0.3)

        if output_file:
            fig.savefig(output_file, dpi=300, bbox_inches='tight')
    except BaseException:
        # Don't leave a half-built figure registered with pyplot's figure
        # manager; re-raise the original error unchanged.
        plt.close(fig)
        raise

    if show:
        plt.show()
    else:
        plt.close(fig)


def export_optimized_mesh(
    mesh: MeshTri,
    output_file: str
) -> None:
    """
    Save a mesh's vertex coordinates and triangle connectivity to disk.

    The data is written as an uncompressed NumPy ``.npz`` archive holding
    two arrays:

    - ``points``: ``mesh.p``, the vertex coordinates (row 0 is x, row 1 is y)
    - ``triangles``: ``mesh.t``, the vertex indices of each triangle, stored
      exactly as the mesh holds them (not transposed)

    Parameters
    ----------
    mesh : MeshTri
        Mesh whose ``p`` and ``t`` arrays are exported
    output_file : str
        Destination path. ``numpy.savez`` appends ``.npz`` when a string path
        does not already end with that extension
    """
    archive_contents = {"points": mesh.p, "triangles": mesh.t}
    np.savez(output_file, **archive_contents)


def visualize_mesh_solution(
    mesh: MeshTri,
    solution: Any,  # Solution from PDESolver.solve
    output_file: Optional[str] = None,
    title: str = "Solution on Mesh",
    show: bool = False,
    colorbar_label: str = "Value",
    figsize: tuple = (10, 8)
) -> None:
    """
    Visualize a solution on a mesh as filled contours with mesh edges overlaid.

    Parameters
    ----------
    mesh : MeshTri
        The mesh; ``mesh.p`` supplies vertex x/y coordinates and ``mesh.t``
        the triangle connectivity
    solution : Any
        Solution from PDESolver.solve. Only its ``value`` attribute is read;
        it must hold exactly one value per mesh vertex
    output_file : Optional[str], default=None
        Path to save the visualization (300 dpi, tight bounding box). If None
        (or empty), the image is not saved
    title : str, default="Solution on Mesh"
        Title for the visualization
    show : bool, default=False
        Whether to display the visualization with ``plt.show()``. If False,
        the figure is closed after (optionally) being saved
    colorbar_label : str, default="Value"
        Label for the colorbar
    figsize : tuple, default=(10, 8)
        Figure size

    Raises
    ------
    ValueError
        If ``solution.value`` does not have one entry per mesh vertex.

    Notes
    -----
    If drawing or saving raises, the figure is closed before the exception
    propagates, so failed calls do not accumulate open figures in pyplot.
    """
    # Extract solution values
    u = solution.value

    # Vertex coordinates and triangulation. mesh.t is transposed to the
    # (n_triangles, 3) layout that tricontourf/triplot expect.
    x = mesh.p[0, :]
    y = mesh.p[1, :]
    triangles = mesh.t.T

    # tricontourf needs one value per vertex. Check this before creating
    # the figure so a bad input fails with a clear message and no figure.
    if np.shape(u) != np.shape(x):
        raise ValueError(
            f"solution.value has shape {np.shape(u)}, but the mesh has "
            f"{np.shape(x)[0]} vertices; expected one value per vertex"
        )

    fig, ax = plt.subplots(figsize=figsize)
    try:
        # Create filled contour plot
        contours = ax.tricontourf(x, y, triangles, u, cmap='viridis', levels=50)

        # Pass the mappable explicitly instead of relying on pyplot's
        # "current image" state.
        cbar = fig.colorbar(contours, ax=ax)
        cbar.set_label(colorbar_label, fontsize=12)

        # Show mesh edges
        ax.triplot(x, y, triangles, 'k-', alpha=0.3, linewidth=0.5)

        # Add labels and title
        ax.set_xlabel('x', fontsize=12)
        ax.set_ylabel('y', fontsize=12)
        ax.set_title(title, fontsize=14)

        # Set aspect ratio to equal
        ax.axis('equal')
        fig.tight_layout()

        # Save if requested
        if output_file:
            fig.savefig(output_file, dpi=300, bbox_inches='tight')
    except BaseException:
        # Don't leave a half-built figure registered with pyplot's figure
        # manager; re-raise the original error unchanged.
        plt.close(fig)
        raise

    # Show or close
    if show:
        plt.show()
    else:
        plt.close(fig)