
from matplotlib.colors import LinearSegmentedColormap, Normalize
from matplotlib.cm import ScalarMappable
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LinearSegmentedColormap

def multiple_figure_sensitivity():
    """Draw the gamma-sensitivity figures for every benchmark dataset.

    For each dataset, a hard-coded 5x5 regret grid (each literal is transposed
    with ``.T``) is passed to ``figure_sensitivity`` (3-D heat surface) and
    ``figure_contour_plot`` (filled contour plot). The dataset name is printed
    before its figures are drawn and is also embedded in the output file names.
    """
    surfaces = {
        "Energy": np.array([
            [3.14, 3.12, 3.09, 3.09, 3.10],
            [3.14, 3.11, 3.09, 3.07, 3.06],
            [3.13, 3.10, 3.07, 3.03, 3.05],
            [3.13, 3.10, 3.05, 3.00, 2.97],
            [3.03, 2.96, 2.96, 2.96, 2.97],
        ]).T,
        "Bimatch": np.array([
            [2.85, 2.84, 2.83, 2.83, 2.83],
            [2.83, 2.81, 2.80, 2.80, 2.80],
            [2.81, 2.79, 2.76, 2.75, 2.76],
            [2.79, 2.76, 2.73, 2.73, 2.73],
            [2.73, 2.73, 2.71, 2.69, 2.69],
        ]).T,
        "Knapsack": np.array([
            [3.76, 3.74, 3.73, 3.73, 3.72],
            [3.74, 3.72, 3.70, 3.71, 3.70],
            [3.74, 3.72, 3.70, 3.68, 3.67],
            [3.74, 3.70, 3.65, 3.63, 3.63],
            [3.75, 3.70, 3.65, 3.63, 3.63],
        ]).T,
        "BudgetAlloc": np.array([
            [2.10, 2.08, 2.07, 2.08, 2.06],
            [2.09, 2.08, 2.06, 2.05, 2.04],
            [2.09, 2.07, 2.05, 2.03, 2.02],
            [2.09, 2.07, 2.04, 2.02, 2.00],
            [2.09, 2.07, 2.05, 2.02, 1.98],
        ]).T,
        "Portfolio": np.array([
            [2.19, 2.16, 2.14, 2.13, 2.14],
            [2.18, 2.16, 2.10, 2.09, 2.10],
            [2.18, 2.14, 2.09, 2.07, 2.08],
            [2.17, 2.13, 2.05, 2.03, 2.05],
            [2.17, 2.14, 2.09, 2.05, 2.05],
        ]).T,
        "Cubic": np.array([
            [43.65, 43.65, 43.65, 43.67, 43.68],
            [43.56, 43.55, 43.57, 43.54, 43.54],
            [43.43, 43.46, 43.47, 43.48, 43.48],
            [43.22, 43.26, 43.26, 43.22, 43.25],
            [43.22, 43.12, 43.12, 43.13, 43.18],
        ]).T,
    }

    # Iterate name/array pairs directly: the positional index is not needed
    # and this avoids a second dictionary lookup per dataset.
    for dataset, regret_surface in surfaces.items():
        print(dataset)
        figure_sensitivity(dataset, regret_surface)
        figure_contour_plot(dataset, regret_surface)

