import matplotlib.pyplot as plt
import json
import numpy as np

# # ##############################################
# # # plot raw score: gnorm vs gprod
# # ##############################################
def plot_gnorm_vs_gprod_forget01():
        """Scatter the raw gradient-norm and gradient-product scores for forget01.

        Both score dictionaries are read from the hard-coded checkpoint
        directory and paired by position. Points are ordered by descending
        gradient-product score and the figure is written to ``raw_score.jpg``.
        """
        checkpoint_dir = "checkpoints/ft_epoch5_lr1e-05_phi_full_wd0.01/checkpoint-5000/"
        raw_path_gprod = checkpoint_dir + "forget01_g_prod_exp_influence_dict_original.json"
        raw_path_gnorm = checkpoint_dir + "forget01_g_norm_exp_influence_dict_original.json"

        with open(raw_path_gprod, 'r', encoding="utf-8") as file:
                raw_gprod = list(json.load(file).values())

        with open(raw_path_gnorm, 'r', encoding="utf-8") as file:
                raw_gnorm = list(json.load(file).values())

        # Order both series by descending gradient-product score so the two
        # point clouds share one x axis.
        sorted_indices_prod = sorted(range(len(raw_gprod)), key=raw_gprod.__getitem__, reverse=True)
        raw_gnorm = [raw_gnorm[i] for i in sorted_indices_prod]
        raw_gprod = [raw_gprod[i] for i in sorted_indices_prod]

        # 1-based positions derived from the data instead of a fixed 40 points.
        x_positions = range(1, len(sorted_indices_prod) + 1)

        fig, ax = plt.subplots(figsize=(4, 4))
        try:
                for spine in ax.spines.values():
                        spine.set_visible(False)
                ax.grid(True, linestyle='--', linewidth=1.5, c="white", alpha=0.7)
                ax.set_facecolor("silver")
                plt.tight_layout(pad=3)
                plt.subplots_adjust(top=0.8)

                ax.scatter(x_positions, raw_gnorm, s=10, c="darkblue", label="Gradient Norm")
                ax.scatter(x_positions, raw_gprod, s=10, c="brown", label="Gradient Product")

                ax.set_title("Attribution Raw Score:\n Gradient Product vs Gradient Norm", fontsize=11)
                ax.set_xlabel("Index of data to forget \n (ordered by score of gradient product)", fontsize=11)
                ax.set_ylabel("Raw Score", fontsize=11)

                ax.legend(fontsize=9, loc="upper left")
                fig.savefig("raw_score.jpg")
        finally:
                # Close (not just clear) the figure so repeated calls do not
                # accumulate open figures.
                plt.close(fig)






# # ##############################################
# # # plot tau-based power unification (g_prod)
# # ##############################################


def plot_tao_power(tao_list, with_raw, save_dir):
        """Plot power-unified gradient-product scores for several tau values.

        One line is drawn per entry in ``tao_list``, read from the matching
        ``forget01_g_prod_power_shiftt<tau>_influence_dict.json`` file in the
        hard-coded checkpoint directory. Samples are ordered by the
        module-level ``sorted_indices`` list, which must be defined before
        this function is called.

        Args:
            tao_list: Tau values; each one selects a score file to plot.
            with_raw: If true, overlay the raw scores as a scatter and use the
                comparison title.
            save_dir: Path of the image file to write (passed to ``savefig``).
        """
        checkpoint_dir = "checkpoints/ft_epoch5_lr1e-05_phi_full_wd0.01/checkpoint-5000/"
        c_list = ["yellow", "limegreen", "blue", "purple"]

        fig, ax = plt.subplots(figsize=(4, 4))
        try:
                for spine in ax.spines.values():
                        spine.set_visible(False)
                ax.grid(True, linestyle='--', linewidth=1.5, c="white", alpha=0.7)
                ax.set_facecolor("silver")
                plt.tight_layout(pad=3)
                plt.subplots_adjust(top=0.8)

                for idx, t in enumerate(tao_list):
                        score_path = checkpoint_dir + "forget01_g_prod_power_shiftt" + str(t) + "_influence_dict.json"
                        with open(score_path, 'r', encoding="utf-8") as file:
                                score = list(json.load(file).values())
                        score = np.array([score[i] for i in sorted_indices])
                        # Cycle through the palette so more than four tau values
                        # do not raise IndexError.
                        ax.plot(range(1, len(score) + 1), score, linewidth=1.8,
                                c=c_list[idx % len(c_list)], label=rf"$\tau$={t}")

                if with_raw:
                        raw_path = checkpoint_dir + "forget01_g_prod_power_influence_dict_original.json"
                        with open(raw_path, 'r', encoding="utf-8") as file:
                                raw = list(json.load(file).values())
                        raw = [raw[i] for i in sorted_indices]
                        ax.scatter(range(1, len(raw) + 1), raw, c="brown", marker="o", s=7, label="raw")
                        title = "Reversely Unified Score v.s. Raw Score\n Attribution: Gradient Product\n Unification: Power"
                else:
                        title = "Attribution_Score (Reversely Unified): gprod + power"

                ax.set_title(title, fontsize=10)
                ax.set_xlabel("Ordered Index of data to forget", fontsize=10)
                ax.set_ylabel("Multiplier Score", fontsize=10)

                ax.legend(fontsize=8, loc="upper left")
                fig.savefig(save_dir)
        finally:
                # Close (not just clear) the figure so repeated calls do not
                # accumulate open figures.
                plt.close(fig)








