import logging
import glob
import tarfile
from pathlib import Path
import random
from math import inf
from typing import Union

import matplotlib.colors as colors
from matplotlib import pyplot as plt
from matplotlib.gridspec import GridSpec
from mpl_toolkits.axes_grid1 import make_axes_locatable

import numpy as np

import torch
from torch import nn
from torch.utils.data import Dataset
from torch.optim import Optimizer

def save_model_checkpoint(epoch: int, device: str, model: nn.Module, optimizer: Optimizer, output_dir: Union[str, Path], prefix = None, suffix = None, overwrite=False):
    """Save model and optimizer state to a ``.pt`` checkpoint inside ``output_dir``.

    The file is named ``[<prefix>_]model_checkpoint[_<epoch>][_<suffix>].pt``.
    The epoch is part of the name only when ``overwrite`` is falsy, so with
    ``overwrite=True`` every call writes to (and replaces) the same file.

    Args:
        epoch: Epoch number; stored in the checkpoint and optionally in the name.
        device: Device string stored alongside the state dicts.
        model: Model whose ``state_dict()`` is saved under ``model_state_dict``.
        optimizer: Optimizer whose ``state_dict()`` is saved under
            ``optimizer_state_dict``.
        output_dir: Existing directory to write into.
        prefix: Optional value prepended to the file name (converted with ``str``).
        suffix: Optional value appended before ``.pt`` (converted with ``str``).
        overwrite: If truthy, omit the epoch so the same file is reused.

    Returns:
        The path of the written checkpoint file, as a string.
    """
    name_parts = []
    if prefix is not None:
        name_parts.append(str(prefix))
    name_parts.append("model_checkpoint" if overwrite else f"model_checkpoint_{epoch}")
    if suffix is not None:
        name_parts.append(str(suffix))
    filepath = f"{output_dir}/{'_'.join(name_parts)}.pt"

    torch.save({'epoch': epoch,
        'device': device,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()},
        filepath)
    return filepath

def load_model(checkpoint_path: str, model: nn.Module):
    """Restore ``model``'s weights from a checkpoint file and return the same model.

    The checkpoint is loaded onto the CPU and must contain a ``model_state_dict``
    entry. ``model`` is modified in place; it is also returned for chaining.
    """
    weights = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
    model.load_state_dict(weights)
    return model

def tensor_to_grayscale(x: torch.Tensor, min: float, max: float):
    """Linearly map ``x`` from ``[min, max]`` onto a ``torch.uint8`` tensor in ``[0, 255]``.

    Values outside ``[min, max]`` are clamped to that range first. Without the
    clamp, the float -> uint8 cast wraps around, so e.g. a slightly negative value
    would become a bright pixel. Fractional results are truncated by the cast,
    not rounded.

    Args:
        x: Floating-point tensor to convert.
        min: Value that maps to 0 (shadows the builtin; kept for keyword callers).
        max: Value that maps to 255 (shadows the builtin; kept for keyword callers).
    """
    xnew = (x - min) / (max - min)
    xnew = xnew.clamp(0.0, 1.0)
    return (xnew * 255.0).type(torch.uint8)

def min_max_norm(x, min: float, max: float):
    """Rescale ``x`` so ``min`` maps to 0 and ``max`` maps to 1 (values are not clamped)."""
    span = max - min
    shifted = x - min
    return shifted / span

def get_mean_std(dataset: Dataset):
    """Average the per-sample means and standard deviations of both sample fields.

    Every item is indexed as ``item[0]`` (``perm``) and ``item[1]`` (``head``).
    For each field, this returns the average over samples of ``tensor.mean()``
    and of ``tensor.std()``. Note that the second value is the mean of per-sample
    stds, not the pooled std of all values; the two differ when sample means vary.

    Returns:
        ``(perm_mean, perm_std, head_mean, head_std)`` as Python floats.
    """
    totals = [0.0, 0.0, 0.0, 0.0]  # perm mean, perm std, head mean, head std
    with torch.no_grad():
        for sample in dataset:
            perm, head = sample[0], sample[1]
            totals[0] = totals[0] + perm.mean()
            totals[1] = totals[1] + perm.std()
            totals[2] = totals[2] + head.mean()
            totals[3] = totals[3] + head.std()
    return tuple(total.item() / len(dataset) for total in totals)

