import torch
import numpy as np
from scipy.interpolate import griddata
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from functools import partial


def points_in_rectangle(points, left, right, top, bottom):
    """Return a boolean mask of the points lying inside an axis-aligned rectangle.

    Args:
        points: Array-like indexed as ``points[:, 0]`` (x) and ``points[:, 1]`` (y).
        left, right: Inclusive x bounds.
        top, bottom: Inclusive y bounds (``bottom <= y <= top``).

    Returns:
        Element-wise boolean mask, one entry per point.
    """
    xs = points[:, 0]
    ys = points[:, 1]
    inside_x = (xs >= left) & (xs <= right)
    inside_y = (ys >= bottom) & (ys <= top)
    return inside_x & inside_y
# hardcoded boundaries
# bigger setting
# points_in_bigger_rec = partial(points_in_rectangle, left=0.035, right=0.335, top=0.125, bottom=-0.125)
# points_in_smaller_rec = partial(points_in_rectangle, left=0.0, right=0.035, top=0.0046, bottom=-0.0046)
# smaller setting
# points_in_bigger_rec = partial(points_in_rectangle, left=0.035, right=0.3, top=0.07, bottom=-0.07)
# points_in_smaller_rec = partial(points_in_rectangle, left=0.0, right=0.035, top=0.0046, bottom=-0.0046)

# centered around (0,0)
# big_rec = np.array([[0.035, 0.3], [0.07, -0.07]])
# small_rec = np.array([[0.0, 0.035], [0.0046, -0.0046]])
# # big_rec = big_rec / big_rec.mean(axis=1)
# # small_rec = small_rec / small_rec.mean(axis=1)
# points_in_bigger_rec = partial(points_in_rectangle, left=big_rec[0,0], right=big_rec[0,1], top=big_rec[1,0], bottom=big_rec[1,1])
# points_in_smaller_rec = partial(points_in_rectangle, left=small_rec[0,0], right=small_rec[0,1], top=small_rec[1,0], bottom=small_rec[1,1])
# points_inside_fn = lambda x: points_in_bigger_rec(x) | points_in_smaller_rec(x)


# x_min = 0.
# x_max = 0.27
# y_min = 0
# y_max = 0.1

# shifted to a positive range
# x_min = 0.
# x_max = 0.27
# y_min = -0.01
# y_max = 0.11

# Active rectangular domain; ``points_inside_fn`` uses it to decide which
# grid points are inside the plotted region.
x_min = 0.035
x_max = 0.31
y_min = -0.06
y_max = 0.06

points_in_bigger_rec = partial(
    points_in_rectangle,
    left=x_min,
    right=x_max,
    top=y_max,
    bottom=y_min,
)


def points_inside_fn(x):
    """Return a boolean mask of the rows of ``x`` inside the active domain."""
    return points_in_bigger_rec(x)



def convert_to_numpy(list_of_tensors):
    """Return a new list where each torch tensor is replaced by a NumPy array.

    Tensors are detached from the autograd graph and moved to the CPU before
    conversion; any element that is not a tensor is passed through unchanged.
    """
    converted = []
    for item in list_of_tensors:
        if torch.is_tensor(item):
            item = item.detach().cpu().numpy()
        converted.append(item)
    return converted


class Plotter_2d:
    """Render a field sampled at scattered points as a single 2-D image."""

    def __init__(self):
        # Stateless: everything needed is passed to ``plot``.
        pass

    def plot(self, mesh_centers, field, resolution=100):
        """Interpolate ``field`` onto a regular grid and show it with ``imshow``.

        The grid spans the fixed canvas ``x in [0, 0.335]``,
        ``y in [-0.125, 0.125]``. Only grid points accepted by
        ``points_inside_fn`` are interpolated; the rest stay NaN.

        Args:
            mesh_centers: Sample coordinates passed to
                ``scipy.interpolate.griddata`` (presumably ``(N, 2)`` --
                confirm with callers).
            field: Values at ``mesh_centers``.
            resolution: Base resolution. The grid has ``int(resolution * 2)``
                columns and ``int(int(resolution * 1.2) * 2)`` rows.
        """
        extent = (0., 0.335, -0.125, 0.125)
        left, right, bottom, top = extent

        n_x = int(resolution * 2)
        n_y = int(int(resolution * 1.2) * 2)
        xx, yy = np.meshgrid(np.linspace(left, right, n_x),
                             np.linspace(bottom, top, n_y))
        grid_points = np.column_stack((xx.ravel(), yy.ravel()))

        inside = points_inside_fn(grid_points)

        # Interpolate only inside the domain; everything else remains NaN.
        flat_image = np.full(grid_points.shape[0], np.nan)
        flat_image[inside] = griddata(mesh_centers, field, grid_points[inside], method='linear')
        image = flat_image.reshape(xx.shape)

        plt.figure()
        plt.imshow(image, extent=extent, origin='lower', cmap='viridis')
        plt.colorbar(label='norm_u')
        plt.show()
        