def figure_sensitivity(dataset, regret_surface):
    """Render a 3-D heat surface of regret over the (gamma_down, gamma_up) grid.

    The surface is coloured with a custom pastel colormap and overlaid with a
    wireframe. The grid point with the lowest regret is highlighted and
    annotated. The figure is written to
    ``./vfe_gamma_sensitivity_heat_{dataset}.pdf`` and ``.png``, shown, and
    then closed so repeated calls do not accumulate open figures.

    Args:
        dataset: Name embedded in the output file names.
        regret_surface: 2-D numpy array with the same layout as
            ``np.meshgrid(gamma_down_vals, gamma_up_vals)``: rows follow
            gamma_up and columns follow gamma_down. It must therefore be
            5x5 to match the grid below.

    Note:
        Updates the global ``plt.rcParams`` as a side effect.
    """
    plt.rcParams.update({
        # "text.usetex": True,
        "font.family": "serif",
        "axes.labelsize": 12,
        "axes.titlesize": 13,
        "legend.fontsize": 11,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
    })

    # Define the grid (default "xy" meshgrid indexing: G_down varies along
    # columns, G_up along rows).
    gamma_down_vals = [0.98, 0.95, 0.9, 0.8, 0.5]
    gamma_up_vals = [1.02, 1.05, 1.1, 1.2, 2]
    G_down, G_up = np.meshgrid(gamma_down_vals, gamma_up_vals)

    # Custom colormap from user-provided palette, normalised to this surface's
    # own min/max so every dataset uses the full colour range.
    palette = ['#bed4ed', '#c8d6e7', '#ede3c8', '#fbf3d6', '#f4dcda', '#eac0c0']
    cmap = LinearSegmentedColormap.from_list("vfe_cmap", palette, N=256)
    norm = Normalize(vmin=regret_surface.min(), vmax=regret_surface.max())
    facecolors = cmap(norm(regret_surface))

    # Find the best (lowest-regret) grid point.
    min_idx = np.unravel_index(np.argmin(regret_surface), regret_surface.shape)
    best_gamma_down = G_down[min_idx]
    best_gamma_up = G_up[min_idx]
    best_regret = regret_surface[min_idx]
    # Take the colormap colour at the minimum and darken it (RGB * 0.7) so the
    # marker stands out against the surface.
    min_color = cmap(norm(best_regret))
    min_color_darker = (min_color[0] * 0.7, min_color[1] * 0.7, min_color[2] * 0.7)

    # Plotting
    fig = plt.figure(figsize=(5.3, 3.8))
    ax = fig.add_subplot(111, projection='3d')

    # Surface with precomputed facecolors (shading disabled so the colormap
    # colours are shown as-is).
    ax.plot_surface(G_down, G_up, regret_surface, facecolors=facecolors,
                    rstride=1, cstride=1, linewidth=0, antialiased=False, shade=False)

    # Wireframe
    ax.plot_wireframe(G_down, G_up, regret_surface, color="#5a8bb0", linewidth=0.8)

    # Best point with the darkened colour and a marker.
    ax.scatter(best_gamma_down, best_gamma_up, best_regret,
               color=min_color_darker, edgecolor='black', s=80, label="Minimum Regret", marker='o', zorder=10)

    # Annotation; the 0.04 z-offset is in data units.
    ax.text(best_gamma_down, best_gamma_up, best_regret - 0.04,
            f"({best_gamma_down:.2f}, {best_gamma_up:.2f})\nRegret: {best_regret:.2f}",
            fontsize=9, ha='center', va='bottom', color='black', zorder=10)
    # ax.set_zlim(regret_surface.min() - 0.1, regret_surface.max() + 0.1)
    # Axis labels (z label and title intentionally disabled).
    ax.set_xlabel(r"$\gamma_\downarrow$")
    ax.set_ylabel(r"$\gamma_\uparrow$")
    # ax.set_zlabel("Regret", labelpad=8)
    # ax.set_title("VFE Parameter Sensitivity (Heat Surface)", pad=12)

    # Set background pane colours to light gray-blue.
    light_background = (0.94, 0.95, 0.97)
    ax.xaxis.pane.set_facecolor(light_background)
    ax.yaxis.pane.set_facecolor(light_background)
    ax.zaxis.pane.set_facecolor(light_background)
    ax.grid(True)

    # View and legend
    ax.view_init(elev=25, azim=135)
    ax.legend(
        loc="upper left",
        bbox_to_anchor=(0.8, 0.95)  # axes-fraction coordinates (not data coordinates)
    )
    fig.subplots_adjust(
        top=1.0,
        bottom=0.062,
        left=0.0,
        right=0.942,
        hspace=0.18,
        wspace=0.215
    )
    # Colorbar driven by a standalone mappable, because the surface was drawn
    # with explicit facecolors rather than a cmap.
    mappable = ScalarMappable(norm=norm, cmap=cmap)
    mappable.set_array([])
    fig.colorbar(mappable, shrink=0.6, aspect=15, pad=0.1, label="Regret")

    fig.savefig(f"./vfe_gamma_sensitivity_heat_{dataset}.pdf", bbox_inches='tight')
    fig.savefig(f"./vfe_gamma_sensitivity_heat_{dataset}.png", bbox_inches='tight')
    plt.show()
    # Release the figure; otherwise every call leaves one more figure open.
    plt.close(fig)

