import os
import json
import numpy as np
import matplotlib.pyplot as plt
import random
import statistics

# problem_id -> list of "concrete_correct" values: one per threshold directory
# that contains the problem, later padded with "err" up to len(THRESHES).
data_concrete = {}
# problem_id -> {threshold: outcome}. The outcome is the sole entry of
# "abstract_correct", "unk" when that list does not hold exactly one entry,
# or "err" when the problem has no file for the threshold.
data_abstract = {}

SUFFIX = ".json"
THRESHES = (1.0, 0.5, 0.25, 0.125)

first_iter = True
for thresh in THRESHES:
    DIR = f"datasets/gqa/epic_compiled/conformal_exec_{thresh}"

    for filename in os.listdir(DIR):
        if not filename.endswith(SUFFIX):
            continue
        problem_id = filename[:-len(SUFFIX)]

        with open(os.path.join(DIR, filename), "r") as f:
            d = json.load(f)

        # New ids may only appear while scanning the first threshold directory;
        # every later directory must revisit already-known ids.
        is_new = problem_id not in data_abstract
        assert is_new == first_iter
        if is_new:
            data_abstract[problem_id] = {}
            data_concrete[problem_id] = []

        answers = d["abstract_correct"]
        data_abstract[problem_id][thresh] = answers[0] if len(answers) == 1 else "unk"
        data_concrete[problem_id].append(d["concrete_correct"])
    first_iter = False

# Thresholds for which a problem has no result file are recorded as errors.
for per_thresh in data_abstract.values():
    for t in THRESHES:
        per_thresh.setdefault(t, "err")

# Pad the concrete results the same way and count how many entries were missing.
concrete_error_count = 0
for results in data_concrete.values():
    missing = len(THRESHES) - len(results)
    if missing > 0:
        results.extend(["err"] * missing)
        concrete_error_count += missing
data_concrete_flat = [item for results in data_concrete.values() for item in results]
# print(concrete_error_count, sum(len(v) for v in data_concrete.values()), len(data_concrete))

def do_procedure(test_target, i, data=None, threshes=None):
    """Run one seeded validation/test split and score the selected threshold.

    The problems are ordered by id, shuffled with seed ``i`` and cut in half.
    Thresholds are scanned in the given order on the first (validation) half;
    the first one whose validation error rate is strictly below
    ``test_target * n / (n + 1)`` (``n`` = validation size) is selected and
    then evaluated on the second (test) half.

    Args:
        test_target: Error rate targeted on the test half.
        i: Seed for the shuffle. Note that this reseeds the module-level
            ``random`` generator.
        data: Optional mapping ``problem_id -> {threshold: outcome}``.
            Defaults to the module-level ``data_abstract``. An outcome equal
            to ``False`` counts as an error; ``"unk"`` and ``"err"`` count as
            uncertain.
        threshes: Optional iterable of thresholds in scan order. Defaults to
            the module-level ``THRESHES``.

    Returns:
        ``(test_error, unk_rate)`` for the selected threshold, or
        ``(None, None)`` when no threshold meets the validation target.

    Raises:
        ValueError: If fewer than two problems are available, since the
            validation half would be empty.
    """
    if data is None:
        data = data_abstract
    if threshes is None:
        threshes = THRESHES

    # Sort by id first so the shuffle depends only on the seed, not on dict order.
    data_ordered = [data[k] for k in sorted(data)]
    if len(data_ordered) < 2:
        raise ValueError(
            "do_procedure needs at least two problems to form a validation/test split"
        )

    random.seed(i)
    random.shuffle(data_ordered)
    n = len(data_ordered) // 2
    val = data_ordered[:n]
    test = data_ordered[n:]

    val_target = test_target * n / (n + 1)

    def error_rate(problems, t):
        # Outcomes may also be the strings "unk"/"err", so test equality with
        # False instead of truthiness.
        return sum(p[t] == False for p in problems) / len(problems)  # noqa: E712

    # None (rather than 0) marks "nothing qualified", so a threshold of 0 stays usable.
    thresh = next((t for t in threshes if error_rate(val, t) < val_target), None)
    if thresh is None:
        return None, None

    test_error = error_rate(test, thresh)
    unk_rate = sum(p[thresh] in ("unk", "err") for p in test) / len(test)
    return test_error, unk_rate

TEST_TARGET = 0.1

# One (test_error, unk_rate) pair per shuffle seed; both entries are None for a
# seed where do_procedure selected no threshold.
test_errors = tuple(do_procedure(TEST_TARGET, seed) for seed in range(100))

# Error rates of the runs that selected a threshold, and how many runs did not.
test_errors_sat = tuple(err for err, _ in test_errors if err is not None)
test_errors_unsat = sum(err is None for err, _ in test_errors)
# Same as test_errors_sat, but runs without a threshold contribute an error of 0.
test_errors_filled = tuple(0 if err is None else err for err, _ in test_errors)
test_frac_unk = tuple(unk for _, unk in test_errors if unk is not None)

print("Dataset size:", len(data_abstract))
print(f"coverage: {statistics.mean(test_errors_sat)*100:0.1f}% \\pm {statistics.stdev(test_errors_sat)*100:0.1f}, {test_errors_unsat/len(test_errors)}")
print(f"uncertain: {statistics.mean(test_frac_unk)*100:0.1f}% \\pm {statistics.stdev(test_frac_unk)*100:0.1f}, {test_errors_unsat/len(test_errors)}")
print(statistics.mean(test_errors_filled), statistics.stdev(test_errors_filled))

