import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter
import matplotlib as mpl
import numpy as np
from matplotlib.ticker import FormatStrFormatter

# Shared x-axis values ("Number" on every subplot). Every data series below has
# exactly one value per entry of this list.
# NOTE(review): what this count represents is not stated in this file — confirm.
spec_number_list = [0, 1, 2, 3, 4, 5, 10, 15, 20, 40, 60, 80, 100]

# Series naming appears to be <class group>_<image type>_<label type>_<metric>,
# e.g. "lt_normal" = long-tailed image with normal label — TODO confirm.
# NOTE(review): the figure titles these top20_* series "Tail Classes" and the
# tail80_* series "Head Classes"; verify that naming is intentional.

# Confidence (probability) values for the top-20 class group, one per setting.
top20_lt_normal_probability = [np.float64(0.011344784241435783), np.float64(0.016183589861112912), np.float64(0.016239648069022224), np.float64(0.01583764810006999), np.float64(0.017403921824048405), np.float64(0.01671606460040818), np.float64(0.01804786221629391), np.float64(0.01828043563617809), np.float64(0.02051260178060537), np.float64(0.020056214722488742), np.float64(0.0192419716263456), np.float64(0.018893782539733884), np.float64(0.020234249884004903)]
top20_normal_lt_probability = [np.float64(0.002183218383978393), np.float64(0.004120775532042639), np.float64(0.004057900456126455), np.float64(0.004219308127534792), np.float64(0.0039786676743096905), np.float64(0.003636938404654627), np.float64(0.0037326745128146805), np.float64(0.004676485491550661), np.float64(0.005601457484847612), np.float64(0.009066341657375493), np.float64(0.012081375486606183), np.float64(0.01548314361220816), np.float64(0.020234249884004903)]
top20_lt_lt_probability = [np.float64(0.0030461776628050888), np.float64(0.004207843946205685), np.float64(0.004398195257521625), np.float64(0.004774633761621771), np.float64(0.004587654129042806), np.float64(0.004274686312209529), np.float64(0.004581009146750148), np.float64(0.005610722630663077), np.float64(0.007571094876147981), np.float64(0.011910419702055931), np.float64(0.014094428199516966), np.float64(0.01727334661633774), np.float64(0.020234249884004903)]

# Accuracy values for the top-20 class group (same x-axis as above).
top20_lt_normal_acc = [np.float64(0.005714285714285714), np.float64(0.35214285714285715), np.float64(0.38142857142857145), np.float64(0.3942857142857143), np.float64(0.44785714285714284), np.float64(0.4057142857142857), np.float64(0.4735714285714286), np.float64(0.4907142857142857), np.float64(0.6021428571428571), np.float64(0.6128571428571429), np.float64(0.5957142857142858), np.float64(0.5735714285714286), np.float64(0.6492857142857142)]
top20_normal_lt_acc = [np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0014285714285714286), np.float64(0.007857142857142858), np.float64(0.04428571428571428), np.float64(0.24142857142857144), np.float64(0.35928571428571426), np.float64(0.4907142857142857), np.float64(0.6492857142857142)]
top20_lt_lt_acc = [np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0035714285714285713), np.float64(0.0007142857142857143), np.float64(0.02214285714285714), np.float64(0.06642857142857143), np.float64(0.16714285714285715), np.float64(0.4057142857142857), np.float64(0.4514285714285714), np.float64(0.5585714285714286), np.float64(0.6492857142857142)]

