import argparse
import numpy as np
import torch
import os
import pyvista as pv
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib import cm
from matplotlib.colors import Normalize


# NOTE: a byte-identical copy of `make_3d_plot` used to be defined here. It was
# dead code: the later `def make_3d_plot` in this module rebinds the same name
# at import time, so only that later definition was ever called.



def plot_regression_gt(
    X_train, y_train, X_test, f_test, y_std, savepath,
):
    """Plot a 1-D regression fit with a +/- 2 std band and save the figure.

    Training samples are scattered in orange; the predicted mean ``f_test`` is
    drawn as a line over ``X_test`` and ``f_test +/- 2 * y_std`` is shaded in
    blue. The axes are titled "LA".

    Parameters
    ----------
    X_train: training inputs, 1-D (an (N, 1) column is flattened first).
    y_train: training targets, one per entry of ``X_train``.
    X_test: inputs at which the prediction is given (flattened likewise).
    f_test: predicted mean at ``X_test``.
    y_std: predictive standard deviation at ``X_test``.
    savepath: output image path passed to ``savefig``.
    """

    def _padded_limits(values, frac):
        # (low, high) spanning the finite entries of ``values`` plus ``frac``
        # of their range on each side; NaN/Inf are ignored. Falls back to
        # (-1, 1) when nothing is finite, and pads by ``frac`` of the
        # magnitude (at least ``frac``) when all values coincide, which would
        # otherwise give identical (singular) limits.
        finite = values[np.isfinite(values)]
        if finite.size == 0:
            return -1.0, 1.0
        lo, hi = finite.min(), finite.max()
        span = hi - lo
        margin = span * frac if span > 0 else max(abs(lo), 1.0) * frac
        return lo - margin, hi + margin

    # Sort by x so the mean line and the band are drawn left to right.
    # Flatten first: np.argsort on an (N, 1) column sorts each length-1 row
    # and returns all zeros, which silently selected only the first sample.
    X_train = np.asarray(X_train).ravel()
    y_train = np.asarray(y_train).ravel()
    sort_idx_train = np.argsort(X_train)
    X_train = X_train[sort_idx_train]
    y_train = y_train[sort_idx_train]

    X_test = np.asarray(X_test).ravel()
    sort_idx_test = np.argsort(X_test)
    X_test = X_test[sort_idx_test]
    f_test = np.asarray(f_test).ravel()[sort_idx_test]
    y_std = np.asarray(y_std).ravel()[sort_idx_test]

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))

    ax.set_title("LA")
    ax.scatter(X_train, y_train, alpha=0.3, color="tab:orange")
    ax.plot(X_test, f_test, label=r"$\mathbb{E}[f]$")
    ax.fill_between(
        X_test,
        f_test - y_std * 2,
        f_test + y_std * 2,
        alpha=0.3,
        color="tab:blue",
        label=r"$2\sqrt{\mathbb{V}\,[y]}$",
    )

    # Fit the axes to the data (training targets and predicted mean) with 10%
    # padding; each axis falls back independently if it has no finite data.
    # The band itself is not considered and may extend past these limits.
    ax.set_xlim(*_padded_limits(np.concatenate([X_train, X_test]), 0.1))
    ax.set_ylim(*_padded_limits(np.concatenate([y_train, f_test]), 0.1))

    # Enlarge the tick labels (major and minor).
    ax.tick_params(axis='both', which='both', labelsize=18)

    fig.savefig(savepath, bbox_inches="tight", pad_inches=0.1)
    plt.close(fig)





def plot_regression_all_samples(
    X_train: np.ndarray,
    y_train: np.ndarray,
    arr_grids: np.ndarray,
    f_mu: np.ndarray,
    high_bd_map: np.ndarray,
    low_bd_map: np.ndarray,
    X_test: np.ndarray,
    y_test: np.ndarray,
    dict_info: dict,
    savepath: str,
):
    '''
    Scatter the training and testing samples, overlay the predicted mean and,
    optionally, a shaded band between two bounds, then save the figure.

    Parameters
    ----------
    X_train: input features of the training samples (flattened for plotting)
    y_train: targets of the training samples
    arr_grids: x positions at which ``f_mu`` and the bounds are evaluated
    f_mu: predicted mean on ``arr_grids``
    high_bd_map: upper bound on ``arr_grids``, or None
    low_bd_map: lower bound on ``arr_grids``, or None; the band is drawn
        only when both bounds are given
    X_test: input features of the testing samples
    y_test: targets of the testing samples
    dict_info: must provide "x_axis_name" and "y_axis_name" (axis labels)
    savepath: output image path passed to ``savefig``

    Returns
    -------
    None
    '''

    def _padded_limits(values, frac):
        # (low, high) spanning the finite entries of ``values`` plus ``frac``
        # of their range on each side; NaN/Inf are ignored. Falls back to
        # (-1, 1) when nothing is finite, and pads by ``frac`` of the
        # magnitude (at least ``frac``) when all values coincide.
        finite = values[np.isfinite(values)]
        if finite.size == 0:
            return -1.0, 1.0
        lo, hi = finite.min(), finite.max()
        span = hi - lo
        margin = span * frac if span > 0 else max(abs(lo), 1.0) * frac
        return lo - margin, hi + margin

    fig = plt.figure()
    ax = fig.add_subplot()

    ax.scatter(X_train.flatten(), y_train.flatten(), alpha=0.4, color="mediumspringgreen")
    ax.scatter(X_test.flatten(), y_test.flatten(), alpha=0.4, color="mediumpurple")

    ax.plot(arr_grids, f_mu, label=r"$\mathbb{E}[f]$", color='darkred')

    if high_bd_map is not None and low_bd_map is not None:
        ax.fill_between(
            arr_grids,
            low_bd_map,
            high_bd_map,
            alpha=0.3,
            color="tab:grey",
            label=r"$2\sqrt{\mathbb{V}\,[y]}$",
        )
        # Thin black outlines along both edges of the band.
        ax.plot(arr_grids, low_bd_map, linewidth=0.3, color='black')
        ax.plot(arr_grids, high_bd_map, linewidth=0.3, color='black')

    # Fit the axes to the observed samples only (5% padding, NaN/Inf ignored,
    # each axis independently); the mean curve and the band may be clipped
    # where they leave this range.
    ax.set_xlim(*_padded_limits(np.concatenate([X_train.flatten(), X_test.flatten()]), 0.05))
    ax.set_ylim(*_padded_limits(np.concatenate([y_train.flatten(), y_test.flatten()]), 0.05))

    ax.set_xlabel(dict_info["x_axis_name"], fontsize=24)
    ax.set_ylabel(dict_info["y_axis_name"], fontsize=24)
    ax.set_box_aspect(1)

    # Enlarge the tick labels (major and minor).
    ax.tick_params(axis='both', which='both', labelsize=18)

    fig.savefig(savepath, bbox_inches="tight", pad_inches=0.1)
    plt.close(fig)


