import numpy as np
import matplotlib.pyplot as plt

def evaluate_model_on_plane(plane, X_test, analytical_sol, diffusion_solver_elm, t_eval_pts, stride=100):
    """
    Evaluate ground truth and model along one spatial coordinate, zeroing all others.

    Every ``stride``-th row of ``X_test`` is taken. In a copy of those samples,
    every column except ``plane`` is set to zero. Both the analytical solution
    and the model are then evaluated on that copy.

    Args:
        plane (int): Column index of the spatial coordinate kept from ``X_test``.
            Only this single column varies; every other column is zero.
        X_test (np.ndarray): Input samples, indexed as ``X_test[row, column]``.
        analytical_sol (callable): Called as ``analytical_sol(X, t=t_eval_pts)``;
            its result is transposed before use.
        diffusion_solver_elm (object): Model object with an
            ``evaluate(x_eval=..., t_eval=...)`` method; its result is transposed
            before use.
        t_eval_pts (np.ndarray): Time evaluation points. Callers in this module
            pass a ``(1, T)`` row vector. It is forwarded unchanged to
            ``analytical_sol`` and transposed for the model.
        stride (int): Row subsampling step applied to ``X_test``. Defaults to 100.

    Returns:
        tuple: ``(X_sampled, u_true_temp, u_model_temp, u_error_temp)`` where
        ``X_sampled`` is the subsampled rows of ``X_test`` (all columns),
        ``u_true_temp`` / ``u_model_temp`` are the transposed ground-truth /
        model outputs, and ``u_error_temp`` is the absolute difference of their
        column ``plane``.

    Raises:
        ValueError: If ``stride`` is smaller than 1.
    """
    if stride < 1:
        raise ValueError(f"stride must be a positive integer, got {stride}")

    dim1 = plane
    # ``::stride`` (not ``0:-1:stride``) so the final row is never dropped and a
    # single-row input still yields one sample.
    X_sampled = X_test[::stride, :]
    # Keep only the selected coordinate; all other coordinates are zero.
    X_temp = np.zeros_like(X_sampled)
    X_temp[:, dim1] = X_sampled[:, dim1]

    u_true_temp = analytical_sol(X_temp, t=t_eval_pts).T
    u_model_temp = diffusion_solver_elm.evaluate(x_eval=X_temp, t_eval=t_eval_pts.T).T
    u_error_temp = np.abs(u_true_temp[:, dim1] - u_model_temp[:, dim1])

    return X_sampled, u_true_temp, u_model_temp, u_error_temp

def plot_plane_results(X_sampled, t_eval_pts, u_true, u_model, u_error, plane_idx, fig_idx, fontsize=10):
    """
    Plot ground truth, model prediction and absolute error as three scatter panels.

    Each panel scatters ``X_sampled[:, plane_idx]`` (horizontal axis) against
    ``t_eval_pts`` (vertical axis) and colours the points by that panel's
    values. The figure is saved to ``spatial_snap_{fig_idx}_t1.pdf`` and then
    shown.

    Args:
        X_sampled (np.ndarray): Sampled input points; column ``plane_idx`` is plotted.
        t_eval_pts (np.ndarray): Time evaluation points (vertical axis).
        u_true (np.ndarray): Ground truth values; column ``plane_idx`` is plotted.
        u_model (np.ndarray): Model prediction values; column ``plane_idx`` is plotted.
        u_error (np.ndarray): Absolute error values, plotted as given.
        plane_idx (int): Column index of the plotted spatial coordinate.
        fig_idx (int): Index embedded in the saved file name.
        fontsize (int): Font size for titles, labels, ticks and colorbar text.
    """
    fig = plt.figure(figsize=(7, 2))

    def _plot_subplot(position, data, title):
        # Draw one colour-coded scatter panel in the 1x3 grid.
        plt.subplot(1, 3, position)
        scatter = plt.scatter(x=X_sampled[:, plane_idx], y=t_eval_pts, c=data, cmap='jet')
        cbar = plt.colorbar(scatter)
        cbar.ax.tick_params(labelsize=fontsize)
        # Vertical colorbar: the scientific-notation offset text is on the y-axis.
        cbar.ax.yaxis.get_offset_text().set_fontsize(fontsize)
        plt.title(title, fontsize=fontsize)
        plt.xlabel(f"$x_{{{plane_idx + 1}}}$", fontsize=fontsize)
        plt.ylabel("t", fontsize=fontsize)
        plt.tick_params(labelsize=fontsize)

    _plot_subplot(1, u_true[:, plane_idx], "Ground Truth")
    _plot_subplot(2, u_model[:, plane_idx], "Frozen-PINN-swim")
    _plot_subplot(3, u_error, "Absolute Error")

    plt.tight_layout()
    # Save before showing. Depending on the backend, ``plt.show()`` may block
    # until the window is closed or release the figure, so writing the file
    # first guarantees it reflects the drawn figure.
    fig.savefig(f"spatial_snap_{fig_idx}_t1.pdf")
    plt.show()

