import re, os, json
from pathlib import Path
import numpy as np
import pandas as pd

# Thresholds covered by this batch: a single fixed y-threshold and the
# s-thresholds swept against it.
y_threshold = 0.95
s_thresholds = [0.7, 0.75, 0.8, 0.85, 0.9, 0.95]

# Directory containing one log file per (y-threshold, s-threshold) pair.
base = Path("/home/disk2/lhr/fairDomainAdaption/mine/mine/logs/threshold/syn")

# Map every s-threshold to its log path. The threshold is rendered with
# trailing zeros (and a then-dangling dot) stripped from its string form.
files = dict(zip(
    s_thresholds,
    (
        base / f"syn_{y_threshold}_{str(s_val).rstrip('0').rstrip('.')}.log"
        for s_val in s_thresholds
    ),
))

# Regex patterns for the "<label>: <mean> ± <var>" summary lines of a log.
# Every pattern exposes exactly two groups: (mean, var), both as strings.
#
# A number may carry a sign and an exponent so that values printed in
# scientific notation (e.g. "1.2e-05") or negative values are captured
# instead of being silently skipped.
_NUM = r'[+-]?(?:[0-9]+(?:\.[0-9]*)?|\.[0-9]+)(?:[eE][+-]?[0-9]+)?'


def _summary_pattern(label):
    """Compile the multiline pattern for one ``<label>: <mean> ± <var>`` line.

    Args:
        label: Literal metric name that starts the line (before the colon).

    Returns:
        A compiled pattern with two capture groups, mean and var.
    """
    return re.compile(
        r'^\s*' + re.escape(label) + r':\s*(' + _NUM + r')\s*±\s*(' + _NUM + r')',
        re.MULTILINE,
    )


pat_acc = _summary_pattern('Acc')
pat_auc = _summary_pattern('auc_roc')
pat_par = _summary_pattern('parity')
pat_eq = _summary_pattern('equality')

def extract_metrics(text):
    """Return the last reported means as ``(acc, auc_roc, parity, equality)``.

    Each module-level pattern is matched over the whole text and only its
    final match is used, so the summary block at the end of a log wins over
    earlier occurrences. A metric without any match is returned as ``None``.
    """
    means = []
    for pattern in (pat_acc, pat_auc, pat_par, pat_eq):
        hits = pattern.findall(text)
        # Each hit is a (mean, var) pair of strings; keep the mean of the last.
        means.append(float(hits[-1][0]) if hits else None)
    return tuple(means)

# Collect one (s_threshold, acc, auc, parity, equality) row per expected log.
# A log that does not exist is recorded in `missing` and still contributes an
# all-None row, so every threshold keeps a slot in `results`.
results, missing = [], []
for st, path in files.items():
    if path.exists():
        with path.open("r", encoding="utf-8", errors="ignore") as fh:
            print(path)  # progress trace: the log currently being parsed
            text = fh.read()
        results.append((st, *extract_metrics(text)))
    else:
        missing.append((st, str(path)))
        results.append((st, None, None, None, None))

# Order rows by ascending s-threshold so array positions are reproducible.
results.sort(key=lambda row: row[0])

# One float array per metric (means only), taken from row positions 1..4.
# With dtype=float, a None entry ends up as NaN in the array.
acc_arr, auc_arr, par_arr, eq_arr = (
    np.array([row[col] for row in results], dtype=float)
    for col in range(1, 5)
)

# Persist each metric array as a .npy file next to the logs.
# NOTE(review): outputs are prefixed "bail_" although the logs read here are
# "syn_" files — possibly a leftover from another dataset; confirm the name.
for _fname, _values in (
    (f"bail_y{y_threshold}_acc_means.npy", acc_arr),
    (f"bail_y{y_threshold}_auc_means.npy", auc_arr),
    (f"bail_y{y_threshold}_parity_means.npy", par_arr),
    (f"bail_y{y_threshold}_equality_means.npy", eq_arr),
):
    np.save(base / _fname, _values)

# Tabular view of the same means: one row per s-threshold, one column per
# position of the result tuples.
_columns = (
    "s_threshold",
    "acc_mean",
    "auc_roc_mean",
    "parity_mean",
    "equality_mean",
)
df = pd.DataFrame({
    col_name: [row[idx] for row in results]
    for idx, col_name in enumerate(_columns)
})

# CSV export and interactive display are currently disabled.
# csv_path = base / f"bail_y{y_threshold}_summary_means.csv"
# df.to_csv(csv_path, index=False)
# display_dataframe_to_user(f"bail_y{y_threshold}_summary_means", df)

# Report the per-metric means (ordered by s-threshold) and up to three of the
# missing log entries.
# `.tolist()` yields built-in floats, so the printed lists read `[0.81, nan]`
# on every NumPy version; `list(arr)` would print `np.float64(...)` wrappers
# under NumPy >= 2. The former bare tuple expression was a no-op outside an
# interactive session and has been dropped.
print(
    acc_arr.tolist(),
    auc_arr.tolist(),
    par_arr.tolist(),
    eq_arr.tolist(),
    missing[:3],
)