from matplotlib import pyplot as plt, font_manager
import numpy as np

# Export settings: embed TrueType (Type 42) fonts in PDF/PS output, emit SVG
# text as paths, and pin the SVG id hash salt so repeated runs produce
# identical files.
plt.rcParams.update(
    {
        'pdf.fonttype': 42,
        'ps.fonttype': 42,
        'svg.fonttype': 'path',
        'svg.hashsalt': 'fixed-salt',  # any constant string for reproducibility
    }
)
# Font used for tick labels, axis labels and titles, and the one used for legends.
font = font_manager.FontProperties(size=10)
font_legend = font_manager.FontProperties(size=10)
# Metadata keys mapped to None, passed to savefig(metadata=...) so that these
# (timestamp / tool-identifying) entries are omitted from the saved files.
strip_svg_meta: dict[str, None] = dict.fromkeys(('Creator', 'Date', 'Format', 'Type'))
strip_pdf_meta: dict[str, None] = dict.fromkeys(
    ('Title', 'Author', 'Subject', 'Keywords', 'Creator', 'Producer', 'CreationDate', 'ModDate', 'Trapped')
)


def plot_tpot_ax(ax: plt.Axes) -> None:
    """Draw the SSQR TPOT-speedup panel for Qwen3-8B onto ``ax`` in place.

    Speedup is the hard-coded PyTorch BF16 TPOT divided by the measured SSQR
    TPOT. It is plotted against the outlier rate, with one line per inlier
    bitwidth, above a dashed parity line at 1.0.
    """
    bitwidths = np.asarray([4, 3, 2])
    rates = np.asarray([0., .01, .02, .03, .04, .05])
    bf16_tpot = 0.02583916299045086
    # One row per outlier rate (order of ``rates``), one column per inlier
    # bitwidth (order of ``bitwidths``).
    ssqr_tpot = np.asarray([
        [0.012132660485804081, 0.01135605201125145, 0.011503500863909721],
        [0.0135238291695714, 0.012321893125772476, 0.01190412137657404],
        [0.013811192475259304, 0.012872005812823772, 0.012395647354424],
        [0.01421270426362753, 0.01309228502213955, 0.0126122971996665],
        [0.014793352223932743, 0.013902941718697548, 0.01313760131597519],
        [0.015192641876637936, 0.014111123979091644, 0.013730524107813835],
    ])
    speedups = bf16_tpot / ssqr_tpot

    palette = (
        '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b',
        '#e377c2', '#7f7f7f', '#bcbd22', '#17becf', 'darkgreen', '#000000',
    )
    # Colors for the 4-, 3- and 2-bit lines, respectively.
    line_colors = (palette[8], palette[10], palette[9])

    for col, (bits, line_color) in enumerate(zip(bitwidths, line_colors)):
        ax.plot(rates, speedups[:, col], label=f'{bits}', color=line_color, marker='o', zorder=4.)

    # Parity with BF16, drawn beneath the data lines.
    ax.axhline(y=1., color='black', linestyle='--', linewidth=1, zorder=2.)

    ax.set_xlim(rates.min(), rates.max())
    ax.set_xticks(rates)
    ax.set_xticklabels([f'{round(rate * 100.)}' for rate in rates], rotation=0., ha='center', fontproperties=font)
    ax.set_xlabel('Outlier Rate [%]', fontproperties=font)

    ax.set_ylim(0., 2.5)
    ax.set_yticks(np.arange(0, 2.6, .5))
    ax.set_yticklabels([f'{y:.1f}' for y in ax.get_yticks()], fontproperties=font)
    ax.set_ylabel('TPOT Speedup vs PyTorch BF16', fontproperties=font)

    ax.grid(True)
    ax.tick_params(axis='both', which='both', length=0)
    ax.legend(
        title='Inlier Bitwidth [bit]',
        ncol=2,
        loc='lower right',
        framealpha=1.,
        title_fontproperties=font_legend,
        prop=font_legend,
    )

    ax.set_title('SSQR on Qwen3-8B (A6000 GPU)', fontproperties=font)
    ax.set_facecolor((1., 1., 1., 1.))


