#!/usr/bin/env python3
import os
import re
import json
import csv
import argparse
from collections import defaultdict
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# -------------------------
# Defaults / constants
# -------------------------
# Display order for languages. main() only plots languages that appear both
# here and in the scanned results tree, and it plots them in this order.
LANGUAGE_ORDER = [
    "even-pairs","repeat-01","parity","cycle-navigation","modular-arithmetic-simple","dyck-2-3","first",
    "majority","stack-manipulation","marked-reversal","unmarked-reversal","marked-copy",
    "missing-duplicate-string","odds-first","binary-addition","binary-multiplication","compute-sqrt","bucket-sort",
]
# Language -> class tag, written to the CSV "class" column ("NA" when a
# language is missing from this map).
# NOTE(review): the tags presumably denote Chomsky-hierarchy classes
# (R = regular, DCF = deterministic context-free, CF = context-free,
# CS = context-sensitive) — confirm against the benchmark definition.
LANGUAGE_CLASSES = {
    "even-pairs": "R","repeat-01": "R","parity": "R","cycle-navigation": "R","modular-arithmetic-simple": "R",
    "dyck-2-3": "R","first": "R","majority": "DCF","stack-manipulation": "DCF","marked-reversal": "DCF",
    "unmarked-reversal": "CF","marked-copy": "CS","missing-duplicate-string": "CS","odds-first": "CS",
    "binary-addition": "CS","binary-multiplication": "CS","compute-sqrt": "CS","bucket-sort": "CS",
}

# Prompt / encoding directory names. Not referenced by the code in this file.
PROMPTS = ["io_prompt","zsr_prompt"]
ENCODINGS = ["many_to_one","one_to_one","one_to_many_2","one_to_many_3","one_to_many_4","one_to_many_5"]

# Human-readable labels used in figure captions and console output. Directory
# names missing from these maps are shown unchanged (looked up with .get(x, x)).
PROMPT_LABELS = {
    "io_prompt": "immediate output",
    "zsr_prompt": "zero-shot reasoning",
}
ENCODING_LABELS = {
    "many_to_one": "many \u2192 one",
    "one_to_one": "one \u2192 one",
    "one_to_many_2": "one \u2192 many (t=2)",
    "one_to_many_3": "one \u2192 many (t=3)",
    "one_to_many_4": "one \u2192 many (t=4)",
    "one_to_many_5": "one \u2192 many (t=5)",
}
# Model directory name -> display name.
MOD_PRETTY = {
    "deepseek-chat": "DeepSeek-V3",
    "chatgpt": "GPT-4o mini",
    "Llama-3.1-8B-Instruct": "Llama-3.1 8B",
    "Llama-3.1-70B-Instruct": "Llama-3.1 70B",
    "Qwen2.5-32B-Instruct": "Qwen2.5 32B",
    "Qwen2.5-7B-Instruct": "Qwen2.5 7B",
}

# Robust key sets: candidate field names that extract_records_ctx() probes when
# scanning per-example records. Within each tuple, the first key present wins.
# Numeric length fields (value must be an int/float).
_LENGTH_KEYS_NUM = (
    "length", "string_length", "input_length", "seq_len", "sequence_length",
    "len", "n", "n_chars", "string_len", "test_length"
)
# Fields holding a token sequence (list/tuple); length = number of items.
_LENGTH_KEYS_TOKENS = (
    "tokens", "chars", "input_tokens", "x_tokens", "sequence_tokens", "symbols"
)
# Raw input-string fields; length is inferred by _infer_length_from_str().
_STRING_KEYS = (
    "string","input","input_string","x","w","s","seq","test_string","sequence","str"
)
# Gold-label fields compared (after _norm_text) against the model output.
_LABEL_KEYS = (
    "label","target","gold","gold_label","true_label",
    "answer","y","expected","gt"
)
# Model-output fields.
_OUTPUT_KEYS = ("model_output","prediction","pred","y_hat","output")
# Explicit correctness flags; when present they take precedence over the
# output/label comparison.
_CORRECT_KEYS = ("is_correct","correct","isCorrect")
# Bracket, pipe and comma characters stripped from a raw string before its
# length is inferred.
_CLEAN_RE = re.compile(r"[<>\[\]\(\)\{\}\|,]")