# # ##############################################
# # # plot tau-based exp unification (g_prod)
# # ############################################## 

def plot_tao_exp(tao_list, with_raw, save_dir):
        """Plot exponentially-unified gradient-product scores for several tau values.

        One navy line is drawn per entry in ``tao_list``, read from the
        matching ``forget01_g_prod_expt<tau>_influence_dict.json`` file in the
        hard-coded checkpoint directory. Each successive line is thinner and
        more opaque than the previous one. Samples are ordered by the
        module-level ``sorted_indices`` list, which must be defined before
        this function is called.

        Args:
            tao_list: Tau values; each one selects a score file to plot.
            with_raw: If true, overlay the raw scores as a scatter and use the
                comparison title.
            save_dir: Path of the image file to write (passed to ``savefig``).
        """
        checkpoint_dir = "checkpoints/ft_epoch5_lr1e-05_phi_full_wd0.01/checkpoint-5000/"

        fig, ax = plt.subplots(figsize=(4, 4))
        try:
                for spine in ax.spines.values():
                        spine.set_visible(False)
                ax.grid(True, linestyle='--', linewidth=1.5, c="white", alpha=0.7)
                ax.set_facecolor("silver")
                plt.tight_layout(pad=3)
                plt.subplots_adjust(top=0.8)

                width = 4
                opacity = 0.4
                for t in tao_list:
                        score_path = checkpoint_dir + "forget01_g_prod_expt" + str(t) + "_influence_dict.json"
                        with open(score_path, 'r', encoding="utf-8") as file:
                                score = list(json.load(file).values())
                        score = np.array([score[i] for i in sorted_indices])
                        # Clamp so long tau lists cannot push alpha above 1 or
                        # the line width to zero or below.
                        ax.plot(range(1, len(score) + 1), score,
                                linewidth=max(width, 0.5), alpha=min(opacity, 1.0),
                                c="navy", label=rf"$\tau$={t}")
                        width -= 0.75
                        opacity += 0.15

                if with_raw:
                        raw_path = checkpoint_dir + "forget01_g_prod_exp_influence_dict_original.json"
                        with open(raw_path, 'r', encoding="utf-8") as file:
                                raw = list(json.load(file).values())
                        raw = [raw[i] for i in sorted_indices]
                        ax.scatter(range(1, len(raw) + 1), raw, c="brown", marker="o", s=7, label="raw")
                        title = "Reversely Unified Score v.s. Raw Score\n Attribution: Gradient Product\n Unification: Exponential"
                else:
                        title = "Attribution_Score (Reversely Unified): gprod + exp"

                ax.set_title(title, fontsize=10)
                ax.set_xlabel("Ordered Index of data to forget", fontsize=10)
                ax.set_ylabel("Multiplier Score", fontsize=10)

                ax.legend(fontsize=7, loc="upper left")
                fig.savefig(save_dir)
        finally:
                # Close (not just clear) the figure so repeated calls do not
                # accumulate open figures.
                plt.close(fig)



if __name__ == "__main__":
        plot_gnorm_vs_gprod_forget01()

        # Rank the forget samples by raw gradient-product score, highest first.
        # Both tau plots read this ordering through the module-level name
        # `sorted_indices`.
        gprod_raw_file = "checkpoints/ft_epoch5_lr1e-05_phi_full_wd0.01/checkpoint-5000/forget01_g_prod_exp_influence_dict_original.json"
        with open(gprod_raw_file, 'r') as handle:
                gprod_raw = list(json.load(handle).values())
        sorted_indices = sorted(range(len(gprod_raw)), key=gprod_raw.__getitem__, reverse=True)

        # Power unification: tau values listed from largest to smallest.
        power_taus = [.1, .06, .03, 0.001]
        plot_tao_power(tao_list=power_taus, with_raw=True, save_dir="tao_gprod_power2.jpg")

        # Exponential unification.
        exp_taus = [0.03, 0.05, 0.1, 0.2, 5]
        plot_tao_exp(tao_list=exp_taus, with_raw=True, save_dir="tao_gprod_exp2.jpg")
