
import matplotlib.pyplot as plt
import numpy as np
import matplotlib
import pandas as pd
import json
import ast
def _read_f1_history(prefix, model, task):
    """Load ``{prefix}_Val_F1_during_training_{model}_{task}.csv`` and add an ``F1`` column.

    The CSV's first column is used as the index. Each ``hook_result`` cell is
    parsed with ``ast.literal_eval`` (a Python-literal dict is expected) and its
    ``'Best F1'`` entry is stored, row by row, in a new ``F1`` column.
    """
    df = pd.read_csv(f"{prefix}_Val_F1_during_training_{model}_{task}.csv", index_col=0)
    # literal_eval only accepts Python literals, so it is safe on these strings
    # (unlike eval). Each cell is parsed exactly once.
    df['F1'] = [ast.literal_eval(cell)['Best F1'] for cell in df['hook_result']]
    return df


def plot_F1_val_during_training_oneplot(model, task, fontsize=35):
    """Plot F1 and validation loss of the reconstruction and masked-predictive runs.

    Reads ``Reconst_Val_F1_during_training_{model}_{task}.csv`` and
    ``Pred_Val_F1_during_training_{model}_{task}.csv`` from the current working
    directory. Each file needs ``hook_result`` and ``prediction_val_loss``
    columns. F1 curves are drawn solid on the left y-axis and the losses dashed
    on a twin right y-axis (blue = reconstruction, green = masked predictive).
    The figure is saved as
    ``Reconstructing_predictive_model_valloss_during_training_merged_{model}_{task}.jpg``
    (500 dpi) and then closed.

    Args:
        model: model name used in the input/output file names.
        task: task/dataset name used in the file names and the title.
        fontsize: font size for the axis labels and title (default 35).
    """
    reconst_df = _read_f1_history("Reconst", model, task)
    pred_df = _read_f1_history("Pred", model, task)

    fig, ax1 = plt.subplots(figsize=(20, 10))
    ax2 = ax1.twinx()  # second y-axis sharing the iteration x-axis
    try:
        ax1.plot(reconst_df['F1'], 'b', linewidth=3)
        ax1.plot(pred_df['F1'], 'g', linewidth=3)
        ax1.set_xlabel("Iterations", fontname="Arial", fontsize=fontsize)
        ax1.set_ylabel("F1 score", fontname="Arial", fontsize=fontsize)
        # NOTE(review): the inputs are *validation* logs; confirm "test set" in
        # the title is intended.
        ax1.set_title(f"F1 score and validation loss on {task} test set during training.",
                      fontname="Arial", fontsize=fontsize)

        ax2.plot(reconst_df['prediction_val_loss'], 'b--', linewidth=3)
        ax2.plot(pred_df['prediction_val_loss'], 'g--', linewidth=3)
        ax2.set_ylabel("Error", fontname="Arial", fontsize=fontsize)

        # Shrink the plotting area to 80% width so both legends fit on the right.
        box = ax1.get_position()
        ax1.set_position([box.x0, box.y0, box.width * 0.8, box.height])
        ax1.legend(["Reconstruction F1", "Masked Predictive F1"],
                   loc='upper left', bbox_to_anchor=(1.1, 1))
        ax2.legend(["Reconstruction Val. Loss", "Masked Predictive Val. Loss"],
                   loc='upper left', bbox_to_anchor=(1.1, 0.8))

        fig.savefig(f"Reconstructing_predictive_model_valloss_during_training_merged_{model}_{task}.jpg",
                    dpi=500, bbox_inches="tight", pad_inches=0.5)
    finally:
        # pyplot keeps every figure alive until closed; release it so repeated
        # calls do not accumulate memory.
        plt.close(fig)


if __name__ == '__main__':
    # Global default font (Arial, size 35) for all text in the figure; the
    # plotting function also passes fontname/fontsize explicitly to its labels.
    matplotlib.rc('font', family='Arial', size=35)

    plot_F1_val_during_training_oneplot('transformer', 'swat')