# -------------------------
# I/O helpers
# -------------------------
def read_latest_json(dir_path):
    """Load the last ``*.json`` file, by filename order, from *dir_path*.

    "Latest" means lexicographically greatest filename, which matches
    chronological order only if the filenames sort that way (e.g. zero-padded
    timestamps). Entries ending in ``.json`` that are not regular files (such
    as directories) are ignored, so they cannot shadow a real results file.

    Args:
        dir_path: directory to search (not recursive).

    Returns:
        ``(data, path)``, where ``data`` is the parsed JSON value. Returns
        ``(None, None)`` when there is no candidate file, or when the directory
        or file cannot be read or parsed. A message is printed in that case.
    """
    try:
        candidates = sorted(
            name
            for name in os.listdir(dir_path)
            if name.endswith(".json") and os.path.isfile(os.path.join(dir_path, name))
        )
        if not candidates:
            return None, None
        fp = os.path.join(dir_path, candidates[-1])
        # JSON is UTF-8 by spec; do not depend on the platform's locale encoding.
        with open(fp, "r", encoding="utf-8") as fh:
            data = json.load(fh)
        return data, fp
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable path. ValueError: covers
        # json.JSONDecodeError and UnicodeDecodeError. This is best-effort:
        # report the failure and skip the directory.
        print(f"Skipping {dir_path}: {e}")
        return None, None

def _sorted_subdirs(parent):
    """Yield ``(name, path)`` for each immediate subdirectory of *parent*, in name order."""
    for name in sorted(os.listdir(parent)):
        path = os.path.join(parent, name)
        if os.path.isdir(path):
            yield name, path


def collect_icls_with_paths(root):
    """
    Scan an ICL results tree laid out as ``root/<model>/<prompt>/<encoding>/<language>/``.

    For each leaf language directory, the latest JSON file (see
    read_latest_json) is loaded. Its top-level ``"accuracy"`` value is recorded
    when it is a real number. Non-object JSON, missing or non-numeric
    accuracies, and boolean values are skipped.

    Directories are visited in sorted name order. This makes the insertion
    order of the returned mappings, and therefore any "first max wins"
    tie-breaking downstream, reproducible across filesystems.

    Returns:
      accs[lang][model][prompt][enc] = float accuracy
      paths[lang][model][prompt][enc] = filepath
    Both are nested defaultdicts; they are empty if *root* is not a directory.
    """
    accs = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
    paths = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
    if not os.path.isdir(root):
        return accs, paths

    for model, mdir in _sorted_subdirs(root):
        for prompt, pdir in _sorted_subdirs(mdir):
            for enc, edir in _sorted_subdirs(pdir):
                for lang, ldir in _sorted_subdirs(edir):
                    data, fp = read_latest_json(ldir)
                    # A JSON array/scalar has no .get(); treat it as "no accuracy".
                    if not isinstance(data, dict):
                        continue
                    acc = data.get("accuracy")
                    # bool is a subclass of int; true/false is not an accuracy.
                    if isinstance(acc, (int, float)) and not isinstance(acc, bool):
                        accs[lang][model][prompt][enc] = float(acc)
                        paths[lang][model][prompt][enc] = fp
    return accs, paths

# -------------------------
# Normalization helpers
# -------------------------
def _get_first_key(d, keys):
    for k in keys:
        if k in d:
            return k
    return None

def _norm_text(x):
    if x is None:
        return None
    if isinstance(x, (int, float)):
        return str(int(x)) if float(x).is_integer() else str(x)
    s = str(x).strip().lower()
    if s in ("true","false"):
        return "1" if s == "true" else "0"
    if s in ("yes","no"):
        return "1" if s == "yes" else "0"
    if s in ("accept","reject"):
        return "1" if s == "accept" else "0"
    return s