def plot_regression_all_samples_with_gt(
    X_train: np.ndarray,
    y_train: np.ndarray,
    arr_grids: np.ndarray,
    f_mu: np.ndarray,
    high_bd_map: np.ndarray,
    low_bd_map: np.ndarray,
    f_mu_gt: np.ndarray,
    high_bd_gt_map: np.ndarray,
    low_bd_gt_map: np.ndarray,
    X_test: np.ndarray,
    y_test: np.ndarray,
    dict_info: dict,
    savepath: str,
):
    '''
    Scatter the training and testing samples and compare the predicted mean
    and band (dark red / grey) with a ground-truth mean and band (orange).

    Parameters
    ----------
    X_train: input features of the training samples (flattened for plotting)
    y_train: targets of the training samples
    arr_grids: x positions at which all curves and bounds are evaluated
    f_mu: predicted mean on ``arr_grids``
    high_bd_map, low_bd_map: predicted upper/lower bounds on ``arr_grids``;
        the predicted band is drawn only when both are given
    f_mu_gt: ground-truth mean on ``arr_grids``, or None to skip it
    high_bd_gt_map, low_bd_gt_map: ground-truth upper/lower bounds; the
        ground-truth band is drawn only when both are given
    X_test: input features of the testing samples
    y_test: targets of the testing samples
    dict_info: must provide "x_axis_name" and "y_axis_name" (axis labels)
    savepath: output image path passed to ``savefig``

    Returns
    -------
    None
    '''

    def _padded_limits(values, frac):
        # (low, high) spanning the finite entries of ``values`` plus ``frac``
        # of their range on each side; NaN/Inf are ignored. Falls back to
        # (-1, 1) when nothing is finite, and pads by ``frac`` of the
        # magnitude (at least ``frac``) when all values coincide.
        finite = values[np.isfinite(values)]
        if finite.size == 0:
            return -1.0, 1.0
        lo, hi = finite.min(), finite.max()
        span = hi - lo
        margin = span * frac if span > 0 else max(abs(lo), 1.0) * frac
        return lo - margin, hi + margin

    fig = plt.figure()
    ax = fig.add_subplot()

    ax.scatter(X_train.flatten(), y_train.flatten(), alpha=0.4, color="mediumspringgreen")
    ax.scatter(X_test.flatten(), y_test.flatten(), alpha=0.4, color="mediumpurple")

    ax.plot(arr_grids, f_mu, label=r"$\mathbb{E}[f]$", color='darkred')

    if high_bd_map is not None and low_bd_map is not None:
        ax.fill_between(
            arr_grids,
            low_bd_map,
            high_bd_map,
            alpha=0.3,
            color="tab:grey",
            label=r"$2\sqrt{\mathbb{V}\,[y]}$",
        )

    # Ground truth is drawn independently of the predicted band: previously it
    # vanished when the predicted bounds were None, and crashed in
    # fill_between when only the ground-truth bounds were None.
    if f_mu_gt is not None:
        ax.plot(arr_grids, f_mu_gt, label=r"$\mathbb{E}[f]$", color='orange')
    if high_bd_gt_map is not None and low_bd_gt_map is not None:
        ax.fill_between(
            arr_grids,
            low_bd_gt_map,
            high_bd_gt_map,
            alpha=0.3,
            color="tab:orange",
            label=r"$2\sqrt{\mathbb{V}\,[y]}$",
        )

    # Fit the axes to the observed samples only (5% padding, NaN/Inf ignored,
    # each axis independently); curves and bands may be clipped outside it.
    ax.set_xlim(*_padded_limits(np.concatenate([X_train.flatten(), X_test.flatten()]), 0.05))
    ax.set_ylim(*_padded_limits(np.concatenate([y_train.flatten(), y_test.flatten()]), 0.05))

    ax.set_xlabel(dict_info["x_axis_name"], fontsize=24)
    ax.set_ylabel(dict_info["y_axis_name"], fontsize=24)
    ax.set_box_aspect(1)

    # Enlarge the tick labels (major and minor).
    ax.tick_params(axis='both', which='both', labelsize=24)

    fig.savefig(savepath, bbox_inches="tight", pad_inches=0.1)
    plt.close(fig)







