import json
import matplotlib.pyplot as plt
import numpy as np

# Every series used by the figure lives in this dict under the key
# "<quantity>_<method>".  The method suffixes are "raw", "baseline" and "tao";
# the plotting code labels them "w.o. Unlearning", "Gradient Ascent" and
# "GUARD" respectively.
metric = {
    # Swept tao values; the plot uses their negative natural log as x.
    "tao": np.array([0.001, 0.005, 0.007, 0.01, 0.03, 0.05, 0.1, 0.2]),
}
metric["neglog_tao"] = -np.log(metric["tao"])

# ROUGE scores, entered as fractions and scaled to percent.  Each vector has
# four entries, indexed by the set ids used for the subplot columns.
metric["rouge_raw"] = np.array([0.986, 0.982, 0.407, 0.747]) * 100
metric["rouge_baseline"] = np.array([0.614, 0.916, 0.415, 0.736]) * 100
# One row per entry of metric["tao"] (the plot pairs row k with tao[k]).
metric["rouge_tao"] = np.array(
    [
        [0.867, 0.940, 0.382, 0.765],
        [0.867, 0.940, 0.382, 0.765],
        [0.660, 0.956, 0.426, 0.790],
        [0.629, 0.963, 0.436, 0.785],
        [0.597, 0.950, 0.429, 0.764],
        [0.601, 0.938, 0.437, 0.764],
        [0.603, 0.933, 0.437, 0.762],
        [0.607, 0.929, 0.452, 0.753],
    ]
) * 100

# Drop of each method relative to the raw scores; "raw" itself is all zeros.
for method in ("raw", "baseline", "tao"):
    metric[f"rougedrop_{method}"] = metric["rouge_raw"] - metric[f"rouge_{method}"]

# Sacrifice rate: each set's drop as a percentage of the drop on set 0
# (first entry / first column), so set 0 is always 100.
metric["scrfc_rate_baseline"] = (
    metric["rougedrop_baseline"] / metric["rougedrop_baseline"][0] * 100
)
# [:, :1] keeps the first column as an (n, 1) array so it divides row-wise.
metric["scrfc_rate_tao"] = (
    metric["rougedrop_tao"] / metric["rougedrop_tao"][:, :1] * 100
)
# No sacrifice rate for the raw scores; the plot skips the raw line on None.
metric["scrfc_rate_raw"] = None

# Set name -> index; the index selects the score entry and the subplot column.
set_order = {
    name: idx
    for idx, name in enumerate(
        ["Forget Set", "Retain Set", "Real Author Set", "Real World Facts Set"]
    )
}

# Reverse lookup (index -> set name), used for the column titles.
order_set = {idx: name for name, idx in set_order.items()}

# Subplot row -> key prefix of the series in `metric` drawn on that row.
row_eval_matrix = dict(enumerate(("rouge", "rougedrop", "scrfc_rate")))

# Key prefix -> y-axis label (mathtext).
abbre_legend = {
    "rouge": "ROUGE-1 Score",
    "rougedrop": r"ROUGE-1 Drop $\Delta_{ROUGE-1}$",
    "scrfc_rate": r"Sacrifice Rate $\rho^{ul}$",
}

# Draw a 3x4 grid of subplots and save it as "SR.jpg".
# Rows follow row_eval_matrix (which series prefix to plot); columns follow
# order_set (which set index to read from each series).

# Only sweep entries from this index onward are plotted.
tao_start = 4
# x is -log(tao), shared by every subplot, so compute it once.
x = metric["neglog_tao"][tao_start:]

# Subplots that need an explicit legend position; all others use the default.
legend_loc = {
    (0, 3): "center right",
    (0, 2): "lower left",
}

fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(12, 8))
for eval_id, axes_row in enumerate(axes):
    eval_matrix = row_eval_matrix[eval_id]
    for set_id, ax in enumerate(axes_row):
        # Flat look: no frame, dashed white grid on a silver background.
        for spine in ax.spines.values():
            spine.set_visible(False)
        ax.grid(True, linestyle="--", linewidth=1.5, c="white", alpha=0.7)
        ax.set_facecolor("silver")

        # Swept method as points; the two reference methods as horizontal lines.
        y_guard = metric[eval_matrix + "_tao"][tao_start:, set_id]
        y_baseline = metric[eval_matrix + "_baseline"][set_id]
        ax.scatter(x, y_guard, s=40, c="maroon", label="GUARD")
        ax.axhline(
            y=y_baseline,
            xmin=0,
            xmax=1,
            c="royalblue",
            linewidth=2.5,
            linestyle="--",
            label="Gradient Ascent",
        )
        # Some rows have no raw series (stored as None); skip the line there.
        raw_series = metric[eval_matrix + "_raw"]
        if raw_series is not None:
            ax.axhline(
                y=raw_series[set_id],
                c="teal",
                linewidth=2.5,
                linestyle="--",
                label="w.o. Unlearning",
            )

        loc = legend_loc.get((eval_id, set_id))
        if loc is None:
            ax.legend(fontsize=10)
        else:
            ax.legend(fontsize=10, loc=loc)

        # Titles on the top row, x labels on the bottom row, y labels on the
        # left column only.
        if eval_id == 0:
            ax.set_title(order_set[set_id], fontsize=15)
        if eval_id == 2:
            # The plotted x values are -log(tao), so the label says so.
            ax.set_xlabel(r"$-\log(\tau)$ for unification", fontsize=13)
        if set_id == 0:
            ax.set_ylabel(abbre_legend[eval_matrix], fontsize=12)

fig.suptitle(
    "Unlearn Phi-1.5B on TOFU (Forget Split = 1%)",
    fontsize=16,
    fontweight="bold",
)
# Lay out once, after every artist exists.  rect spans the whole figure, so
# it reserves no extra margin for the suptitle.
# NOTE(review): if the suptitle overlaps the top row on your matplotlib
# version, lower the top of rect (e.g. 0.96).
plt.tight_layout(rect=[0, 0, 1, 1])
plt.savefig("SR.jpg")
plt.clf()

