from matplotlib import pyplot as plt
from matplotlib.ticker import LogLocator, ScalarFormatter
from datasets import load_dataset
import os
from collections import Counter
import matplotlib.gridspec as gridspec
import numpy as np


# Beam-search run to analyse.
# NOTE(review): this used to be documented as the "number of steps", but `n`
# is used as the `--n-` field of the subset name and as the suffix of the
# `pred_naive@{n}` column, which suggests it is the number of sampled
# completions/beams — confirm.
n = 256
dataset_name = "HuggingFaceH4/Llama-3.2-1B-Instruct-beam-search-completions"
dataset_subset = (
    "HuggingFaceH4_MATH-500--T-0.8--top_p-1.0"
    f"--n-{n}--m-4--iters-40--look-0--seed-0--agg_strategy-last"
)
figure_dir = "data/plots"
figure_name = f"step_score_plot--n-{n}"

# First 500 rows of the train split, restricted to the columns used below.
dataset = (
    load_dataset(dataset_name, name=dataset_subset, split="train")
    .select(range(500))
    .select_columns(
        ["problem", "answer", "subject", "level", "unique_id", "scores", "agg_scores", f"pred_naive@{n}"]
    )
)

def compute_max_score(example):
    """Collapse one problem record onto its best-scoring completion.

    The completion whose ``agg_scores`` entry is largest (the first one on
    ties) is selected. Its per-step ``scores`` and its single ``agg_scores``
    value replace the full lists. ``answer`` is returned wrapped as
    ``\\boxed{...}``. ``is_correct`` is True iff the record's
    ``pred_naive@{n}`` string equals that wrapped answer exactly.

    Reads the module-level ``n``. Raises ValueError when ``agg_scores`` is
    empty (``max`` of an empty list).
    """
    boxed_answer = "\\boxed{" + example["answer"] + "}"
    aggregate = example["agg_scores"]
    best = aggregate.index(max(aggregate))
    return {
        "answer": boxed_answer,
        "scores": example["scores"][best],
        "agg_scores": aggregate[best],
        "is_correct": example[f"pred_naive@{n}"] == boxed_answer,
    }

# Keep only the best completion per problem and add the `is_correct` flag.
dataset = dataset.map(compute_max_score)

# Write the reduced per-problem records as JSON Lines next to the figures.
# The plots below are pie charts of absolute step-score changes (|Δscore|)
# at lags of 1, 2 and 3 steps, pooled over all problems.
os.makedirs(figure_dir, exist_ok=True)
dataset.to_json(
    os.path.join(figure_dir, figure_name + ".jsonl"),
    lines=True,
    force_ascii=False,
)


def _lagged_abs_deltas(scores, lag):
    """Return ``|scores[i] - scores[i - lag]|`` for every ``i >= lag``, in order."""
    return [abs(cur - prev) for prev, cur in zip(scores, scores[lag:])]


# Absolute changes of the per-step `scores` at lags of 1, 2 and 3 steps,
# concatenated over all problems in dataset order. (No figure is created
# here; the plotting figure is set up right before drawing.)
all_deltas_1 = []
all_deltas_2 = []
all_deltas_3 = []
for example in dataset:
    scores = example["scores"]
    all_deltas_1.extend(_lagged_abs_deltas(scores, 1))
    all_deltas_2.extend(_lagged_abs_deltas(scores, 2))
    all_deltas_3.extend(_lagged_abs_deltas(scores, 3))

n_bins = 4

# d_min_1, d_max_1 = min(all_deltas_1), max(all_deltas_1)
# print(f"Min delta: {d_min_1}, Max delta: {d_max_1}")
# bin_width_1 = (d_max_1 - d_min_1) / n_bins
# bins_1 = [d_min_1 + i * bin_width_1 for i in range(n_bins)] + [d_max_1 + 1e-6]  # ensure max included
# bin_labels_1 = [f"{bins_1[i]:.2f} - {bins_1[i + 1]:.2f}" for i in range(n_bins)]
bins_1 = [0, 0.02, 0.05, 0.1, 0.15, 0.25, 1.0]
bin_labels_1 = [f"{bins_1[i]:.2f} - {bins_1[i + 1]:.2f}" for i in range(len(bins_1) - 1)]

