from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Input CSV; the script reads a 'Layer' column plus every column whose name
# starts with ITERATION_PREFIX (e.g. 'iter_1', 'iter_2', 'iter_3', ...).
CSV_FILE_PATH = './save/medinit_mass.csv'
# Destination of the rendered figure.
PDF_FILENAME = './Figures/medium_init_mass_analysis.pdf'
# Prefix that identifies the per-iteration measurement columns.
ITERATION_PREFIX = 'iter_'


def compute_rank_statistics(data: pd.DataFrame,
                            prefix: str = ITERATION_PREFIX) -> pd.DataFrame:
    """Return a copy of *data* extended with per-row iteration statistics.

    Every column whose name starts with *prefix* is treated as one iteration.
    Two columns are added:

    * ``mean_rank`` -- row-wise mean across the iteration columns.
    * ``std_rank``  -- row-wise sample standard deviation (pandas' default
      ``ddof=1``), which is NaN when only one iteration column exists.

    The input frame is left unmodified.

    Raises:
        ValueError: if no column name starts with *prefix*. Without this
            check the statistics would silently be all-NaN and the resulting
            figure would be empty.
    """
    iteration_columns = [
        col for col in data.columns if str(col).startswith(prefix)
    ]
    if not iteration_columns:
        raise ValueError(
            f"no columns starting with {prefix!r} found; "
            f"available columns: {list(data.columns)}"
        )

    iterations = data[iteration_columns]
    return data.assign(
        mean_rank=iterations.mean(axis=1),
        std_rank=iterations.std(axis=1),
    )


def plot_rank_statistics(stats: pd.DataFrame,
                         pdf_filename: str = PDF_FILENAME) -> None:
    """Plot mean +/- one standard deviation per layer and save it as a PDF.

    *stats* must provide the columns ``Layer``, ``mean_rank`` and
    ``std_rank``. The figure is written to *pdf_filename* (parent directories
    are created if missing), then shown and closed.
    """
    layers = stats['Layer']
    mean_rank = stats['mean_rank']
    std_rank = stats['std_rank']

    plt.figure(figsize=(10, 7))

    # Mean curve with large markers. No explicit colour is set, so the active
    # style's colour cycle decides it. The label only becomes visible if a
    # legend is added; none is drawn here.
    plt.plot(layers, mean_rank, label='GPT2 medium(layers=24)', marker='o',
             markersize=10, linestyle='-', linewidth=3)

    # Translucent band spanning one standard deviation around the mean.
    plt.fill_between(layers,
                     mean_rank - std_rank,
                     mean_rank + std_rank,
                     alpha=0.2)

    # Axis labels and ticks use enlarged fonts.
    plt.xlabel('# Layers', fontsize=20)
    plt.ylabel('Avg. # columns with 90% total mass', fontsize=20)
    plt.tick_params(axis='both', which='major', labelsize=22)

    # Light dashed grid for readability.
    plt.grid(True, linestyle='--', alpha=0.5)

    plt.tight_layout()

    # savefig does not create missing directories, so make sure the target
    # folder exists before writing.
    Path(pdf_filename).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(pdf_filename, format='pdf', bbox_inches="tight", dpi=600)

    plt.show()
    plt.close()


def main() -> None:
    """Load the CSV, compute per-layer statistics and render the figure."""
    # Apply seaborn's default theme before any figure is created.
    sns.set()

    data = pd.read_csv(CSV_FILE_PATH)
    plot_rank_statistics(compute_rank_statistics(data))


if __name__ == "__main__":
    main()
