import matplotlib.pyplot as plt
import matplotlib as mpl

# Global font sizes, set once so every figure created afterwards picks them up.
mpl.rcParams["font.size"] = 16         # base size for all text
mpl.rcParams["axes.labelsize"] = 18    # x/y labels
mpl.rcParams["axes.titlesize"] = 18    # axes/figure titles
mpl.rcParams["xtick.labelsize"] = 14   # tick labels along x
mpl.rcParams["ytick.labelsize"] = 14   # tick labels along y
mpl.rcParams["legend.fontsize"] = 14   # legend entries

# ------------------ Recall ------------------ #
# Scores for three configurations swept over "recall strength". Each
# configuration has its own 20-point strength grid, so every curve carries
# its own x-values.
base_recall = [
    77.16, 77.45, 76.66, 76.00, 74.92, 73.09, 70.36, 65.66, 60.11, 50.83,
    39.46, 28.88, 18.25, 11.47, 6.75, 3.75, 2.01, 1.16, 0.73, 0.43,
]
base_recall_scale = [
    0.00, 0.25, 0.50, 0.75, 1.00, 1.25, 1.50, 1.75, 2.00, 2.25,
    2.50, 2.75, 3.00, 3.25, 3.50, 3.75, 4.00, 4.25, 4.50, 4.75,
]

no_fzr_recall = [
    77.61, 76.96, 77.56, 76.87, 76.08, 74.99, 72.53, 68.74, 62.97, 53.85,
    44.44, 34.42, 24.13, 16.58, 11.42, 7.15, 4.17, 2.37, 1.19, 0.61,
]
no_fzr_scale = [
    0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5,
    5.0, 5.5, 6.0, 6.5, 7.0, 7.5, 8.0, 8.5, 9.0, 9.5,
]

full_stack_recall = [
    78.41, 78.40, 78.98, 78.93, 78.86, 79.12, 78.72, 78.95, 78.69, 78.64,
    78.38, 78.05, 77.72, 77.29, 75.47, 73.52, 70.18, 63.77, 55.96, 45.05,
]
full_stack_recall_scale = [
    0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0,
    10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0,
]

fig, ax = plt.subplots()
# One entry per curve: (x-values, y-values, keyword options for plot()).
# The last curve deliberately passes no linestyle so it uses the default.
for scale, scores, line_opts in (
    (base_recall_scale, base_recall,
     {"linestyle": (0, (1, 1, 1, 1, 4, 1)), "label": "Base"}),
    (no_fzr_scale, no_fzr_recall,
     {"linestyle": "--", "label": "No Fuzzy Recall"}),
    (full_stack_recall_scale, full_stack_recall,
     {"label": "Full Stack"}),
):
    ax.plot(scale, scores, **line_opts)
ax.set_xlabel("Recall Strength")
ax.set_ylabel("BLEU Score")
# x-axis spans the widest grid; y-axis starts at 50, so lower scores fall
# outside the visible area.
ax.set_xlim(0, max(full_stack_recall_scale))
ax.set_ylim(bottom=50)
ax.legend()
fig.tight_layout()
out1 = "/mnt/data/recall_vs_scale_v2.png"
# Saving is disabled; uncomment to write the figure to out1.
#fig.savefig(out1, dpi=150, bbox_inches="tight")
plt.show()

# ------------------ Blur ------------------ #
# Scores for three configurations swept over "blur strength". All three
# share the same 20-point grid; each still gets its own list object.
base_blur_scale = [
    0.00, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45,
    0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95,
]
no_fzr_blur_scale = list(base_blur_scale)
full_stack_blur_scale = list(base_blur_scale)

base_blur = [
    77.57, 77.16, 75.17, 68.94, 44.48, 4.01, 0.12, 0.06, 0.08, 0.13,
    0.14, 0.14, 0.13, 0.13, 0.12, 0.13, 0.13, 0.12, 0.13, 0.12,
]
no_fzr_blur = [
    77.41, 77.54, 77.47, 77.81, 77.53, 77.97, 77.65, 77.70, 77.67, 77.40,
    77.46, 77.54, 76.93, 76.59, 76.37, 76.45, 76.21, 75.79, 75.11, 74.19,
]
full_stack_blur = [
    78.51, 78.56, 78.25, 78.10, 78.23, 77.98, 78.10, 78.14, 77.77, 77.77,
    78.12, 77.71, 77.97, 77.41, 77.48, 77.56, 77.42, 76.95, 77.17, 77.19,
]

fig, ax = plt.subplots()
# One entry per curve: (x-values, y-values, keyword options for plot()).
# The last curve deliberately passes no linestyle so it uses the default.
for scale, scores, line_opts in (
    (base_blur_scale, base_blur,
     {"linestyle": (0, (1, 1, 1, 1, 4, 1)), "label": "Base"}),
    (no_fzr_blur_scale, no_fzr_blur,
     {"linestyle": "--", "label": "No Fuzzy Recall"}),
    (full_stack_blur_scale, full_stack_blur,
     {"label": "Full Stack"}),
):
    ax.plot(scale, scores, **line_opts)