def plot_regression_all_samples_with_ood(
    X_train: np.ndarray,
    y_train: np.ndarray,
    arr_grids: np.ndarray,
    f_mu: np.ndarray,
    high_bd_map: np.ndarray,
    low_bd_map: np.ndarray,
    X_test: np.ndarray,
    y_test: np.ndarray,
    X_ood_test: np.ndarray,
    y_ood_test: np.ndarray,
    dict_info: dict,
    savepath: str,
):
    '''
    Scatter training, testing and out-of-distribution (OOD) testing samples,
    overlay the predicted mean and, optionally, a shaded band between two
    bounds, then save the figure.

    Parameters
    ----------
    X_train: input features of the training samples (flattened for plotting)
    y_train: targets of the training samples
    arr_grids: x positions at which ``f_mu`` and the bounds are evaluated
    f_mu: predicted mean on ``arr_grids``
    high_bd_map, low_bd_map: upper/lower bounds on ``arr_grids``; the band is
        drawn only when both are given
    X_test: input features of the testing samples
    y_test: targets of the testing samples
    X_ood_test: input features of the OOD testing samples (drawn in pink)
    y_ood_test: targets of the OOD testing samples
    dict_info: must provide "x_axis_name" and "y_axis_name" (axis labels)
    savepath: output image path passed to ``savefig``

    Returns
    -------
    None
    '''

    def _padded_limits(values, frac):
        # (low, high) spanning the finite entries of ``values`` plus ``frac``
        # of their range on each side; NaN/Inf are ignored. Falls back to
        # (-1, 1) when nothing is finite, and pads by ``frac`` of the
        # magnitude (at least ``frac``) when all values coincide.
        finite = values[np.isfinite(values)]
        if finite.size == 0:
            return -1.0, 1.0
        lo, hi = finite.min(), finite.max()
        span = hi - lo
        margin = span * frac if span > 0 else max(abs(lo), 1.0) * frac
        return lo - margin, hi + margin

    fig = plt.figure()
    ax = fig.add_subplot()

    ax.scatter(X_train.flatten(), y_train.flatten(), alpha=0.4, color="mediumspringgreen")
    ax.scatter(X_test.flatten(), y_test.flatten(), alpha=0.4, color="mediumpurple")
    ax.scatter(X_ood_test.flatten(), y_ood_test.flatten(), alpha=0.4, color="pink")

    ax.plot(arr_grids, f_mu, label=r"$\mathbb{E}[f]$", color='darkred')

    if high_bd_map is not None and low_bd_map is not None:
        ax.fill_between(
            arr_grids,
            low_bd_map,
            high_bd_map,
            alpha=0.3,
            color="tab:grey",
            label=r"$2\sqrt{\mathbb{V}\,[y]}$",
        )
        # Thin black outlines along both edges of the band.
        ax.plot(arr_grids, low_bd_map, linewidth=0.3, color='black')
        ax.plot(arr_grids, high_bd_map, linewidth=0.3, color='black')

    # Fit the axes to all observed samples, OOD included (5% padding, NaN/Inf
    # ignored, each axis independently).
    all_x_data = np.concatenate([X_train.flatten(), X_test.flatten(), X_ood_test.flatten()])
    all_y_data = np.concatenate([y_train.flatten(), y_test.flatten(), y_ood_test.flatten()])
    ax.set_xlim(*_padded_limits(all_x_data, 0.05))
    ax.set_ylim(*_padded_limits(all_y_data, 0.05))

    ax.set_xlabel(dict_info["x_axis_name"], fontsize=24)
    ax.set_ylabel(dict_info["y_axis_name"], fontsize=24)
    ax.set_box_aspect(1)

    # Enlarge the tick labels (major and minor).
    ax.tick_params(axis='both', which='both', labelsize=18)

    fig.savefig(savepath, bbox_inches="tight", pad_inches=0.1)
    plt.close(fig)