def visualize_2d_planes(X_test, analytical_sol, diffusion_solver_elm, planes=(0, 20, 99)):
    """
    Plot ground truth vs. model prediction for each selected spatial coordinate.

    For every entry of ``planes``, the function does three things:
    evaluates ground truth and model via ``evaluate_model_on_plane``, prints the
    result shapes, and draws/saves a three-panel figure via ``plot_plane_results``.

    Args:
        X_test (np.ndarray): Input test samples.
        analytical_sol (callable): Ground truth solution function.
        diffusion_solver_elm (object): Model with an ``evaluate`` method.
        planes (sequence of int): Column indices of the spatial coordinates to
            visualize. Defaults to ``(0, 20, 99)``. A tuple is used instead of a
            list so the default cannot be mutated and shared between calls.
    """
    # 100 time points in [0, 1] as a (1, 100) row vector.
    t_eval_pts = np.linspace(0, 1, 100).reshape(1, -1)

    # ``idx`` starts at 1 and numbers the saved figure files.
    for idx, plane in enumerate(planes, start=1):
        X_sampled, u_true, u_model, u_error = evaluate_model_on_plane(
            plane, X_test, analytical_sol, diffusion_solver_elm, t_eval_pts
        )
        print(f"Shapes -> True: {u_true.shape}, Model: {u_model.shape}")
        plot_plane_results(X_sampled, t_eval_pts, u_true, u_model, u_error, plane, idx)


def evaluate_solution_on_plane(plane, X_test, analytical_sol, diffusion_solver_elm, t_eval_pts, default_value=-0.5, stride=100):
    """
    Evaluate ground truth and model along one spatial coordinate, fixing all others.

    Every ``stride``-th row of ``X_test`` is taken. In a copy of those samples,
    every column except ``plane`` is set to ``default_value``. Both the
    analytical solution and the model are then evaluated on that copy.

    Args:
        plane (int): Column index of the spatial coordinate kept from ``X_test``.
            Only this single column varies; all other columns are set to
            ``default_value``.
        X_test (np.ndarray): Test input data, indexed as ``X_test[row, column]``.
        analytical_sol (callable): Called as ``analytical_sol(X, t=t_eval_pts)``;
            its result is transposed before use.
        diffusion_solver_elm (object): Model with an
            ``evaluate(x_eval=..., t_eval=...)`` method; its result is transposed
            before use.
        t_eval_pts (np.ndarray): Time points. This array is forwarded unchanged
            to ``analytical_sol`` and transposed for the model.
        default_value (float): Value written to every non-selected column.
        stride (int): Row subsampling step applied to ``X_test``. Defaults to 100.

    Returns:
        tuple: ``(X_sampled, u_true_temp, u_model_temp, u_error_temp)`` where
        ``X_sampled`` is the subsampled rows of ``X_test`` (all columns),
        ``u_true_temp`` / ``u_model_temp`` are the transposed ground-truth /
        model outputs, and ``u_error_temp`` is the absolute difference of their
        column ``plane``.

    Raises:
        ValueError: If ``stride`` is smaller than 1.
    """
    if stride < 1:
        raise ValueError(f"stride must be a positive integer, got {stride}")

    dim1 = plane
    X_sampled = X_test[::stride, :]
    # Multiplying ones_like (rather than full_like) keeps the original dtype
    # promotion: integer inputs become float arrays holding default_value.
    X_temp = default_value * np.ones_like(X_sampled)
    X_temp[:, dim1] = X_sampled[:, dim1]

    u_true_temp = analytical_sol(X_temp, t=t_eval_pts).T
    u_model_temp = diffusion_solver_elm.evaluate(x_eval=X_temp, t_eval=t_eval_pts.T).T
    u_error_temp = np.abs(u_true_temp[:, dim1] - u_model_temp[:, dim1])

    return X_sampled, u_true_temp, u_model_temp, u_error_temp