def figure_contour_plot(dataset, regret_surface):
    """Render a filled contour plot of regret over the (gamma_down, gamma_up) grid.

    The grid point with the lowest regret is marked and annotated. The figure
    is written to ``./vfe_gamma_contour_plot_{dataset}.pdf`` and ``.png``,
    shown, and then closed so repeated calls do not accumulate open figures.

    Args:
        dataset: Name embedded in the output file names.
        regret_surface: 2-D numpy array with the same layout as
            ``np.meshgrid(gamma_down_vals, gamma_up_vals)``: rows follow
            gamma_up and columns follow gamma_down. It must therefore be
            5x5 to match the grid below.

    Note:
        Updates the global ``plt.rcParams`` as a side effect.
    """
    plt.rcParams.update({
        # "text.usetex": True,
        "font.family": "serif",
        "axes.labelsize": 12,
        "axes.titlesize": 13,
        "legend.fontsize": 11,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
    })

    # Data preparation.
    # NOTE(review): figure_sensitivity plots the same regret_surface (both are
    # called with it from multiple_figure_sensitivity) against a different grid,
    # [0.98, 0.95, 0.9, 0.8, 0.5] x [1.02, 1.05, 1.1, 1.2, 2]. As a result, the
    # two figures label identical values with different gamma settings.
    # Confirm which grid the data was produced on.
    gamma_down_vals = [0.80, 0.85, 0.90, 0.95, 0.98]
    gamma_up_vals = [1.02, 1.03, 1.05, 1.07, 1.10]
    G_down, G_up = np.meshgrid(gamma_down_vals, gamma_up_vals)

    palette = ['#bed4ed', '#c8d6e7', '#ede3c8', '#fbf3d6', '#f4dcda', '#eac0c0']
    cmap = LinearSegmentedColormap.from_list("vfe_cmap", palette, N=256)

    # Locate the lowest-regret grid point.
    min_idx = np.unravel_index(np.argmin(regret_surface), regret_surface.shape)
    best_gamma_down = G_down[min_idx]
    best_gamma_up = G_up[min_idx]
    best_regret = regret_surface[min_idx]

    fig1, ax1 = plt.subplots(figsize=(6, 5))
    contour = ax1.contourf(G_down, G_up, regret_surface, levels=20, cmap=cmap)
    ax1.plot(best_gamma_down, best_gamma_up, 'o', color='red', markersize=7, markeredgecolor='black')

    ax1.set_xlabel(r"$\gamma_\downarrow$")
    ax1.set_ylabel(r"$\gamma_\uparrow$")
    # Label offsets are in data units and tuned to this grid's spacing.
    ax1.text(best_gamma_down - 0.005, best_gamma_up + 0.003,
             f"({best_gamma_down:.2f}, {best_gamma_up:.2f})\nRegret: {best_regret:.2f}",
             ha='center', va='bottom', fontsize=9)
    # ax1.set_title("VFE Parameter Sensitivity (Contour Plot)")
    fig1.colorbar(contour, ax=ax1, label="Regret")
    fig1.tight_layout()
    fig1.savefig(f"./vfe_gamma_contour_plot_{dataset}.pdf", bbox_inches='tight')
    fig1.savefig(f"./vfe_gamma_contour_plot_{dataset}.png", bbox_inches='tight')
    plt.show()
    # Release the figure; otherwise every call leaves one more figure open.
    plt.close(fig1)





