
import os
import json
import matplotlib.pyplot as plt
import seaborn as sns
# Apply the seaborn theme BEFORE choosing the font: ``sns.set_theme`` resets
# ``font.family`` (to its ``font`` argument, "sans-serif" by default), so a
# font assigned earlier would be silently overwritten.
sns.set_theme(style="whitegrid")
plt.rcParams["font.family"] = "Times New Roman"

# === Path and config ===
# Directory where the MCL results are stored, one sub-folder per
# (dataset, model) run. Set MCL_BASE_DIR to override without editing the file.
base_dir = os.environ.get("MCL_BASE_DIR", "/input_path")
datasets = ["reddit", "cnn", "wiki", "gov"]
models = ["llama", "mistral", "qwen"]
# Display names, index-aligned with ``models``; used as the row labels.
model_names = ["Llama", "Mistral", "Qwen"]
# Part of every run's folder name: "<dataset>_<model>_samples<num_samples>".
num_samples = 100

# Column titles, keyed by the dataset identifiers above.
pretty_names = {
    "reddit": "Writing Prompts",
    "cnn": "News Articles",
    "wiki": "Wikipedia",
    "gov": "Government Reports"
}

# One panel per (model, dataset): rows are models, columns are datasets.
# squeeze=False keeps ``axes`` two-dimensional even when either list has a
# single entry, so ``axes[i, j]`` indexing always works.
fig, axes = plt.subplots(
    len(models),
    len(datasets),
    figsize=(20, 6),
    sharex=False,
    sharey=False,
    squeeze=False,
)


import numpy as np
from scipy.optimize import curve_fit

def power_law(x, a, b):
    """Evaluate the decaying power law ``a * x ** (-b)``.

    Args:
        x: Point(s) at which to evaluate; a scalar or a NumPy array.
        a: Scale factor (the value at ``x == 1``).
        b: Decay exponent; larger ``b`` gives a faster fall-off.

    Returns:
        ``a * x ** (-b)``, with the same shape as ``x``.
    """
    exponent = -b
    return a * x ** exponent

# Shortest MCL kept in the plots; histogram bins start at this value.
MIN_CTX_LEN = 32
# Threshold for the "p(l_i <= 64)" share annotated on each panel.
SHORT_CTX_LEN = 64


def _load_conf_lens(json_path, min_len=MIN_CTX_LEN):
    """Return every ``conf_ctx_len`` value >= *min_len* stored in *json_path*.

    The JSON is walked as ``{story: {token: {"conf_ctx_len": value, ...}}}``.
    Tokens whose ``conf_ctx_len`` is missing or null are skipped. A missing
    file yields an empty list, so the caller draws a "No Data" panel.
    """
    if not os.path.exists(json_path):
        return []
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    lens = []
    for story in data.values():
        for tok_info in story.values():
            val = tok_info.get("conf_ctx_len")
            if val is not None and val >= min_len:
                lens.append(val)
    return lens


def _fit_loglog_line(bin_edges, counts):
    """Fit ``log(count) = a + b * log(bin midpoint)`` to a histogram.

    Empty bins are dropped first because ``log(0)`` is undefined. All
    remaining bins, including the first, take part in the fit.

    Returns:
        ``(a, b, fit_x)``: intercept, slope, and the bin midpoints used.
        Returns ``None`` when fewer than two bins are non-empty.

    Raises:
        Whatever ``scipy.optimize.curve_fit`` raises if the fit fails.
    """
    edges = np.asarray(bin_edges, dtype=float)
    counts = np.asarray(counts)
    midpoints = 0.5 * (edges[:-1] + edges[1:])
    valid = counts > 0
    fit_x = midpoints[valid]
    fit_y = counts[valid]
    if len(fit_x) < 2:
        return None
    popt, _ = curve_fit(lambda x, a, b: a + b * x, np.log(fit_x), np.log(fit_y))
    a, b = popt
    return a, b, fit_x