def get_min_max(dataset: Dataset):
    """Return the global ``(perm_min, perm_max, head_min, head_max)`` over a dataset.

    ``item[0]`` (perm) and ``item[1]`` (head) of every sample are scanned. The
    extrema are whatever ``tensor.min()`` / ``tensor.max()`` return (0-dim
    tensors). For an empty dataset the initial ``inf`` / ``-inf`` sentinels are
    returned unchanged.
    """
    lo_perm, hi_perm, lo_head, hi_head = inf, -inf, inf, -inf
    with torch.no_grad():
        for sample in dataset:
            perm, head = sample[0], sample[1]
            # builtin min/max keep the current value unless the candidate is
            # strictly smaller/larger, matching a strict `<` / `>` update.
            lo_perm = min(lo_perm, perm.min())
            hi_perm = max(hi_perm, perm.max())
            lo_head = min(lo_head, head.min())
            hi_head = max(hi_head, head.max())
    return lo_perm, hi_perm, lo_head, hi_head

def plot_arr(arr, path):
    """Save a single line plot of ``arr`` to ``path``.

    As a side effect, this sets the global rcParams ``savefig.bbox='tight'`` and
    ``figure.dpi=300``.
    """
    plt.rcParams.update({"savefig.bbox": 'tight', "figure.dpi": 300.0})
    fig, axs = plt.subplots(nrows=1, ncols=1, squeeze=False)
    line_ax = axs[0, 0]
    line_ax.plot(arr)
    fig.savefig(path)
    plt.close(fig)

def plot_lists(lists, title, labels, path):
    """Plot groups of line series on one axes and save the figure to ``path``.

    Args:
        lists: Sequence of groups; each group is a sequence of 1-D series. All
            series in group ``i`` share a colour and the legend label ``labels[i]``.
        title: Axes title.
        labels: One legend label per group.
        path: Output file path passed to ``Figure.savefig``.
    """
    plt.rcParams["savefig.bbox"] = 'tight'
    plt.rcParams["figure.dpi"] = 300.0
    fig, axs = plt.subplots(nrows=1, ncols=1, squeeze=False)
    ax = axs[0, 0]
    colors_list = list(colors.BASE_COLORS.keys())
    for i, ds in enumerate(lists):
        # Cycle through the palette instead of raising IndexError past its length.
        color = colors_list[i % len(colors_list)]
        for j, data in enumerate(ds):
            # Label only the first series of a group so the legend has one entry per group.
            ax.plot(data, label=labels[i] if j == 0 else "_nolegend_", color=color)
    ax.set_title(title)
    ax.legend(loc='upper left')
    fig.savefig(path)
    plt.close(fig)

def scatter_lists(lists, title, labels, path):
    """Scatter-plot groups of series against their index and save the figure to ``path``.

    Args:
        lists: Sequence of groups; each group is a sequence of 1-D series. Each
            series is plotted against ``0..len(series)-1``. All series in group
            ``i`` share a colour, a marker and the legend label ``labels[i]``.
        title: Axes title.
        labels: One legend label per group.
        path: Output file path passed to ``Figure.savefig``.
    """
    plt.rcParams["savefig.bbox"] = 'tight'
    plt.rcParams["figure.dpi"] = 300.0
    fig, axs = plt.subplots(nrows=1, ncols=1, squeeze=False)
    ax = axs[0, 0]
    colors_list = list(colors.BASE_COLORS.keys())
    markers_list = ['.', '+', 'x', '*', ',', 'X']
    for i, ds in enumerate(lists):
        # Cycle palettes instead of raising IndexError when there are many groups.
        color = colors_list[i % len(colors_list)]
        marker = markers_list[i % len(markers_list)]
        for j, data in enumerate(ds):
            # x is derived per series so series of different lengths are supported.
            x = np.arange(len(data))
            # Label only the first series of a group so the legend has one entry per group.
            ax.scatter(x, data, label=labels[i] if j == 0 else "_nolegend_", marker=marker, color=color)
    ax.set_title(title)
    ax.legend(loc='upper left')
    fig.savefig(path)
    plt.close(fig)