def figure_pretrain_ablation():
    """Plot regret against the number of pretraining steps, one figure per dataset.

    Each baseline in ``methods`` is drawn as a solid line with markers. VFE,
    whose hard-coded regret is the same at every step, is drawn as a dashed
    line in its own colour. Only the last dataset's figure gets a legend.
    Each figure is saved as ``mainpaper_sensitivity_{dataset}.pdf`` and
    ``.png``, shown, and closed.
    """
    import matplotlib.pyplot as plt

    # Sample data (replace with actual results)
    pretrain_steps = [0, 5, 10, 20]
    datasets = ["Energy", "BiMatch", "Knapsack", "BudgetAlloc", "Portfolio", "Cubic"]
    methods = ["SPO+", "DFL", "BB", "NCE", "PointLTR", "PairLTR", "ListLTR"]

    # dataset -> method -> regret list, one value per entry of pretrain_steps (len=4)
    regret_data = {
        "Energy": {
            "SPO+": [5.50, 5.00, 4.38, 3.84],
            "DFL": [5.30, 4.54, 4.59, 3.99],
            "BB": [6.50, 5.99, 5.08, 5.18],
            "NCE": [4.80, 4.48, 3.85, 3.37],
            "PointLTR": [4.70, 4.08, 3.81, 3.49],
            "PairLTR": [4.90, 4.38, 3.96, 3.69],
            "ListLTR": [4.60, 3.97, 3.72, 3.33],
            "VFE": [2.96, 2.96, 2.96, 2.96]
        },
        "BiMatch": {
            "SPO+": [2.75, 2.92, 2.79, 2.94],
            "DFL": [2.90, 2.69, 2.82, 2.85],
            "BB": [3.10, 3.03, 2.99, 3.12],
            "NCE": [3.05, 2.84, 3.29, 3.18],
            "PointLTR": [3.00, 3.15, 3.10, 3.11],
            "PairLTR": [2.95, 2.88, 2.77, 3.12],
            "ListLTR": [2.93, 2.73, 2.84, 2.85],
            "VFE": [2.69, 2.69, 2.69, 2.69]
        },
        "Knapsack": {
            "SPO+": [6.19, 6.50, 6.83, 7.17],
            "DFL": [11.80, 10.03, 8.53, 7.25],
            "BB": [24.50, 20.82, 17.70, 15.04],
            "NCE": [13.60, 11.56, 9.83, 8.36],
            "PointLTR": [6.55, 6.88, 7.22, 7.58],
            "PairLTR": [7.70, 8.09, 8.49, 8.91],
            "ListLTR": [6.05, 6.35, 6.67, 7.00],
            "VFE": [3.63, 3.63, 3.63, 3.63]
        },
        "BudgetAlloc": {
            "SPO+": [1.50, 1.46, 1.42, 1.40],
            "DFL": [36.22, 35.13, 34.41, 33.68],
            "BB": [26.70, 25.90, 25.36, 24.83],
            "NCE": [10.10, 9.80, 9.59, 9.39],
            "PointLTR": [69.00, 66.93, 65.55, 64.17],
            "PairLTR": [5.90, 5.72, 5.61, 5.49],
            "ListLTR": [5.85, 5.67, 5.56, 5.44],
            "VFE": [1.98, 1.98, 1.98, 1.98]
        },
        "Portfolio": {
            "SPO+": [3.20, 2.88, 2.69, 2.19],
            "DFL": [3.60, 3.28, 2.87, 2.48],
            "BB": [3.80, 3.59, 3.33, 2.95],
            "NCE": [3.20, 2.82, 2.53, 2.44],
            "PointLTR": [3.05, 2.73, 2.42, 2.26],
            "PairLTR": [3.40, 2.90, 2.96, 2.42],
            "ListLTR": [3.00, 2.75, 2.43, 2.23],
            "VFE": [2.03, 2.03, 2.03, 2.03]
        },
        "Cubic": {
            "SPO+": [160.00, 144.00, 140.44, 141.37],
            "DFL": [1.97, 1.77, 1.75, 1.80],
            "BB": [13.94, 12.55, 12.28, 12.78],
            "NCE": [160.00, 144.00, 141.40, 139.73],
            "PointLTR": [5.07, 4.56, 4.63, 4.44],
            "PairLTR": [0.19, 0.17, 0.17, 0.17],
            "ListLTR": [0.17, 0.15, 0.15, 0.15],
            "VFE": [41.65, 41.65, 41.65, 41.65]
        }
    }

    # Line colours: index j is used for methods[j]. The trailing hex comments
    # list matching light shades; the first six are the palette used by the
    # surface/contour figures.
    colors = [
        '#6396d8',  # #bed4ed
        '#8393b7',  # #c8d6e7
        '#c9b376',  # #ede3c8
        '#e1c05a',  # #fbf3d6
        '#d98c8c',  # #f4dcda
        '#c97979',  # #eac0c0
        '#63a88f',  # #b9d8c2
        '#9c7bc1',  # #d0b8e3
    ]
    # VFE gets the first colour not assigned to any baseline. Reusing
    # colors[3] would give it the same colour as NCE in the plot and legend.
    vfe_color = colors[len(methods)]

    for i, dataset in enumerate(datasets):
        fig, ax = plt.subplots(figsize=(5, 3))
        for j, method in enumerate(methods):
            ax.plot(pretrain_steps, regret_data[dataset][method], label=method,
                    color=colors[j], marker='o', linewidth=2)
        # VFE flat line
        ax.plot(pretrain_steps, regret_data[dataset]["VFE"], label="VFE (anneal)",
                color=vfe_color, linestyle='--', linewidth=2)
        ax.set_title(dataset)
        ax.set_xlabel("Pretraining Steps")
        ax.set_ylabel("Regret")
        ax.grid(True, linestyle="--", alpha=0.5)

        # Legend only on the last dataset's figure. Compare against the number
        # of datasets; ``len(dataset)`` would be the length of the name string.
        if i == len(datasets) - 1:
            ax.legend(loc="upper right", bbox_to_anchor=(1, 1))
        fig.tight_layout()
        fig.savefig(f"mainpaper_sensitivity_{dataset}.pdf")
        fig.savefig(f"mainpaper_sensitivity_{dataset}.png")
        plt.show()
        plt.close(fig)