def plot_regression_all_samples_for_esb(
    X_train,
    y_train,
    arr_grids,
    f_mu_list,
    X_test,
    y_test,
    dict_info,
    savepath,
):
    """Plot an ensemble of predicted curves over the training/testing samples.

    Each member of ``f_mu_list`` is drawn as a thin black line; the member
    mean is drawn in dark red with a grey band of +/- 2 standard deviations
    across members. When ``dict_info["y_axis_name"] == "Y"``, a closed-form
    toy curve for input "X1" or "X2" is overlaid in blue.

    Parameters
    ----------
    X_train, y_train: training inputs/targets (flattened and scattered).
    arr_grids: x positions at which every ensemble member is evaluated.
    f_mu_list: sequence of member predictions, each aligned with
        ``arr_grids``; stacked with ``np.asarray`` (rows = members).
    X_test, y_test: testing inputs/targets (flattened and scattered).
    dict_info: must provide "x_axis_name" and "y_axis_name" (axis labels and
        selection of the overlay curve).
    savepath: output image path passed to ``savefig``.
    """
    fig = plt.figure()
    ax = fig.add_subplot()

    ax.scatter(X_train.flatten(), y_train.flatten(), alpha=0.4, color="mediumspringgreen")
    ax.scatter(X_test.flatten(), y_test.flatten(), alpha=0.4, color="mediumpurple")

    # Stack once; statistics are taken across members (axis 0).
    members = np.asarray(f_mu_list)
    f_mu = members.mean(axis=0)
    f_std = members.std(axis=0)

    ax.plot(arr_grids, f_mu, label=r"$\mathbb{E}[f]$", color='darkred')
    # f_std always exists here (it is computed above), so the band is
    # unconditional; the former `is not None` check was dead code.
    ax.fill_between(
        arr_grids,
        f_mu - f_std * 2,
        f_mu + f_std * 2,
        alpha=0.3,
        color="tab:grey",
        label=r"$2\sqrt{\mathbb{V}\,[y]}$",
    )
    for member in members:
        ax.plot(arr_grids, member, linewidth=0.2, color='black')

    if dict_info["y_axis_name"] == "Y":
        if dict_info["x_axis_name"] == "X1":
            # g1(x) = 10 sin(0.3 x) + exp(0.2 x)
            g1 = np.sin(arr_grids * 0.3) * 10 + np.exp(arr_grids * 0.2)
            ax.plot(arr_grids, g1, color='blue')
        elif dict_info["x_axis_name"] == "X2":
            # g2(x) = x + 10 sin(1.5 ln x); np.log yields NaN/-inf for x <= 0.
            g2 = arr_grids + np.sin(np.log(arr_grids) * 5 * 0.3) * 10
            ax.plot(arr_grids, g2, color='blue')

    ax.set_xlabel(dict_info["x_axis_name"])
    ax.set_ylabel(dict_info["y_axis_name"])
    ax.set_box_aspect(1)
    fig.savefig(savepath, bbox_inches="tight", pad_inches=0.1)
    plt.close(fig)







def make_3d_plot(X_train,
                 y_train,
                 X_test,
                 f_test,
                 y_std,
                 savepath):
    """Export an interactive 3-D view (HTML) of a two-input regression fit.

    Training points are drawn in red, the predicted mean surface in green and
    the ``f_test +/- 2 * y_std`` surfaces in blue (all surfaces at opacity
    0.3). Axis titles are fixed to age / weight / CSA [mm^2] and the z axis
    is scaled by 0.5. The scene is written to ``savepath + '_comp.html'``.

    Parameters
    ----------
    X_train: training inputs; ``y_train[..., None]`` is appended along the
        last axis, so presumably shape (N_train, 2) -- TODO confirm.
    y_train: training targets, one per row of ``X_train``.
    X_test: grid inputs, presumably (M, 2). The M points are reshaped into
        ``int(sqrt(M))`` rows, so they must be stored row by row.
    f_test: predicted mean at ``X_test``.
    y_std: predictive standard deviation at ``X_test``.
    savepath: output path prefix.
    """
    pts_train = np.concatenate((X_train, y_train[..., None]), axis=-1)

    def _surface(z):
        # Lift the (x1, x2) grid points to height ``z`` and wrap them as a
        # structured surface with int(sqrt(M)) rows.
        pts = np.concatenate((X_test, z[..., None]), axis=-1)
        n_rows = int(np.sqrt(len(pts[..., 0])))
        return pv.StructuredGrid(*(pts[..., k].reshape(n_rows, -1) for k in range(3)))

    grid_mean = _surface(f_test)
    grid_up_CI = _surface(f_test + 2 * y_std)
    grid_down_CI = _surface(f_test - 2 * y_std)

    pv.start_xvfb()
    plotter = pv.Plotter(off_screen=True, window_size=[1024, 1024])
    try:
        plotter.add_mesh(pv.PolyData(pts_train), color='red')
        plotter.add_mesh(grid_mean, color='green', opacity=0.3)
        plotter.add_mesh(grid_up_CI, color='blue', opacity=0.3)
        plotter.add_mesh(grid_down_CI, color='blue', opacity=0.3)

        plotter.show_bounds(
            grid='front',
            location='outer',
            ticks='both',
            xtitle='age',
            ytitle='weight',
            ztitle='CSA [mm$^2$]',
            all_edges=True,
        )
        plotter.set_scale(zscale=0.5)
        plotter.export_html(savepath + '_comp.html')
        plotter.enable_zoom_style()
    finally:
        # Release the off-screen render window even if adding or exporting fails.
        plotter.close()