def plot_action_trajs(id_data, ood_data, title, labels, path):
    """Scatter in-distribution and out-of-distribution series side by side.

    Column 0 shows the series of ``id_data`` and column 1 those of ``ood_data``,
    one series per row, on a fixed 5x2 grid (so each input may hold at most 5
    series). ``labels[0]`` / ``labels[1]`` label the two columns in each subplot's
    legend, and ``title`` is used as the figure title.
    """
    plt.rcParams["savefig.bbox"] = 'tight'
    plt.rcParams["figure.dpi"] = 300.0
    fig, axs = plt.subplots(nrows=5, ncols=2, squeeze=False, figsize=(10, 30))
    colors_list = list(colors.BASE_COLORS.keys())
    for col, ds in enumerate([id_data, ood_data]):
        for row, data in enumerate(ds):
            ax = axs[row, col]
            # x is derived per series so id/ood series may differ in length.
            ax.scatter(np.arange(len(data)), data, label=labels[col], color=colors_list[col], marker=".")
            # A legend per populated subplot; plt.legend() only targeted the
            # bottom-right axes, which may be empty.
            ax.legend(loc='upper left')
    # plt.title() would title only the bottom-right axes; title the whole figure.
    fig.suptitle(title)
    fig.savefig(path)
    plt.close(fig)

def plot_mse_sim_calls(mses, surr_mses, sim_calls, path):
    """Plot simulator-call counts and MSE curves on twin y-axes and save to ``path``.

    The left y-axis shows ``sim_calls``; the right (twin) y-axis shows ``mses``
    ("RL MSE") and ``surr_mses`` ("Surr MSE"). A single combined legend is drawn.
    """
    plt.rcParams["savefig.bbox"] = 'tight'
    plt.rcParams["figure.dpi"] = 300.0
    # Not named `colors`, so the matplotlib.colors module is not shadowed.
    line_colors = ["blue", "green", "orange"]
    fig, axs = plt.subplots(nrows=1, ncols=1, squeeze=False)
    ax = axs[0, 0]
    # twinx() hides the twin's x-axis, so the x label must be set on the primary axes.
    ax.set_xlabel('timestep')
    ax.set_ylabel("# Sim Calls", color=line_colors[2])
    lns3 = ax.plot(sim_calls, color=line_colors[2], label="# Sim Calls")
    ax2 = ax.twinx()
    ax2.set_ylabel('MSE', color=line_colors[0])
    lns1 = ax2.plot(mses, color=line_colors[0], label="RL MSE")
    lns2 = ax2.plot(surr_mses, color=line_colors[1], label="Surr MSE")
    lns = lns1 + lns2 + lns3
    labs = [line.get_label() for line in lns]
    ax.legend(lns, labs, loc=0)
    fig.savefig(path)
    plt.close(fig)

def save_np_img(np_arr: np.ndarray, path, min_max=None, cmap='viridis'):
    """Render a 2-D array as an image with a horizontal colorbar and save it to ``path``.

    Args:
        np_arr: Array passed to ``imshow`` (drawn with ``origin="lower"``).
        path: Output file path.
        min_max: Optional ``(vmin, vmax)`` colour limits; autoscaled when None.
        cmap: Matplotlib colormap name.
    """
    plt.rcParams["savefig.bbox"] = 'tight'
    plt.rcParams["figure.dpi"] = 300.0
    # imshow's defaults for vmin/vmax are None, i.e. autoscale.
    lo, hi = (None, None) if min_max is None else (min_max[0], min_max[1])
    fig, axs = plt.subplots(nrows=1, ncols=1, squeeze=False)
    panel = axs[0, 0]
    mappable = panel.imshow(np_arr, origin="lower", vmin=lo, vmax=hi, cmap=cmap)
    panel.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    bar_ax = make_axes_locatable(panel).append_axes("bottom", size="5%", pad=0.05)
    bar = fig.colorbar(mappable, cax=bar_ax, orientation="horizontal")
    bar.ax.tick_params(labelsize='xx-small')
    fig.savefig(path)
    plt.close(fig)