def figure_vfe_loss():
    """Plot prediction loss and decision loss per dataset on twin y-axes.

    Reads ``loss_log.xlsx``. Each dataset occupies a group of columns starting
    at the offset given in ``dataset_cols``:
    - column ``start`` holds the epoch;
    - ``start + 2`` holds the decision loss;
    - ``start + 3`` holds the prediction loss.

    For every dataset a figure is saved as ``vfe_loss_dual_axis_{ds}.pdf`` and
    ``.png`` (dpi=300), shown, and closed.
    """
    import pandas as pd
    import matplotlib.pyplot as plt

    file_path = "loss_log.xlsx"
    df = pd.read_excel(file_path)

    # Index of the first column of each dataset's column group in the sheet.
    dataset_cols = {
        "Bimatching": 0,
        "Budget": 5,
        "Cubic": 10,
        "Knapsack": 15,
        "Energy": 20,
        "Portfolio": 25,
    }

    loss_data = {}
    for ds, start_col in dataset_cols.items():
        epoch_col = df.columns[start_col]  # 'epoch'
        loss_col = df.columns[start_col + 2]  # 'loss'
        pred_col = df.columns[start_col + 3]  # 'pred_loss'

        sub = df[[epoch_col, loss_col, pred_col]].dropna()

        # Epoch cells are parsed via their first run of digits, so labels that
        # contain text still yield an integer epoch.
        ep_raw = sub[epoch_col].astype(str).str.extract(r'(\d+)')[0]
        epochs = pd.to_numeric(ep_raw, errors='coerce')
        decision_loss = pd.to_numeric(sub[loss_col], errors='coerce')
        prediction_loss = pd.to_numeric(sub[pred_col], errors='coerce')

        # Keep only rows where all three values parsed as numbers.
        valid_mask = (~epochs.isna()) & (~decision_loss.isna()) & (~prediction_loss.isna())
        epochs = epochs[valid_mask].astype(int).reset_index(drop=True)
        decision_loss = decision_loss[valid_mask].reset_index(drop=True)
        prediction_loss = prediction_loss[valid_mask].reset_index(drop=True)

        loss_data[ds] = (epochs, decision_loss, prediction_loss)

    # matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*"
    # and 3.8 removed the old names, so "seaborn-whitegrid" alone raises on
    # current releases. Prefer the new name; fall back on older versions.
    whitegrid_style = ("seaborn-v0_8-whitegrid"
                       if "seaborn-v0_8-whitegrid" in plt.style.available
                       else "seaborn-whitegrid")
    plt.style.use(whitegrid_style)

    color_pred = "#397ebc"
    color_dec = "#d1495b"

    for ds, (ep, d_loss, p_loss) in loss_data.items():
        fig, ax1 = plt.subplots(figsize=(6, 4))

        # Left axis: prediction loss.
        ax1.plot(ep, p_loss, label="Prediction Loss", color=color_pred, marker="o", linewidth=2, linestyle="--")
        ax1.set_ylabel("Prediction Loss", color=color_pred, fontsize=11)
        ax1.tick_params(axis='y', labelcolor=color_pred)

        # Right axis: decision loss, sharing the epoch x-axis.
        ax2 = ax1.twinx()
        ax2.plot(ep, d_loss, label="Decision Loss", color=color_dec, marker="s", linewidth=2)
        ax2.set_ylabel("Decision Loss", color=color_dec, fontsize=11)
        ax2.tick_params(axis='y', labelcolor=color_dec)

        ax1.set_xlabel("Epoch", fontsize=11)
        ax1.set_title(f"{ds}", fontsize=12)
        ax1.grid(True, alpha=0.3)

        # Merge both axes' entries into a single figure-level legend.
        lines, labels = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        fig.legend(lines + lines2, labels + labels2, loc='upper center', ncol=2, fontsize=10, frameon=False)

        # Leave the top 10% of the figure free for the legend.
        fig.tight_layout(rect=[0, 0, 1, 0.90])
        fig.savefig(f"vfe_loss_dual_axis_{ds}.pdf", dpi=300)
        fig.savefig(f"vfe_loss_dual_axis_{ds}.png", dpi=300)
        plt.show()
        plt.close(fig)



