import numpy as np
import matplotlib.pyplot as plt
from matplotlib import font_manager

# Serif font spec shared by every text element of the plots. Note: despite the
# variable name, this requests the "DejaVu Serif" family, not Times New Roman.
# The name is kept unchanged because the plotting code refers to it.
times_new_roman = font_manager.FontProperties(family="DejaVu Serif")

# Llama-3.1-8B series: one value per layer index, in layer order (32 entries).
sum_sngs_values_llama = [
    640.3672, 689.9639, 475.1726, 513.9049,
    396.7185, 383.3875, 425.1302, 456.9826,
    439.0451, 429.9915, 372.2959, 356.3692,
    415.5772, 440.6296, 386.2642, 376.9870,
    293.5426, 209.3110, 159.9887, 103.2894,
    93.7213, 99.2580, 87.0160, 86.5179,
    79.8955, 68.6258, 70.4560, 78.9706,
    77.9802, 72.6557, 112.0390, 85.3180,
]

# Mistral-7B series: one value per layer index, in layer order (32 entries).
sum_sngs_mistral = [
    500.6174, 745.0583, 519.2461, 493.4625, 357.4195, 436.1569, 309.028, 391.3306,
    324.1765, 329.0556, 384.4009, 344.4065, 333.3658, 341.9063, 336.2464, 353.2995,
    275.4776, 240.2239, 228.4757, 178.439, 136.4129, 106.5436, 69.9699, 76.9448,
    59.616, 51.9169, 56.5729, 57.8368, 72.8635, 108.6942, 80.7009, 85.0743,
]

def _plot_layer_curve(x, values, label, color, filename, *, fontsize=16):
    """Draw one per-layer curve as a wide, short line chart and save it as an image.

    Args:
        x: Positions along the x-axis; every position also gets its own tick.
        values: y-values to plot, one per entry of ``x``.
        label: Legend text for the curve.
        color: Matplotlib colour spec for the line.
        filename: Output path passed to ``plt.savefig`` (written at 300 dpi).
        fontsize: Point size used for axis labels, tick labels and legend.
    """
    # Fold the size into a single FontProperties instead of passing both
    # ``fontsize=`` and ``fontproperties=``. ``set_fontproperties`` replaces the
    # size, so combining the two kwargs makes the result depend on the order
    # matplotlib applies them.
    font = times_new_roman.copy()
    font.set_size(fontsize)

    fig = plt.figure(figsize=(10, 2))
    try:
        plt.plot(x, values, marker="*", label=label, color=color, linewidth=2)

        plt.xlabel("Layer Index", fontproperties=font)
        plt.ylabel("CCA Bound", fontproperties=font)

        plt.xticks(x, fontproperties=font)  # one tick per position, so every index is labelled
        plt.yticks(fontproperties=font)

        plt.grid(visible=True, linestyle="--", alpha=0.7)
        # Legend uses the same font as the rest of the figure's text.
        plt.legend(prop=font)

        plt.savefig(filename, dpi=300, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)


# One figure per model. The loop leaves module-level ``data`` / ``x`` bound to
# the last series, exactly as the original copy-pasted sequence did.
for data, label, color, filename in (
    (sum_sngs_values_llama, "Llama-3.1-8B", "#4682B4", "1d_array_plot_llama.png"),
    (sum_sngs_mistral, "Mistral-7B", "orange", "1d_array_plot_mistral.png"),
):
    x = np.arange(len(data))
    _plot_layer_curve(x, data, label, color, filename)