def save_tensor_img(t: torch.Tensor, path, min_max=None, cmap='viridis'):
    """Render a 2-D tensor as an image with a horizontal colorbar and save it to ``path``.

    The tensor is detached and copied to host memory, then drawn by
    ``save_np_img``. This delegates to ``save_np_img`` rather than duplicating its
    plotting code, so both functions produce identical figures.

    Args:
        t: Tensor to render.
        path: Output file path.
        min_max: Optional ``(vmin, vmax)`` colour limits; autoscaled when None.
        cmap: Matplotlib colormap name.
    """
    save_np_img(t.detach().cpu().numpy(), path, min_max=min_max, cmap=cmap)

def dict_tensor_to_items(x: dict):
    """Return a new dict mapping each key to ``value.item()`` (single-element tensors)."""
    converted = {}
    for key, tensor in x.items():
        converted[key] = tensor.item()
    return converted

def save_traj_images_ns(gts, preds, errs, smoke_min, smoke_max, output_path):
    """Save one 1x3 figure per time step: ground truth, prediction and error.

    For every index ``t`` along dim 0 of ``gts``, ``grid-{t}.png`` is written into
    ``output_path``. It is joined with ``/``, so a ``pathlib.Path`` is expected.
    Ground truth and prediction share the colour range ``[smoke_min, smoke_max]``.
    The error panel uses the min/max of ``errs`` over all time steps, so frames
    are comparable.
    """
    title_size = 7
    cmap = "viridis"
    plt.rcParams["savefig.bbox"] = 'tight'
    plt.rcParams["figure.dpi"] = 300.0

    def to_numpy(tensor):
        """Detach ``tensor`` and copy it to a host NumPy array."""
        return tensor.detach().cpu().numpy()

    def draw_panel(ax, image, title, lo, hi):
        """Show ``image`` on ``ax`` with limits ``[lo, hi]``, no ticks, and a colorbar below."""
        mappable = ax.imshow(image, origin="lower", vmin=lo, vmax=hi, cmap=cmap)
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
        ax.set_title(title, size=title_size)
        ax.axis("off")
        bar_ax = make_axes_locatable(ax).append_axes("bottom", size="5%", pad=0.05)
        bar = ax.figure.colorbar(mappable, cax=bar_ax, orientation="horizontal")
        bar.ax.tick_params(labelsize='xx-small')

    # The error colour limits are identical for every frame; compute them once.
    err_lo, err_hi = errs.min(), errs.max()
    for t in range(gts.shape[0]):
        fig, axs = plt.subplots(nrows=1, ncols=3, squeeze=False)
        err_img = to_numpy(errs[t])
        draw_panel(axs[0, 0], to_numpy(gts[t]), "Ground Truth", smoke_min, smoke_max)
        draw_panel(axs[0, 1], to_numpy(preds[t]), "Prediction", smoke_min, smoke_max)
        draw_panel(axs[0, 2], err_img, f"Abs Error (mean: {err_img.mean():.3f})", err_lo, err_hi)
        fig.savefig(output_path / f"grid-{t}.png")
        plt.close(fig)

