import numpy as np
import matplotlib.pyplot as plt




def generate_labels(d):
    """
    Build LaTeX axis labels ``$x_{1}$`` ... ``$x_{d}$`` for a d-dimensional input.

    Args:
        d (int): Number of dimensions.

    Returns:
        tuple: ``(x_labels, y_labels)``. Both entries are the *same* list
        object (1-based labels), so mutating one also mutates the other.
    """
    names = []
    for index in range(1, d + 1):
        names.append(f"$x_{{{index}}}$")
    return names, names


def compute_errors(X_sampled, row, col, t_snap, analytical_sol, model):
    """
    Absolute error between ground truth and model on a 2-D slice of the inputs.

    All columns of ``X_sampled`` except ``row`` and ``col`` are set to zero,
    then the ground truth and the model are both evaluated at ``t_snap`` on
    that zeroed copy.

    Args:
        X_sampled (np.ndarray): Sampled input data of shape (N, d).
        row (int): Column index kept and used as the scatter y-values.
        col (int): Column index kept and used as the scatter x-values.
        t_snap (np.ndarray): Time snapshot, e.g. of shape (1, 1).
        analytical_sol (callable): Ground truth, called as ``analytical_sol(X, t=...)``.
        model (object): Object with ``.evaluate(x_eval=..., t_eval=...)``.

    Returns:
        tuple: ``(X_sampled[:, col], X_sampled[:, row], error)``. The first two
        are views into ``X_sampled``; ``error`` is the flattened absolute difference.
    """
    kept_columns = [col, row]
    sliced = np.zeros_like(X_sampled)
    sliced[:, kept_columns] = X_sampled[:, kept_columns]

    truth = analytical_sol(sliced, t=t_snap).reshape(-1)
    prediction = model.evaluate(x_eval=sliced, t_eval=t_snap).reshape(-1)

    return X_sampled[:, col], X_sampled[:, row], np.abs(truth - prediction)


def plot_error_grid(X_sampled, analytical_sol, model, t_snap, d=10, fontsize=16, figsize=(3.3, 3.3),
                    save_path='errors_slices.png'):
    """
    Plots a d x d grid of absolute error scatter plots for each (row > col) pair.

    Panel (row, col) in the strictly lower triangle shows the absolute error
    at ``t_snap`` when only input coordinates ``col`` (x-axis) and ``row``
    (y-axis) of ``X_sampled`` are kept and all others are zeroed (see
    ``compute_errors``). Diagonal and upper-triangle panels are hidden.

    All panels share one colour normalisation (the global min/max of all
    panel errors). The single colorbar on the right is therefore valid for
    every panel, not just the first one.

    Args:
        X_sampled (np.ndarray): Sampled input data.
        analytical_sol (callable): Ground truth solution function.
        model (object): Trained model object with `.evaluate()`.
        t_snap (np.ndarray): Single time point of shape (1, 1).
        d (int): Dimensionality of input (grid is d x d).
        fontsize (int): Font size for labels and ticks.
        figsize (tuple): Size of the entire figure.
        save_path (str): File the figure is written to at 600 dpi.
            Defaults to the historical name ``'errors_slices.png'``.
    """
    # squeeze=False guarantees a 2-D array of Axes even when d == 1.
    fig, axes = plt.subplots(d, d, figsize=figsize, squeeze=False,
                             gridspec_kw={'wspace': 0., 'hspace': 0.})
    x_labels, y_labels = generate_labels(d)

    # Evaluate every lower-triangle panel first (same row-major order as the
    # drawing loop) so a global colour range is known before anything is drawn.
    panel_data = {
        (row, col): compute_errors(X_sampled, row, col, t_snap, analytical_sol, model)
        for row in range(d)
        for col in range(row)
    }

    if panel_data:
        all_errors = np.concatenate([np.ravel(err) for _, _, err in panel_data.values()])
        vmin, vmax = np.nanmin(all_errors), np.nanmax(all_errors)
    else:
        vmin = vmax = None

    scatter_plots = []
    for row in range(d):
        for col in range(d):
            ax = axes[row, col]
            ax.set_xticks([])
            ax.set_yticks([])

            if (row, col) in panel_data:
                x_vals, y_vals, errors = panel_data[(row, col)]
                # Shared vmin/vmax: otherwise each panel auto-scales and the
                # single colorbar below would only describe the first panel.
                sc = ax.scatter(x_vals, y_vals, c=errors, cmap='jet', s=2, vmin=vmin, vmax=vmax)
                scatter_plots.append(sc)
            else:
                ax.axis('off')

            if row == d - 1:
                ax.set_xlabel(x_labels[col], fontsize=fontsize)
            if col == 0:
                ax.set_ylabel(y_labels[row], fontsize=fontsize)

    # Shared colorbar (skipped when there is no lower-triangle panel, e.g. d == 1).
    if scatter_plots:
        cbar_ax = fig.add_axes([0.9, 0.13, 0.02, 0.7])  # [left, bottom, width, height]
        cbar = fig.colorbar(scatter_plots[0], cax=cbar_ax, orientation='vertical')
        cbar.formatter.set_useMathText(True)
        cbar.ax.tick_params(labelsize=fontsize)
        cbar.formatter.set_powerlimits((0, 0))
        cbar.ax.yaxis.set_offset_position('right')
        cbar.ax.yaxis.get_offset_text().set_fontsize(fontsize)

    plt.tight_layout()
    # Save before show(): the saved file must not depend on what happens to
    # the interactive window (resizing, closing) while it is displayed.
    fig.savefig(save_path, dpi=600)
    plt.show()


