import math

from measure_coverage_patch import main as measure_coverage_patch, save_div
from measure_coverage_patch_multi import main as measure_coverage_patch_multi

# CSV header row: the approach label, then the four coverage metrics
# reported once for the good cases and once for the bad cases.
print(",".join(
    ["approach"]
    + [
        f"{metric} | {case_kind}"
        for case_kind in ("good_case", "bad_case")
        for metric in (
            "pct_lines_covered",
            "pct_missed_lines_covered",
            "pct_any_line_covered",
            "pct_any_missed_line_covered",
        )
    ]
))

# Build the set of instance ids whose record has good_case == 1 in *every*
# reference run listed below (set intersection across runs). The per-approach
# loop further down only counts a good case when its instance is in this set.
cmp_pass_cases = None
for cmp_eval_dir in [
    "swe-agent-demo3__swt_bench_lite__test/mode_vanilla",
    "swt_lite_golden_test/mode_vanillafuzzy",
]:
    # Measure the run currently being iterated. Previously the directory was
    # hard-coded to the swe-agent run, so the same run was measured twice and
    # intersected with itself, and the golden run never took part.
    cmp_coverage_patch = []
    measure_coverage_patch(
        f"evaluation_output/{cmp_eval_dir}",
        log=cmp_coverage_patch.append,
        dataset="./datasets/swt_bench_lite_aug1_bm25_diff_27k_cl100k",
        split="test",
    )
    # Records without a "good_case" key default to -1 and are excluded.
    local_cases = {
        res["instance_id"]
        for res in cmp_coverage_patch
        if res.get("good_case", -1) == 1
    }
    if cmp_pass_cases is None:
        # First run seeds the set; later runs narrow it.
        cmp_pass_cases = local_cases
    else:
        cmp_pass_cases &= local_cases

# Dataset and split handed to every coverage measurement in the loop below.
DATASET = "./datasets/swt_bench_lite_aug1_bm25_diff_27k_cl100k"
SPLIT = "test"


def _covered_lines(res, missed_only):
    """Return removed + added lines covered after the prediction for one record.

    With ``missed_only`` set, the ``*_missed_*`` counters are read instead of
    the overall ones.
    """
    if missed_only:
        return res["no_removed_missed_lines_covered_after_pred"] + res["no_added_missed_lines_covered_after_pred"]
    return res["no_removed_lines_covered_after_pred"] + res["no_added_lines_covered_after_pred"]


def _total_lines(res, missed_only):
    """Return removed + added line count for one record (the ratio denominator)."""
    if missed_only:
        return res["no_removed_missed_lines"] + res["no_added_missed_lines"]
    return res["no_removed_lines"] + res["no_added_lines"]


def _metric_cells(cases):
    """Return the four per-case-averaged metrics for ``cases`` as CSV cell strings.

    Order matches the header: lines covered, missed lines covered, any line
    covered, any missed line covered. Each cell is "NA" when ``save_div``
    falls back to its default for the final division by ``len(cases)``.
    NOTE(review): save_div presumably returns its third argument when the
    divisor is zero -- confirm in measure_coverage_patch.
    """
    cells = []
    # The first two metrics use raw counts; the last two collapse both
    # counts to booleans ("was anything covered at all").
    for convert in (lambda value: value, bool):
        for missed_only in (False, True):
            # Keep the built-in sum(): a hand-written accumulation loop may
            # round floats differently on newer Python versions.
            total = sum(
                save_div(
                    convert(_covered_lines(res, missed_only)),
                    convert(_total_lines(res, missed_only)),
                    1,
                )
                for res in cases
            )
            cells.append(str(save_div(total, len(cases), "NA")))
    return cells


# One CSV row per enabled approach: (evaluation output directory, display name).
for eval_dir, name in [
    ("swt_lite_golden_test/mode_vanillafuzzy", "golden"),
    # ("gpt-4-1106-preview__swt_bench_lite_aug1_bm25_diff_27k_cl100k__seed=0,temperature=0__test/mode_vanillafuzzy", "gpt4 one-shot diff"),
    # ("gpt-4-1106-preview__swt_bench_lite_aug1_bm25_27k_cl100k__seed=0,temperature=0__test/mode_custom", "gpt4 one-shot prompt"),
    # ("gpt-4-1106-preview__swt_bench_lite_aug1_bm25_27k_cl100k__seed=1,temperature=07__test/mode_custom", "gpt4 libro"),
    # ("gpt-4-1106-preview__swt_bench_lite_aug1_bm25_27k_cl100k__seed=1,temperature=07__test/mode_custom", "gpt4 libro (ideal)"),
    ("swe-agent-demo3__swt_bench_lite__test/mode_vanilla", "gpt4 swe-agent"),
    # ("acr_swt_bench_lite/mode_vanilla", "gpt4 acr"),
    # ("claude-3-haiku-20240307__swt_bench_lite_aug1_bm25_diff_27k_cl100k__seed=0,temperature=0__test/mode_vanillafuzzy", "haiku one-shot diff"),
    # ("claude-3-haiku-20240307__swt_bench_lite_aug1_bm25_27k_cl100k__seed=0,temperature=0__test/mode_custom", "haiku one-shot prompt"),
    # ("swe-agent-demo3-haiku__swt_bench_lite__test/mode_vanilla", "haiku swe-agent"),
    # ("Mixtral-8x22B-Instruct-v01__swt_bench_lite_aug1_bm25_diff_27k_cl100k__temperature=0__test/mode_vanillafuzzy", "Mixtral 8x22B one-shot diff"),
    # ("Mixtral-8x22B-Instruct-v01__swt_bench_lite_aug1_bm25_27k_cl100k__temperature=0__test/mode_custom", "Mixtral 8x22B one-shot prompt"),
    # ("ollama_llama3-gradient:70b__swt_bench_lite_aug1_bm25_diff_27k_cl100k__seed=0,temperature=0,max_tokens=2000__test/mode_vanillafuzzy", "ollama 70B one-shot diff"),
    # ("claude-3-opus-20240229__swt_bench_lite_aug1_bm25_diff_27k_cl100k__seed=0,temperature=0__test/mode_vanillafuzzy", "opus one-shot diff"),
]:
    # Records are collected through the log callback of the measurement.
    ress = []
    # Directories whose name contains "=07" are routed to the multi variant.
    measure = measure_coverage_patch_multi if "=07" in eval_dir else measure_coverage_patch
    measure(f"evaluation_output/{eval_dir}", log=ress.append, dataset=DATASET, split=SPLIT)

    # Only records that carry a non-None "no_added_lines" value are analysed.
    applied_cases = [res for res in ress if res.get("no_added_lines") is not None]

    # Good cases must also be in cmp_pass_cases (built earlier in this file);
    # bad cases are taken as-is.
    good_cases = [
        res for res in applied_cases
        if res.get("good_case", -1) == 1 and res["instance_id"] in cmp_pass_cases
    ]
    bad_cases = [res for res in applied_cases if res.get("good_case", -1) == 0]

    print(",".join([name, *_metric_cells(good_cases), *_metric_cells(bad_cases)]))
