#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import argparse
import numpy as np
import matplotlib.pyplot as plt

# Use the tab20 colormap, which provides 20 distinct colors.
tab20 = plt.cm.tab20


def set_global_style(font_size: int):
    """Set matplotlib's global font sizes relative to ``font_size``.

    The base font and axis labels use ``font_size`` itself, axis titles use
    ``font_size + 2``, and tick labels and legend text use ``font_size - 1``.
    """
    # (rcParams key, offset added to the base font size)
    size_offsets = (
        ("font.size", 0),
        ("axes.titlesize", 2),
        ("axes.labelsize", 0),
        ("xtick.labelsize", -1),
        ("ytick.labelsize", -1),
        ("legend.fontsize", -1),
    )
    for rc_key, offset in size_offsets:
        plt.rcParams[rc_key] = font_size + offset


def sanitize_name(s: str) -> str:
    """Return ``s`` with surrounding whitespace removed and punctuation normalized.

    After stripping, spaces, slashes and commas each become an underscore,
    and parentheses are dropped. All other characters are kept unchanged.
    """
    # One translation table handles every single-character substitution;
    # mapping a character to None deletes it.
    char_map = str.maketrans({
        " ": "_",
        "/": "_",
        ",": "_",
        "(": None,
        ")": None,
    })
    trimmed = s.strip()
    return trimmed.translate(char_map)


def plot_single_dataset(out_dir: str,
                        dataset_name: str,
                        remain_layers: np.ndarray,
                        acc: np.ndarray,
                        color,
                        dpi: int):
    """Plot one dataset's accuracy curve (line plus markers) and save it as a PNG.

    The image is written to ``<out_dir>/acc_vs_remain_<sanitized name>.png``.
    ``out_dir`` is not created here and must already exist.

    Args:
        out_dir: Directory the image is written into.
        dataset_name: Used as the legend label, in the title, and (after
            ``sanitize_name``) in the output file name.
        remain_layers: X values. Every value also becomes an x tick, labelled
            with two decimals. The axis is labelled "Pruning Rate".
        acc: Y values, one per entry of ``remain_layers``.
        color: Color used for both the line and the markers.
        dpi: Resolution passed to ``savefig``.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        # Line with markers on top; the white marker edge keeps points
        # readable where they sit on the line.
        ax.plot(remain_layers, acc, linewidth=2.8, color=color, label=dataset_name)
        ax.scatter(remain_layers, acc, s=38, color=color,
                   edgecolor="white", linewidth=0.6, zorder=5)

        # One tick per data point. The axis direction is left to matplotlib
        # (no explicit xlim), so values are shown in increasing order.
        ax.set_xticks(remain_layers)
        ax.set_xticklabels([f"{x:.2f}" for x in remain_layers], rotation=0)

        ax.set_xlabel("Pruning Rate")
        ax.set_ylabel("Acc")
        ax.set_title(f"Zero-shot Acc vs Pruning Rate | {dataset_name}")
        ax.legend()

        fig.tight_layout()
        fn = os.path.join(out_dir, f"acc_vs_remain_{sanitize_name(dataset_name)}.png")
        fig.savefig(fn, dpi=dpi)
    finally:
        # Close this specific figure even if plotting or saving failed, so it
        # does not stay registered (and in memory) inside pyplot.
        plt.close(fig)

def plot_all_datasets(out_dir: str,
                      remain_layers: np.ndarray,
                      data: dict,
                      dpi: int,
                      model_name: str = "Llama3-8B"):
    """Plot every dataset's accuracy curve on one figure and save it as a PNG.

    The image is written to ``<out_dir>/acc_vs_remain_ALL.png``. ``out_dir``
    is not created here and must already exist.

    Args:
        out_dir: Directory the image is written into.
        remain_layers: X values shared by all curves. Every value also becomes
            an x tick, labelled with two decimals. The axis is labelled
            "Pruning Rate".
        data: Mapping of dataset name to a sequence of accuracies, one per
            entry of ``remain_layers``. Each entry becomes one curve, colored
            from the tab20 colormap in iteration order (colors repeat after
            the colormap is exhausted).
        dpi: Resolution passed to ``savefig``.
        model_name: Model name shown in the figure title.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        n_colors = tab20.N
        for i, (ds, values) in enumerate(data.items()):
            y = np.asarray(values, dtype=np.float32)
            col = tab20(i % n_colors)

            ax.plot(remain_layers, y, linewidth=6, color=col, label=ds)
            ax.scatter(remain_layers, y, s=80, color=col,
                       edgecolor="white", linewidth=0.3, zorder=5)

        # One tick per data point. The axis direction is left to matplotlib
        # (no explicit xlim), so values are shown in increasing order.
        ax.set_xticks(remain_layers)
        # Font sizes below are set explicitly and therefore take precedence
        # over the global rcParams style for this summary figure.
        ax.set_xticklabels([f"{x:.2f}" for x in remain_layers], rotation=45, fontsize=20)
        ax.tick_params(axis='y', labelsize=20)

        ax.set_xlabel("Pruning Rate", fontsize=22)
        ax.set_ylabel("Acc", fontsize=22)
        ax.set_title(f"Zero-shot Acc vs Pruning Rate | {model_name}", fontsize=22)
        ax.legend(ncol=1, fontsize=14)

        fig.tight_layout()
        fn = os.path.join(out_dir, "acc_vs_remain_ALL.png")
        fig.savefig(fn, dpi=dpi)
    finally:
        # Close this specific figure even if plotting or saving failed, so it
        # does not stay registered (and in memory) inside pyplot.
        plt.close(fig)