def save_traj_images(perm, gts, preds, errs, cumulative_errs, abs_cumulative_errs, perm_min, perm_max, head_min, head_max, output_path):
    """Save a 2x3 diagnostic grid for every time step of a trajectory.

    Top row: the time-invariant ``perm`` field on ``[perm_min, perm_max]``, then
    ground truth ``gts[i]`` and prediction ``preds[i]`` on ``[head_min, head_max]``.
    Bottom row: ``errs[i]``, ``cumulative_errs[i]`` and ``abs_cumulative_errs[i]``,
    drawn with a diverging ``coolwarm`` map centred at 0. Their limits come from
    each tensor's min/max over all time steps, so frames are comparable.

    One ``grid-{i}.png`` is written per index ``i`` along dim 0 of ``gts`` into
    ``output_path``. It is joined with ``/``, so a ``pathlib.Path`` is expected.
    """
    title_size = 7
    cmap = "coolwarm"
    plt.rcParams["savefig.bbox"] = 'tight'
    plt.rcParams["figure.dpi"] = 300.0

    def to_numpy(tensor):
        """Detach ``tensor`` and copy it to a host NumPy array."""
        return tensor.detach().cpu().numpy()

    def centered_norm(vmin, vmax):
        """Build a TwoSlopeNorm centred at 0 that tolerates one-signed data.

        TwoSlopeNorm raises ValueError unless vmin < 0 < vmax. That happens, for
        example, when every error is positive, or when an all-zero abs error gives
        vmax == 0. An invalid limit is dropped instead, so the norm autoscales it
        from the frame (and mirrors it around 0 when it lands on the wrong side).
        """
        vmin = None if vmin is None or float(vmin) >= 0.0 else float(vmin)
        vmax = None if vmax is None or float(vmax) <= 0.0 else float(vmax)
        return colors.TwoSlopeNorm(vcenter=0.0, vmin=vmin, vmax=vmax)

    def draw_panel(ax, image, title, **imshow_kwargs):
        """Show ``image`` on ``ax`` without ticks and add a horizontal colorbar below it."""
        mappable = ax.imshow(image, origin="lower", cmap=cmap, **imshow_kwargs)
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
        ax.set_title(title, size=title_size)
        ax.axis("off")
        cax = make_axes_locatable(ax).append_axes("bottom", size="5%", pad=0.05)
        cbar = ax.figure.colorbar(mappable, cax=cax, orientation="horizontal")
        cbar.ax.tick_params(labelsize='xx-small')

    perm_img = to_numpy(perm)
    # Global (all time steps) colour limits, computed once instead of per frame.
    err_lims = (errs.min(), errs.max())
    cum_err_lims = (cumulative_errs.min(), cumulative_errs.max())
    abs_cum_err_max = abs_cumulative_errs.max()

    for i in range(gts.shape[0]):
        fig, axs = plt.subplots(nrows=2, ncols=3, squeeze=False)
        draw_panel(axs[0, 0], perm_img, "Permeability", vmin=perm_min, vmax=perm_max)
        draw_panel(axs[0, 1], to_numpy(gts[i]), "Ground Truth", vmin=head_min, vmax=head_max)
        draw_panel(axs[0, 2], to_numpy(preds[i]), "Prediction", vmin=head_min, vmax=head_max)
        # A fresh norm per image: norms are stateful (autoscaling) and must not be shared.
        draw_panel(axs[1, 0], to_numpy(errs[i]), "Error", norm=centered_norm(*err_lims))
        draw_panel(axs[1, 1], to_numpy(cumulative_errs[i]), "Cumulative Err", norm=centered_norm(*cum_err_lims))
        # vmin is left to autoscale from the frame, as before.
        draw_panel(axs[1, 2], to_numpy(abs_cumulative_errs[i]), "Abs Cumul Err", norm=centered_norm(None, abs_cum_err_max))
        fig.savefig(output_path / f"grid-{i}.png")
        plt.close(fig)

