import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os

def _load_radius_means(file_path, radius):
    """Load a JSONL results file and average the records matching *radius*.

    Every record whose ``"radius"`` field equals *radius* (exact ``==``) is
    kept. Their ``"grads_list"`` and ``"delta_z_list"`` fields are each
    concatenated along axis 0 and averaged along axis 0.

    Returns:
        ``(grads_mean, delta_z_mean)`` as numpy arrays.

    Raises:
        ValueError: If no record in *file_path* matches *radius*.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        # Skip blank lines so a stray empty line does not break json.loads.
        results = [json.loads(line) for line in f if line.strip()]

    radius_results = [item for item in results if item.get("radius") == radius]
    if not radius_results:
        # Without this check np.concatenate([]) fails with an opaque
        # "need at least one array to concatenate".
        # NOTE(review): the match is an exact ==, so the type/format of `radius`
        # (e.g. the string "0.50" vs a float) must match what the file stores.
        raise ValueError(f"No results with radius={radius!r} in {file_path}")

    grads = np.concatenate([item["grads_list"] for item in radius_results])
    delta_z = np.concatenate([item["delta_z_list"] for item in radius_results])
    return np.mean(grads, axis=0), np.mean(delta_z, axis=0)


def plot_grads_vs_delta_z(dataset, model_name_or_path, radius, save_path=None):
    """Plot mean gradient norm, mean ||Δz|| and their product for one radius.

    Reads ``results/data_results/turb/{model_name_or_path}/d-{dataset}_results.jsonl``.
    Records whose ``"radius"`` equals *radius* are averaged over all samples,
    and a 3x3-inch line plot is saved with one point per entry of the
    averaged vectors.

    Args:
        dataset: Dataset name; used in the input and default output file names.
        model_name_or_path: Model identifier; used as a sub-directory in both
            the input and default output paths.
        radius: Compared with ``==`` against each record's ``"radius"`` field
            and shown in the plot title.
        save_path: Output file path. Defaults to
            ``results/figure_results/gradient_vs_delta_z/{model_name_or_path}/{dataset}_grad_vs_delta_z_05.pdf``.

    Raises:
        ValueError: If no record in the results file matches *radius*.
    """
    g_color = "#0D4C6D"
    z_color = "#FEB705"
    gz_color = "#BF1E2E"

    file_path = f"results/data_results/turb/{model_name_or_path}/d-{dataset}_results.jsonl"
    grads_mean, delta_z_mean = _load_radius_means(file_path, radius)

    # NOTE(review): this is the product of the two means, not the mean of the
    # per-sample products ||g_i||·||Δz_i||; confirm which one "Upper Bound"
    # is meant to show.
    upper_bound_mean = delta_z_mean * grads_mean
    x = np.arange(len(grads_mean))

    if save_path is None:
        # NOTE(review): the "_05" suffix is fixed and does not follow `radius`,
        # so different radii overwrite the same file unless save_path is given.
        save_path = (
            f"results/figure_results/gradient_vs_delta_z/{model_name_or_path}/"
            f"{dataset}_grad_vs_delta_z_05.pdf"
        )

    fig, ax = plt.subplots(figsize=(3, 3))
    try:
        ax.plot(
            x,
            grads_mean,
            label=r"$\|\nabla_{\boldsymbol{z}} \log \pi (y_t|\boldsymbol{z}_0)\|$",
            color=g_color,
        )
        ax.plot(
            x,
            delta_z_mean,
            label=r"$\|\Delta\boldsymbol{z}\|$",
            color=z_color,
        )
        ax.plot(
            x,
            upper_bound_mean,
            label=r"Upper Bound",
            color=gz_color,
        )

        ax.set_xlabel("Number of layers")
        ax.set_title(f"Radius = {radius}")
        ax.legend()
        ax.grid(True)
        fig.tight_layout()

        out_dir = os.path.dirname(save_path)
        if out_dir:  # os.makedirs("") raises, e.g. for a bare file name
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path)
    finally:
        # Always release the figure, even on error; this function is called
        # repeatedly from the driver loop below.
        plt.close(fig)


# Script driver: plot every (dataset, model) pair at one fixed radius.
# NOTE(review): radius is the string "0.50" and is matched with == against each
# record's "radius" field, so it must equal the stored representation exactly.
dataset_list = [
    "ARC_Challenge",
    "CommonSenseQA",
    "MMLU",
    "OpenBookQA",
]
model_list = [
    "meta-llama/Llama-3.2-3B",
    "meta-llama/Llama-3.2-1B",
    "Qwen/Qwen1.5-4B",
    "Qwen/Qwen1.5-0.5B",
    "Qwen/Qwen1.5-1.8B",
]

for dataset in dataset_list:
    for model_name_or_path in model_list:
        plot_grads_vs_delta_z(dataset, model_name_or_path, radius="0.50")