def _infer_length_from_str(s):
    """Estimate the length of a raw input string.

    Bracket, pipe and comma characters (``_CLEAN_RE``) are removed first. If
    what remains contains any whitespace, it is treated as a
    whitespace-separated token sequence and the token count is returned.
    Otherwise every remaining character counts as one symbol.

    Args:
        s: candidate input string; non-strings are rejected.

    Returns:
        The inferred length, or ``None`` for non-strings and blank strings. The
        result may be 0 when only stripped delimiters remain (e.g. ``"[]"``).
    """
    if not isinstance(s, str):
        return None
    stripped = s.strip()
    if not stripped:
        return None
    cleaned = _CLEAN_RE.sub("", stripped)
    # Detect any whitespace, not just " ", so the test agrees with str.split().
    # Otherwise tab- or newline-separated input would be counted character by
    # character, separators included.
    if any(ch.isspace() for ch in cleaned):
        return len(cleaned.split())
    return len(cleaned)

# -------------------------
# Context-aware extractor (inherits length from parents)
# -------------------------
def extract_records_ctx(obj, current_len, out_pairs):
    """
    Recursively collect (length:int, correct:bool) pairs, inheriting length from parents.

    Walks dicts and lists. For every dict, a length is resolved in this order:
      1. the dict's own numeric length field (_LENGTH_KEYS_NUM);
      2. otherwise, the length inherited from the enclosing dict;
      3. otherwise, the size of a token list (_LENGTH_KEYS_TOKENS);
      4. otherwise, a length inferred from a raw input string (_STRING_KEYS).
    A correctness flag is resolved too: an explicit correctness field
    (_CORRECT_KEYS) if present, else the model output compared with the gold
    label after _norm_text normalization. When the length is a positive int and
    correctness is a bool, ``(length, correct)`` is appended to ``out_pairs``.
    All values of the dict are then visited with the resolved length.

    Args:
        obj: parsed JSON value (dict, list, or scalar; scalars are ignored).
        current_len: length inherited from the enclosing dict, or None.
        out_pairs: list that receives (length, correct) tuples; mutated in place.

    Returns:
        None; results accumulate in ``out_pairs``.
    """
    if isinstance(obj, dict):
        # Start from the inherited length; a numeric field on this dict overrides it.
        L = current_len

        # numeric length
        # Note: bools also pass this isinstance check, and int() raises on NaN/inf.
        # NOTE(review): "n" is in _LENGTH_KEYS_NUM, so an enclosing summary dict
        # with an "n" count would pass that count down as the length of every
        # nested record — confirm the results schema.
        k_num = _get_first_key(obj, _LENGTH_KEYS_NUM)
        if k_num is not None and isinstance(obj[k_num], (int, float)):
            L = int(obj[k_num])

        # tokens length (only when no length was set here or inherited)
        if L is None:
            k_tok = _get_first_key(obj, _LENGTH_KEYS_TOKENS)
            if k_tok is not None and isinstance(obj[k_tok], (list, tuple)):
                L = len(obj[k_tok])

        # raw string length (last resort; may yield None or 0)
        if L is None:
            k_str = _get_first_key(obj, _STRING_KEYS)
            if k_str is not None and isinstance(obj[k_str], str):
                L = _infer_length_from_str(obj[k_str])

        # correctness: an explicit flag wins; otherwise compare output with label.
        # Values that cannot be interpreted leave `correct` as None (no record).
        correct = None
        k_corr = _get_first_key(obj, _CORRECT_KEYS)
        if k_corr is not None:
            v = obj[k_corr]
            if isinstance(v, bool):
                correct = v
            elif isinstance(v, (int, float)):
                correct = (v != 0)
            else:
                vv = str(v).strip().lower()
                if vv in ("true","1","yes"):
                    correct = True
                elif vv in ("false","0","no"):
                    correct = False
        else:
            outk = _get_first_key(obj, _OUTPUT_KEYS)
            labk = _get_first_key(obj, _LABEL_KEYS)
            if outk is not None:
                outv = obj[outk]
                # A numeric -1 output is scored as wrong even without a label;
                # presumably a sentinel for a failed/unparseable answer — confirm.
                if isinstance(outv, (int, float)) and float(outv) == -1.0:
                    correct = False
                elif labk is not None:
                    correct = (_norm_text(outv) == _norm_text(obj[labk]))

        # Only dicts with both a positive int length and a definite correctness
        # count as records.
        # NOTE(review): an aggregate dict carrying both a length-like key and a
        # correctness-like key (e.g. "n" and "correct") would also be counted as a
        # record — confirm the schema avoids this.
        if isinstance(L, int) and L > 0 and isinstance(correct, bool):
            out_pairs.append((L, correct))

        # Recurse into every value with the resolved length so nested records inherit it.
        for v in obj.values():
            extract_records_ctx(v, L, out_pairs)

    elif isinstance(obj, list):
        # Lists resolve no length themselves; elements inherit the caller's length.
        for v in obj:
            extract_records_ctx(v, current_len, out_pairs)