def save_traj_resid_images(perm, gts, preds, errs, cumulative_errs, residuals, perm_min, perm_max, head_min, head_max, output_path):
    """Save a 2x3 diagnostic grid (including residuals) for every time step.

    Top row: the time-invariant ``perm`` field on ``[perm_min, perm_max]``, then
    ground truth ``gts[i]`` and prediction ``preds[i]`` on ``[head_min, head_max]``.
    Bottom row: ``errs[i]`` and ``cumulative_errs[i]``, drawn with a diverging map
    centred at 0 and limited by each tensor's min/max over all time steps. The
    last panel is ``residuals[i]``, autoscaled per frame.

    One ``grid-{i}.png`` is written per index ``i`` along dim 0 of ``gts`` into
    ``output_path``. It is joined with ``/``, so a ``pathlib.Path`` is expected.
    """
    title_size = 7
    cmap = "coolwarm"
    plt.rcParams["savefig.bbox"] = 'tight'
    plt.rcParams["figure.dpi"] = 300.0

    def to_numpy(tensor):
        """Detach ``tensor`` and copy it to a host NumPy array."""
        return tensor.detach().cpu().numpy()

    def centered_norm(vmin, vmax):
        """Build a TwoSlopeNorm centred at 0 that tolerates one-signed data.

        TwoSlopeNorm raises ValueError unless vmin < 0 < vmax, e.g. when every
        error is positive. An invalid limit is dropped instead, so the norm
        autoscales it from the frame (and mirrors it around 0 when needed).
        """
        vmin = None if vmin is None or float(vmin) >= 0.0 else float(vmin)
        vmax = None if vmax is None or float(vmax) <= 0.0 else float(vmax)
        return colors.TwoSlopeNorm(vcenter=0.0, vmin=vmin, vmax=vmax)

    def draw_panel(ax, image, title, **imshow_kwargs):
        """Show ``image`` on ``ax`` without ticks and add a horizontal colorbar below it."""
        mappable = ax.imshow(image, origin="lower", cmap=cmap, **imshow_kwargs)
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
        ax.set_title(title, size=title_size)
        ax.axis("off")
        cax = make_axes_locatable(ax).append_axes("bottom", size="5%", pad=0.05)
        cbar = ax.figure.colorbar(mappable, cax=cax, orientation="horizontal")
        cbar.ax.tick_params(labelsize='xx-small')

    perm_img = to_numpy(perm)
    # Global (all time steps) colour limits, computed once instead of per frame.
    err_lims = (errs.min(), errs.max())
    cum_err_lims = (cumulative_errs.min(), cumulative_errs.max())

    for i in range(gts.shape[0]):
        fig, axs = plt.subplots(nrows=2, ncols=3, squeeze=False)
        draw_panel(axs[0, 0], perm_img, "Permeability", vmin=perm_min, vmax=perm_max)
        draw_panel(axs[0, 1], to_numpy(gts[i]), "Ground Truth", vmin=head_min, vmax=head_max)
        draw_panel(axs[0, 2], to_numpy(preds[i]), "Prediction", vmin=head_min, vmax=head_max)
        # A fresh norm per image: norms are stateful (autoscaling) and must not be shared.
        draw_panel(axs[1, 0], to_numpy(errs[i]), "Error", norm=centered_norm(*err_lims))
        draw_panel(axs[1, 1], to_numpy(cumulative_errs[i]), "Cumulative Err", norm=centered_norm(*cum_err_lims))
        # Residuals use per-frame (relative) colour limits on purpose.
        draw_panel(axs[1, 2], to_numpy(residuals[i]), "Residual")
        fig.savefig(output_path / f"grid-{i}.png")
        plt.close(fig)

def save_traj_ns_custom(gts, surr_preds, rl_preds, baseline_preds, smoke_min, smoke_max, output_path):
    """Save a custom-layout comparison figure for every time step of a trajectory.

    Layout (4x4 GridSpec, constrained layout): ground truth spans the middle of
    column 0. Columns 1-3 show the surrogate-only, random-policy and hybrid-policy
    predictions on top, with their absolute errors against ground truth below.
    Predictions share ``[smoke_min, smoke_max]``. The three error maps of a frame
    share that frame's error min/max.

    Writes ``grid-{t}.png`` into ``output_path`` for each index ``t`` along dim 0
    of ``gts``. It is joined with ``/``, so a ``pathlib.Path`` is expected.
    """
    title_size = 9
    cmap = "viridis"
    plt.rcParams["savefig.bbox"] = 'tight'
    plt.rcParams["figure.dpi"] = 300.0

    def to_numpy(tensor):
        """Detach ``tensor`` and copy it to a host NumPy array."""
        return tensor.detach().cpu().numpy()

    def draw_panel(ax, image, title, lo, hi):
        """Show ``image`` on ``ax`` with limits ``[lo, hi]``, no ticks, and a colorbar below."""
        mappable = ax.imshow(image, origin="lower", vmin=lo, vmax=hi, cmap=cmap)
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
        ax.set_title(title, size=title_size)
        ax.axis("off")
        bar_ax = make_axes_locatable(ax).append_axes("bottom", size="5%", pad=0.05)
        bar = ax.figure.colorbar(mappable, cax=bar_ax, orientation="horizontal")
        bar.ax.tick_params(labelsize='xx-small')

    for t in range(gts.shape[0]):
        fig = plt.figure(figsize=(8, 5), layout="constrained")
        grid = GridSpec(4, 4, figure=fig)
        gt_ax = fig.add_subplot(grid[1:3, 0])
        pred_axes = [fig.add_subplot(grid[0:2, col]) for col in (1, 2, 3)]
        err_axes = [fig.add_subplot(grid[2:4, col]) for col in (1, 2, 3)]

        gt = to_numpy(gts[t])
        surr = to_numpy(surr_preds[t])
        rl = to_numpy(rl_preds[t])
        baseline = to_numpy(baseline_preds[t])

        surr_err = np.abs(surr - gt)
        rl_err = np.abs(rl - gt)
        baseline_err = np.abs(baseline - gt)
        err_lo = min(np.min(e) for e in (surr_err, rl_err, baseline_err))
        err_hi = max(np.max(e) for e in (surr_err, rl_err, baseline_err))

        draw_panel(gt_ax, gt, "Ground Truth", smoke_min, smoke_max)
        # Column order: surrogate only, random policy (baseline), hybrid policy (RL).
        predictions = (
            (surr, "Surrogate Only Prediction"),
            (baseline, "Random Policy Prediction"),
            (rl, "Hybrid Policy Prediction"),
        )
        for ax, (image, title) in zip(pred_axes, predictions):
            draw_panel(ax, image, title, smoke_min, smoke_max)
        errors = (
            (surr_err, f"Surrogate Only Error {np.mean(surr_err):.3f}"),
            (baseline_err, f"Random Policy Error {np.mean(baseline_err):.3f}"),
            (rl_err, f"Hybrid Policy Error {np.mean(rl_err):.3f}"),
        )
        for ax, (image, title) in zip(err_axes, errors):
            draw_panel(ax, image, title, err_lo, err_hi)

        fig.savefig(output_path / f"grid-{t}.png")
        plt.close(fig)