for i, model in enumerate(models):
    for j, dataset in enumerate(datasets):
        ax = axes[i, j]
        folder = f"{dataset}_{model}_samples{num_samples}"
        path = os.path.join(base_dir, folder, "context_lengths.json")
        conf_lens = _load_conf_lens(path)

        if conf_lens:
            bin_size = 64 if dataset == "gov" else 16
            # Guarantee at least one full bin: if every value equals
            # MIN_CTX_LEN, range() would otherwise produce a single edge.
            upper = max(max(conf_lens), MIN_CTX_LEN + bin_size)
            bins = list(range(MIN_CTX_LEN, upper + bin_size, bin_size))

            # Histogram drawn on a log-scaled count axis. Its counts feed the
            # fit, so the fit uses exactly the bars that are displayed.
            ax.set_yscale("log")
            counts, bin_edges, _ = ax.hist(
                conf_lens, bins=bins, alpha=0.8, color="green", edgecolor="black"
            )

            # Best-effort power-law fit in log-log space. A failure only drops
            # the red line and the exponent annotation.
            b_val = None
            try:
                fit = _fit_loglog_line(bin_edges, counts)
            except Exception as e:
                print(f"⚠️ Fit failed for {dataset}-{model}: {e}")
                fit = None
            if fit is not None:
                a, b_val, fit_x = fit
                # Transform back to original space: y = e^a * x^b.
                x_fit = np.linspace(fit_x.min(), fit_x.max(), 100)
                y_fit = np.exp(a) * x_fit ** b_val
                ax.plot(x_fit, y_fit, color="red", linewidth=2, label="Power-law fit")

            # Percentage of kept tokens whose MCL is at most SHORT_CTX_LEN.
            n_short = sum(x <= SHORT_CTX_LEN for x in conf_lens)
            pct_short = 100.0 * n_short / len(conf_lens)
            ax.text(0.98, 0.95,
                    rf"$p(l_i \leq {SHORT_CTX_LEN}) = {pct_short:.1f}\%$",
                    transform=ax.transAxes,
                    fontsize=13,
                    ha='right', va='top')

            # Power-law exponent (slope of the log-log fit).
            if b_val is not None:
                ax.text(0.98, 0.8,
                        rf"$\hat{{b}} = {b_val:.2f}$",
                        transform=ax.transAxes,
                        fontsize=12,
                        ha='right', va='top',
                        color='red')
        else:
            # Axes-fraction coordinates centre the label whatever the limits.
            ax.text(0.5, 0.5, "No Data", transform=ax.transAxes,
                    ha="center", va="center", fontsize=14, color="red")
            ax.set_xlim(0, 128)

        # Only the bottom row keeps x tick labels, and only the first column
        # keeps y tick labels. which="both" also hides minor-tick labels,
        # which a log-scaled axis can show on narrow ranges.
        # NOTE(review): the subplots are created with sharex=False and
        # sharey=False, so hidden labels need not match the visible ones.
        # Consider sharing axes if the panels are meant to be compared.
        if i == len(models) - 1:
            ax.tick_params(axis='x', labelsize=12)
        else:
            ax.tick_params(axis='x', which='both', labelbottom=False)

        if j == 0:
            ax.tick_params(axis='y', labelsize=12)
        else:
            ax.tick_params(axis='y', which='both', labelleft=False)

        if i == 0:
            ax.set_title(pretty_names[dataset], fontsize=16)

        if j == len(datasets) - 1:
            model_name = model_names[i]
            ax.annotate(model_name, xy=(1.05, 0.5), xycoords='axes fraction',
                        fontsize=16, ha='left', va='center', rotation=270)


# Leave room on the left for the shared y-axis label added below.
plt.tight_layout(rect=[0.03, 0, 1, 1])
fig.text(0.02, 0.5, "Token Count (log scale)", va='center', rotation='vertical', fontsize=16)
fig.text(0.5, 0.005, "Minimum Context Length (MCL)", ha='center', fontsize=16)


save_path = os.path.join(base_dir, "combined_conf_ctxlen_histograms_all_4datasets_4Pager_withFit.png")
plt.savefig(save_path)
plt.close()

print(f"✅ Saved plot to: {save_path}")
