import math

from measure_coverage_patch import main as measure_coverage_patch, save_div
from measure_coverage_patch_multi import main as measure_coverage_patch_multi
# Evaluated runs, in the order their rows are printed.
# Each entry is (run directory below ``evaluation_output/``, pre-formatted
# LaTeX cells that open the row for that run).
appraoches = [
    # RAG
    (
        "gpt-4-1106-preview__swt_bench_lite_aug1_bm25_27k_cl100k__seed=0,temperature=0__test/mode_custom",
        " - & - & - ",
    ),
    # RAG with oracle test
    (
        "gpt-4-1106-preview__swt_bench_aug1_bm25_cl27k_oracle_test__seed=0,temperature=0__test/mode_custom",
        "\\cmark & - & - ",
    ),
    # Oracle
    (
        "gpt-4-1106-preview__swt_bench_aug1_oracle__seed=0,temperature=0__test/mode_custom",
        "\\cmark & - & \\cmark ",
    ),
    # Oracle + GP
    (
        "gpt-4-1106-preview__swt_bench_aug1_oracle_patch__seed=0,temperature=0__test/mode_custom",
        "\\cmark & \\cmark & \\cmark ",
    ),
    # Oracle + FP
    (
        "gpt-4-1106-preview__swt_bench_aug1_oracle_failing_patch__seed=0,temperature=0__test/mode_custom",
        "\\cmark & \\xmark & \\cmark",
    ),
    # Consistent + FP
    (
        "gpt-4-1106-preview__swt_bench_aug1_oracle_failing_patch_consistent__seed=0,temperature=0__test/mode_custom",
        "\\cmark & \\xmark & \\xmark",
    ),
]

# Measure patch coverage once per configured run. The results are kept in the
# same order as `appraoches` so the two sequences can be zipped later on.
all_ress = [
    measure_coverage_patch(
        f"evaluation_output/{run_dir}",
        dataset="./datasets/swt_bench_lite_aug1_bm25_diff_27k_cl100k",
        split="test",
        fuzzy=False,
    )
    for run_dir, _row_label in appraoches
]

# For every run, collect the ids of the instances whose result is not flagged
# with the "no eval output found" message.
evaluated_ids_per_run = [
    {
        res["instance_id"]
        for res in run_results
        if res.get("message") != "no eval output found"
    }
    for run_results in all_ress
]
# Only instances present in every run are compared, so each table row is
# computed over the same set of cases.
agreed_all_cases = evaluated_ids_per_run[0].intersection(*evaluated_ids_per_run[1:])

# Print one LaTeX table row per run: the row label followed by four
# percentages, all relative to the number of instances shared by every run.
for (eval_dir, name), ress in zip(appraoches, all_ress):

    # Restrict this run to the shared instances so all rows are comparable.
    ress = [res for res in ress if res["instance_id"] in agreed_all_cases]
    no_all_cases = len(ress)
    # Results that carry no "message" entry are counted as applied.
    applied_cases = [res for res in ress if res.get("message") is None]
    no_applied_cases = len(applied_cases)
    # Applied cases that fail or error initially.
    ftx = sum(res["fails_initially"] or res["error_initially"] for res in applied_cases)
    good_cases = sum(res["good_case"] for res in applied_cases)

    # Applied cases that neither fail nor error initially.
    ptp = sum(not (res["fails_initially"] or res["error_initially"]) for res in applied_cases)

    ratios = [
        # Guarded like the other columns: with no shared instances the
        # denominator is zero and a bare division would raise.
        no_applied_cases / no_all_cases if no_all_cases else "NA",
        # NOTE(review): save_div is assumed to return its third argument when
        # the denominator is zero -- confirm against measure_coverage_patch.
        save_div(ftx, no_all_cases, "NA"),
        save_div(good_cases, no_all_cases, "NA"),
        save_div(ptp, no_all_cases, "NA"),
    ]
    # String placeholders such as "NA" are emitted verbatim; multiplying them
    # by 100 and applying the ".1f" format would raise ValueError.
    cells = [
        ratio if isinstance(ratio, str) else f"{ratio * 100:.1f}"
        for ratio in ratios
    ]

    print(name + " & " + " & ".join(cells), end=" \\\\ \n")