import matplotlib.pyplot as plt
import seaborn as sns

# Axes styling helper shared by both plots
def make_grid_box(ax):
    """Give ``ax`` a dashed x-axis grid and a gray frame, and remove its y ticks.

    Args:
        ax: Axes-like object to style in place; nothing is returned.
    """
    grid_style = {"linestyle": '--', "linewidth": 1.5, "alpha": 0.7}
    ax.grid(True, which='major', axis='x', **grid_style)

    # Show all four sides of the frame with a uniform look.
    for side in ('top', 'bottom', 'left', 'right'):
        edge = ax.spines[side]
        edge.set_visible(True)
        edge.set_color("gray")
        edge.set_linewidth(1.5)

    # The y axis carries no values in these plots, so drop its ticks.
    ax.set_yticks([])

# Shared plot configuration
tick_fontsize = 18        # size of the x tick labels
axis_label_fontsize = 20  # size of the x axis label
line_width = 2.5          # width of the dashed target line
box_linewidth = 2         # outline width of the box plots
fliersize = 8             # size of the outlier markers

# Palette
error_color = "#cf77f5"        # fill of the test-error box
unknown_color = "#ff7f0e"      # fill of the uncertain-fraction box
target_line_color = "#d62728"  # target line and its tick label

# --- Plot 1: Test Error Boxplot ---
# Horizontal box plot of the per-run test errors, with the target error rate
# marked by a dashed vertical line and an extra, colored tick.
fig, ax = plt.subplots(figsize=(10, 2))
sns.boxplot(
    x=test_errors_sat,
    color=error_color,
    width=0.5,
    linewidth=box_linewidth,
    fliersize=fliersize,
    flierprops=dict(marker='o', markerfacecolor='black', markeredgecolor='black', markersize=fliersize, markeredgewidth=2),
    ax=ax,
    orient="h"
)
ax.axvline(TEST_TARGET, linestyle="--", linewidth=line_width, color=target_line_color)
ax.set_xlim(0, 1)
ax.set_yticks([])
# Keep the automatically chosen ticks and add one at the target.
xticks = list(ax.get_xticks()) + [TEST_TARGET]
ax.set_xticks(xticks)
ax.tick_params(axis="x", labelsize=tick_fontsize)
ax.set_xlabel("coverage rate", fontsize=axis_label_fontsize)
# Color only the target tick. Match on the tick position rather than parsing
# the label text: matplotlib may leave label strings empty until the figure is
# drawn, and float("") would raise ValueError.
for tick_value, label in zip(ax.get_xticks(), ax.get_xticklabels()):
    if np.isclose(tick_value, TEST_TARGET, atol=1e-3):
        label.set_color(target_line_color)
sns.despine(left=True)
make_grid_box(ax)
plt.tight_layout()
plt.savefig("/tmp/conformal_error.pdf")
plt.close()

# --- Plot 2: Unknown Fraction Boxplot ---
# Horizontal box plot of the per-run fraction of uncertain test predictions.
fig, ax = plt.subplots(figsize=(10, 2))
outlier_marker = {
    "marker": 'o',
    "markerfacecolor": 'black',
    "markeredgecolor": 'black',
    "markersize": fliersize,
    "markeredgewidth": 2,
}
sns.boxplot(
    x=test_frac_unk,
    orient="h",
    color=unknown_color,
    width=0.5,
    linewidth=box_linewidth,
    fliersize=fliersize,
    flierprops=outlier_marker,
    ax=ax,
)
ax.set_xlim(0, 1)
ax.set_yticks([])
ax.set_xlabel("fraction of predictions uncertain", fontsize=axis_label_fontsize)
ax.tick_params(axis="x", labelsize=tick_fontsize)
sns.despine(left=True)
make_grid_box(ax)
plt.tight_layout()
plt.savefig("/tmp/conformal_unk.pdf")
plt.close()



if False:
    # Disabled diagnostic: tab-separated false / unknown-or-error / true rates
    # for the concrete results and for each abstract threshold.
    concrete_valid = [x for x in data_concrete_flat if x != "err"]
    rate_false = sum(x == False for x in data_concrete_flat) / len(concrete_valid)
    rate_true = sum(x == True for x in data_concrete_flat) / len(concrete_valid)
    print(f"Conc\t{rate_false*100:0.1f}\t----\t{rate_true*100:0.1f}")

    n_problems = len(data_abstract)
    for t in THRESHES:
        outcomes = [p[t] for p in data_abstract.values()]
        rate_false = sum(o == False for o in outcomes) / n_problems
        rate_true = sum(o == True for o in outcomes) / n_problems
        rate_unk = sum(o == "unk" for o in outcomes) / n_problems
        rate_err = sum(o == "err" for o in outcomes) / n_problems
        # Everything that is neither False nor True (i.e. "unk" plus "err").
        rate_unkerr = 1 - rate_false - rate_true
        # print(f"{t}\t{rate_false*100:0.1f}\t{rate_true*100:0.1f}\t{rate_unk*100:0.1f}\t{rate_err*100:0.1f}")
        print(f"{t}\t{rate_false*100:0.1f}\t{rate_unkerr*100:0.1f}\t{rate_true*100:0.1f}")