def save_traj_ns(gts, surr_preds, rl_preds, baseline_preds, smoke_min, smoke_max, output_path):
    """Save a 2x4 comparison grid for every time step of a trajectory.

    Top row: ground truth, surrogate, policy and baseline predictions, all on
    ``[smoke_min, smoke_max]``. Bottom row: the absolute error of each top-row
    image against ground truth, so the ground-truth column is all zeros. These
    errors are drawn on the shared error min/max of the three predictions for
    that frame.

    Writes ``grid-{t}.png`` into ``output_path`` for each index ``t`` along dim 0
    of ``gts``. It is joined with ``/``, so a ``pathlib.Path`` is expected.
    """
    title_size = 7
    cmap = "viridis"
    plt.rcParams["savefig.bbox"] = 'tight'
    plt.rcParams["figure.dpi"] = 300.0

    def to_numpy(tensor):
        """Detach ``tensor`` and copy it to a host NumPy array."""
        return tensor.detach().cpu().numpy()

    def draw_panel(ax, image, title, lo, hi):
        """Show ``image`` on ``ax`` with limits ``[lo, hi]``, no ticks, and a colorbar below."""
        mappable = ax.imshow(image, origin="lower", vmin=lo, vmax=hi, cmap=cmap)
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
        ax.set_title(title, size=title_size)
        ax.axis("off")
        bar_ax = make_axes_locatable(ax).append_axes("bottom", size="5%", pad=0.05)
        bar = ax.figure.colorbar(mappable, cax=bar_ax, orientation="horizontal")
        bar.ax.tick_params(labelsize='xx-small')

    for t in range(gts.shape[0]):
        fig, axs = plt.subplots(nrows=2, ncols=4, squeeze=False)
        gt = to_numpy(gts[t])
        surr = to_numpy(surr_preds[t])
        rl = to_numpy(rl_preds[t])
        baseline = to_numpy(baseline_preds[t])

        top_row = (
            (gt, "Ground Truth"),
            (surr, "Surrogate Prediction"),
            (rl, "Policy Prediction"),
            (baseline, "Baseline Prediction"),
        )
        for col, (image, title) in enumerate(top_row):
            draw_panel(axs[0, col], image, title, smoke_min, smoke_max)

        zero_err = np.abs(gt - gt)  # placeholder panel for the ground-truth column
        surr_err = np.abs(surr - gt)
        rl_err = np.abs(rl - gt)
        baseline_err = np.abs(baseline - gt)
        err_lo = min(np.min(e) for e in (surr_err, rl_err, baseline_err))
        err_hi = max(np.max(e) for e in (surr_err, rl_err, baseline_err))

        bottom_row = (
            (zero_err, "Ground Truth Abs Err"),
            (surr_err, f"Surrogate Abs Err {np.mean(surr_err):.3f}"),
            (rl_err, f"Policy Abs Err {np.mean(rl_err):.3f}"),
            (baseline_err, f"Baseline Abs Err {np.mean(baseline_err):.3f}"),
        )
        for col, (image, title) in enumerate(bottom_row):
            draw_panel(axs[1, col], image, title, err_lo, err_hi)

        fig.savefig(output_path / f"grid-{t}.png")
        plt.close(fig)

