import matplotlib.pyplot as plt
import numpy as np

import argparse
import os

# Command-line interface: the only option is where the figure is written.
parser = argparse.ArgumentParser(
    description="Plot pruned-vs-dense LLM layer alignment (decision and hidden argmax)."
)
parser.add_argument(
    "--output_dir",
    type=str,
    required=True,
    help="Directory for dense_argmax_alignment.png (created if it does not exist).",
)
args = parser.parse_args()

# =========================
# Data
# =========================
# For each layer of the pruned model (x-axis), the dense-model layer ID picked
# by an argmax under two criteria, "decision" and "hidden" (presumably
# decision-level vs hidden-state similarity -- confirm against the script that
# produced these numbers). Both series were renamed from an original single
# `dense_argmax`.
#
# Exactly one model/task pair is active (uncommented); the others are kept for
# reference. To switch, comment out the active pair, uncomment another, and
# update the plot title to match. `pruned` is derived from the data length, so
# the 16-entry Llama series and the 18-entry Qwen series both work unchanged.

################################################# llama2-7b #################################################
# hellaswag
# decision_dense_argmax = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 12, 12, 10, 10, 8]
# hidden_dense_argmax = [1,2,3,4,5,6,7,8,9,10,10,14,15,15,30,1]
# arc_challenge
# decision_dense_argmax = [1,2,3,4,5,6,7,8,9,10,10,11,11,11,10,10]
# hidden_dense_argmax = [1,2,3,4,5,6,7,8,9,10,10,14,15,15,31,30]
# arc_easy
# decision_dense_argmax = [1,2,3,4,5,6,7,8,9,10,10,11,11,10,10,10]
# hidden_dense_argmax = [1,2,3,4,5,6,7,8,9,10,10,12,15,15,31,30]
#############################################################################################################
################################################# llama3-8b #################################################
# hellaswag
# decision_dense_argmax =[1,2,3,4,4,4,4,5,5,5,4,5,5,5,5,5]
# hidden_dense_argmax = [1,2,3,4,4,5,6,7,8,9,9,14,15,22,21,22]
# arc_challenge
# decision_dense_argmax =[1,2,2,4,4,4,7,7,7,7,9,9,9,9,9,9]
# hidden_dense_argmax =[1,2,3,4,4,5,6,7,8,9,9,10,15,22,22,24]
# arc_easy
# decision_dense_argmax =[1,2,2,4,4,4,5,5,7,9,9,9,9,9,9,9]
# hidden_dense_argmax =[1,2,3,4,5,5,6,7,8,9,9,10,15,22,22,25]
#############################################################################################################
################################################# qwen3-4b #################################################
# arc_challenge
# decision_dense_argmax = [1,3,3,3,4,4,3,3,3,3,3,3,3,3,3,5,5,5]
# hidden_dense_argmax = [1,1,1,1,15,21,21,22,22,21,22,22,22,24,25,1,1,1]
# arc_easy
# decision_dense_argmax = [1,1,2,2,2,2,2,2,2,2,2,5,5,5,6,5,5,5]
# hidden_dense_argmax = [1,1,1,1,12,12,12,12,12,12,12,16,15,15,25,1,1,1]
# hellaswag (active)
# NOTE(review): this decision series is identical to the qwen3-4b
# arc_challenge decision series above -- possibly a copy-paste slip; confirm.
decision_dense_argmax = [1,3,3,3,4,4,3,3,3,3,3,3,3,3,3,5,5,5]
hidden_dense_argmax = [1,1,1,1,23,23,16,23,23,23,23,23,23,23,24,28,28,28]

# Both series are plotted against the same x-axis, so they must line up.
if len(decision_dense_argmax) != len(hidden_dense_argmax):
    raise ValueError(
        "decision_dense_argmax and hidden_dense_argmax must have the same "
        f"length (got {len(decision_dense_argmax)} and {len(hidden_dense_argmax)})"
    )

# Pruned-model layer IDs (1-based), one per data point. Derived from the data
# instead of the former hard-coded range(1, 19), which had to be edited by
# hand whenever a different-length series was swapped in.
pruned = list(range(1, len(decision_dense_argmax) + 1))



# =========================
# Colors
# =========================
# tab20 pairs every hue with a lighter shade; the even indices are the
# saturated members, so index 0 is blue and index 2 is orange.
tab20 = plt.cm.tab20
color_decision, color_hidden = tab20(0), tab20(2)  # blue family, orange family

# =========================
# Plot
# =========================
# Fail early with a readable message if the data section and the x-axis
# disagree (e.g. a 16-entry Llama series plotted against 18 pruned layers);
# otherwise matplotlib raises a less obvious size-mismatch error.
if not pruned:
    raise ValueError("pruned must contain at least one layer ID")
for series_name, series in (
    ("decision_dense_argmax", decision_dense_argmax),
    ("hidden_dense_argmax", hidden_dense_argmax),
):
    if len(series) != len(pruned):
        raise ValueError(
            f"{series_name} has {len(series)} entries but pruned has "
            f"{len(pruned)}; update the data section consistently"
        )

fig, ax = plt.subplots(figsize=(10, 6))

# Diagonal: perfect alignment (pruned layer i matches dense layer i). It spans
# the plotted x-range instead of a hard-coded [1, 18], so it stays correct when
# the data (and therefore `pruned`) change length.
first_layer, last_layer = pruned[0], pruned[-1]
ax.plot(
    [first_layer, last_layer], [first_layer, last_layer],
    color='gray',
    linestyle='--',
    alpha=0.5,
    linewidth=2.0,             # thicker so the reference diagonal stands out
    label='Alignment'
)

# Each series: large markers (drawn above the lines via zorder) plus a faint
# connecting line. Creation order matches the original, so the legend order does too.
for values, color, label in (
    (decision_dense_argmax, color_decision, 'Decision Argmax'),
    (hidden_dense_argmax, color_hidden, 'Hidden Argmax'),
):
    ax.scatter(
        pruned,
        values,
        color=color,
        s=200,
        label=label,
        zorder=3
    )
    ax.plot(
        pruned,
        values,
        color=color,
        alpha=0.3,
        linewidth=3
    )

# =========================
# Formatting
# =========================
ax.set_xlabel('Pruned LLM Layer ID', fontsize=22)
ax.set_ylabel('Dense LLM Layer ID', fontsize=22)
# Hard-coded title: keep it in sync with the active data set.
ax.set_title('Qwen3-4B | Hellaswag', fontsize=22)

# One x tick per pruned layer (previously a hard-coded range(1, 19)).
ax.set_xticks(pruned)
# Odd y ticks up to at least 35, as before; extended automatically if any
# argmax value exceeds 36 so no point falls outside the tick range.
y_top = max(36, max(decision_dense_argmax), max(hidden_dense_argmax))
ax.set_yticks(range(1, y_top + 1, 2))

ax.tick_params(axis='x', labelsize=20)
ax.tick_params(axis='y', labelsize=20)

ax.grid(True, which='both', linestyle=':', alpha=0.5)

ax.legend(ncol=1, fontsize=14)

plt.tight_layout()
os.makedirs(args.output_dir, exist_ok=True)
output_path = os.path.join(args.output_dir, "dense_argmax_alignment.png")
plt.savefig(output_path, dpi=300)
plt.close(fig)  # release the figure once it has been written
