import numpy as np
import os
from matplotlib import pyplot as plt
from functions import filter_pairs

def _window_mean(series, end, window):
    """Return the mean of the ``window`` samples of ``series`` just before index ``end``."""
    return np.mean(series[end - window:end])


def _window_curve(click_lists, cost_lists, ks, end, window):
    """Return ``filter_pairs(xs, ys)`` for per-series window means over indices ``ks``."""
    pairs = [
        (_window_mean(click_lists[k], end, window), _window_mean(cost_lists[k], end, window))
        for k in ks
    ]
    # zip(*pairs) splits the (x, y) pairs into one tuple of xs and one of ys.
    return filter_pairs(*zip(*pairs))


def plot_from_saved_data(data_path: str = 'training_data.npz', out_dir: str = 'imgs',
                         show: bool = True) -> None:
    """Plot click/cost snapshots from a saved training run.

    For every rollout index ``i`` that is a multiple of 1000 and greater than
    100, draw a 1x3 figure. Subplot ``j`` (titled ``Traffic j+1``) uses the
    24 samples just before index ``i - 100 - (2 - j) * 20`` and shows:

    * two curves of window means from ``lossclick_list[k]`` /
      ``losscost_list[k]``, for ``k`` in ``range(8)`` and ``range(8, 16)``,
      each passed through ``filter_pairs``;
    * two single points of window means, from ``lossclick2``/``losscost2``
      and from ``lossclick22``/``losscost22``.

    Each figure is saved to ``<out_dir>/dynamic_env_steps{i-3000}_to_{i}.png``
    and then closed.

    Args:
        data_path: ``.npz`` archive to read (defaults to the original
            hard-coded ``'training_data.npz'``).
        out_dir: Directory the PNGs are written to; created if missing.
        show: If True (the original behaviour), call ``plt.show()`` after
            saving each figure.

    Raises:
        KeyError: if the archive lacks one of the arrays read below.
    """
    window = 24

    # Indexing an NpzFile returns a fully loaded ndarray, so the arrays stay
    # valid after the archive's file handle is closed by the `with` block.
    with np.load(data_path) as data:
        rollouts = len(data['losspr2'])
        lossclick2 = data['lossclick2']
        lossclick22 = data['lossclick22']
        losscost2 = data['losscost2']
        losscost22 = data['losscost22']
        losscost_list = data['losscost_list']
        lossclick_list = data['lossclick_list']

    for i in range(0, rollouts, 1000):
        if i <= 100:
            continue

        fig, axs = plt.subplots(1, 3, figsize=(15, 5), dpi=600)  # Adjust subplot size
        for j, ax in enumerate(axs):
            idx = i - 100 - (2 - j) * 20
            fixed_point = (_window_mean(lossclick2, idx, window),
                           _window_mean(losscost2, idx, window))
            fixed_point2 = (_window_mean(lossclick22, idx, window),
                            _window_mean(losscost22, idx, window))
            x_point, y_point = _window_curve(lossclick_list, losscost_list, range(8), idx, window)
            x_point2, y_point2 = _window_curve(lossclick_list, losscost_list, range(8, 16), idx, window)

            ax.plot(x_point, y_point, marker='o', linestyle='-', color='green', label='Regret Net (offline)')
            ax.plot(x_point2, y_point2, marker='x', linestyle='-', color='blue', label='Regret Net (online)')
            # NOTE(review): the "22" series is labelled offline and the "2" series
            # online here, while Regret Net offline uses the first 8 lists --
            # confirm these AMMD labels are not swapped.
            ax.scatter(*fixed_point2, marker='s', color='red', label='AMMD (offline)')
            ax.scatter(*fixed_point, marker='p', color='brown', label='AMMD (online)')
            ax.set_xlabel('click', size=13)
            ax.set_ylabel('cost', size=13)
            ax.set_title(f'Traffic {j+1}', size=13)

        # All subplots share the same labels, so one figure-level legend suffices.
        handles, labels = axs[-1].get_legend_handles_labels()
        fig.legend(handles, labels, loc='lower center', ncol=4, bbox_to_anchor=(0.5, 0.01),
                   framealpha=1, prop={'size': 15})
        fig.tight_layout(rect=[0, 0.1, 1, 1])  # Leave room at the bottom for the legend

        os.makedirs(out_dir, exist_ok=True)
        # NOTE(review): the plotted windows cover roughly i-164..i-100, not
        # i-3000..i as the filename suggests -- confirm the intended label.
        fig.savefig(os.path.join(out_dir, f'dynamic_env_steps{i-3000}_to_{i}.png'))
        if show:
            plt.show()
        # Release the large dpi=600 canvas; otherwise figures accumulate.
        plt.close(fig)

if __name__ == "__main__":
    # Script entry point: render and save every snapshot figure using the
    # default data file and output directory.
    plot_from_saved_data()