def main():
    inf = float('inf')

    ppl_baseline_8b_wiki2 = 9.73
    ppl_baseline_8b_c4 = 13.55

    bitwidth_huffman_8b = np.asarray([4.125, 3.125, 2.125, 1.125])
    ppl_huffman_8b_wiki2 = np.asarray([9.86, 9.88, 10.1, 10.4, 13.97, 3505.29])
    ppl_huffman_8b_c4 = np.asarray([13.14, 13.6, 16.89, 2232.63])

    bitwidth_mse_8b = np.asarray([4.125, 3.125, 2.125, 1.125])
    ppl_mse_8b_wiki2 = np.asarray([10.1, 12.7671, 57.51, inf])
    ppl_mse_8b_c4 = np.asarray([14., 15.61, 36.14, inf])

    bitwidth_rtn_h_8b = np.asarray([4.125, 3.125, 2.125, 1.125])
    ppl_rtn_h_8b_wiki2 = np.asarray([9.9, 10.75, 593.05, inf])
    ppl_rtn_h_8b_c4 = np.asarray([13.8, 14.63, 503, inf])

    bitwidth_rtn_8b = np.asarray([4.125, 3.125, 2.125, 1.125])
    ppl_rtn_8b_wiki2 = np.asarray([10.3, 16.3, 2020014720, inf])
    ppl_rtn_8b_c4 = np.asarray([15.2, 21.08, 2246669824, inf])

    bitwidth_huffman_06b = np.asarray([4.125, 3.125, 2.125, 1.125])
    ppl_huffman_06b_wiki2 = np.asarray([22.72, 31.43, 156.45, 2768744])
    ppl_huffman_06b_c4 = np.asarray([28.35, 37.92, 171.38, 2129788.5])

    ppl_huffman_1_7b_wiki2 = np.asarray([18.1788, 19.72, 46.94, inf])
    ppl_huffman_1_7b_c4 = np.asarray([20.9924, 23.1487, 51.9566, inf])

    ppl_huffman_4b_wiki2 = np.asarray([14.2565, 14.55, 24.4, inf])
    ppl_huffman_4b_c4 = np.asarray([17.3858, 18.1674, 26.4568, inf])

    bitwidth_huffman_14b = np.asarray([4.125, 3.125, 2.125, 1.125])
    ppl_huffman_14b_wiki2 = np.asarray([8.76, 9.06, 11.36, 878])
    ppl_huffman_14b_c4 = np.asarray([12.12, 13.97, 15.5, 426.26])

    bitwidth_outlier_8b = np.asarray([4.125, 3.125, 2.125])
    outlier_percentages = np.asarray([.01, .02, .03, .04, .05])
    ppl_outlier_8b_wiki2 = np.asarray([[10, 10.64, 22.3], [9.96, 10.57, 16.55], [9.92, 10.42, 14.05], [9.84, 10.34, 13.12], [9.8, 10.32, 12.88]])
    ppl_outlier_8b_c4 = np.asarray([[13.83, 14.71, 27.07], [13.76, 14.56, 20.8], [13.76, 14.32, 18.57], [13.71, 14.29, 17.6], [13.67, 14.22, 16.85]])
    outlier_percentages = np.asarray([.01, .03, .05])
    ppl_outlier_8b_wiki2 = np.asarray([[10, 10.64, 22.3], [9.92, 10.42, 14.05], [9.8, 10.32, 12.88]])

    n_params = np.asarray([.6, 1.7, 4, 8., 14.])

    colors = '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf', '#000000'

    fig, (ax_methods, ax_scaling_law, ax_speedup) = plt.subplots(1, 3, figsize=(10., 3.5), gridspec_kw={'width_ratios': [2., 1., 1.]})

    plot_metric: str = 'WikiText-2'

    # new dot lines
    bitwidth_huffman_8b_2 = np.asarray([4.445, 4.125, 3.445, 3.125, 2.125, 1.125])
    ppl_huffman_8b_wiki2_2 = np.asarray([9.88, 10.4, 13.97, 3505.29])

    # plot_metric: str = 'C4'
    match plot_metric:
        case 'WikiText-2':
            ax_methods.plot(bitwidth_rtn_8b[:-1], ppl_rtn_8b_wiki2[:-1], label='RTN', color=colors[0], linestyle='-', marker='o', zorder=4.)
            ax_methods.plot(bitwidth_mse_8b[:-1], ppl_mse_8b_wiki2[:-1], label='GPTQ', color=colors[1], linestyle='-', marker='o', zorder=4.)
            ax_methods.plot(bitwidth_rtn_h_8b[:-1], ppl_rtn_h_8b_wiki2[:-1], label='HRTN', color=colors[2], linestyle='-', marker='o', zorder=4.)
            ax_methods.plot(bitwidth_huffman_8b_2[:-1], ppl_huffman_8b_wiki2[:-1], label='HPTQ', color=colors[3], linestyle='-', marker='o', zorder=4.)
            ax_methods.plot([-100., 100.], [ppl_baseline_8b_wiki2] * 2, label='BF16', color='yellow', linestyle='-', zorder=2.)

            ppl_data = np.stack([ppl_huffman_06b_wiki2, ppl_huffman_1_7b_wiki2, ppl_huffman_4b_wiki2, ppl_huffman_8b_wiki2_2, ppl_huffman_14b_wiki2], axis=-1)

            for i in range(len(outlier_percentages)):
                ax_methods.plot(bitwidth_outlier_8b + 32. * outlier_percentages[i], ppl_outlier_8b_wiki2[i], label=f'SSQR-{int(outlier_percentages[i] * 100.)}%', color=(i / (len(outlier_percentages) + 0),) * 3, linestyle='--', marker='o', zorder=3.)

        case 'C4':
            ax_methods.plot(bitwidth_rtn_8b[:-1], ppl_rtn_8b_c4[:-1], label='RTN', color=colors[0], linestyle='-', marker='o', zorder=4.)
            ax_methods.plot(bitwidth_mse_8b[:-1], ppl_mse_8b_c4[:-1], label='GPTQ', color=colors[1], linestyle='-', marker='o', zorder=4.)
            ax_methods.plot(bitwidth_rtn_h_8b[:-1], ppl_rtn_h_8b_c4[:-1], label='HRTN', color=colors[2], linestyle='-', marker='o', zorder=4.)
            ax_methods.plot(bitwidth_huffman_8b[:-1], ppl_huffman_8b_c4[:-1], label='HPTQ', color=colors[3], linestyle='-', marker='o', zorder=4.)
            ax_methods.plot([-100., 100.], [ppl_baseline_8b_c4] * 2, label='BF16 Baseline', color='yellow', linestyle='-', zorder=2.)

            ppl_data = np.stack([ppl_huffman_06b_c4, ppl_huffman_1_7b_c4, ppl_huffman_4b_c4, ppl_huffman_8b_c4, ppl_huffman_14b_c4], axis=-1)
        case _:
            raise NotImplementedError

    for i in range(len(bitwidth_huffman_8b[:-1])):
        ax_scaling_law.plot(bitwidth_huffman_8b[i] * n_params / 8., ppl_data[i], label=f'{bitwidth_huffman_8b[i]}', color=colors[[4, 5, 6][i]], linestyle='-', marker='o')

    ax_methods.set_xlim(2., 6.)
    ax_methods.set_xticks(bitwidth_rtn_8b[:-1].tolist() + [5.125])
    ax_methods.set_xticklabels([f'{v}' for v in ax_methods.get_xticks()], fontproperties=font)
    ax_methods.set_yscale('log')
    ax_methods.set_ylim(9., 80.)
    ax_methods.set_yticks([10., 20., 40., 80.])
    ax_methods.set_yticklabels([f'{int(tick)}' for tick in ax_methods.get_yticks()], fontproperties=font)
    ax_methods.minorticks_off()
    ax_methods.tick_params(axis='both', which='both', length=0.)
    ax_methods.grid()
    ax_methods.legend(ncol=1, loc='upper right', framealpha=1., prop=font_legend)
    ax_methods.set_xlabel('Average Bitwidth [bit]', fontproperties=font)
    ax_methods.set_ylabel(f'{plot_metric} Perplexity', fontproperties=font)
    ax_methods.set_title(f'Different Methods on Qwen3-8B', fontproperties=font)
    ax_methods.set_facecolor((1., 1., 1., 1.))

    axins = ax_methods.inset_axes(bounds=[.35, .49, .26, .47])
    for line in ax_methods.get_lines():
        axins.plot(
            line.get_xdata(),
            line.get_ydata(),
            color=line.get_color(),
            linestyle=line.get_linestyle(),
            marker=line.get_marker(),
            markersize=line.get_markersize(),
            zorder=line.get_zorder(),
            # label=line.get_label(),
        )
    axins.set_xlim(3., 5.25)
    axins.set_xticks([3.125, 4.125, 5.125])
    axins.set_xticklabels([f'{v}' for v in axins.get_xticks()], fontproperties=font)
    axins.set_yscale('log')
    axins.set_ylim(9.7, 11.)
    axins.set_yticks([10., 11.])
    axins.minorticks_off()
    axins.set_yticklabels([f'{int(tick)}' for tick in axins.get_yticks()], fontproperties=font)
    axins.tick_params(axis='both', which='both', length=0.)
    axins.grid()

    ax_scaling_law.set_xlim(0., 8.)
    ax_scaling_law.set_xticks(range(9))
    ax_scaling_law.set_xticklabels([f'{v}' for v in range(9)], fontproperties=font)
    # ax_scaling_law.set_ylim(5., 40.)
    ax_scaling_law.set_yscale('log')
    ax_scaling_law.set_ylim(5., 80.)
    ax_scaling_law.set_yticks([5., 10., 20., 40., 80.])
    ax_scaling_law.set_yticklabels([f'{int(tick)}' for tick in [5., 10., 20., 40., 80.]], fontproperties=font)
    ax_scaling_law.tick_params(axis='both', which='both', length=0.)
    ax_scaling_law.grid()
    ax_scaling_law.legend(title='Average Bitwidth [bit]', ncol=1, framealpha=1., title_fontproperties=font_legend, prop=font_legend)
    ax_scaling_law.text(.5, 8., 'Pareto Optimal', color=colors[5], fontproperties=font)
    ax_scaling_law.set_xlabel('Model Size [GB]', fontproperties=font)
    ax_scaling_law.set_ylabel(f'{plot_metric} Perplexity', fontproperties=font)
    ax_scaling_law.set_title(f'HPTQ on Qwen3-0.6/1.7/4/8/14B', fontproperties=font)
    ax_scaling_law.set_facecolor((1., 1., 1., 1.))

    plot_tpot_ax(ax_speedup)

    labels = ['(a)', '(b)', '(c)']
    for ax, label in zip((ax_methods, ax_scaling_law, ax_speedup), labels):
        ax.annotate(
            label,
            xy=(.5, -.25),
            xycoords='axes fraction',
            ha='center',
            va='top',
            fontproperties=font,
        )

    fig.set_facecolor((1., 1., 1., 0.))
    fig.tight_layout()
    fig.savefig(f'4_application.pdf', bbox_inches='tight', pad_inches=.01, transparent=False, metadata=strip_pdf_meta)
    fig.savefig(f'4_application.svg', bbox_inches='tight', pad_inches=.01, transparent=False, metadata=strip_svg_meta)
    fig.show()
    # fig.clf()


if __name__ == '__main__':
    # Build, save and display the figure only when run as a script, not on import.
    main()