# d_min_2, d_max_2 = min(all_deltas_2), max(all_deltas_2)
# print(f"Min delta: {d_min_2}, Max delta: {d_max_2}")
# bin_width_2 = (d_max_2 - d_min_2) / n_bins
# bins_2 = [d_min_2 + i * bin_width_2 for i in range(n_bins)] + [d_max_2 + 1e-6]  # ensure max included
# bin_labels_2 = [f"{bins_2[i]:.2f} - {bins_2[i + 1]:.2f}" for i in range(n_bins)]
bins_2 = [0, 0.02, 0.05, 0.1, 0.15, 0.25, 1.0]
bin_labels_2 = [f"{bins_2[i]:.2f} - {bins_2[i + 1]:.2f}" for i in range(len(bins_2) - 1)]

# d_min_3, d_max_3 = min(all_deltas_3), max(all_deltas_3)
# print(f"Min delta: {d_min_3}, Max delta: {d_max_3}")
# bin_width_3 = (d_max_3 - d_min_3) / n_bins
# bins_3 = [d_min_3 + i * bin_width_3 for i in range(n_bins)] + [d_max_3 + 1e-6]  # ensure max included
# bin_labels_3 = [f"{bins_3[i]:.2f} - {bins_3[i + 1]:.2f}" for i in range(n_bins)]
bins_3 = [0, 0.02, 0.05, 0.1, 0.15, 0.25, 1.0]
bin_labels_3 = [f"{bins_3[i]:.2f} - {bins_3[i + 1]:.2f}" for i in range(len(bins_3) - 1)]

def _count_into_bins(deltas, edges, labels):
    """Count how many ``deltas`` fall into each bin delimited by ``edges``.

    Bins are half-open ``[edges[i], edges[i + 1])`` except the last, which is
    closed ``[edges[-2], edges[-1]]``, so a delta equal to the top edge is
    counted. Values outside ``[edges[0], edges[-1]]`` (and NaNs) are ignored.

    Returns a Counter keyed by ``labels`` (one label per bin, in edge order),
    with an explicit 0 for empty bins.
    """
    counts, _ = np.histogram(deltas, bins=edges)
    return Counter(dict(zip(labels, counts.tolist())))


# Each lag is counted against its own edges and labels.
bin_counts_1 = _count_into_bins(all_deltas_1, bins_1, bin_labels_1)
bin_counts_2 = _count_into_bins(all_deltas_2, bins_2, bin_labels_2)
bin_counts_3 = _count_into_bins(all_deltas_3, bins_3, bin_labels_3)

# One pie per lag (step = 1, 2, 3). No axis sharing: pie charts have no
# meaningful y scale to share.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 10))

# Legend entries are the bin range labels themselves.
legend_labels_1 = list(bin_labels_1)
legend_labels_2 = list(bin_labels_2)
legend_labels_3 = list(bin_labels_3)

# Slice sizes, in the same order as the labels (and hence the slice colours).
values_1 = [bin_counts_1[label] for label in bin_labels_1]
values_2 = [bin_counts_2[label] for label in bin_labels_2]
values_3 = [bin_counts_3[label] for label in bin_labels_3]

# Same colour sequence for every pie, so slice i means the same bin in all
# three charts and a single legend covers them. `pie` cycles through this
# list, so the extra colour beyond the number of bins is simply unused.
PIE_COLORS = ['yellowgreen', 'gold', 'lightcoral', 'lightskyblue', 'lightpink', 'lightgrey', 'lightblue']

for lag, ax, values in ((1, ax1, values_1), (2, ax2, values_2), (3, ax3, values_3)):
    wedges, _texts, _autotexts = ax.pie(
        values,
        labels=None,
        autopct="%1.1f%%",
        startangle=90,
        pctdistance=1.15,  # > 1 places the percentages just outside the rim
        textprops={"fontsize": 12},
        colors=PIE_COLORS,
    )
    ax.set_title(f"ΔScore Distribution (n={n}) for step={lag}", fontsize=20)

# One legend, to the right of the last (step=3) pie, built from its wedges.
ax3.legend(
    wedges,
    legend_labels_3,
    title="ΔScore Bins",
    loc="center left",
    bbox_to_anchor=(1, 0, 0.45, 1),
    fontsize=14,
)

# tight_layout() recomputes every subplot margin itself, so a manual
# subplots_adjust() before it would have no effect.
fig.tight_layout()
fig.savefig(os.path.join(figure_dir, f"{figure_name}_pie_diff.png"), dpi=300)
plt.close(fig)