def save_traj_surr_rl_images(perm, gts, surr_preds, surr_errs, rl_preds, rl_errs, perm_min, perm_max, head_min, head_max, output_path):
    """Save a 2x3 surrogate-vs-policy comparison grid for every time step.

    Top row: ground truth, surrogate prediction and policy prediction, all on
    ``[head_min, head_max]``. Bottom row: the time-invariant ``perm`` field on
    ``[perm_min, perm_max]``, then ``surr_errs[i]`` and ``rl_errs[i]``. The two
    error panels share one colour range spanning both error tensors over all
    time steps.

    Writes ``grid-{i}.png`` into ``output_path`` for each index ``i`` along dim 0
    of ``gts``. It is joined with ``/``, so a ``pathlib.Path`` is expected.
    """
    title_size = 7
    cmap = "coolwarm"
    plt.rcParams["savefig.bbox"] = 'tight'
    plt.rcParams["figure.dpi"] = 300.0

    def to_numpy(tensor):
        """Detach ``tensor`` and copy it to a host NumPy array."""
        return tensor.detach().cpu().numpy()

    def draw_panel(ax, image, title, lo, hi):
        """Show ``image`` on ``ax`` with limits ``[lo, hi]``, no ticks, and a colorbar below."""
        mappable = ax.imshow(image, origin="lower", vmin=lo, vmax=hi, cmap=cmap)
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
        ax.set_title(title, size=title_size)
        ax.axis("off")
        bar_ax = make_axes_locatable(ax).append_axes("bottom", size="5%", pad=0.05)
        bar = ax.figure.colorbar(mappable, cax=bar_ax, orientation="horizontal")
        bar.ax.tick_params(labelsize='xx-small')

    perm_img = to_numpy(perm)
    surr_errs_np = to_numpy(surr_errs)
    rl_errs_np = to_numpy(rl_errs)
    # One colour range for both error panels, over all time steps.
    err_lo = min(surr_errs_np.min(), rl_errs_np.min())
    err_hi = max(surr_errs_np.max(), rl_errs_np.max())

    for i in range(gts.shape[0]):
        fig, axs = plt.subplots(nrows=2, ncols=3, squeeze=False)
        panels = (
            (axs[0, 0], to_numpy(gts[i]), "Ground Truth", head_min, head_max),
            (axs[0, 1], to_numpy(surr_preds[i]), "Surrogate Prediction", head_min, head_max),
            (axs[0, 2], to_numpy(rl_preds[i]), "Policy Prediction", head_min, head_max),
            (axs[1, 0], perm_img, "Permeability", perm_min, perm_max),
            (axs[1, 1], surr_errs_np[i], "Surrogate Abs Error", err_lo, err_hi),
            (axs[1, 2], rl_errs_np[i], "Policy Abs Error", err_lo, err_hi),
        )
        for ax, image, title, lo, hi in panels:
            draw_panel(ax, image, title, lo, hi)
        fig.savefig(output_path / f"grid-{i}.png")
        plt.close(fig)

def get_num_params(model, trainable=False):
    """Count the scalar elements across ``model.parameters()``.

    With ``trainable=True``, only parameters whose ``requires_grad`` is set are
    counted.
    """
    params = model.parameters()
    if trainable:
        params = filter(lambda p: p.requires_grad, params)
    return sum(p.numel() for p in params)