def plot_toy_subnet_with_gt_local_ctb(
    X_train,
    y_train,
    arr_grids,
    f_mu,
    f_std,
    X_test,
    y_test,
    dict_info,
    savepath,
):
    """Plot one toy sub-network contribution against its known component.

    For ``dict_info["x_axis_name"]`` "X1" the component ``10 sin(0.3 x)`` is
    drawn in orange, for "X2" the identity ``x``; for any other name nothing
    is overlaid and no shift is applied (this used to raise
    UnboundLocalError). ``f_mu`` (and ``f_std`` if given) are smoothed with a
    Savitzky-Golay filter (window 10, cubic). ``f_mu`` is then shifted by the
    mean of the overlaid component over ``arr_grids``, and a +/- 2 std band
    is shaded around it.

    Parameters
    ----------
    X_train, y_train, X_test, y_test: unused; kept for call compatibility.
    arr_grids: x positions of ``f_mu``/``f_std``; savgol_filter needs at
        least 10 points (the window length).
    f_mu: predicted contribution on ``arr_grids``.
    f_std: predicted std on ``arr_grids``, or None to skip the band.
    dict_info: must provide "x_axis_name" and "y_axis_name".
    savepath: output image path passed to ``savefig``.
    """
    # Imported lazily; only this plot needs SciPy.
    from scipy.signal import savgol_filter

    fig = plt.figure()
    ax = fig.add_subplot()

    # NOTE(review): the shift presumably aligns a zero-centred predicted
    # contribution with the known component -- confirm with the model code.
    intercept = 0.0
    if dict_info["x_axis_name"] == "X1":
        f1 = np.sin(arr_grids * 0.3) * 10
        ax.plot(arr_grids, f1, color='orange')
        intercept = f1.mean()
    elif dict_info["x_axis_name"] == "X2":
        f2 = arr_grids
        ax.plot(arr_grids, f2, color='orange')
        intercept = f2.mean()

    f_mu = savgol_filter(f_mu, 10, 3) + intercept
    ax.plot(arr_grids, f_mu, label=r"$\mathbb{E}[f]$", color='darkred')

    # Smooth f_std only when provided (savgol_filter(None) used to crash
    # before the None check was reached).
    if f_std is not None:
        f_std = savgol_filter(f_std, 10, 3)
        lower = f_mu - f_std * 2
        upper = f_mu + f_std * 2
        ax.fill_between(
            arr_grids,
            lower,
            upper,
            alpha=0.3,
            color="tab:grey",
            label=r"$2\sqrt{\mathbb{V}\,[y]}$",
        )
        ax.plot(arr_grids, lower, linewidth=0.3, color='black')
        ax.plot(arr_grids, upper, linewidth=0.3, color='black')

    ax.set_xlabel(dict_info["x_axis_name"])
    ax.set_ylabel(dict_info["y_axis_name"])
    ax.set_box_aspect(1)
    fig.savefig(savepath, bbox_inches="tight", pad_inches=0.1)
    plt.close(fig)



def plot_toy_subnet_with_gt_local_ctb_namesb(
    X_train,
    y_train,
    arr_grids,
    f_mu_list,
    X_test,
    y_test,
    dict_info,
    savepath,
):
    """Plot an ensemble of toy sub-network contributions against the known component.

    For ``dict_info["x_axis_name"]`` "X1" the component ``10 sin(0.3 x)`` is
    drawn in orange, for "X2" the identity ``x``; for any other name nothing
    is overlaid and no shift is applied (this used to raise
    UnboundLocalError). Every member, and the member mean, is shifted by the
    mean of the overlaid component over ``arr_grids``. Members are drawn as
    thin black lines, the mean in dark red with a +/- 2 std band.

    Parameters
    ----------
    X_train, y_train, X_test, y_test: unused; kept for call compatibility.
    arr_grids: x positions at which every member is evaluated.
    f_mu_list: sequence of member predictions aligned with ``arr_grids``;
        stacked with ``np.asarray`` (rows = members).
    dict_info: must provide "x_axis_name" and "y_axis_name".
    savepath: output image path passed to ``savefig``.
    """
    fig = plt.figure()
    ax = fig.add_subplot()

    # NOTE(review): the shift presumably aligns zero-centred predicted
    # contributions with the known component -- confirm with the model code.
    intercept = 0.0
    if dict_info["x_axis_name"] == "X1":
        f1 = np.sin(arr_grids * 0.3) * 10
        ax.plot(arr_grids, f1, color='orange')
        intercept = f1.mean()
    elif dict_info["x_axis_name"] == "X2":
        f2 = arr_grids
        ax.plot(arr_grids, f2, color='orange')
        intercept = f2.mean()

    # Stack once; statistics are taken across members (axis 0).
    members = np.asarray(f_mu_list)
    f_mu = members.mean(axis=0) + intercept
    ax.plot(arr_grids, f_mu, label=r"$\mathbb{E}[f]$", color='darkred')

    # The spread is computed just above, so the band is unconditional (the
    # former `is not None` check could never fail).
    f_std = members.std(axis=0)
    ax.fill_between(
        arr_grids,
        f_mu - f_std * 2,
        f_mu + f_std * 2,
        alpha=0.3,
        color="tab:grey",
        label=r"$2\sqrt{\mathbb{V}\,[y]}$",
    )
    for member in members:
        ax.plot(arr_grids, member + intercept, linewidth=0.2, color='black')

    ax.set_xlabel(dict_info["x_axis_name"])
    ax.set_ylabel(dict_info["y_axis_name"])
    ax.set_box_aspect(1)
    fig.savefig(savepath, bbox_inches="tight", pad_inches=0.1)
    plt.close(fig)