def generate_fixed_input(X_template, x_vals, coordinate_index):
    """
    Build an all-zero input array whose single selected column holds ``x_vals``.

    The result has the shape and dtype of ``X_template``; ``X_template``
    itself is not modified. ``x_vals`` is broadcast into the selected column.

    Args:
        X_template (np.ndarray): Array supplying the output shape and dtype.
        x_vals (np.ndarray): Values written into column ``coordinate_index``.
        coordinate_index (int): Index of the column to fill.

    Returns:
        np.ndarray: New array, zero everywhere except the selected column.
    """
    fixed = np.zeros_like(X_template)
    fixed[:, coordinate_index] = x_vals
    return fixed


def evaluate_at_time(X_temp, t_val, analytical_sol, model):
    """
    Evaluate ground truth and model on ``X_temp`` at a single time point.

    Args:
        X_temp (np.ndarray): Input array (rows are evaluation points).
        t_val (np.ndarray): Time value, e.g. of shape (1, 1).
        analytical_sol (callable): Ground truth, called as ``analytical_sol(X, t=...)``.
        model (object): Object with ``.evaluate(x_eval=..., t_eval=...)``.

    Returns:
        tuple: ``(u_true, u_model, u_error)``, each reshaped to
        ``(-1, X_temp.shape[0])``. That is ``(1, n)`` when both functions
        return one value per row of ``X_temp``.
    """
    n_points = X_temp.shape[0]

    u_true = analytical_sol(X_temp, t=t_val)
    u_true = u_true.reshape(-1, n_points)

    u_model = model.evaluate(x_eval=X_temp, t_eval=t_val)
    u_model = u_model.reshape(-1, n_points)

    return u_true, u_model, np.abs(u_true - u_model)


def plot_time_slices(x_vals, solutions, coordinate, fontsize=10, times=(0.01, 0.5, 0.99)):
    """
    Plots true vs predicted solutions along one spatial coordinate, one panel per time slice.

    Args:
        x_vals (np.ndarray): Values of the spatial coordinate (x-axis of every panel).
        solutions (list): List of tuples (u_true, u_model, u_error), one per time point.
            Only row 0 of ``u_true`` and ``u_model`` is plotted; ``u_error`` is ignored.
        coordinate (int): Zero-based column index of the spatial coordinate being plotted.
            The axis label is 1-based (``x_{coordinate + 1}``), matching the other
            plots in this module.
        fontsize (int): Font size for labels, ticks, and titles.
        times (sequence of float): Time value shown in each panel's title. It needs at
            least as many entries as ``solutions``. The default matches the three slices
            evaluated by ``plot_coordinate_comparisons``.

    Raises:
        ValueError: If ``times`` has fewer entries than ``solutions``.
    """
    n_panels = len(solutions)
    if n_panels == 0:
        return  # nothing to draw
    if len(times) < n_panels:
        raise ValueError(
            f"Got {n_panels} solutions but only {len(times)} time labels."
        )

    # squeeze=False keeps a 2-D Axes array even for a single panel.
    fig, axes = plt.subplots(1, n_panels, figsize=(7, 2), squeeze=False)

    for i, (ax, (u_true, u_model, _), t_val) in enumerate(zip(axes[0], solutions, times)):
        ax.plot(x_vals, u_true[0], label="True", color="orange")
        ax.plot(x_vals, u_model[0], label="Frozen-PINN-elm", color="green", linestyle='--')
        ax.set_title(f"t = {t_val}", fontsize=fontsize)
        # 1-based label, consistent with generate_labels / plot_comparison.
        ax.set_xlabel(f"$x_{{{coordinate + 1}}}$", fontsize=fontsize)
        ax.tick_params(labelsize=fontsize)
        ax.legend(fontsize=fontsize)
        if i == 0:
            ax.set_ylabel("y", fontsize=fontsize)

    plt.tight_layout()
    plt.show()


