import numpy as np
import json
import matplotlib.pyplot as plt

import os

def plot_figure(
        radius_list,
        mean_r_of_current_dataset,
        std_r_of_current_dataset,
        mean_loss_of_current_dataset,
        std_loss_of_current_dataset,
        save_path,
        title):
    """Draw mean Pearson's r and mean loss against radius index and save the figure.

    Each series is drawn as a line with a shaded band of +/- one standard
    deviation around it.

    Args:
        radius_list: Radii the statistics belong to. Only its length is used:
            the i-th statistic is plotted at x = i.
        mean_r_of_current_dataset: Per-radius mean of Pearson's r.
        std_r_of_current_dataset: Per-radius standard deviation of r.
        mean_loss_of_current_dataset: Per-radius mean loss.
        std_loss_of_current_dataset: Per-radius standard deviation of the loss.
        save_path: Output file path; missing parent directories are created.
        title: Axes title (may contain Matplotlib mathtext).
    """
    xs = list(range(len(radius_list)))

    means_r = np.array(mean_r_of_current_dataset)
    stds_r = np.array(std_r_of_current_dataset)
    means_loss = np.array(mean_loss_of_current_dataset)
    stds_loss = np.array(std_loss_of_current_dataset)

    r_color = "#0D4C6D"
    loss_color = "#BF1E2E"
    line_color = "#FF9E02"

    fig, ax = plt.subplots(figsize=(3, 3))
    try:
        # Keep "Pearson's" outside $...$: in mathtext the apostrophe renders as
        # a prime and the letters are italicised (matches the title's wording).
        ax.plot(xs, means_r, label=r"Pearson's $r$", color=r_color)
        ax.fill_between(xs, means_r - stds_r, means_r + stds_r, alpha=0.20, color=r_color)

        ax.plot(xs, means_loss, label=r"$\mathcal{O}(\|\boldsymbol{z}_1-\boldsymbol{z}_0\|^2)$", color=loss_color)
        ax.fill_between(xs, means_loss - stds_loss, means_loss + stds_loss, alpha=0.20, color=loss_color)

        # NOTE(review): the marker at x=72 and the 0.0-1.0 tick labels assume
        # radius_list is 0.00, 0.01, ..., 1.00 in order (index i <-> radius i/100);
        # confirm before reusing this with a different radius grid.
        ax.axvline(x=72, color=line_color, linestyle="--", label=r"$\rho=0.72$")

        ax.set_xticks([0, 20, 40, 60, 80, 100])
        ax.set_xticklabels([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])

        ax.grid(True, alpha=0.3)
        ax.legend()
        ax.set_title(title)
        ax.set_xlabel("Radius")

        fig.tight_layout()

        # os.makedirs("") raises FileNotFoundError, so only create a directory
        # when save_path actually contains one.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=200, bbox_inches="tight")
    finally:
        # Release the figure even if saving fails, so repeated calls from a
        # loop do not accumulate open figures.
        plt.close(fig)

def pearson_similarity(x, y):
    """Return Pearson's correlation coefficient between samples ``x`` and ``y``.

    Both inputs are turned into NumPy arrays and handed to ``np.corrcoef``;
    the entry at row 0, column 1 of the resulting correlation matrix (the
    x-vs-y coefficient for 1-D inputs) is returned.
    """
    corr_matrix = np.corrcoef(np.array(x), np.array(y))
    return corr_matrix[0][1]

def plot_r_item(
    dataset, model_name_or_path, layer_index
):
    """Aggregate per-record r and loss by radius for one layer and plot them.

    Reads ``results/data_results/turb/<model_name_or_path>/d-<dataset>_results.jsonl``
    (one JSON object per line). For each record it computes:

    * Pearson's r between ``delta_log_prob_list`` and column ``layer_index`` of
      ``linear_approx_list``;
    * the mean of column ``layer_index`` of ``loss_list``.

    Values are grouped by the record's ``radius`` (radii kept in first-seen
    order), reduced to mean and population standard deviation (``np.std``
    default, ddof=0), and passed to ``plot_figure``, which writes
    ``results/figure_results/r_vs_loss/<model_name_or_path>/<dataset>_l-<layer>_r_loss.pdf``.

    Args:
        dataset: Dataset name used in the input and output paths.
        model_name_or_path: Model identifier, used verbatim as a path component.
        layer_index: Index into each row of the per-layer lists; negative values
            count from the end. The output file name uses the equivalent
            non-negative layer number.

    Raises:
        ValueError: If the results file contains no records.
    """
    file_path_template = "results/data_results/turb/{}/d-{}_results.jsonl"
    file_path = file_path_template.format(model_name_or_path, dataset)

    # radius -> list of per-record values. Dicts preserve insertion order, so
    # the keys come out in the order each radius first appears in the file.
    r_of_all_prompts = {}
    loss_of_all_prompts = {}
    num_layers = None

    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines instead of failing in json.loads
            item = json.loads(line)
            radius = item["radius"]

            # Layer count is taken from the first row of delta_z_list; as in a
            # plain overwrite, the last record read determines it.
            num_layers = len(item["delta_z_list"][0])

            # Pick column `layer_index` out of every row of the per-layer lists.
            linear_approx_at_layer = [row[layer_index] for row in item["linear_approx_list"]]
            loss_at_layer = [row[layer_index] for row in item["loss_list"]]

            r = pearson_similarity(item["delta_log_prob_list"], linear_approx_at_layer)
            r_of_all_prompts.setdefault(radius, []).append(float(r))
            loss_of_all_prompts.setdefault(radius, []).append(float(np.mean(loss_at_layer)))

    if not r_of_all_prompts:
        raise ValueError(f"no records found in {file_path}")

    radius_list = list(r_of_all_prompts)

    mean_r_of_current_dataset = [np.mean(r_of_all_prompts[radius]) for radius in radius_list]
    std_r_of_current_dataset = [np.std(r_of_all_prompts[radius]) for radius in radius_list]

    mean_loss_of_current_dataset = [np.mean(loss_of_all_prompts[radius]) for radius in radius_list]
    std_loss_of_current_dataset = [np.std(loss_of_all_prompts[radius]) for radius in radius_list]

    # Map a possibly negative layer_index to the layer number used in the file
    # name; range indexing raises IndexError when out of range, like a list.
    layer_number = range(num_layers)[layer_index]

    plot_figure(
        radius_list,
        mean_r_of_current_dataset,
        std_r_of_current_dataset,
        mean_loss_of_current_dataset,
        std_loss_of_current_dataset,
        save_path=f"results/figure_results/r_vs_loss/{model_name_or_path}/{dataset}_l-{layer_number}_r_loss.pdf",
        title=r"Pearson's $r$ vs. $\mathcal{O}(\|\boldsymbol{z}_1-\boldsymbol{z}_0\|^2)$"
    )


# Dataset names; each is substituted into the results/figure paths by plot_r_item.
dataset_list = ["MMLU", "OpenBookQA", "CommonSenseQA", "ARC_Challenge"]

# Model identifiers; each is used verbatim as a directory component of the paths.
model_list = [
    "Qwen/Qwen1.5-4B",
    "Qwen/Qwen1.5-0.5B",
    "Qwen/Qwen1.5-1.8B",
    "meta-llama/Llama-3.2-1B",
    "meta-llama/Llama-3.2-3B",
]

# For every dataset/model pair, plot the first (index 0) and then the last
# (index -1) entry of the per-layer lists.
for dataset in dataset_list:
    for model_name_or_path in model_list:
        for layer_index in (0, -1):
            plot_r_item(dataset, model_name_or_path, layer_index=layer_index)