# def plot_regression_all_samples_with_bounds(
#     X_train,
#     y_train,
#     arr_grids,
#     f_mu,
#     f_up,
#     f_lo,
#     X_test,
#     y_test,
#     savepath,
# ):
#     '''
#
#     plot all samples with CI in the training set and testing set
#
#     Parameters
#     ----------
#     X_train: input features of training samples
#     y_train: output of testing samples
#     f_mu_train: predicted mean of training samples
#     f_std_train: predicted std of training samples
#     X_test: input features of testing samples
#     y_test: output of testing samples
#     savepath: output path
#
#     Returns
#     -------
#
#     '''
#     # sort_idx_train = np.argsort(X_train)
#     # X_train = X_train[sort_idx_train]
#     # y_train = y_train[sort_idx_train]
#     # f_mu_train = f_mu_train[sort_idx_train]
#     # f_std_train = f_std_train[sort_idx_train]
#     #
#     # sort_idx_test = np.argsort(X_test)
#     # X_test = X_test[sort_idx_test]
#     # y_test = y_test[sort_idx_test]
#
#     def extend_x_range(x):
#         return x
#
#     def extend_y_range(y):
#         return np.r_[y, y[np.newaxis, -1]]
#
#     new_x_vals = extend_x_range(arr_grids)
#     new_y_vals = extend_y_range(f_mu)
#
#     new_y_hi = extend_y_range(f_up)
#     new_y_lo = extend_y_range(f_lo)
#
#
#     fig, ax2 = plt.subplots(nrows=1, ncols=1, sharey=True)
#
#     ax2.set_title("LA")
#     ax2.scatter(X_train.flatten(), y_train.flatten(), alpha=0.3, color="tab:orange")
#     ax2.scatter(X_test.flatten(), y_test.flatten(), alpha=0.3, color="tab:green")
#
#     ax2.plot(new_x_vals, new_y_vals, label=r"$\mathbb{E}[f]$")
#     ax2.fill_between(
#         new_x_vals,
#         new_y_hi,
#         new_y_lo,
#         alpha=0.3,
#         color="tab:blue",
#         label=r"$2\sqrt{\mathbb{V}\,[y]}$",
#     )
#     plt.savefig(savepath)
#     plt.close()




def plot_airway_shape_with_population(
    list_train_data,
    list_test_data,
    arr_grids,
    f_mu,
    f_std,
    dict_info,
    savepath,
):
    '''
    Plot every training/testing profile together with the smoothed predicted
    mean and an optional +/- 2 std band, on a broken y-axis that shows only
    [0, 500] and [1700, 1750].

    Parameters
    ----------
    list_train_data: iterable of dicts with entries 'input' and 'output';
        presumably torch tensors (``.detach().cpu()`` is called on them).
        Channel 0 of 'input' along the last axis is used as x.
    list_test_data: same layout as ``list_train_data``
    arr_grids: x positions of ``f_mu``/``f_std`` (squeezed before plotting)
    f_mu: predicted mean on ``arr_grids``
    f_std: predicted std on ``arr_grids``, or None to skip the band
    dict_info: currently unused
    savepath: output image path passed to ``savefig``

    Returns
    -------
    None
    '''
    # Imported lazily; only this plot needs these packages.
    from brokenaxes import brokenaxes
    from scipy.signal import savgol_filter

    plt.figure()
    bax = brokenaxes(ylims=((0, 500), (1700, 1750)))
    # (The sin/cos demo curves copied from the brokenaxes example were
    # removed; they were drawn into every saved figure.)

    # Individual profiles as thin lines: training in green, testing in purple.
    # detach() so tensors that track gradients can be converted for plotting.
    for ith_data in list_train_data:
        bax.plot(ith_data['input'][..., 0].detach().cpu(), ith_data['output'].detach().cpu(),
                 linewidth=0.1, color="mediumspringgreen", alpha=0.4)
    for ith_data in list_test_data:
        bax.plot(ith_data['input'][..., 0].detach().cpu(), ith_data['output'].detach().cpu(),
                 linewidth=0.1, color="mediumpurple", alpha=0.4)

    grid = arr_grids.squeeze()

    # Savitzky-Golay smoothing (window 10, cubic polynomial).
    f_mu = savgol_filter(f_mu, 10, 3)
    bax.plot(grid, f_mu, label=r"$\mathbb{E}[f]$", color='darkred')

    if f_std is not None:
        f_std = savgol_filter(f_std, 10, 3)
        # Lower bound floored at 0 (presumably the plotted quantity is
        # non-negative -- confirm).
        lower = np.maximum(f_mu - f_std * 2, 0.0)
        upper = f_mu + f_std * 2
        bax.fill_between(
            grid,
            lower,
            upper,
            alpha=0.3,
            color="tab:grey",
            label=r"$2\sqrt{\mathbb{V}\,[y]}$",
        )
        bax.plot(grid, lower, linewidth=0.3, color='black')
        bax.plot(grid, upper, linewidth=0.3, color='black')

    plt.savefig(savepath, bbox_inches="tight", pad_inches=0.1)
    plt.close()