# without roll        
def plot_frames_2d_batched(mesh_centers, 
                           true_norms, 
                           pred_norms, 
                           diff_norms, 
                           true_w_noise_norms,
                           resolution, 
                           prediction_titles=None,
                           true_w_noise_titles=None,
                           plot_pts=None, 
                           logify=False, 
                           return_fig=False
                           ):
    """Plot one row of True / [True with noise] / Predicted / Difference per sample.

    Each frame holds values at ``mesh_centers``. It is linearly interpolated
    (``scipy.interpolate.griddata``) onto a regular grid spanning the
    module-level ``x_min``/``x_max``/``y_min``/``y_max`` box. Grid points
    rejected by ``points_inside_fn`` stay NaN and are drawn transparent.

    Args:
        mesh_centers: Sample coordinates passed to ``griddata`` (presumably
            ``(N, 2)`` -- confirm with callers).
        true_norms: Ground-truth frames. ``true_norms.shape[0]`` sets the
            number of rows.
        pred_norms, diff_norms: Predicted frames and differences, indexed by row.
        true_w_noise_norms: Optional noisy ground truth. When not ``None``, an
            extra column is added after "True".
        resolution: Base grid resolution. The grid is ``2*resolution`` wide
            and ``2*int(1.2*resolution)`` tall.
        prediction_titles: Optional per-row titles for the prediction column.
        true_w_noise_titles: Optional per-row titles for the noisy column.
        plot_pts: Optional per-row point sets. ``plot_pts[idx]`` is scattered
            onto the axes of row ``idx`` only.
        logify: Use a logarithmic colour scale for all value images.
        return_fig: Return the figure instead of calling ``plt.show()``.

    Returns:
        The figure if ``return_fig`` is true, otherwise ``None``.
    """
    import copy  # local: the file header does not import ``copy``

    batch_size = true_norms.shape[0]
    has_noise = true_w_noise_norms is not None

    # Columns: True [| True with noise] | Predicted | Difference
    cols = 4 if has_noise else 3
    fig, axes = plt.subplots(batch_size, cols, figsize=(24, batch_size*6), squeeze=False)

    # One colour scale shared by every value image (all columns but the last).
    # The noisy ground truth is drawn with this scale too, so it is included
    # in the range instead of being silently clipped.
    value_arrays = [true_norms, pred_norms]
    if has_noise:
        value_arrays.append(true_w_noise_norms)
    vmin = min(np.min(a) for a in value_arrays)
    vmax = max(np.max(a) for a in value_arrays)
    if logify:
        # LogNorm needs a strictly positive lower bound.
        norm = mcolors.LogNorm(vmin=max(vmin, 1e-10), vmax=vmax)
    else:
        norm = mcolors.Normalize(vmin=vmin, vmax=vmax)

    # Copy the colormaps before ``set_bad`` so the shared ``plt.cm`` objects
    # used elsewhere in the process are not modified. NaN cells become transparent.
    cmap_viridis = copy.copy(plt.cm.viridis)
    cmap_viridis.set_bad(color='white', alpha=0)
    cmap_diff = copy.copy(plt.cm.bwr)
    cmap_diff.set_bad(color='white', alpha=0)

    extent=(x_min, x_max, y_min, y_max)

    # Define the grid resolution
    grid_resolution = (resolution, int(resolution*1.2))
    factor = 2
    x_grid = np.linspace(x_min, x_max, int(grid_resolution[0]*factor))
    y_grid = np.linspace(y_min, y_max, int(grid_resolution[1]*factor))
    x_grid, y_grid = np.meshgrid(x_grid, y_grid)
    grid_points = np.column_stack((x_grid.ravel(), y_grid.ravel()))

    mask = points_inside_fn(grid_points)
    non_masked_points = grid_points[mask]

    def get_masked_interpolation(frame):
        """Interpolate ``frame`` onto the grid, leaving points outside ``mask`` NaN."""
        interpolated_values = griddata(mesh_centers, frame, non_masked_points, method='linear')
        interpolated_grid = np.full(grid_points.shape[0], np.nan)
        interpolated_grid[mask] = interpolated_values
        return interpolated_grid.reshape(x_grid.shape)

    for idx in range(batch_size):
        row = axes[idx]

        # (frame, title) for each value column, in display order.
        panels = [(true_norms[idx], 'True')]
        if has_noise:
            noise_title = 'True' if true_w_noise_titles is None else true_w_noise_titles[idx]
            panels.append((true_w_noise_norms[idx], noise_title))
        pred_title = 'Predicted' if prediction_titles is None else prediction_titles[idx]
        panels.append((pred_norms[idx], pred_title))

        # Every value image uses the same (possibly logarithmic) norm.
        for col, (frame, title) in enumerate(panels):
            sc_value = row[col].imshow(get_masked_interpolation(frame), cmap=cmap_viridis,
                                       origin='lower', extent=extent, norm=norm)
            row[col].set_title(title)
        n_value_cols = len(panels)

        # Single colorbar for all value columns of this row
        cbar = fig.colorbar(sc_value, ax=row[:n_value_cols], orientation='vertical', fraction=0.015, pad=0.05)
        cbar.set_label('norm(u)')

        # Difference column with its own (autoscaled) colour scale
        diff_ax = row[n_value_cols]
        sc_diff = diff_ax.imshow(get_masked_interpolation(diff_norms[idx]), cmap=cmap_diff,
                                 origin='lower', extent=extent)
        diff_ax.set_title('Difference')
        cbar_diff = fig.colorbar(sc_diff, ax=diff_ax, orientation='vertical', fraction=0.015, pad=0.05)
        cbar_diff.set_label('Difference')

        if plot_pts is not None:
            plot_pts_i = plot_pts[idx]
            # These points belong to sample ``idx``: draw them on its row only.
            for ax in row:
                ax.scatter(plot_pts_i[:, 0], plot_pts_i[:, 1], marker='x', alpha=0.2)

    if return_fig:
        return fig

    plt.show()



