import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os

# Experiment parameters. They select which CSV files are read below and
# should match the values used by the data generation script.
n = 400
K = 4
s_n_values = [1, 2, 3, 5, 10]

# Anchor the data and image directories to this script's own location rather
# than to the current working directory, so the script can be launched from
# anywhere. E.g. if this script is in experiments/scripts/, then data_dir
# points to experiments/dat/ and img_dir to experiments/img/.
try:
    _script_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    # __file__ is undefined when the code is pasted into an interactive
    # session; fall back to resolving against the working directory.
    _script_dir = os.getcwd()

data_dir = os.path.join(_script_dir, os.pardir, "dat")
img_dir = os.path.join(_script_dir, os.pardir, "img")

# Ensure the image directory exists before anything tries to save into it.
os.makedirs(img_dir, exist_ok=True)

# Optional global style, currently disabled:
# plt.style.use('seaborn-v0_8-darkgrid')
# One wide figure that holds the two side-by-side panels.
plt.figure(figsize=(15, 7))

# Left panel: the "potential" column against "iteration", one curve per s_n.
potential_ax = plt.subplot(1, 2, 1)
for s_n in s_n_values:
    csv_path = os.path.join(data_dir, f"potentials_s{s_n}_n{n}_K{K}.csv")
    if not os.path.exists(csv_path):
        # A missing run is reported and skipped so the other curves still plot.
        print(f"Warning: File not found for s_n={s_n}: {csv_path}")
        continue
    run = pd.read_csv(csv_path)
    potential_ax.plot(run["iteration"], run["potential"], label=f"$s_n={s_n}$")

potential_ax.set_xlabel("Iteration (t)", fontsize=12)
# Raw string: the TeX backslashes (\langle, \Vert) are not valid Python escape
# sequences and would otherwise trigger an invalid-escape warning.
potential_ax.set_ylabel(r"$\langle 1, x \rangle^2 / \Vert x \Vert^2$", fontsize=12)
# plt.title("Alignment with u_1 (Global Mean)", fontsize=14)
potential_ax.legend(title="Negative Samples (s_n)", loc="best")
potential_ax.grid(True, linestyle='--', alpha=0.6)

# Right panel: the "community_potential" column against "iteration", one
# curve per s_n value, read from the same per-run CSV files.
community_ax = plt.subplot(1, 2, 2)
for s_n in s_n_values:
    csv_path = os.path.join(data_dir, f"potentials_s{s_n}_n{n}_K{K}.csv")
    if not os.path.exists(csv_path):
        # A missing run is reported and skipped so the other curves still plot.
        print(f"Warning: File not found for s_n={s_n}: {csv_path}")
        continue
    run = pd.read_csv(csv_path)
    community_ax.plot(
        run["iteration"], run["community_potential"], label=f"$s_n={s_n}$"
    )

community_ax.set_xlabel("Iteration (t)", fontsize=12)
# Raw string: the TeX backslashes (\Vert, \Pi) are not valid Python escape
# sequences and would otherwise trigger an invalid-escape warning.
community_ax.set_ylabel(r"$\Vert \Pi x \Vert^2 / \Vert x \Vert^2$", fontsize=12)
# plt.title("Alignment with Community Vectors", fontsize=14)
community_ax.legend(title="Negative Samples (s_n)", loc="best")
community_ax.grid(True, linestyle='--', alpha=0.6)

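# # Plot for column orthogonality potential (disabled).
# # NOTE: this panel uses a 1x3 grid; re-enabling it requires switching the
# # two active panels from subplot(1, 2, k) to subplot(1, 3, k) as well.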
# plt.subplot(1, 3, 3)
# for s in s_n_values:
#     filename = os.path.join(data_dir, f"potentials_s{s}_n{n}_K{K}.csv")
#     if os.path.exists(filename):
#         df = pd.read_csv(filename)
#         plt.plot(df["iteration"], df["col_orth_potential"], label=f"$s_n={s}$")
#     else:
#         print(f"Warning: File not found for s_n={s}: {filename}")

# plt.xlabel("Iteration (t)", fontsize=12)
# plt.ylabel("Column Orthogonality Potential", fontsize=12)
# plt.title("Orthogonality of Embedding Dimensions", fontsize=14)
# plt.legend(title="Negative Samples (s_n)", loc="best")
# plt.grid(True, linestyle='--', alpha=0.6)

# Fit the panels into a rectangle that leaves a strip free at the top of the
# figure, then place the overall title in that strip.
layout_rect = [0, 0.03, 1, 0.95]
plt.tight_layout(rect=layout_rect)
plt.suptitle(
    f"Potentials for n={n}, K={K}",
    y=0.98,
    fontsize=16,
    weight='bold',
)

# Write the finished figure to the image directory as EPS and report the path.
output_filename = os.path.join(img_dir, f"potentials_plot_n{n}_K{K}.eps")
plt.savefig(output_filename, bbox_inches='tight', format='eps')
print(f"Plot saved to {output_filename}")

# plt.show() # Uncomment to display the plot interactively