def plot_3d_uncertainty(
    arr_x_train: np.ndarray,    # shape (N_train, 2)
    arr_y_train: np.ndarray,    # shape (N_train,) or (N_train, 1)
    arr_x_grids: np.ndarray,    # shape (M, 2)  - a dense grid over (x1, x2)
    mh_map: np.ndarray,         # shape (M,)    - model mean on the grid
    high_bd_map: np.ndarray,    # shape (M,)    - upper bound on the grid
    low_bd_map: np.ndarray,     # shape (M,)    - lower bound on the grid
    arr_x_test: np.ndarray,     # shape (N_test, 2)
    arr_y_test: np.ndarray,     # shape (N_test,) or (N_test, 1)
    savepath: str
):
    """
    3-D regression view exported as HTML: training/testing points, the mean
    surface and the upper/lower bound surfaces.

    Colours follow ``plot_regression_all_samples``: training points in
    mediumspringgreen, testing points in mediumpurple, the mean surface in
    blue and both bound surfaces in grey (surfaces at opacity 0.3). Each axis
    is rescaled by the inverse of its data range so all three look equally
    long. The scene is written to ``f"{savepath}.html"``.

    Raises AssertionError if M is not a perfect square; the grid must be
    stored row by row so that ``reshape(n, n)`` recovers it.
    """

    def _column(values):
        # Accept both (N,) and (N, 1) targets/maps; np.hstack with an (N,)
        # array next to an (N, 2) one used to raise ValueError.
        return np.reshape(values, (-1, 1))

    # 1) Scatter point arrays, (N, 3) each.
    pts_train = np.hstack([arr_x_train, _column(arr_y_train)])
    pts_test = np.hstack([arr_x_test, _column(arr_y_test)])

    # 2) The grid must be n x n.
    M = arr_x_grids.shape[0]
    n = int(np.sqrt(M))
    assert n*n == M, "网格点数必须是完全平方数"

    pts_mean = np.hstack([arr_x_grids, _column(mh_map)])

    def _surface(z_map):
        # Lift the grid to height ``z_map`` as an n x n structured surface.
        pts = np.hstack([arr_x_grids, _column(z_map)])
        return pv.StructuredGrid(*(pts[:, k].reshape(n, n) for k in range(3)))

    grid_mean = _surface(mh_map)
    grid_up_CI = _surface(high_bd_map)
    grid_down_CI = _surface(low_bd_map)

    # 3) Per-axis scale = 1 / data range, so each axis spans unit length.
    #    A degenerate (zero-width) range keeps scale 1 instead of dividing by
    #    zero; NaNs are ignored.
    def _inverse_range(*cols):
        values = np.concatenate(cols)
        extent = np.nanmax(values) - np.nanmin(values)
        return 1.0 / extent if extent > 0 else 1.0

    sx = _inverse_range(pts_train[:, 0], pts_test[:, 0], pts_mean[:, 0])
    sy = _inverse_range(pts_train[:, 1], pts_test[:, 1], pts_mean[:, 1])
    sz = _inverse_range(pts_train[:, 2], pts_test[:, 2], pts_mean[:, 2])

    # 4) Render off-screen and export.
    pv.start_xvfb()
    plotter = pv.Plotter(off_screen=True, window_size=[1024, 1024])
    try:
        plotter.add_mesh(
            pv.PolyData(pts_train),
            color="mediumspringgreen",
            render_points_as_spheres=True,
            point_size=8,
            name="train"
        )
        plotter.add_mesh(
            pv.PolyData(pts_test),
            color="mediumpurple",
            render_points_as_spheres=True,
            point_size=8,
            name="test"
        )
        plotter.add_mesh(grid_mean, color="blue", opacity=0.3, name="mean_surface")
        plotter.add_mesh(grid_up_CI, color="grey", opacity=0.3, name="ci_upper")
        plotter.add_mesh(grid_down_CI, color="grey", opacity=0.3, name="ci_lower")

        plotter.set_scale(xscale=sx, yscale=sy, zscale=sz)
        plotter.enable_zoom_style()
        plotter.export_html(f"{savepath}.html")
    finally:
        # Release the off-screen render window even if rendering/export fails.
        plotter.close()