# without roll        
def plot_frames_2d(idx, mesh_centers, true_norm_list, pred_norm_list, diff_norm_list, resolution, logify=False, return_fig=False):
    """Plot True / Predicted / Difference images for frame ``idx``.

    Each frame holds values at ``mesh_centers``. It is linearly interpolated
    (``scipy.interpolate.griddata``) onto a regular grid over the fixed canvas
    ``x in [0, 0.335]``, ``y in [-0.125, 0.125]``. Grid points rejected by
    ``points_inside_fn`` stay NaN and are drawn transparent.

    Args:
        idx: Index of the frame to plot in each list.
        mesh_centers: Sample coordinates passed to ``griddata`` (presumably
            ``(N, 2)`` -- confirm with callers).
        true_norm_list, pred_norm_list, diff_norm_list: Per-frame values.
        resolution: Base grid resolution. The grid is ``2*resolution`` wide
            and ``2*int(1.2*resolution)`` tall.
        logify: Use a logarithmic colour scale for the True/Predicted images.
        return_fig: Return the figure instead of calling ``plt.show()``.

    Returns:
        The figure if ``return_fig`` is true, otherwise ``None``.
    """
    import copy  # local: the file header does not import ``copy``

    fig, axes = plt.subplots(1, 3, figsize=(24, 6))

    # Shared colour scale for the True and Predicted images
    vmin = min(np.min(true_norm_list[idx]), np.min(pred_norm_list[idx]))
    vmax = max(np.max(true_norm_list[idx]), np.max(pred_norm_list[idx]))
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)

    # Copy the colormaps before ``set_bad`` so the shared ``plt.cm`` objects
    # used elsewhere in the process are not modified. NaN cells become transparent.
    cmap_viridis = copy.copy(plt.cm.viridis)
    cmap_viridis.set_bad(color='white', alpha=0)
    cmap_diff = copy.copy(plt.cm.bwr)
    cmap_diff.set_bad(color='white', alpha=0)

    # Fixed display canvas. It is larger than the masked domain, which comes
    # from the module-level bounds used by ``points_inside_fn``.
    x_min = 0.
    x_max = 0.335
    y_min = -0.125
    y_max = 0.125
    extent=(x_min, x_max, y_min, y_max)

    # Define the grid resolution
    grid_resolution = (resolution, int(resolution*1.2))
    factor = 2
    x_grid = np.linspace(x_min, x_max, int(grid_resolution[0]*factor))
    y_grid = np.linspace(y_min, y_max, int(grid_resolution[1]*factor))
    x_grid, y_grid = np.meshgrid(x_grid, y_grid)
    grid_points = np.column_stack((x_grid.ravel(), y_grid.ravel()))

    mask = points_inside_fn(grid_points)
    non_masked_points = grid_points[mask]

    def get_masked_interpolation(frame):
        """Interpolate ``frame`` onto the grid, leaving points outside ``mask`` NaN."""
        interpolated_values = griddata(mesh_centers, frame, non_masked_points, method='linear')
        interpolated_grid = np.full(grid_points.shape[0], np.nan)
        interpolated_grid[mask] = interpolated_values
        return interpolated_grid.reshape(x_grid.shape)

    masked_true = get_masked_interpolation(true_norm_list[idx])
    masked_pred = get_masked_interpolation(pred_norm_list[idx])
    masked_diff = get_masked_interpolation(diff_norm_list[idx])

    # START PLOTTING

    sc1 = axes[0].imshow(masked_true, cmap=cmap_viridis, origin='lower', extent=extent, norm=norm)
    axes[0].set_title('True')

    sc2 = axes[1].imshow(masked_pred, cmap=cmap_viridis, origin='lower', extent=extent, norm=norm)
    axes[1].set_title('Predicted')

    if logify:
        # LogNorm needs a strictly positive lower bound.
        sc1.set_norm(mcolors.LogNorm(vmin=max(vmin, 1e-10), vmax=vmax))
        sc2.set_norm(mcolors.LogNorm(vmin=max(vmin, 1e-10), vmax=vmax))

    # Difference image with its own (autoscaled) colour scale
    sc3 = axes[2].imshow(masked_diff, cmap=cmap_diff, origin='lower', extent=extent)
    axes[2].set_title('Difference')

    # Single colorbar for the first two plots
    cbar = fig.colorbar(sc2, ax=axes[:2], orientation='vertical', fraction=0.015, pad=0.05)
    cbar.set_label('norm(u)')

    # Separate colorbar for the difference plot
    cbar_diff = fig.colorbar(sc3, ax=axes[2], orientation='vertical', fraction=0.015, pad=0.05)
    cbar_diff.set_label('Difference')

    if return_fig:
        return fig

    plt.show()