def plot_coordinate_comparisons(X_test, analytical_sol, model, coordinates, fontsize=10):
    """
    Controls the pipeline for comparing model predictions to ground truth at various time steps.

    For each coordinate index, builds inputs that vary only that coordinate
    over 100 evenly spaced values in [-1, 1] (all other coordinates zero).
    It evaluates the ground truth and the model at t = 0.01, 0.5 and 0.99,
    then plots the three slices.

    Args:
        X_test (np.ndarray): Test data input array of shape (N, d). Only its
            column count ``d`` and dtype are used.
        analytical_sol (callable): Ground truth solution function.
        model (object): Trained model object.
        coordinates (list): List of coordinate (column) indices to evaluate.
        fontsize (int): Font size for plots.
    """
    x_vals = np.linspace(-1, 1, 100)
    # The template only provides shape/dtype for generate_fixed_input. Sizing
    # it to len(x_vals) rows (instead of a strided subsample of X_test) means
    # the column assignment broadcasts for any number of rows in X_test.
    X_template = np.zeros((x_vals.shape[0], X_test.shape[1]), dtype=X_test.dtype)
    t_values = (0.01, 0.5, 0.99)

    for coord in coordinates:
        X_temp = generate_fixed_input(X_template, x_vals, coord)
        # Each entry is a (u_true, u_model, u_error) tuple.
        solutions = [
            evaluate_at_time(X_temp, np.array([[t_val]]), analytical_sol, model)
            for t_val in t_values
        ]
        plot_time_slices(x_vals, solutions, coordinate=coord, fontsize=fontsize)


def generate_sampled_data(X_test, plane, step=100):
    """
    Subsample the test inputs and build a copy that keeps only one coordinate.

    Args:
        X_test (np.ndarray): The full test input array, shape (N, d).
        plane (int): Column index of the coordinate to keep.
        step (int): Row stride of the subsampling (default 100, the original
            hard-coded value).

    Returns:
        tuple of np.ndarray: ``(X_sampled, X_temp)``.
            ``X_sampled`` is the slice ``X_test[0:-1:step]``: every ``step``-th
            row, never including the final row because the stop is ``-1``.
            ``X_temp`` has the same shape and dtype, zero everywhere except
            column ``plane``, which is copied from ``X_sampled``.
    """
    # Stop at -1 exactly as before; downstream reshapes depend on the
    # resulting row count, so ``::step`` is not a drop-in replacement.
    X_sampled = X_test[0:-1:step, :]
    X_temp = np.zeros_like(X_sampled)
    X_temp[:, plane] = X_sampled[:, plane]
    return X_sampled, X_temp


def evaluate_solutions(X_temp, t_eval_pts, analytical_sol, model, plane, svd_on=True):
    """
    Evaluate ground truth and model on ``X_temp`` over ``t_eval_pts``.

    Both outputs are reshaped into rows of 100 values. Only the model output
    is transposed afterwards. The error is taken between column ``plane`` of
    the two reshaped grids.

    NOTE(review): the hard-coded 100 presumably matches the 100 time points /
    100 sampled rows used by ``plot_plane_comparisons`` — confirm.
    NOTE(review): only the model grid is transposed; presumably the model and
    ``analytical_sol`` return the grid in opposite orderings — verify.
    NOTE(review): ``plane`` (a spatial-dimension index) selects a column of
    the 100-column grid; confirm this is intended.

    Args:
        X_temp (np.ndarray): The test inputs for evaluation.
        t_eval_pts (np.ndarray): Time evaluation points.
        analytical_sol (callable): Ground truth, called as ``analytical_sol(X, t=...)``.
        model (object): Object with ``.evaluate(x_eval=..., t_eval=..., svd_on=...)``.
        plane (int): Column index into the reshaped grids.
        svd_on (bool): Forwarded unchanged to ``model.evaluate``.

    Returns:
        tuple of np.ndarray: ``(u_true, u_model, u_error)``. ``u_true`` is the
        reshaped ground truth, ``u_model`` the reshaped and transposed model
        output, and ``u_error`` is ``|u_true[:, plane] - u_model[:, plane]|``.
    """
    row_width = 100

    truth_grid = analytical_sol(X_temp, t=t_eval_pts)
    truth_grid = truth_grid.reshape(-1, row_width)

    model_grid = model.evaluate(x_eval=X_temp, t_eval=t_eval_pts, svd_on=svd_on)
    model_grid = model_grid.reshape(-1, row_width).T

    plane_error = np.abs(truth_grid[:, plane] - model_grid[:, plane])
    return truth_grid, model_grid, plane_error


