import itertools
from collections import defaultdict

from datasets import load_from_disk

from measure_coverage_patch import main as measure_coverage_patch, save_div, BLACKLIST
from measure_coverage_patch_multi import main as measure_coverage_patch_multi

# Compare which SWT-bench-lite instances each test-generation approach resolves
# (a result counts as resolved when its "good_case" flag is truthy), and print
# the size of every Venn-diagram region over the blacklist-filtered instances.

# Single source of truth for the evaluation dataset; reused by every measurer call.
DATASET_PATH = "./datasets/swt_bench_lite_aug1_bm25_diff_27k_cl100k"
SPLIT = "test"

# (evaluation directory, display name, LIBRO inference file or None).
# Entries with an inference file are evaluated with the multi-seed measurer;
# all others use the single-run measurer. Selecting on this explicit field
# (rather than on a substring of the run directory) keeps routing stable when
# runs are renamed.
APPROACHES = [
    (
        "gpt-4-1106-preview__swt_bench_lite_aug1_bm25_27k_cl100k__seed=1,temperature=07__test/mode_custom",
        "LIBRO",
        "gpt-4-1106-preview__libro_gpt-4-1106-preview__swt_bench_lite_aug1__test__test.jsonl",
    ),
    ("acr_swt_bench_lite/mode_vanilla", "AutoCodeRover", None),
    ("swe-agent-demo3__swt_bench_lite__test/mode_vanilla", "SWE-Agent", None),
]

dataset = load_from_disk(DATASET_PATH)[SPLIT]
# Universe of instances under comparison: every dataset instance not blacklisted.
instances = [x["instance_id"] for x in dataset if x["instance_id"] not in BLACKLIST]
instance_set = set(instances)
print("Instances considered", len(instances))

names = [name for _, name, _ in APPROACHES]
# Sets, so repeated results for the same instance (e.g. across seeds) count once.
resolved_by_approach = defaultdict(set)

for eval_dir, name, libro_file in APPROACHES:
    if libro_file is not None:
        ress = measure_coverage_patch_multi(
            f"evaluation_output/{eval_dir}",
            dataset=DATASET_PATH,
            split=SPLIT,
            seeds="1,2,3,4,5",
            setting="IDEAL" if "ideal" in name else "LIBRO",
            libro_inference_results=f"inference_output/{libro_file}",
        )
    else:
        ress = measure_coverage_patch(
            f"evaluation_output/{eval_dir}",
            dataset=DATASET_PATH,
            split=SPLIT,
            fuzzy=False,
        )

    for res in ress:
        instance_id = res.get("instance_id")
        # Keep only ids from the (blacklist-filtered) universe so that the Venn
        # regions printed below partition exactly `instances`.
        if instance_id in instance_set and res.get("good_case"):
            resolved_by_approach[name].add(instance_id)

all_resolved = [resolved_by_approach[n] for n in names]
print("Not solved by any approach", len(instance_set - set().union(*all_resolved)))
print("Solved by all approaches", len(set.intersection(*all_resolved)))

for name in names:
    others = set().union(*(resolved_by_approach[n] for n in names if n != name))
    print("Solved only by", name, len(resolved_by_approach[name] - others))

for n1, n2 in itertools.combinations(names, 2):
    # Subtract the union of *all* remaining approaches (not just one) so the
    # region stays correct when more than three approaches are compared.
    others = set().union(*(resolved_by_approach[n] for n in names if n not in (n1, n2)))
    both = resolved_by_approach[n1] & resolved_by_approach[n2]
    print("Solved only by", n1, "and", n2, len(both - others))