def plot_3din2d_uncertainty(
    arr_x_train: np.ndarray,
    arr_y_train: np.ndarray,
    arr_x_grids: np.ndarray,
    mh_map: np.ndarray,
    sh_map: np.ndarray,
    arr_x_test: np.ndarray,
    arr_y_test: np.ndarray,
    savepath: str,
    dict_info: dict,
    cmap_name: str = "turbo"
):
    """
    Top-down 2-D view of a regression over two inputs.

    The predicted mean ``mh_map`` on the n x n grid is drawn as a colour map
    with 10 black contour levels. Its opacity decreases with the predicted
    std ``sh_map``, from 1 at the lowest std down to about 0.3 at the
    highest. Training and testing samples are white-edged dots coloured by
    their target on the same colour scale.

    Parameters
    ----------
    arr_x_train, arr_x_test: sample inputs; columns 0 and 1 are (x1, x2).
    arr_y_train, arr_y_test: sample targets (dot colours and colour range).
    arr_x_grids: (n*n, 2) grid points stored row by row, so that
        ``reshape(n, n)`` recovers the grid; AssertionError otherwise.
    mh_map, sh_map: predicted mean / std at each grid point.
    savepath: output image path (saved at 300 dpi).
    dict_info: must provide "x_axis_name" and "y_axis_name" (axis labels).
    cmap_name: Matplotlib colormap name.
    """
    # 1) Colour range = 5th..95th percentile of all targets (NaNs ignored), so
    #    a few extreme samples do not compress the scale; values outside the
    #    range are clipped to the end colours by Normalize(clip=True).
    all_y = np.concatenate([np.ravel(arr_y_train), np.ravel(arr_y_test)])
    vmin, vmax = np.nanpercentile(all_y, [5, 95])

    # 2) Recover the n x n grid.
    M = arr_x_grids.shape[0]
    n = int(np.sqrt(M))
    assert n * n == M, "arr_x_grids 需要是 (n², 2) 的网格"

    x1 = arr_x_grids[:, 0].reshape(n, n)
    x2 = arr_x_grids[:, 1].reshape(n, n)
    mean_pred = mh_map.reshape(n, n)
    std_pred = sh_map.reshape(n, n)

    # 3) Opacity from std: clip to its 5th..95th percentile, min-max
    #    normalise to [0, 1), then map to alpha in (0.3, 1].
    std_pred = np.clip(std_pred, a_min=np.quantile(std_pred, 0.05), a_max=np.quantile(std_pred, 0.95))
    std_norm = (std_pred - std_pred.min()) / (std_pred.max() - std_pred.min() + 1e-6)
    alpha_map = 1.0 - 0.7 * std_norm

    # 4) Render.
    fig, ax = plt.subplots(figsize=(7, 6))
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; the pyplot
    # accessor works on both old and new versions.
    cmap = plt.get_cmap(cmap_name)
    norm = Normalize(vmin=vmin, vmax=vmax, clip=True)

    quad = ax.pcolormesh(
        x1, x2, mean_pred,
        shading='auto',
        cmap=cmap,
        norm=norm
    )
    quad.set_alpha(alpha_map)

    point_size = 120
    sc_train = ax.scatter(
        arr_x_train[:, 0], arr_x_train[:, 1],
        c=arr_y_train, cmap=cmap, norm=norm,
        edgecolor='white', linewidth=1.5,
        s=point_size, marker='o'
    )
    ax.scatter(
        arr_x_test[:, 0], arr_x_test[:, 1],
        c=arr_y_test, cmap=cmap, norm=norm,
        edgecolor='white', linewidth=1.5,
        s=point_size, marker='o'
    )

    ax.contour(x1, x2, mean_pred, levels=10, colors='k', linewidths=0.5, alpha=0.5)

    # Train and test share ``norm``, so either scatter defines the colour bar.
    cbar = fig.colorbar(sc_train, ax=ax, shrink=0.9)
    cbar.set_label("Predicted Mean", fontsize=24)
    cbar.ax.tick_params(labelsize=14)

    # Axes span exactly the finite extent of the grid (no padding); NaN/Inf
    # are ignored, with (-1, 1) as a fallback when nothing is finite.
    x_data = arr_x_grids[:, 0]
    y_data = arr_x_grids[:, 1]
    valid_x_data = x_data[np.isfinite(x_data)]
    valid_y_data = y_data[np.isfinite(y_data)]
    if len(valid_x_data) > 0 and len(valid_y_data) > 0:
        ax.set_xlim(valid_x_data.min(), valid_x_data.max())
        ax.set_ylim(valid_y_data.min(), valid_y_data.max())
    else:
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)

    ax.set_xlabel(dict_info["x_axis_name"], fontsize=24)
    ax.set_ylabel(dict_info["y_axis_name"], fontsize=24)
    ax.set_box_aspect(1)

    # Enlarge the tick labels (major and minor).
    ax.tick_params(axis='both', which='both', labelsize=18)

    fig.tight_layout()
    fig.savefig(savepath, dpi=300)
    plt.close(fig)















def plot_regression_all_samples_toydata(
    X_train: np.ndarray,
    y_train: np.ndarray,
    arr_grids: np.ndarray,
    f_mu: np.ndarray,
    high_bd_map: np.ndarray,
    low_bd_map: np.ndarray,
    X_test: np.ndarray,
    y_test: np.ndarray,
    dict_info: dict,
    savepath: str,
):
    '''
    Scatter the training and testing samples of a toy dataset, overlay the
    predicted mean and, optionally, a shaded band between two bounds, then
    save the figure. Axis limits are left to Matplotlib's autoscaling and
    default font sizes are used.

    Parameters
    ----------
    X_train: input features of the training samples (flattened for plotting)
    y_train: targets of the training samples
    arr_grids: x positions at which ``f_mu`` and the bounds are evaluated
    f_mu: predicted mean on ``arr_grids``
    high_bd_map, low_bd_map: upper/lower bounds on ``arr_grids``; the band is
        drawn only when both are given
    X_test: input features of the testing samples
    y_test: targets of the testing samples
    dict_info: must provide "x_axis_name" and "y_axis_name" (axis labels)
    savepath: output image path passed to ``savefig``

    Returns
    -------
    None
    '''
    fig = plt.figure()
    ax = fig.add_subplot()

    ax.scatter(X_train.flatten(), y_train.flatten(), alpha=0.4, color="mediumspringgreen")
    ax.scatter(X_test.flatten(), y_test.flatten(), alpha=0.4, color="mediumpurple")

    ax.plot(arr_grids, f_mu, label=r"$\mathbb{E}[f]$", color='darkred')

    if high_bd_map is not None and low_bd_map is not None:
        ax.fill_between(
            arr_grids,
            low_bd_map,
            high_bd_map,
            alpha=0.3,
            color="tab:grey",
            label=r"$2\sqrt{\mathbb{V}\,[y]}$",
        )
        # Thin black outlines along both edges of the band.
        ax.plot(arr_grids, low_bd_map, linewidth=0.3, color='black')
        ax.plot(arr_grids, high_bd_map, linewidth=0.3, color='black')

    ax.set_xlabel(dict_info["x_axis_name"])
    ax.set_ylabel(dict_info["y_axis_name"])
    ax.set_box_aspect(1)
    fig.savefig(savepath, bbox_inches="tight", pad_inches=0.1)
    plt.close(fig)