def plot_comparison(X_sampled, t_eval_pts, u_true, u_model, u_error, plane, dim1, fontsize, fig_index):
    """
    Plots ground truth, model output, and absolute error for a given plane.

    Draws three side-by-side scatter panels at the points
    ``(X_sampled[:, plane], t_eval_pts)``, each with its own colorbar. It
    saves the figure to ``spatial_snap_{fig_index}_t1.pdf`` and then shows it.

    Args:
        X_sampled (np.ndarray): Sampled spatial coordinates.
        t_eval_pts (np.ndarray): Time evaluation points (scatter y-values).
        u_true (np.ndarray): Ground truth solution; column ``plane`` is plotted.
        u_model (np.ndarray): Model output; column ``plane`` is plotted.
        u_error (np.ndarray): Absolute error, plotted as-is.
        plane (int): Column index used for the x-values and for ``u_true``/``u_model``.
        dim1 (int): Zero-based spatial dimension index; the axis label is ``x_{dim1 + 1}``.
        fontsize (int): Font size for all plot text.
        fig_index (int): Index embedded in the saved file name.
    """
    fig = plt.figure(figsize=(7, 2))
    panels = (
        (u_true[:, plane], "Ground Truth"),
        (u_model[:, plane], "Frozen-PINN-elm"),
        (u_error, "Absolute Error"),
    )

    for i, (data, title) in enumerate(panels, start=1):
        # Explicit Axes instead of pyplot's implicit "current axes" state.
        ax = fig.add_subplot(1, 3, i)
        scatter = ax.scatter(x=X_sampled[:, plane], y=t_eval_pts, c=data, cmap='jet')
        cbar = fig.colorbar(scatter, ax=ax)
        cbar.ax.tick_params(labelsize=fontsize)
        cbar.ax.yaxis.get_offset_text().set_fontsize(fontsize)
        ax.set_title(title, fontsize=fontsize)
        ax.set_xlabel(f"$x_{{{dim1 + 1}}}$", fontsize=fontsize)
        ax.set_ylabel("t", fontsize=fontsize)
        ax.tick_params(labelsize=fontsize)

    fig.tight_layout()
    # Save before show(): the saved PDF must not depend on what happens to
    # the interactive window (resizing, closing) while it is displayed.
    fig.savefig(f"spatial_snap_{fig_index}_t1.pdf")
    plt.show()


def plot_plane_comparisons(
    X_test, analytical_sol, model, planes, d, fontsize=10, svd_on=True
):
    """
    Produce one ground-truth / model / error figure per requested plane.

    For every entry of ``planes`` (figures numbered from 1), subsample and
    zero the test data, evaluate on 100 times evenly spaced in [0, 1], and
    plot and save the comparison.

    Args:
        X_test (np.ndarray): Full test input array.
        analytical_sol (callable): Function for computing the true solution.
        model (object): Trained model with `.evaluate()` method.
        planes (list): Dimension indices to visualize.
        d (int): Total number of spatial dimensions (currently unused here).
        fontsize (int): Font size for all labels and ticks.
        svd_on (bool): Forwarded to ``evaluate_solutions`` / ``model.evaluate``.
    """
    time_grid = np.linspace(0, 1, 100).reshape(-1, 1)

    figure_number = 0
    for plane in planes:
        figure_number += 1
        sampled, zeroed = generate_sampled_data(X_test, plane)
        truth, prediction, abs_error = evaluate_solutions(
            zeroed, time_grid, analytical_sol, model, plane, svd_on=svd_on
        )
        # The same index serves as both the data column and the label dimension.
        plot_comparison(
            sampled, time_grid, truth, prediction, abs_error,
            plane, plane, fontsize, figure_number,
        )