def figure_vfe_loss_three():
    """Plot prediction, decision and eval loss per dataset on three y-axes.

    Reads ``loss_log.xlsx``. Each dataset occupies a group of columns starting
    at the offset given in ``dataset_cols``:
    - column ``start`` holds the epoch;
    - ``start + 2`` holds the decision loss;
    - ``start + 3`` holds the prediction loss;
    - ``start + 4`` holds the eval value.

    The plots use an axisartist host axis plus two parasite axes; the second
    parasite axis is offset to the right. Each figure is saved (dpi=300),
    shown, and closed.

    Note:
        Updates the global ``plt.rcParams`` and style as a side effect.
    """
    import pandas as pd
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import host_subplot
    import mpl_toolkits.axisartist as AA

    plt.rcParams.update({
        # "text.usetex": True,
        "font.family": "serif",
        "axes.labelsize": 12,
        "axes.titlesize": 13,
        "legend.fontsize": 11,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
    })

    file_path = "loss_log.xlsx"
    df = pd.read_excel(file_path)

    # Index of the first column of each dataset's column group in the sheet.
    dataset_cols = {
        "Bimatching": 0,
        "Budget": 5,
        "Cubic": 10,
        "Knapsack": 15,
        "Energy": 20,
        "Portfolio": 25,
    }

    # matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*"
    # and 3.8 removed the old names, so "seaborn-whitegrid" alone raises on
    # current releases. Prefer the new name; fall back on older versions.
    whitegrid_style = ("seaborn-v0_8-whitegrid"
                       if "seaborn-v0_8-whitegrid" in plt.style.available
                       else "seaborn-whitegrid")
    plt.style.use(whitegrid_style)

    color_pred = "#397ebc"
    color_dec = "#d1495b"
    color_eval = "#55a868"

    for ds, start_col in dataset_cols.items():

        epoch_col = df.columns[start_col]
        loss_col = df.columns[start_col + 2]
        pred_col = df.columns[start_col + 3]
        eval_col = df.columns[start_col + 4]

        sub = df[[epoch_col, loss_col, pred_col, eval_col]].dropna()
        # Epoch cells are parsed via their first run of digits, so labels that
        # contain text still yield an integer epoch.
        ep_str = sub[epoch_col].astype(str).str.extract(r'(\d+)')[0]
        epochs = pd.to_numeric(ep_str, errors='coerce')
        decision_loss = pd.to_numeric(sub[loss_col], errors='coerce')
        prediction_loss = pd.to_numeric(sub[pred_col], errors='coerce')
        eval_loss = pd.to_numeric(sub[eval_col], errors='coerce')

        # Keep only rows where all four values parsed as numbers.
        valid_mask = (~epochs.isna()) & (~decision_loss.isna()) & (~prediction_loss.isna()) & (~eval_loss.isna())
        epochs = epochs[valid_mask].astype(int).reset_index(drop=True)
        decision_loss = decision_loss[valid_mask].reset_index(drop=True)
        prediction_loss = prediction_loss[valid_mask].reset_index(drop=True)
        eval_loss = eval_loss[valid_mask].reset_index(drop=True)

        fig = plt.figure(figsize=(8, 4))
        host = host_subplot(111, axes_class=AA.Axes)

        par1 = host.twinx()
        par2 = host.twinx()

        # Move par2's right spine outward (in points) so it does not overlap par1's.
        offset = 40
        new_fixed_axis = par2.new_fixed_axis
        par2.axis["right"] = new_fixed_axis(loc="right", offset=(offset, 0))
        par2.axis["right"].toggle(all=True)
        par1.axis["right"].toggle(all=True)

        p1, = host.plot(epochs, prediction_loss, color=color_pred, linewidth=2, marker='o', linestyle="--",
                        label="Prediction Loss")
        p2, = par1.plot(epochs, decision_loss, color=color_dec, linewidth=2, marker='s', label="Decision Loss")
        p3, = par2.plot(epochs, eval_loss, color=color_eval, linewidth=2, marker='^', linestyle=":", label="Eval")

        host.set_ylabel("Prediction Loss", color=color_pred)
        par1.set_ylabel("Decision Loss", color=color_dec)
        par2.set_ylabel("Eval", color=color_eval)

        host.yaxis.label.set_color(color_pred)
        par1.yaxis.label.set_color(color_dec)
        par2.yaxis.label.set_color(color_eval)

        host.tick_params(axis='y', colors=color_pred)
        par1.tick_params(axis='y', colors=color_dec)
        par2.tick_params(axis='y', colors=color_eval)

        host.set_xlabel("Epoch")
        host.set_title(ds)

        lines = [p1, p2, p3]
        host.legend(lines, [l.get_label() for l in lines], loc='upper center', ncol=3, frameon=False)

        fig.tight_layout(rect=[0.03, 0, 0.97, 1])
        # NOTE(review): these file names are the same ones figure_vfe_loss
        # writes, so whichever function runs last overwrites the other's output.
        fig.savefig(f"vfe_loss_dual_axis_{ds}.pdf", dpi=300)
        fig.savefig(f"vfe_loss_dual_axis_{ds}.png", dpi=300)
        plt.show()
        plt.close(fig)



if __name__ == '__main__':
    # Build the figures in a fixed order: sensitivity surfaces/contours,
    # three-axis loss curves, then the pretraining ablation.
    # figure_vfe_loss is not run here.
    for make_figures in (multiple_figure_sensitivity,
                         figure_vfe_loss_three,
                         figure_pretrain_ablation):
        make_figures()