from typing import Sequence

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import numpy.typing as npt

def plot(x: npt.NDArray[np.float64], u: npt.NDArray[np.float64],
         title: str, savefig: bool = False, fontsize: int = 14,
         timesteps: Sequence[int] = (0, 30, 60, 99), figsize: tuple = (6, 2),
         cmap_offset: float = 0., figname: str = 'fig.pdf',
         titles: Sequence[str] = ('t = 0', 't = 0.3', 't = 0.6', 't = 1')):
    """Make a row of scatter plots of the solution, one panel per time-step.

    All panels share a single colour range and a single colorbar, so colours
    are directly comparable across time.

    Args:
        x (npt.NDArray[np.float64]): Spatial data points (dimensions: n_points * n_dim).
            Columns 0 and 1 are used as the plot coordinates.
        u (npt.NDArray[np.float64]): Solution indexed by time-step along axis 0
            (dimensions: n_timesteps * n_points * n_dim); ``u[t]`` must provide
            one value per spatial point.
        title (str): Title of the whole figure.
        savefig (bool, optional): Whether to save the figure to ``figname``.
            Defaults to False.
        fontsize (int, optional): Font size for titles and tick labels.
            Defaults to 14.
        timesteps (Sequence[int], optional): Indices into axis 0 of ``u``.
            One panel is drawn per entry. Defaults to (0, 30, 60, 99).
        figsize (tuple, optional): Figure size in inches. Defaults to (6, 2).
        cmap_offset (float, optional): Amount by which the colour range is
            narrowed at both ends (vmin is raised and vmax lowered by this
            value). Defaults to 0.
        figname (str, optional): File name used when ``savefig`` is True.
            Defaults to 'fig.pdf'.
        titles (Sequence[str], optional): Panel titles, one per entry of
            ``timesteps``. Defaults to ('t = 0', 't = 0.3', 't = 0.6', 't = 1').

    Raises:
        ValueError: If ``timesteps`` is empty, or if ``titles`` and
            ``timesteps`` differ in length.
    """
    if len(timesteps) == 0:
        raise ValueError('timesteps must contain at least one index')
    if len(titles) != len(timesteps):
        raise ValueError(
            f'got {len(titles)} titles for {len(timesteps)} timesteps; '
            'pass one title per timestep')

    # squeeze=False keeps `axes` a 2-D array even when only one panel is drawn.
    fig, axes = plt.subplots(1, len(timesteps), figsize=figsize, sharex=True,
                             sharey=True, constrained_layout=True, squeeze=False)
    axes = axes[0]
    cmap = 'jet'
    marker_size = 0.4

    # Common colour range for every panel. It is taken over the whole of `u`
    # (not over a single spatial point), so no point's colour is clipped
    # unless `cmap_offset` deliberately narrows the range.
    vmin = np.min(u) + cmap_offset
    vmax = np.max(u) - cmap_offset

    images = []
    for axis, step, panel_title in zip(axes, timesteps, titles):
        images.append(axis.scatter(x[:, 0], x[:, 1], c=u[step], cmap=cmap,
                                   vmin=vmin, vmax=vmax, s=marker_size))
        axis.tick_params(axis='both', labelsize=fontsize)
        axis.set_title(panel_title, fontsize=fontsize)

    fig.supxlabel('X')
    fig.supylabel('Y')

    # All panels share vmin/vmax/cmap, so one colorbar spanning the row suffices.
    cbar = fig.colorbar(images[0], ax=axes, location='right', aspect=8)
    cbar.ax.tick_params(labelsize=fontsize)

    fig.suptitle(title, fontsize=fontsize)
    if savefig:
        fig.savefig(figname)
    plt.show()


def plot_error(x: npt.NDArray[np.float64], u_true: npt.NDArray[np.float64],
               u_nn: npt.NDArray[np.float64], title: str,
               timesteps: Sequence[int] = (0, 30, 60, 99), figsize: tuple = (8, 3),
               fontsize: int = 14, savefig: bool = False, figname: str = 'fig.pdf',
               titles: Sequence[str] = ('t = 0', 't = 0.3', 't = 0.6', 't = 1')):
    """Plot the point-wise absolute error |u_true - u_nn|, one panel per time-step.

    Each panel has its own colour scale and its own horizontal colorbar
    underneath, with tick labels forced into scientific notation.

    Args:
        x (npt.NDArray[np.float64]): Spatial data points (dimensions: n_points * n_dim).
            Columns 0 and 1 are used as the plot coordinates.
        u_true (npt.NDArray[np.float64]): True solution indexed by time-step
            along axis 0 (dimensions: n_timesteps * n_points * n_dim).
        u_nn (npt.NDArray[np.float64]): Neural network solution, same layout
            as ``u_true``.
        title (str): Title of the whole figure.
        timesteps (Sequence[int], optional): Indices into axis 0 of the
            solutions. One panel is drawn per entry. Defaults to (0, 30, 60, 99).
        figsize (tuple, optional): Figure size in inches. Defaults to (8, 3).
        fontsize (int, optional): Font size for titles, labels and tick labels.
            Defaults to 14.
        savefig (bool, optional): Whether to save the figure to ``figname``.
            Defaults to False.
        figname (str, optional): File name used when ``savefig`` is True.
            Defaults to 'fig.pdf'.
        titles (Sequence[str], optional): Panel titles, one per entry of
            ``timesteps``. Defaults to ('t = 0', 't = 0.3', 't = 0.6', 't = 1').

    Raises:
        ValueError: If ``timesteps`` is empty, or if ``titles`` and
            ``timesteps`` differ in length.
    """
    if len(timesteps) == 0:
        raise ValueError('timesteps must contain at least one index')
    if len(titles) != len(timesteps):
        raise ValueError(
            f'got {len(titles)} titles for {len(timesteps)} timesteps; '
            'pass one title per timestep')

    # squeeze=False keeps `axes` a 2-D array even when only one panel is drawn.
    fig, axes = plt.subplots(1, len(timesteps), figsize=figsize, sharex=True,
                             sharey=True, constrained_layout=True, squeeze=False)
    cmap = 'jet'
    marker_size = 0.4

    for axis, step, panel_title in zip(axes[0], timesteps, titles):
        error = np.abs(u_true[step] - u_nn[step])
        image = axis.scatter(x[:, 0], x[:, 1], c=error, cmap=cmap, s=marker_size)
        axis.tick_params(axis='both', labelsize=fontsize)
        axis.set_xlabel('x', fontsize=fontsize)
        axis.set_title(panel_title, fontsize=fontsize)

        # Per-panel colorbar: error magnitudes can differ widely between time-steps.
        cbar = fig.colorbar(image, ax=axis, location='bottom', aspect=8)
        # scilimits=(0, 0) always uses scientific notation, with the exponent
        # shown once as offset text instead of on every tick.
        cbar.ax.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
        cbar.ax.tick_params(labelsize=fontsize)
        cbar.ax.xaxis.get_offset_text().set_fontsize(fontsize)

    # y=0.65 lifts the shared label above the colorbars along the bottom.
    fig.supylabel('Y', y=0.65)
    fig.suptitle(title, fontsize=fontsize)
    if savefig:
        fig.savefig(figname)
    plt.show()