# without roll        
def plot_frames_2d_with_naive_diff(idx, mesh_centers, true_list, pred_list, resolution, logify=False, return_fig=False):
    """Plot True / Predicted images for frame ``idx`` and their pixel-wise difference.

    Unlike ``plot_frames_2d``, the difference is not supplied. It is computed
    as ``pred - true`` on the interpolated images. Frames are linearly
    interpolated (``scipy.interpolate.griddata``) onto a regular grid over the
    fixed canvas ``x in [0, 0.335]``, ``y in [-0.125, 0.125]``. Grid points
    rejected by ``points_inside_fn`` stay NaN and are drawn transparent.

    Args:
        idx: Index of the frame to plot in each list.
        mesh_centers: Sample coordinates passed to ``griddata`` (presumably
            ``(N, 2)`` -- confirm with callers).
        true_list, pred_list: Per-frame values.
        resolution: Base grid resolution. The grid is ``2*resolution`` wide
            and ``2*int(1.2*resolution)`` tall.
        logify: Use a logarithmic colour scale for the True/Predicted images.
        return_fig: Return the figure instead of calling ``plt.show()``.

    Returns:
        The figure if ``return_fig`` is true, otherwise ``None``.
    """
    import copy  # local: the file header does not import ``copy``

    fig, axes = plt.subplots(1, 3, figsize=(24, 6))

    # Shared colour scale for the True and Predicted images
    vmin = min(np.min(true_list[idx]), np.min(pred_list[idx]))
    vmax = max(np.max(true_list[idx]), np.max(pred_list[idx]))
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)

    # Copy the colormaps before ``set_bad`` so the shared ``plt.cm`` objects
    # used elsewhere in the process are not modified. NaN cells become transparent.
    cmap_viridis = copy.copy(plt.cm.viridis)
    cmap_viridis.set_bad(color='white', alpha=0)
    cmap_diff = copy.copy(plt.cm.bwr)
    cmap_diff.set_bad(color='white', alpha=0)

    # Fixed display canvas (larger than the masked domain)
    x_min = 0.
    x_max = 0.335
    y_min = -0.125
    y_max = 0.125
    extent=(x_min, x_max, y_min, y_max)

    # Define the grid resolution
    grid_resolution = (resolution, int(resolution*1.2))
    factor = 2
    x_grid = np.linspace(x_min, x_max, int(grid_resolution[0]*factor))
    y_grid = np.linspace(y_min, y_max, int(grid_resolution[1]*factor))
    x_grid, y_grid = np.meshgrid(x_grid, y_grid)
    grid_points = np.column_stack((x_grid.ravel(), y_grid.ravel()))

    mask = points_inside_fn(grid_points)
    non_masked_points = grid_points[mask]

    def get_masked_interpolation(frame):
        """Interpolate ``frame`` onto the grid, leaving points outside ``mask`` NaN."""
        interpolated_values = griddata(mesh_centers, frame, non_masked_points, method='linear')
        interpolated_grid = np.full(grid_points.shape[0], np.nan)
        interpolated_grid[mask] = interpolated_values
        return interpolated_grid.reshape(x_grid.shape)

    masked_true = get_masked_interpolation(true_list[idx])
    masked_pred = get_masked_interpolation(pred_list[idx])

    # START PLOTTING

    sc1 = axes[0].imshow(masked_true, cmap=cmap_viridis, origin='lower', extent=extent, norm=norm)
    axes[0].set_title('True')

    sc2 = axes[1].imshow(masked_pred, cmap=cmap_viridis, origin='lower', extent=extent, norm=norm)
    axes[1].set_title('Predicted')

    if logify:
        # LogNorm needs a strictly positive lower bound.
        sc1.set_norm(mcolors.LogNorm(vmin=max(vmin, 1e-10), vmax=vmax))
        sc2.set_norm(mcolors.LogNorm(vmin=max(vmin, 1e-10), vmax=vmax))

    # Pixel-wise difference of the interpolated images (NaN where either is NaN)
    diff = masked_pred - masked_true

    # Explicitly mask grid points outside the domain (they are already NaN there).
    masked_diff = np.ma.masked_where(~mask.reshape(diff.shape), diff)

    sc3 = axes[2].imshow(masked_diff, cmap=cmap_diff, origin='lower', extent=extent)
    axes[2].set_title('Difference')

    # Single colorbar for the first two plots
    cbar = fig.colorbar(sc2, ax=axes[:2], orientation='vertical', fraction=0.015, pad=0.05)
    cbar.set_label('norm(u)')

    # Separate colorbar for the difference plot
    cbar_diff = fig.colorbar(sc3, ax=axes[2], orientation='vertical', fraction=0.015, pad=0.05)
    cbar_diff.set_label('Difference')

    if return_fig:
        return fig

    plt.show()
    
    