def plot_true_solution_planes(X_sampled_list, u_true_list, planes, t_eval_pts, fontsize=14, save_path="high_dim_pde_true.pdf"):
    """
    Plot the ground truth solution for several spatial coordinates side by side.

    Panel ``j`` scatters ``X_sampled_list[j][:, planes[j]]`` against
    ``t_eval_pts`` and colours the points by ``u_true_list[j][:, planes[j]]``.
    The figure is saved to ``save_path`` and then shown.

    Args:
        X_sampled_list (list): Sampled input arrays, one per plane.
        u_true_list (list): Ground truth solution arrays, one per plane.
        planes (sequence of int): Column index plotted in each panel; its length
            sets the number of subplot columns.
        t_eval_pts (np.ndarray): Evaluation time points (vertical axis).
        fontsize (int): Font size for labels, ticks and colorbar text.
        save_path (str): Path the figure is saved to.
    """
    fig = plt.figure(figsize=(7, 2))
    n_panels = len(planes)
    for j, (X_sampled, u_true, plane) in enumerate(zip(X_sampled_list, u_true_list, planes), start=1):
        plt.subplot(1, n_panels, j)
        scatter = plt.scatter(x=X_sampled[:, plane], y=t_eval_pts, c=u_true[:, plane], cmap='jet')
        # Pass the mappable explicitly rather than relying on pyplot's "current image".
        cbar = plt.colorbar(scatter)
        cbar.ax.tick_params(labelsize=fontsize)
        # Colorbars are vertical by default, so the scientific-notation offset
        # text lives on the y-axis (the previous ``xaxis`` call had no effect).
        cbar.ax.yaxis.get_offset_text().set_fontsize(fontsize)
        plt.xlabel(f"$x_{{{plane + 1}}}$", fontsize=fontsize)
        plt.ylabel("t", fontsize=fontsize)
        plt.tick_params(labelsize=fontsize)

    plt.tight_layout()
    # Save before showing. Depending on the backend, ``plt.show()`` may block
    # or release the figure.
    fig.savefig(save_path)
    plt.show()

def visualize_true_solution(X_test, analytical_sol, diffusion_solver_elm, planes=(0, 99), fontsize=14):
    """
    Visualize the ground truth solution for each selected spatial coordinate.

    For every entry of ``planes`` the samples and ground truth are computed via
    ``evaluate_solution_on_plane`` and the shape is printed. All planes are then
    drawn in one figure via ``plot_true_solution_planes``.

    Args:
        X_test (np.ndarray): Test data for spatial dimensions.
        analytical_sol (callable): Analytical solution function.
        diffusion_solver_elm (object): Model forwarded to
            ``evaluate_solution_on_plane``. Its prediction is computed there but
            discarded here; only the ground truth is plotted.
        planes (sequence of int): Column indices of the spatial coordinates to
            plot. Defaults to ``(0, 99)``. A tuple is used so the default cannot
            be mutated and shared between calls.
        fontsize (int): Font size for plots.
    """
    # 100 time points in [0, 1] as a (1, 100) row vector.
    t_eval_pts = np.linspace(0, 1, 100).reshape(1, -1)
    X_samples, u_true_vals = [], []

    for plane in planes:
        # Model output and error are not needed for this ground-truth-only figure.
        X_sampled, u_true, _, _ = evaluate_solution_on_plane(
            plane, X_test, analytical_sol, diffusion_solver_elm, t_eval_pts
        )
        print(f"Plane {plane}: True shape: {u_true.shape}")
        X_samples.append(X_sampled)
        u_true_vals.append(u_true)

    plot_true_solution_planes(X_samples, u_true_vals, planes, t_eval_pts, fontsize=fontsize)