# Label-entropy values for the top-20 class group. The x-axis is identical to
# spec_number_list; this separate name is kept for compatibility.
spec_number_list_entropy = [0, 1, 2, 3, 4, 5, 10, 15, 20, 40, 60, 80, 100]
top20_lt_normal_entropy = [np.float32(5.6648626), np.float32(3.8422647), np.float32(3.8123982), np.float32(3.9025292), np.float32(3.6237488), np.float32(3.616661), np.float32(3.0400279), np.float32(3.0016918), np.float32(3.0187619), np.float32(2.860466), np.float32(2.826674), np.float32(2.9241517), np.float32(1.8249185)]
top20_normal_lt_entropy = [np.float32(4.1526685), np.float32(4.5011816), np.float32(4.394666), np.float32(4.1888804), np.float32(4.188092), np.float32(4.4936137), np.float32(4.097073), np.float32(4.5800586), np.float32(4.328223), np.float32(3.8746536), np.float32(3.5122585), np.float32(2.9795494), np.float32(1.8249185)]
top20_lt_lt_entropy = [np.float32(6.4173517), np.float32(3.771934), np.float32(3.2715278), np.float32(3.0069776), np.float32(3.0580797), np.float32(2.9762578), np.float32(2.551225), np.float32(2.5741076), np.float32(2.4895694), np.float32(2.0019598), np.float32(2.1331477), np.float32(1.8858362), np.float32(1.8249185)]

# Confidence (probability) values for the tail-80 class group. As with every
# group above, the three settings share the same value at x = 100.
tail80_lt_normal_probability = [np.float64(0.009663803949473991), np.float64(0.008454102539020434), np.float64(0.008440087957611532), np.float64(0.008540587990355562), np.float64(0.008149019556090431), np.float64(0.008320983846438139), np.float64(0.007988034460721402), np.float64(0.00792989107783699), np.float64(0.007371849590536905), np.float64(0.007485946289852368), np.float64(0.007689507100660811), np.float64(0.007776554375725157), np.float64(0.007441437461694217)]
tail80_normal_lt_probability = [np.float64(0.011954195417553949), np.float64(0.011469806129584217), np.float64(0.011485524887589237), np.float64(0.011445172988162864), np.float64(0.011505333069156094), np.float64(0.011590765387544317), np.float64(0.011566831365740394), np.float64(0.011330878647791646), np.float64(0.011099635660184504), np.float64(0.010233414586377226), np.float64(0.009479656104673008), np.float64(0.008629214104588365), np.float64(0.007441437461694217)]
tail80_lt_lt_probability = [np.float64(0.01173845558825373), np.float64(0.01144803901967472), np.float64(0.011400451174856375), np.float64(0.011306341582014413), np.float64(0.011353086443991092), np.float64(0.011431328415390324), np.float64(0.011354747698408444), np.float64(0.01109731932298538), np.float64(0.010607226281192978), np.float64(0.009522395100763918), np.float64(0.0089763929618344), np.float64(0.008181663302943239), np.float64(0.007441437461694217)]

# Alternative smaller-font rcParams preset (kept for reference; not applied):
# mpl.rcParams.update({
#     'font.size': 11,  # overall font size
#     'axes.titlesize': 12,
#     'axes.labelsize': 11,
#     'xtick.labelsize': 10,
#     'ytick.labelsize': 10,
#     'legend.fontsize': 8,
#     'lines.markersize': 6,  # marker size
#     'lines.linewidth': 2,  # line width
#     'grid.linewidth': 0.2,
#     'grid.alpha': 0.2,
#     # set title size
#     'axes.titlesize': 12,

# })

mpl.rcParams.update({
    'font.size': 11,  # 控制整体字体大小
    'axes.titlesize': 12,
    'axes.labelsize': 12,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 8,
    'lines.markersize': 6,  # 控制标记大小
    'lines.linewidth': 2,  # 控制线宽
    'grid.linewidth': 0.2,
    'grid.alpha': 0.2,
    # set title size
    'axes.titlesize': 12,
})