def plot_frames_2d_with_roll(idx, 
                             mesh_centers, 
                             true_list, 
                             pred_list, 
                             pred_diff_list, 
                             roll_list, 
                             roll_diff_list, 
                             resolution, 
                             logify=False, 
                             select_first=-1,
                             return_fig=False,
                             ):
    """Plot next-step prediction and auto-regressive rollout for frame ``idx``.

    Layout (2x3):
        top:    True | Next-step prediction | Difference next-step prediction
        bottom: (off) | Rollout at idx      | Difference auto-regressive rollout

    Frames are linearly interpolated (``scipy.interpolate.griddata``) onto a
    regular grid over the fixed canvas ``x in [0, 0.335]``,
    ``y in [-0.125, 0.125]``. Grid points rejected by ``points_inside_fn``
    stay NaN and are drawn transparent.

    Args:
        idx: Frame index, applied after the optional truncation.
        mesh_centers: Sample coordinates passed to ``griddata`` (presumably
            ``(N, 2)`` -- confirm with callers).
        true_list, pred_list, roll_list: Per-frame values sharing one colour scale.
        pred_diff_list, roll_diff_list: Per-frame differences sharing a second
            colour scale.
        resolution: Base grid resolution. The grid is ``2*resolution`` wide
            and ``2*int(1.2*resolution)`` tall.
        logify: Use a logarithmic colour scale for the value images.
        select_first: If > 0, truncate every list to its first ``select_first`` items.
        return_fig: Return the figure instead of calling ``plt.show()``.

    Returns:
        The figure if ``return_fig`` is true, otherwise ``None``.
    """
    import copy  # local: the file header does not import ``copy``

    # Fixed display canvas (larger than the masked domain)
    x_min = 0.
    x_max = 0.335
    y_min = -0.125
    y_max = 0.125
    extent=(x_min, x_max, y_min, y_max)

    if select_first > 0:
        true_list = true_list[:select_first]
        pred_list = pred_list[:select_first]
        pred_diff_list = pred_diff_list[:select_first]
        roll_list = roll_list[:select_first]
        roll_diff_list = roll_diff_list[:select_first]

    # Shared colour scale for the three value images
    vmin = min(np.min(true_list[idx]), np.min(pred_list[idx]), np.min(roll_list[idx]))
    vmax = max(np.max(true_list[idx]), np.max(pred_list[idx]), np.max(roll_list[idx]))

    fig, axes = plt.subplots(2, 3, figsize=(24, 12))
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)

    # Copy the colormaps before ``set_bad`` so the shared ``plt.cm`` objects
    # used elsewhere in the process are not modified. NaN cells become transparent.
    cmap_viridis = copy.copy(plt.cm.viridis)
    cmap_viridis.set_bad(color='white', alpha=0)
    cmap_diff = copy.copy(plt.cm.bwr)
    cmap_diff.set_bad(color='white', alpha=0)

    # Define the grid resolution
    grid_resolution = (resolution, int(resolution*1.2))
    factor = 2
    x_grid = np.linspace(x_min, x_max, int(grid_resolution[0]*factor))
    y_grid = np.linspace(y_min, y_max, int(grid_resolution[1]*factor))
    x_grid, y_grid = np.meshgrid(x_grid, y_grid)
    grid_points = np.column_stack((x_grid.ravel(), y_grid.ravel()))

    mask = points_inside_fn(grid_points)
    non_masked_points = grid_points[mask]

    def get_masked_interpolation(frame):
        """Interpolate ``frame`` onto the grid, leaving points outside ``mask`` NaN."""
        interpolated_values = griddata(mesh_centers, frame, non_masked_points, method='linear')
        interpolated_grid = np.full(grid_points.shape[0], np.nan)
        interpolated_grid[mask] = interpolated_values
        return interpolated_grid.reshape(x_grid.shape)

    masked_true = get_masked_interpolation(true_list[idx])
    masked_pred = get_masked_interpolation(pred_list[idx])
    masked_roll = get_masked_interpolation(roll_list[idx])
    masked_pred_diff = get_masked_interpolation(pred_diff_list[idx])
    masked_roll_diff = get_masked_interpolation(roll_diff_list[idx])

    # START PLOTTING

    sc1 = axes[0,0].imshow(masked_true, cmap=cmap_viridis, origin='lower', extent=extent, norm=norm)
    axes[0,0].set_title('True')

    sc2 = axes[0,1].imshow(masked_pred, cmap=cmap_viridis, origin='lower', extent=extent, norm=norm)
    axes[0,1].set_title('Next-step prediction')

    sc4 = axes[1,1].imshow(masked_roll, cmap=cmap_viridis, origin='lower', extent=extent, norm=norm)
    axes[1,1].set_title(f'Rollout at {idx}')

    if logify:
        # LogNorm needs a strictly positive lower bound.
        sc1.set_norm(mcolors.LogNorm(vmin=max(vmin, 1e-10), vmax=vmax))
        sc2.set_norm(mcolors.LogNorm(vmin=max(vmin, 1e-10), vmax=vmax))
        sc4.set_norm(mcolors.LogNorm(vmin=max(vmin, 1e-10), vmax=vmax))

    # Both difference panels share one colorbar, so they must share one norm.
    # Otherwise the rollout panel is autoscaled on its own and its colours
    # do not correspond to the colorbar.
    diff_values = np.concatenate([
        masked_pred_diff[np.isfinite(masked_pred_diff)],
        masked_roll_diff[np.isfinite(masked_roll_diff)],
    ])
    if diff_values.size:
        diff_norm = mcolors.Normalize(vmin=diff_values.min(), vmax=diff_values.max())
    else:
        diff_norm = mcolors.Normalize()  # nothing finite: let matplotlib autoscale

    sc3 = axes[0,2].imshow(masked_pred_diff, cmap=cmap_diff, origin='lower', extent=extent, norm=diff_norm)
    axes[0,2].set_title('Difference next-step prediction')

    axes[1,2].imshow(masked_roll_diff, cmap=cmap_diff, origin='lower', extent=extent, norm=diff_norm)
    axes[1,2].set_title('Difference auto-regressive rollout')

    axes_flat = axes.flatten()

    # Single colorbar for the three value plots (flat indices 0, 1, 4)
    pred_axes = [0, 1, 4]
    cbar = fig.colorbar(sc2, ax=axes_flat[pred_axes], orientation='vertical', fraction=0.015, pad=0.05)
    cbar.set_label('norm(u)')

    # Shared colorbar for the two difference plots (flat indices 2, 5)
    diff_axes = [2, 5]
    cbar_diff = fig.colorbar(sc3, ax=axes_flat[diff_axes], orientation='vertical', fraction=0.015, pad=0.05)
    cbar_diff.set_label('Difference')

    # Bottom-left cell is unused.
    axes[1, 0].axis('off')

    if return_fig:
        return fig

    plt.show()
    
    
    

