import matplotlib.pyplot as plt
import pandas as pd



# Each results file is paired with its legend label in ONE place so the two
# can never drift out of step. (Previously `labels` was reordered separately
# from `file_paths`, which put "Original version" on the adaptive-mu run and
# "Adaptive μ" on the original run.) Entries are listed in legend order.
RUNS = [
    ('results/SingleRun_origin.csv', 'Original version'),
    ('results/SingleRun_fixed_mu.csv', 'Fixed μ'),
    ('results/SingleRun_adaptive.csv', 'Adaptive μ'),
    ('results/SingleRun_centralized.csv', 'Centralized'),
]

file_paths = [path for path, _ in RUNS]
labels = [label for _, label in RUNS]
dfs = [pd.read_csv(file_path) for file_path in file_paths]


def _case_values(df):
    """Return the distinct non-missing 'case' values of *df*, in first-seen order.

    An empty list means the run has no per-case breakdown (the 'case' column
    is absent or entirely null), so it is drawn as a single line.
    """
    if 'case' not in df.columns:
        return []
    # NaN never compares equal to itself, so a NaN "case" would select no rows
    # and only add an empty legend entry and shift the case numbering.
    return list(df['case'].dropna().unique())


cases_list = [_case_values(df) for df in dfs]


def plot_metric(dfs, labels, cases_list, column, ylabel, output_path):
    """Plot *column* against 'tValue' for every run, save the figure, then show it.

    Args:
        dfs: one DataFrame per run; each must have 'tValue' and *column*.
        labels: legend label for each run, parallel to *dfs*.
        cases_list: distinct case values for each run, parallel to *dfs*.
            An empty entry means the whole run is drawn as one line.
            Otherwise one line is drawn per case, labelled "Case 1", "Case 2", ...
        column: name of the column plotted on the y axis.
        ylabel: y-axis label text.
        output_path: path the figure is written to (at 300 dpi).
    """
    plt.figure(figsize=(10, 5))
    for df, label, cases in zip(dfs, labels, cases_list):
        if not cases:
            plt.plot(df['tValue'], df[column], label=label)
            continue
        for case_number, case in enumerate(cases, start=1):
            case_data = df[df['case'] == case]
            plt.plot(case_data['tValue'], case_data[column],
                     label=f'{label} - Case {case_number}')

    plt.xlabel('Time(s)')
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(True)

    plt.savefig(output_path, dpi=300)
    plt.show()


plot_metric(dfs, labels, cases_list, 'lossValue', 'Loss Value', 'loss_plot.png')
plot_metric(dfs, labels, cases_list, 'predictionAccuracy', 'Prediction Accuracy',
            'accuracy_plot.png')