ax.set_xlabel("Blur Strength")
ax.set_ylabel("BLEU Score")
ax.set_xlim(0, 1.0)
# y-axis starts at 65, so the low tail of the "Base" curve is off the plot.
ax.set_ylim(bottom=65)
ax.legend()
fig.tight_layout()
out2 = "/mnt/data/blur_vs_scale_v2.png"
# Saving is disabled; uncomment to write the figure to out2.
#fig.savefig(out2, dpi=150, bbox_inches="tight")
plt.show()

# ------------------ Noise ------------------ #
# Scores for three configurations swept over "noise strength". "Base" and
# "No Fuzzy Recall" share one grid; "Full Stack" uses a grid twice as wide.
base_noise_scale = [
    0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9,
    1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9,
]
no_fzr_noise_scale = list(base_noise_scale)
full_stack_noise_scale = [
    0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8,
    2.0, 2.2, 2.4, 2.6, 2.8, 3.0, 3.2, 3.4, 3.6, 3.8,
]

base_noise = [
    77.32, 77.41, 77.31, 76.96, 76.64, 76.68, 75.78, 74.34, 73.21, 70.88,
    66.07, 59.62, 49.67, 36.63, 22.85, 12.38, 5.93, 2.52, 1.09, 0.49,
]
no_fzr_noise = [
    77.45, 77.49, 77.37, 77.32, 77.21, 77.04, 77.04, 77.34, 77.03, 76.82,
    76.72, 76.00, 75.07, 73.50, 70.40, 65.38, 57.71, 47.75, 36.38, 25.99,
]
full_stack_noise = [
    78.16, 78.03, 78.25, 78.59, 78.18, 78.50, 78.64, 78.80, 78.98, 78.91,
    79.20, 78.88, 79.28, 78.76, 78.25, 77.53, 76.15, 72.86, 63.66, 46.89,
]

fig, ax = plt.subplots()
# One entry per curve: (x-values, y-values, keyword options for plot()).
# The last curve deliberately passes no linestyle so it uses the default.
for scale, scores, line_opts in (
    (base_noise_scale, base_noise,
     {"linestyle": (0, (1, 1, 1, 1, 4, 1)), "label": "Base"}),
    (no_fzr_noise_scale, no_fzr_noise,
     {"linestyle": "--", "label": "No Fuzzy Recall"}),
    (full_stack_noise_scale, full_stack_noise,
     {"label": "Full Stack"}),
):
    ax.plot(scale, scores, **line_opts)
ax.set_xlabel("Noise Strength")
ax.set_ylabel("BLEU Score")
# x-axis spans the widest grid; y-axis starts at 50, so lower scores fall
# outside the visible area.
ax.set_xlim(0, max(full_stack_noise_scale))
ax.set_ylim(bottom=50)
ax.legend()
fig.tight_layout()
out3 = "/mnt/data/noise_vs_scale_v2.png"
# Saving is disabled; uncomment to write the figure to out3.
#fig.savefig(out3, dpi=150, bbox_inches="tight")
plt.show()

# Bare expression listing the three output paths; it has no effect when the
# file runs as a plain script.
(out1, out2, out3)







# Concept-pair graph: one overlap-difference value per layer index (0-23)
# for each of three word pairs.
layers = list(range(24))

big_bit = [
    0.626, 0.551, 0.578, 0.530, 0.481, 0.417,
    0.398, 0.369, 0.333, 0.258, 0.230, 0.246,
    0.268, 0.266, 0.263, 0.249, 0.247, 0.246,
    0.237, 0.238, 0.246, 0.240, 0.241, 0.222,
]

map_food = [
    0.688, 0.610, 0.590, 0.542, 0.495, 0.409,
    0.379, 0.373, 0.453, 0.673, 0.627, 0.590,
    0.617, 0.601, 0.600, 0.562, 0.558, 0.561,
    0.548, 0.547, 0.562, 0.558, 0.553, 0.564,
]

him_her = [
    0.491, 0.424, 0.405, 0.352, 0.326, 0.273,
    0.243, 0.210, 0.160, 0.167, 0.134, 0.140,
    0.144, 0.146, 0.147, 0.142, 0.146, 0.144,
    0.141, 0.142, 0.140, 0.135, 0.137, 0.138,
]

fig, ax = plt.subplots()
# One entry per curve: (y-values, keyword options for plot()); all curves
# share the layer index as x. The last curve uses the default linestyle.
for overlap, line_opts in (
    (big_bit, {"linestyle": (0, (1, 1, 1, 1, 4, 1)), "label": "Big–Bit"}),
    (map_food, {"linestyle": "--", "label": "Map–Food"}),
    (him_her, {"label": "Him–Her"}),
):
    ax.plot(layers, overlap, **line_opts)
ax.set_xlabel("Decoder layer")
ax.set_ylabel("Proportional Difference in Overlap")
ax.legend()
fig.tight_layout()

# Save a copy for download (disabled). Kept ahead of show() so that, if
# re-enabled, the figure is written before the window is displayed.
#out_path = "/FinalFigures/overlap_by_layer.png"
#fig.savefig(out_path, dpi=150, bbox_inches="tight")

plt.show()

