from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Legend label -> metrics CSV for each training run. Keeping both in one
# (insertion-ordered) mapping means a label can never be paired with the
# wrong file; `labels` and `file_paths` are derived from it in the same order.
_runs = {
    '256': 'results/005-SiT-B-2-256/metrics.csv',
    '1024': 'results/007-SiT-B-2-1024/metrics.csv',
    '4096': 'results/006-SiT-B-2-4096/metrics.csv',
}
labels = list(_runs)
file_paths = list(_runs.values())

fontsize = 14

# NOTE: despite its name, this holds the *quality metric* CSV column that is
# plotted on the left axis; the comment lists the accepted choices.
memory_metrics = 'clipiqa'  #  'musiq' or 'clipiqa'

# Per-run colours: colors_left for the left-axis curves, colors_right for the
# right-axis curves (indexed by run position).
colors_left = ['tab:blue', 'tab:green', 'tab:purple']
colors_right = ['tab:red', 'tab:orange', 'tab:brown']

# Primary axis for the quality metric plus a twin y-axis that shares the
# same x-axis for the memory curves.
fig, ax1 = plt.subplots(figsize=(12, 7))
ax2 = ax1.twinx()

for i, (file_path, label) in enumerate(zip(file_paths, labels)):
    data = pd.read_csv(file_path)

    color = colors_left[i]
    ax1.plot(data['train_step'], data[memory_metrics], color=color, 
             label=f'{label} {memory_metrics}', linestyle='-')
    
    color = colors_right[i]
    ax2.plot(data['train_step'], data['memory'], color=color, 
             label=f'{label} memory', linestyle='--')

# Left axis: train step vs. the selected quality metric (the column name
# stored in `memory_metrics`).
ax1.set_xlabel('Train Step', fontsize=fontsize)
ax1.set_ylabel(memory_metrics, color='black', fontsize=fontsize)

# Right axis: memory. The view spans [-0.05, 0.8], leaving a small margin
# below zero, while ticks are placed only at 0, 0.2, ..., 0.8.
# (Previously this was set_ylim(0, 0.8) followed by set_ylim(bottom=-0.05);
# a single call yields the same final limits.)
ax2.set_ylabel('Memory', color='black', fontsize=fontsize)
ax2.set_ylim(-0.05, 0.8)
ax2.set_yticks(np.linspace(0, 0.8, 5))

# Combine the handles from both axes into one legend drawn on the left axis;
# otherwise each axis would only list its own curves.
lines1, labels1 = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax1.legend(lines1 + lines2, labels1 + labels2, loc='upper left',
           fontsize=fontsize - 2)


plt.tight_layout()
output_path = 'plot/quality_memory_comparison.png'
# savefig does not create missing directories; ensure 'plot/' exists so the
# script does not fail with FileNotFoundError at the very last step.
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path, dpi=300, bbox_inches='tight')
# The figure is only written to disk, never shown, so release it.
plt.close()