def plot_means_and_variances(idx, 
                             mesh_centers, 
                             true_mean_list, 
                             pred_mean_list, 
                             true_var_list, 
                             pred_var_list, 
                             resolution, 
                             logify=False
                             ):
    """Plot true/predicted mean (top row) and variance (bottom row) for frame ``idx``.

    Torch tensors among the inputs are converted to NumPy with
    ``convert_to_numpy``. Frames are linearly interpolated
    (``scipy.interpolate.griddata``) onto a regular grid over the fixed canvas
    ``x in [0, 0.335]``, ``y in [-0.125, 0.125]``. Grid points rejected by
    ``points_inside_fn`` stay NaN and are drawn transparent.

    Args:
        idx: Index of the frame to plot in each list.
        mesh_centers: Sample coordinates passed to ``griddata`` (presumably
            ``(N, 2)`` -- confirm with callers). May be a torch tensor.
        true_mean_list, pred_mean_list: Per-frame means sharing one colour scale.
        true_var_list, pred_var_list: Per-frame variances sharing a second scale.
        resolution: Base grid resolution. The grid is ``2*resolution`` wide
            and ``2*int(1.2*resolution)`` tall.
        logify: Use logarithmic colour scales for all four images.
    """
    import copy  # local: the file header does not import ``copy``

    # Convert mesh_centers like the other inputs so torch tensors (e.g. on GPU
    # or requiring grad) can be passed straight to griddata.
    mesh_centers = convert_to_numpy([mesh_centers])[0]
    true_mean_list = convert_to_numpy(true_mean_list)
    pred_mean_list = convert_to_numpy(pred_mean_list)
    true_var_list = convert_to_numpy(true_var_list)
    pred_var_list = convert_to_numpy(pred_var_list)

    fig, axes = plt.subplots(2, 2, figsize=(8, 8))
    axes = axes.ravel()

    # Separate shared colour scales for the means and for the variances
    vmin_mean = min(np.min(true_mean_list[idx]), np.min(pred_mean_list[idx]))
    vmax_mean = max(np.max(true_mean_list[idx]), np.max(pred_mean_list[idx]))
    vmin_var = min(np.min(true_var_list[idx]), np.min(pred_var_list[idx]))
    vmax_var = max(np.max(true_var_list[idx]), np.max(pred_var_list[idx]))

    norm_mean = mcolors.Normalize(vmin=vmin_mean, vmax=vmax_mean)
    norm_var = mcolors.Normalize(vmin=vmin_var, vmax=vmax_var)

    # Copy the colormap before ``set_bad`` so the shared ``plt.cm.viridis``
    # used elsewhere in the process is not modified. NaN cells become transparent.
    cmap_viridis = copy.copy(plt.cm.viridis)
    cmap_viridis.set_bad(color='white', alpha=0)

    # Fixed display canvas (larger than the masked domain)
    x_min = 0.
    x_max = 0.335
    y_min = -0.125
    y_max = 0.125
    extent=(x_min, x_max, y_min, y_max)

    # Define the grid resolution
    grid_resolution = (resolution, int(resolution*1.2))
    factor = 2
    x_grid = np.linspace(x_min, x_max, int(grid_resolution[0]*factor))
    y_grid = np.linspace(y_min, y_max, int(grid_resolution[1]*factor))
    x_grid, y_grid = np.meshgrid(x_grid, y_grid)
    grid_points = np.column_stack((x_grid.ravel(), y_grid.ravel()))

    mask = points_inside_fn(grid_points)
    non_masked_points = grid_points[mask]

    def get_masked_interpolation(frame):
        """Interpolate ``frame`` onto the grid, leaving points outside ``mask`` NaN."""
        interpolated_values = griddata(mesh_centers, frame, non_masked_points, method='linear')
        interpolated_grid = np.full(grid_points.shape[0], np.nan)
        interpolated_grid[mask] = interpolated_values
        return interpolated_grid.reshape(x_grid.shape)

    masked_true_mean = get_masked_interpolation(true_mean_list[idx])
    masked_pred_mean = get_masked_interpolation(pred_mean_list[idx])
    masked_true_var = get_masked_interpolation(true_var_list[idx])
    masked_pred_var = get_masked_interpolation(pred_var_list[idx])

    # START PLOTTING

    sc1 = axes[0].imshow(masked_true_mean, cmap=cmap_viridis, origin='lower', extent=extent, norm=norm_mean)
    axes[0].set_title('True mean')

    sc2 = axes[1].imshow(masked_pred_mean, cmap=cmap_viridis, origin='lower', extent=extent, norm=norm_mean)
    axes[1].set_title('Predicted mean')

    sc3 = axes[2].imshow(masked_true_var, cmap=cmap_viridis, origin='lower', extent=extent, norm=norm_var)
    axes[2].set_title('True var')

    sc4 = axes[3].imshow(masked_pred_var, cmap=cmap_viridis, origin='lower', extent=extent, norm=norm_var)
    axes[3].set_title('Predicted var')

    if logify:
        # LogNorm needs a strictly positive lower bound.
        sc1.set_norm(mcolors.LogNorm(vmin=max(vmin_mean, 1e-10), vmax=vmax_mean))
        sc2.set_norm(mcolors.LogNorm(vmin=max(vmin_mean, 1e-10), vmax=vmax_mean))
        sc3.set_norm(mcolors.LogNorm(vmin=max(vmin_var, 1e-10), vmax=vmax_var))
        sc4.set_norm(mcolors.LogNorm(vmin=max(vmin_var, 1e-10), vmax=vmax_var))

    # One colorbar per row. Raw strings keep the LaTeX backslashes without
    # relying on invalid escape sequences.
    cbar_mean = fig.colorbar(sc2, ax=axes[[0,1]], orientation='vertical', fraction=0.015, pad=0.05)
    cbar_mean.set_label(r'$\mu$')

    cbar_var = fig.colorbar(sc4, ax=axes[[2,3]], orientation='vertical', fraction=0.015, pad=0.05)
    cbar_var.set_label(r'$\sigma^2$')

    for ax in axes.flat:
        ax.axis('off')

    plt.show()