def refine_figure_with_title_adjustment(data, labels, title, y_labels, figsize=(5.5, 5.5),
                                        x_values=None, subtitles=None,
                                        output_basename="class_metrics"):
    """Draw the 2x2 class-metrics figure and save it as EPS and PNG.

    Panels are filled in row-major order. A panel is drawn for each position
    that ``data``, ``subtitles``, ``y_labels`` and the four axes all provide.
    Within a panel, one curve is drawn per pair of (series, label).

    Args:
        data: Sequence of panel groups. Each group is a sequence of y-series,
            and each series has one value per entry of ``x_values``.
        labels: Legend label for each series within a group.
        title: Overall figure title. Currently unused because the suptitle
            call is disabled; kept for interface compatibility.
        y_labels: y-axis label for each panel.
        figsize: Figure size in inches.
        x_values: x coordinates shared by every series. Defaults to
            ``[0, 1, 2, 3, 4, 5, 10, 15, 20, 40, 60, 80, 100]``.
        subtitles: Panel titles. Defaults to the four titles used by this
            script.
        output_basename: Path stem for the output files. The function writes
            ``<stem>.eps`` and ``<stem>.png``.

    Returns:
        The created matplotlib ``Figure``, so callers can show or close it.
    """
    if x_values is None:
        x_values = [0, 1, 2, 3, 4, 5, 10, 15, 20, 40, 60, 80, 100]
    if subtitles is None:
        subtitles = [
            "(a) Accuracy on Tail Classes",
            "(b) Entropy of Label on Tail Classes",
            "(c) Confidence on Tail Classes",
            "(d) Confidence on Head Classes",
        ]

    # Curve colors. They cycle so that a 4th+ series is drawn rather than
    # silently dropped.
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c']

    fig, axes = plt.subplots(2, 2, figsize=figsize, gridspec_kw={"height_ratios": [1, 1]})
    axes = axes.flatten()

    last_ax = None
    for idx, (ax, y_data, subtitle, y_label) in enumerate(zip(axes, data, subtitles, y_labels)):
        for series_idx, (y, label) in enumerate(zip(y_data, labels)):
            ax.plot(
                x_values,
                y,
                label=label,
                marker='.',
                color=colors[series_idx % len(colors)],
            )
        # Explicit size 10 is smaller than the rcParams default title size.
        ax.set_title(subtitle, fontsize=10)
        ax.set_ylabel(y_label)
        if idx < 2:
            # Top row (accuracy / entropy): plain one-decimal tick labels.
            ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
        else:
            # Bottom row (small confidence values): scientific notation, with a
            # fixed y-range shared by both panels. The range is hard-coded for
            # this script's data, whose maximum is about 0.0205.
            ax.yaxis.set_major_formatter(ScalarFormatter())
            ax.ticklabel_format(style='scientific', axis='y', scilimits=(0, 0))
            ax.set_ylim(0, 0.0225)
        ax.set_xlabel("Number")
        ax.grid(True, linestyle="--", color="#d3d3d3", alpha=0.7)
        last_ax = ax

    # Single legend, placed inside the last drawn panel. Skipped if nothing
    # was drawn; the original raised NameError in that case.
    if last_ax is not None:
        last_ax.legend()

    fig.tight_layout()
    # Applied after tight_layout so these explicit gaps take precedence.
    fig.subplots_adjust(hspace=0.5, wspace=0.33)
    fig.savefig(f"{output_basename}.eps", format="eps", bbox_inches="tight")
    fig.savefig(f"{output_basename}.png", format="png", bbox_inches="tight")
    return fig


# Metric groups. Each entry holds the three configurations in legend order
# (Config 1, Config 2, Config 3).
data = [
    [top20_lt_normal_probability, top20_normal_lt_probability, top20_lt_lt_probability],
    [tail80_lt_normal_probability, tail80_normal_lt_probability, tail80_lt_lt_probability],
    [top20_lt_normal_acc, top20_normal_lt_acc, top20_lt_lt_acc],
    [top20_lt_normal_entropy, top20_normal_lt_entropy, top20_lt_lt_entropy],
]

# Legend labels, one per series within each data entry. The descriptive
# alternative below maps positionally onto "Config 1/2/3":
# labels = ["lt image with normal label", "normal image with lt label", "lt image with lt label"]
labels = ["Config 1", "Config 2", "Config 3"]

# y-axis label for each subplot, in plotting order.
y_labels = ["Accuracy", "Entropy", "Confidence", "Confidence"]

# Order in which the `data` groups are plotted: accuracy, entropy,
# top-20 confidence, tail-80 confidence (aligned with y_labels).
panel_order = [2, 3, 0, 1]

# Only plot and write files when run as a script, not on import.
if __name__ == "__main__":
    refine_figure_with_title_adjustment(
        [data[i] for i in panel_order],
        labels,
        "Class Metrics for Top-20 and Tail-80 Classes",
        y_labels,
    )