def per_length_counts(json_path):
    """
    Count correct and total per-example results by input length for one results file.

    Per-example records are found with extract_records_ctx.

    Args:
        json_path: path to a JSON results file.

    Returns:
        per_len: plain dict ``{length: (correct_count, total_count)}`` with
        lengths in ascending order. Returns an empty dict when the file cannot
        be read or parsed (a message is printed), or when no records are found.
    """
    try:
        # JSON is UTF-8 by spec; do not depend on the platform's locale encoding.
        with open(json_path, "r", encoding="utf-8") as fh:
            data = json.load(fh)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file; ValueError: covers JSONDecodeError
        # and UnicodeDecodeError.
        print(f"Failed reading {json_path}: {e}")
        return {}

    pairs = []
    extract_records_ctx(data, current_len=None, out_pairs=pairs)

    counts = defaultdict(lambda: [0, 0])  # length -> [correct, total]
    for length, ok in pairs:
        counts[length][1] += 1
        if ok:
            counts[length][0] += 1
    # Freeze into the documented shape: a plain dict of (correct, total) tuples.
    return {length: (c[0], c[1]) for length, c in sorted(counts.items())}

# -------------------------
# Binning + plotting
# -------------------------
def bin_accuracy(per_len_counts, bin_width=20):
    """
    Aggregate per-length counts into fixed-width length bins.

    Bins are anchored at 1 for comparability across languages:
    ``[1, w], [w+1, 2w], ...``, up to the bin containing the largest length.
    Lengths below 1 are ignored. Bins with no examples are still emitted, with
    ``acc=None`` and ``total=0``.

    Args:
        per_len_counts: mapping ``{length: (correct_count, total_count)}``.
            Any 2-item sequence works as the value.
        bin_width: width of each bin; must be >= 1.

    Returns:
        list of (lo, hi, acc, total)

    Raises:
        ValueError: if ``bin_width < 1``. A non-positive width could never
            advance past the first bin.
    """
    if bin_width < 1:
        raise ValueError(f"bin_width must be >= 1, got {bin_width!r}")
    if not per_len_counts:
        return []

    max_len = max(per_len_counts)
    # Number of bins needed to cover [1, max_len]; 0 if every length is < 1.
    n_bins = max(0, (max_len - 1) // bin_width + 1)
    ok_sums = [0] * n_bins
    tot_sums = [0] * n_bins
    # Single pass: map each length straight to its bin index.
    for length, (ok, tot) in per_len_counts.items():
        if length < 1:
            continue
        idx = (length - 1) // bin_width
        ok_sums[idx] += ok
        tot_sums[idx] += tot

    bins = []
    for idx in range(n_bins):
        lo = 1 + idx * bin_width
        hi = lo + bin_width - 1
        total = tot_sums[idx]
        acc = (ok_sums[idx] / total) if total > 0 else None
        bins.append((lo, hi, acc, total))
    return bins

def safe_slug(s: str) -> str:
    """Collapse each run of characters outside ``[A-Za-z0-9_-]`` into a single ``-``."""
    pieces = re.split(r"[^a-zA-Z0-9_\-]+", s)
    return "-".join(pieces)

def plot_histogram_for_language(
    lang, bins_acc, out_path,
    caption_text=None, img_format="png", dpi=200, bin_width=20
):
    """
    Save a bar chart of accuracy per length bin for one language.

    x = length bins (1–20, 21–40, ...), y = accuracy in [0, 1]. Bins whose
    accuracy is ``None`` (no examples) are drawn at height 0.

    Args:
        lang: language name shown in the title.
        bins_acc: list of ``(lo, hi, acc, total)`` tuples, as returned by bin_accuracy.
        out_path: destination file path.
        caption_text: optional text drawn centred below the axes.
        img_format: format passed to ``savefig`` (e.g. "png", "pdf", "svg").
        dpi: resolution used for raster formats.
        bin_width: currently unused; kept for call-site compatibility.
    """
    labels = [f"{lo}–{hi}" for (lo, hi, _acc, _tot) in bins_acc]
    heights = [0.0 if acc is None else acc for (_lo, _hi, acc, _tot) in bins_acc]
    positions = list(range(len(heights)))

    fig, ax = plt.subplots()
    try:
        ax.bar(positions, heights)
        ax.set_xticks(positions)
        ax.set_xticklabels(labels, rotation=45, ha="right")
        ax.set_ylim(0.0, 1.0)
        ax.set_ylabel("Accuracy")
        ax.set_title(f"Accuracy vs. length (best config) — {lang}")

        if caption_text:
            fig.subplots_adjust(bottom=0.22)
            fig.text(0.5, 0.04, caption_text, ha="center", va="center", wrap=True)

        fig.tight_layout()
        # bbox_inches="tight" grows the saved area so the caption is included.
        fig.savefig(out_path, format=img_format, dpi=dpi, bbox_inches="tight")
    finally:
        # Always release the figure, even if saving fails, so long runs don't leak figures.
        plt.close(fig)

# -------------------------
# Main
# -------------------------
def _select_best_config(lang_accs):
    """Pick the configuration with the highest overall accuracy for one language.

    Args:
        lang_accs: nested mapping ``{model: {prompt: {encoding: accuracy}}}``.

    Returns:
        ``(model, prompt, encoding, accuracy)``. The first three are ``None``
        (and accuracy is ``-1.0``) when no numeric accuracy above -1.0 exists.

    Keys are visited in sorted order, and only a strictly greater accuracy
    replaces the current best. Ties therefore resolve to the alphabetically
    first configuration, independent of filesystem listing order.
    """
    best = (None, None, None, -1.0)
    for model in sorted(lang_accs):
        prompts = lang_accs[model]
        for prompt in sorted(prompts):
            encs = prompts[prompt]
            for enc in sorted(encs):
                acc = encs[enc]
                if isinstance(acc, (int, float)) and acc > best[3]:
                    best = (model, prompt, enc, acc)
    return best


def _render_pdf_page(pdf_pages, lang, bins, caption):
    """Append one accuracy-vs-length bar chart for *lang* as a page of *pdf_pages*.

    Uses the per-language image layout: title, 0–1 y-axis, rotated bin labels,
    and the caption below the axes. The figure is always closed, even if
    rendering fails.
    """
    labels = [f"{lo}–{hi}" for (lo, hi, _acc, _tot) in bins]
    heights = [0.0 if acc is None else acc for (_lo, _hi, acc, _tot) in bins]
    positions = list(range(len(heights)))

    fig, ax = plt.subplots()
    try:
        ax.bar(positions, heights)
        ax.set_xticks(positions)
        ax.set_xticklabels(labels, rotation=45, ha="right")
        ax.set_ylim(0.0, 1.0)
        ax.set_ylabel("Accuracy")
        ax.set_title(f"Accuracy vs. length (best config) — {lang}")
        fig.subplots_adjust(bottom=0.22)
        fig.text(0.5, 0.04, caption, ha="center", va="center", wrap=True)
        fig.tight_layout()
        pdf_pages.savefig(fig, bbox_inches="tight")
    finally:
        plt.close(fig)


def main():
    """Command-line entry point: plot per-language accuracy by string length.

    For every language in LANGUAGE_ORDER that has ICL results under ``--root``,
    this function:
      - picks the (model, prompt, encoding) with the best overall accuracy;
      - bins that run's per-example correctness by string length;
      - saves one bar chart per language;
      - optionally adds each chart as a page of a combined PDF;
      - writes a CSV summary of the chosen configurations and binned accuracies.
    """
    ap = argparse.ArgumentParser(description="Plot per-language accuracy vs string length (bins) for best ICL config.")
    ap.add_argument("--root", type=str, default="./results/ICL", help="ICL results root")
    ap.add_argument("--out", type=str, default="./tables/accuracy_by_length_imgs", help="output directory for images")
    ap.add_argument("--bin-width", type=int, default=20, help="bin width for string length")
    ap.add_argument("--img-format", type=str, choices=["png","pdf","svg"], default="png", help="image format for per-language files")
    ap.add_argument("--dpi", type=int, default=200, help="DPI for raster formats")
    ap.add_argument("--make-pdf", action="store_true", help="also produce a multi-page PDF with all languages")
    args = ap.parse_args()
    # A non-positive bin width cannot partition lengths; fail fast instead of hanging.
    if args.bin_width < 1:
        ap.error("--bin-width must be a positive integer")

    os.makedirs(args.out, exist_ok=True)

    # scan results & locate best config per language
    accs, paths = collect_icls_with_paths(args.root)
    languages = [lang for lang in LANGUAGE_ORDER if lang in accs]
    ignored = sorted(set(accs) - set(LANGUAGE_ORDER))
    if ignored:
        print(f"[warn] Results found for languages not in LANGUAGE_ORDER (ignored): {', '.join(ignored)}")

    # CSV summary of chosen configs & binned accuracies
    csv_path = os.path.join(args.out, "best_config_per_language_bins.csv")
    pdf_path = os.path.join(args.out, "accuracy_by_length_ALL.pdf")
    with open(csv_path, "w", newline="", encoding="utf-8") as cf:
        writer = csv.writer(cf)
        writer.writerow([
            "language","class","best_model","best_prompt","best_encoding","best_accuracy",
            "bin_lo","bin_hi","bin_accuracy","bin_total"
        ])

        pdf_pages = PdfPages(pdf_path) if args.make_pdf else None
        try:
            for lang in languages:
                best_model, best_prompt, best_enc, best_acc = _select_best_config(accs[lang])
                if best_model is None:
                    print(f"[warn] No accuracy entries for {lang}, skipping.")
                    continue

                best_path = paths[lang][best_model][best_prompt][best_enc]
                per_len = per_length_counts(best_path)
                if not per_len:
                    print(f"[warn] No per-example (length, correctness) found in {best_path}. "
                          f"Figure may be empty. Check field names if needed.")
                bins = bin_accuracy(per_len, bin_width=args.bin_width)

                # Caption: model + configuration. The leading blank lines push
                # "String length" below the rotated tick labels, presumably in
                # place of the (commented-out) x-axis label.
                pretty_model = MOD_PRETTY.get(best_model, best_model)
                prompt_label = PROMPT_LABELS.get(best_prompt, best_prompt)
                enc_label = ENCODING_LABELS.get(best_enc, best_enc)
                caption = (
                    f"\n\n\n\nString length"
                    f"\nBest configuration — Model: {pretty_model} | "
                    f"Prompt: {prompt_label} | Encoding: {enc_label} | "
                    f"Accuracy: {best_acc:.2f}"
                )

                # per-language image
                base = f"accuracy_by_length_{safe_slug(lang)}.{args.img_format}"
                out_path = os.path.join(args.out, base)
                plot_histogram_for_language(
                    lang, bins, out_path,
                    caption_text=caption, img_format=args.img_format, dpi=args.dpi, bin_width=args.bin_width
                )

                # Multi-page PDF
                if pdf_pages is not None:
                    _render_pdf_page(pdf_pages, lang, bins, caption)

                # CSV rows: one per bin, repeating the chosen configuration.
                lang_class = LANGUAGE_CLASSES.get(lang, "NA")
                for (lo, hi, acc, tot) in bins:
                    acc_out = "" if acc is None else f"{acc:.4f}"
                    writer.writerow([lang, lang_class, best_model, best_prompt, best_enc, f"{best_acc:.4f}",
                                     lo, hi, acc_out, tot])

                print(f"✓ {lang}: best=({pretty_model}, {best_prompt}, {best_enc} @ {best_acc:.3f}); image → {out_path}")
        finally:
            # Close (and finalize) the PDF even if a language fails mid-run.
            if pdf_pages is not None:
                pdf_pages.close()

        if pdf_pages is not None:
            print(f"✓ Multi-page PDF → {pdf_path}")

    print(f"CSV summary → {csv_path}")


if __name__ == "__main__":
    main()