class FramePlotter:
    '''
    Plot true / next-step / rollout frames and their differences, one timestep at a time.

    All griddata interpolations are computed once in ``__init__`` and cached,
    so ``plot_frames`` only draws. The constructor is therefore slow: the
    original author measured ca. 1.5 s per timestep.
    '''

    def __init__(self,
                 mesh_centers,
                 true_list,
                 pred_list,
                 pred_diff_list,
                 roll_list,
                 roll_diff_list,
                 resolution,
                 logify=False,
                 select_first=-1):
        """Interpolate and cache every frame.

        Args:
            mesh_centers: Sample coordinates passed to ``griddata`` (presumably
                ``(N, 2)`` -- confirm with callers).
            true_list, pred_list, roll_list: Per-frame values. One colour scale
                covers all frames of all three lists.
            pred_diff_list, roll_diff_list: Per-frame differences.
            resolution: Base grid resolution. The grid is ``2*resolution``
                wide and ``2*int(1.2*resolution)`` tall.
            logify: Use a logarithmic colour scale for the value images.
            select_first: If > 0, keep only the first ``select_first`` frames.
        """
        # Fixed display canvas (larger than the masked domain)
        x_min = 0.
        x_max = 0.335
        y_min = -0.125
        y_max = 0.125
        self.extent=(x_min, x_max, y_min, y_max)
        self.logify = logify

        if select_first > 0:
            true_list = true_list[:select_first]
            pred_list = pred_list[:select_first]
            pred_diff_list = pred_diff_list[:select_first]
            roll_list = roll_list[:select_first]
            roll_diff_list = roll_diff_list[:select_first]

        self.true_list = true_list
        self.pred_list = pred_list
        self.pred_diff_list = pred_diff_list
        self.roll_list = roll_list
        self.roll_diff_list = roll_diff_list

        # Global colour scale across all frames so timesteps are comparable
        self.vmin = min(np.min(true_list), np.min(pred_list), np.min(roll_list))
        self.vmax = max(np.max(true_list), np.max(pred_list), np.max(roll_list))

        # Define the grid resolution
        grid_resolution = (resolution, int(resolution*1.2))
        factor = 2
        x_grid = np.linspace(x_min, x_max, int(grid_resolution[0]*factor))
        y_grid = np.linspace(y_min, y_max, int(grid_resolution[1]*factor))
        x_grid, y_grid = np.meshgrid(x_grid, y_grid)
        grid_points = np.column_stack((x_grid.ravel(), y_grid.ravel()))

        mask = points_inside_fn(grid_points)
        non_masked_points = grid_points[mask]

        def get_masked_interpolation(frame):
            """Interpolate ``frame`` onto the grid, leaving points outside ``mask`` NaN."""
            interpolated_values = griddata(mesh_centers, frame, non_masked_points, method='linear')
            interpolated_grid = np.full(grid_points.shape[0], np.nan)
            interpolated_grid[mask] = interpolated_values
            return interpolated_grid.reshape(x_grid.shape)

        # Cache the masked interpolations (the expensive part)
        self.masked_true_list = [get_masked_interpolation(frame) for frame in true_list]
        self.masked_pred_list = [get_masked_interpolation(frame) for frame in pred_list]
        self.masked_roll_list = [get_masked_interpolation(frame) for frame in roll_list]
        self.masked_pred_diff_list = [get_masked_interpolation(frame) for frame in pred_diff_list]
        self.masked_roll_diff_list = [get_masked_interpolation(frame) for frame in roll_diff_list]

    def plot_frames(self, idx, return_fig=False):
        """Draw the 2x3 panel for cached frame ``idx``.

        Layout:
            top:    True | Next-step prediction | Difference next-step prediction
            bottom: (off) | Auto-regressive rollout | Difference auto-regressive rollout

        Args:
            idx: Index into the cached frame lists.
            return_fig: Return the figure instead of calling ``plt.show()``.

        Returns:
            The figure if ``return_fig`` is true, otherwise ``None``.
        """
        import copy  # local: the file header does not import ``copy``

        # Same height as ``plot_frames_2d_with_roll``: two rows need 12 inches.
        fig, axes = plt.subplots(2, 3, figsize=(24, 12))
        norm = mcolors.Normalize(vmin=self.vmin, vmax=self.vmax)

        # Copy the colormaps before ``set_bad`` so the shared ``plt.cm`` objects
        # used elsewhere in the process are not modified. NaN cells become transparent.
        cmap_viridis = copy.copy(plt.cm.viridis)
        cmap_viridis.set_bad(color='white', alpha=0)
        cmap_diff = copy.copy(plt.cm.bwr)
        cmap_diff.set_bad(color='white', alpha=0)

        masked_true = self.masked_true_list[idx]
        masked_pred = self.masked_pred_list[idx]
        masked_roll = self.masked_roll_list[idx]
        masked_pred_diff = self.masked_pred_diff_list[idx]
        masked_roll_diff = self.masked_roll_diff_list[idx]

        # START PLOTTING

        sc1 = axes[0, 0].imshow(masked_true, cmap=cmap_viridis, origin='lower', extent=self.extent, norm=norm)
        axes[0, 0].set_title('True')

        sc2 = axes[0, 1].imshow(masked_pred, cmap=cmap_viridis, origin='lower', extent=self.extent, norm=norm)
        axes[0, 1].set_title('Next-step prediction')

        sc4 = axes[1, 1].imshow(masked_roll, cmap=cmap_viridis, origin='lower', extent=self.extent, norm=norm)
        axes[1, 1].set_title('Auto-regressive rollout')

        if self.logify:
            # LogNorm needs a strictly positive lower bound.
            sc1.set_norm(mcolors.LogNorm(vmin=max(self.vmin, 1e-10), vmax=self.vmax))
            sc2.set_norm(mcolors.LogNorm(vmin=max(self.vmin, 1e-10), vmax=self.vmax))
            sc4.set_norm(mcolors.LogNorm(vmin=max(self.vmin, 1e-10), vmax=self.vmax))

        # Both difference panels share one colorbar, so they must share one norm.
        # Otherwise the rollout panel is autoscaled on its own and its colours
        # do not correspond to the colorbar.
        diff_values = np.concatenate([
            masked_pred_diff[np.isfinite(masked_pred_diff)],
            masked_roll_diff[np.isfinite(masked_roll_diff)],
        ])
        if diff_values.size:
            diff_norm = mcolors.Normalize(vmin=diff_values.min(), vmax=diff_values.max())
        else:
            diff_norm = mcolors.Normalize()  # nothing finite: let matplotlib autoscale

        sc3 = axes[0, 2].imshow(masked_pred_diff, cmap=cmap_diff, origin='lower', extent=self.extent, norm=diff_norm)
        axes[0, 2].set_title('Difference next-step prediction')

        axes[1, 2].imshow(masked_roll_diff, cmap=cmap_diff, origin='lower', extent=self.extent, norm=diff_norm)
        axes[1, 2].set_title('Difference auto-regressive rollout')

        axes_flat = axes.flatten()

        # Single colorbar for the three value plots (flat indices 0, 1, 4)
        pred_axes = [0, 1, 4]
        cbar = fig.colorbar(sc2, ax=axes_flat[pred_axes], orientation='vertical', fraction=0.015, pad=0.05)
        cbar.set_label('norm(u)')

        # Shared colorbar for the two difference plots (flat indices 2, 5)
        diff_axes = [2, 5]
        cbar_diff = fig.colorbar(sc3, ax=axes_flat[diff_axes], orientation='vertical', fraction=0.015, pad=0.05)
        cbar_diff.set_label('Difference')

        # Bottom-left cell is unused.
        axes[1, 0].axis('off')

        if return_fig:
            return fig

        plt.show()