def main():
    """Parse CLI options, then render the hard-coded accuracy table to ``--out_dir``.

    Options:
        --out_dir: Output directory (required); created if it does not exist.
        --font_size: Base font size passed to ``set_global_style`` (default 14).
        --dpi: Resolution of the saved figure (default 250).

    Raises:
        ValueError: If any dataset has a different number of values than
            there are x positions.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--out_dir", type=str, required=True)
    parser.add_argument("--font_size", type=int, default=14)
    parser.add_argument("--dpi", type=int, default=250)
    args = parser.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)
    set_global_style(args.font_size)

    # =========================================================
    # Table data, hard-coded verbatim here to avoid any ambiguity from
    # parsing files. Exactly one (remain_layers, data) pair should be active;
    # the alternatives are kept commented out below.
    # Despite the variable name, the active x values are pruning rates
    # (0.03 .. 0.50); the older "remaining layers" axis is kept for reference.
    # =========================================================
    # remain_layers = np.array([31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16], dtype=np.int32)
    # llama3-8b
    # NOTE(review): this block and the commented-out one below are both
    # labelled Llama3-8B but hold different numbers -- confirm which is which.
    remain_layers = np.array([0.03, 0.06, 0.09, 0.13, 0.16, 0.19, 0.22, 0.25, 0.28, 0.31, 0.34, 0.38, 0.41, 0.44, 0.47, 0.50])
    data = {
        "arc-challenge": [82.51, 80.55, 67.32, 75.26, 75.77, 65.19, 67.15, 80.03, 79.01, 75.51, 73.38, 29.10, 28.33, 27.73, 22.53, 25.77],
        "hellaswag":     [69.02, 67.72, 53.66, 51.17, 54.97, 54.08, 33.12, 64.20, 72.45, 62.46, 53.91, 26.93, 26.18, 26.31, 25.15, 24.74],
        "arc_easy":      [92.88, 91.33, 77.14, 86.69, 86.57, 73.89, 78.82, 90.61, 88.63, 84.76, 81.22, 30.86, 31.83, 29.52, 27.58, 27.92],
        # "winogrande":    [53.35, 52.57, 49.80, 50.36, 50.20, 49.57, 49.57, 55.56, 58.41, 58.88, 55.72, 50.83, 50.59, 49.72, 48.54, 50.99],
        # "piqa":          [80.30, 79.27, 70.29, 72.96, 70.51, 55.06, 57.78, 80.58, 79.98, 76.55, 55.60, 52.61, 51.58, 51.96, 50.11, 49.78],
        # "boolq":         [82.17, 83.73, 84.07, 84.37, 80.95, 82.20, 83.33, 83.67, 84.10, 82.78, 68.99, 62.60, 62.42, 62.20, 55.23, 46.42],
    }
    # Llama3-8B
    # remain_layers = np.array([0.03, 0.06, 0.09, 0.13, 0.16, 0.19, 0.22, 0.25, 0.28, 0.31, 0.34, 0.38, 0.41, 0.44, 0.47, 0.50])
    # data = {
    #     "arc-challenge": [55.12, 54.35, 53.92, 53.24, 53.24, 53.92, 52.13, 53.16, 55.29, 55.63, 54.10, 40.61, 25.60, 25.00, 26.11, 24.74],
    #     "hellaswag":     [45.19, 43.99, 38.77, 36.68, 33.90, 33.22, 29.06, 31.04, 39.55, 35.14, 33.90, 27.29, 25.63, 25.01, 24.56, 25.39],
    #     "arc_easy":      [72.63, 72.00, 71.16, 69.98, 70.27, 71.07, 68.51, 70.19, 73.52, 73.09, 73.43, 55.49, 33.68, 28.51, 26.53, 25.89],
    # }

    # qwen3-4b
    # remain_layers = np.array([0.03, 0.06, 0.08, 0.11, 0.14, 0.17, 0.19, 0.22, 0.25, 0.28, 0.31, 0.33, 0.36, 0.39, 0.42, 0.44, 0.47, 0.50, 0.53, 0.56])
    # data = {
    #     "arc-challenge": [88.82, 88.31, 88.23, 87.88, 83.02, 81.57, 57.59, 53.24, 52.99, 44.88, 22.44, 23.12, 24.15, 22.18, 22.87, 22.61, 23.04, 22.87, 22.35, 24.06 ],
    #     "hellaswag":     [82.98, 81.75, 81.48, 79.77, 69.68, 69.27, 34.51, 35.23, 35.88, 32.85, 25.12, 24.82, 25.00, 25.01, 25.05, 25.13, 25.02, 25.14, 24.27, 25.24],
    #     "arc_easy":      [96.29, 96.25, 95.03, 94.61, 87.37, 86.69, 73.81, 66.44, 64.55, 54.61, 23.92, 24.97, 25.01, 25.22, 25.39, 25.18, 25.26, 25.05, 24.21, 25.43],
    # }

    # Basic sanity check: every dataset needs one value per x position.
    for k, v in data.items():
        if len(v) != len(remain_layers):
            raise ValueError(f"Length mismatch for {k}: got {len(v)}, expect {len(remain_layers)}")

    # =========================================================
    # Per-dataset plots (one figure per dataset) are currently disabled.
    # Colors come from tab20 as well, so they match the summary figure.
    # To re-enable, uncomment:
    # =========================================================
    # for i, (ds, acc_list) in enumerate(data.items()):
    #     plot_single_dataset(
    #         out_dir=args.out_dir,
    #         dataset_name=ds,
    #         remain_layers=remain_layers,
    #         acc=np.asarray(acc_list, dtype=np.float32),
    #         color=tab20(i % 20),
    #         dpi=args.dpi
    #     )

    # =========================================================
    # Summary figure (all datasets on one plot)
    # =========================================================
    plot_all_datasets(
        out_dir=args.out_dir,
        remain_layers=remain_layers,
        data=data,
        dpi=args.dpi
    )

    print("[Done] Saved plots to:", args.out_dir)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
