import json, re, argparse, csv
from pathlib import Path
from collections import defaultdict
import pdb
import sys

# Lift CPython's cap on int <-> str conversion length (added in 3.11 and in
# security backports). Results may hold very large integers, and
# model_is_ok() calls str() on result values, which would otherwise raise
# ValueError for them. Interpreters without this API have no cap, so a
# missing attribute is the only failure that is safe to ignore.
try:
    sys.set_int_max_str_digits(0)
except AttributeError:
    pass

def load_results(p):
    """Load and return the JSON document stored at *p*.

    CPython 3.11+ (and security backports) refuses to parse integer literals
    longer than ``sys.get_int_max_str_digits()`` digits and raises
    ``ValueError("Exceeds the limit ...")``. On that specific error the parse
    is retried with the limit disabled, and the previous limit is restored
    afterwards. The file is read only once; only the parse is retried.

    Args:
        p: path to a UTF-8 encoded JSON file.

    Returns:
        The decoded JSON value.

    Raises:
        OSError: if the file cannot be opened or read.
        ValueError: for malformed JSON (``json.JSONDecodeError``) or any other
            parse error not caused by the integer-digit limit.
    """
    with open(p, encoding="utf-8") as f:
        text = f.read()

    try:
        return json.loads(text)
    except ValueError as e:
        # Only the digit-limit failure is recoverable; re-raise anything else.
        if "Exceeds the limit" not in str(e):
            raise

    get_lim = getattr(sys, "get_int_max_str_digits", None)
    set_lim = getattr(sys, "set_int_max_str_digits", None)
    old = get_lim() if get_lim else None
    try:
        if set_lim:
            set_lim(0)  # 0 disables the limit
        return json.loads(text)
    finally:
        # Restore the interpreter-wide setting so this call has no lasting effect.
        if set_lim and old is not None:
            set_lim(old)

def gt_is_assert(r):
    """Return True if the row's ground-truth result is a string mentioning AssertionError.

    Raises KeyError if the row has no ``"groundtruth_result"`` entry.
    """
    outcome = r["groundtruth_result"]
    if not isinstance(outcome, str):
        return False
    return "AssertionError" in outcome

def model_is_ok(r):
    """Return True when the row's model result does not look like a failure.

    The value of ``r["model_result"]`` (``""`` if absent) is converted with
    ``str()`` and rejected when it is one of the status strings ``Timeout``,
    ``InvalidInput`` or ``NoFunction``, when it starts with ``Exception:`` or
    ``<Name>Error:``, or when it contains ``error`` anywhere
    (case-insensitive).
    """
    # NOTE(review): the substring check also rejects legitimate outputs that
    # merely contain "error" (e.g. "no errors"); confirm this is intended.
    text = str(r.get("model_result", ""))
    rejected = (
        text in ("Timeout", "InvalidInput", "NoFunction")
        or re.match(r"^(Exception|[A-Za-z_]*Error):", text) is not None
        or re.search(r"error", text, re.IGNORECASE) is not None
    )
    return not rejected

def collect_inputs(rows, keep):
    """Group the distinct inputs of the rows accepted by *keep*.

    Args:
        rows: iterable of dicts with ``"task_id"``, ``"section"`` and ``"input"`` keys.
        keep: predicate called on each row; rows for which it is falsy are skipped.

    Returns:
        A nested defaultdict ``task_id -> section -> [input, ...]``. Inputs keep
        their first-seen order, and duplicates (compared with ``==``) are dropped.
    """
    grouped = defaultdict(lambda: defaultdict(list))
    for row in filter(keep, rows):
        seen = grouped[row["task_id"]][row["section"]]
        value = row["input"]
        # Linear membership test: inputs may be unhashable (lists/dicts from JSON).
        if value in seen:
            continue
        seen.append(value)
    return grouped

def rewrite_dataset(src_jsonl, dst_jsonl, task2sec2inputs):
    """Write a filtered copy of a JSONL dataset and report which tasks were dropped.

    Each non-blank line of *src_jsonl* is parsed as a JSON object. Its
    ``"name"`` is looked up in *task2sec2inputs*. If that task has at least
    one section with a non-empty input list, ``obj["grammar"][0]["production"]``
    is replaced by a one-element list. That element is a JSON string of
    ``{section: [{"input": ...}, ...]}``, with empty sections omitted, and the
    object is kept. Otherwise the task name is recorded as removed.

    Args:
        src_jsonl: path to the source JSONL file (UTF-8).
        dst_jsonl: output path (str or Path). Parent directories are created
            and an existing file is overwritten.
        task2sec2inputs: mapping ``task name -> {section: [input, ...]}``,
            e.g. the result of ``collect_inputs``.

    Returns:
        ``(kept_count, removed_names)``, where ``removed_names`` is in source order.
    """
    kept = []
    removed_ids = []

    # Use ``with`` so the source handle is closed even if a line fails to parse.
    with open(src_jsonl, encoding="utf-8") as src:
        for line in src:
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing empty line
            obj = json.loads(line)
            tid = obj["name"]
            parts = task2sec2inputs.get(tid, {})
            new_prod = {
                sec: [{"input": i} for i in inputs]
                for sec, inputs in parts.items()
                if inputs
            }
            if new_prod:
                obj["grammar"][0]["production"] = [json.dumps(new_prod, ensure_ascii=False)]
                kept.append(obj)
            else:
                removed_ids.append(tid)

    dst = Path(dst_jsonl)
    dst.parent.mkdir(parents=True, exist_ok=True)
    with open(dst, "w", encoding="utf-8") as fw:
        fw.writelines(json.dumps(o, ensure_ascii=False) + "\n" for o in kept)

    return len(kept), removed_ids

def make_two(results_json, src_jsonl, out_dir: Path, basename: str):
    """Write filtered copies of *src_jsonl* plus a CSV summary into *out_dir*.

    The filtering mode is chosen from the results file *path*:

    * contains ``assert_specification``: two datasets are written.
      ``<basename>_gt_assert_only.jsonl`` keeps inputs whose ground truth is an
      AssertionError. ``<basename>_gt_assert_and_model_ok.jsonl`` keeps those
      inputs whose model result also passes ``model_is_ok``.
    * contains ``functionality_specification``: one dataset,
      ``<basename>_gt_and_model_ok.jsonl``, keeps inputs whose ground truth is
      not an AssertionError and whose model result passes ``model_is_ok``.
    * neither: no dataset is written and a warning is printed to stderr.

    In every case ``<basename>_counts.csv`` is written with one row per
    dataset: file name, number of tasks kept, and ``;``-joined removed task names.

    Args:
        results_json: path (str or Path) to the results JSON, presumably a
            list of row dicts (see ``collect_inputs``).
        src_jsonl: path to the original JSONL dataset.
        out_dir: output directory (str or Path), created if missing.
        basename: prefix for every output file name.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    # ``in`` on a Path object raises TypeError, so match on its string form.
    results_name = str(results_json)

    if "assert_specification" in results_name:
        jobs = [
            (f"{basename}_gt_assert_only.jsonl", gt_is_assert),
            (
                f"{basename}_gt_assert_and_model_ok.jsonl",
                lambda r: gt_is_assert(r) and model_is_ok(r),
            ),
        ]
    elif "functionality_specification" in results_name:
        jobs = [
            (
                f"{basename}_gt_and_model_ok.jsonl",
                lambda r: not gt_is_assert(r) and model_is_ok(r),
            ),
        ]
    else:
        jobs = []
        print(
            f"Warning: results path {results_name!r} contains neither "
            "'assert_specification' nor 'functionality_specification'; "
            "no dataset written.",
            file=sys.stderr,
        )

    summary = []
    if jobs:
        rows = load_results(results_json)
        for fname, keep in jobs:
            dst = out_dir / fname
            count, removed = rewrite_dataset(src_jsonl, dst, collect_inputs(rows, keep))
            summary.append((dst.name, count, removed))

    csv_path = out_dir / f"{basename}_counts.csv"
    with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['file', 'count', 'removed_task_ids'])
        for fname, count, removed in summary:
            # Task names may be non-strings (e.g. ints); join needs str.
            writer.writerow([fname, count, ';'.join(map(str, removed))])

    print(f"Saved counts summary CSV (with removed task_ids) to: {csv_path}")

if __name__ == "__main__":
    # Command-line entry point: parse the arguments and build the filtered datasets.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--results",
        required=True,
        help="chatgpt-generated_step_all_results.json",
    )
    parser.add_argument(
        "--src_jsonl",
        required=True,
        help="Original Humaneval/MBPP JSONL",
    )
    parser.add_argument(
        "--output_dir",
        required=True,
        help="Output folder (created if not exists)",
    )
    parser.add_argument(
        "--basename",
        default="dataset",
        help="Output file prefix",
    )
    cli = parser.parse_args()

    make_two(
        results_json=cli.results,
        src_jsonl=cli.src_jsonl,
        out_dir=Path(cli.output_dir),
        basename=cli.basename,
    )
