import json
import math
from collections import defaultdict

from cachier import cachier
from datasets import load_from_disk

from measure_coverage_patch import main as measure_coverage_patch, save_div, BLACKLIST
from measure_coverage_patch_multi import main as measure_coverage_patch_multi
import tiktoken

# Encoding that tiktoken associates with the "gpt-4" model; issue lengths in
# this script are measured as the number of tokens it produces.
tokenizer = tiktoken.encoding_for_model("gpt-4")

# Path handed to `load_from_disk`; its "test" split supplies the issue texts.
dataset = "./datasets/swt_bench_lite_aug1_bm25_diff_27k_cl100k"


@cachier()
def instance_issue_lens(dataset):
    """Map each non-blacklisted test instance to its issue length in tokens.

    The "test" split of the dataset stored at ``dataset`` is loaded from disk.
    For every example whose ``instance_id`` is not in ``BLACKLIST``, the text
    following the first ``<issue>`` marker and preceding the next
    ``</issue>`` marker is tokenized with the module-level ``tokenizer``.

    Args:
        dataset: Location passed straight to ``load_from_disk``.

    Returns:
        A dict ``{instance_id: number_of_tokens_in_issue}``.

    Raises:
        IndexError: If an example's text contains no ``<issue>`` marker.

    NOTE(review): the function is wrapped in ``@cachier()``; presumably cached
    results are keyed on ``dataset`` only, so changes to ``BLACKLIST`` or the
    tokenizer would not be reflected until the cache is cleared -- confirm.
    """
    token_counts = {}
    test_split = load_from_disk(dataset)["test"]
    for row in test_split:
        key = row["instance_id"]
        if key not in BLACKLIST:
            # Same two-step split as before: first isolate what follows the
            # opening tag, then cut at the closing tag.
            after_open_tag = row["text"].split("<issue>")[1]
            issue_body = after_open_tag.split("</issue>")[0]
            token_counts[key] = len(tokenizer.encode(issue_body))
    return token_counts


# ---------------------------------------------------------------------------
# Table 1: how many instances fall into each issue-length bucket.
# `buckets` holds inclusive upper bounds (in tokens); the last is open-ended.
# ---------------------------------------------------------------------------
print("bucket sizes")
buckets = [100, 200, 500, float("inf")]
print(",".join(["<= " + str(upper) for upper in buckets]))

instance_issue_len = instance_issue_lens(dataset)

# Collect issue lengths under the part of the instance id that precedes the
# first underscore.
repo_lens = defaultdict(list)
for inst, n_tokens in instance_issue_len.items():
    prefix = inst.partition("_")[0]
    repo_lens[prefix].append(n_tokens)


# Index instances by exact length and tally each one in the first bucket
# whose upper bound it does not exceed.
instance_by_len = defaultdict(list)
global_bucket_sizes = [0] * len(buckets)
for inst, n_tokens in instance_issue_len.items():
    instance_by_len[n_tokens].append(inst)
    slot = next(
        (idx for idx, upper in enumerate(buckets) if n_tokens <= upper),
        None,
    )
    if slot is not None:
        global_bucket_sizes[slot] += 1
print(",".join(str(count) for count in global_bucket_sizes))

# ---------------------------------------------------------------------------
# Table 2: mean issue length per id prefix, columns in sorted order.
# ---------------------------------------------------------------------------
print("repo avg len")
repos = sorted(repo_lens)
print("," + ",".join(repos))
averages = [str(sum(repo_lens[repo]) / len(repo_lens[repo])) for repo in repos]
print("avg len" + "," + ",".join(averages))


# ---------------------------------------------------------------------------
# "resolved" table: for every approach, the percentage of instances in each
# issue-length bucket whose result carries a truthy/positive "good_case".
# ---------------------------------------------------------------------------
print("resolved")
# Raw string: "\l" is not a recognised escape sequence, so the non-raw
# spelling triggers an invalid-escape warning on recent Python versions.
# The printed text is unchanged.
print(r",\leq 100,\leq 200,\leq 500,> 500")
for approach in [
    # ("gpt-4-1106-preview__swt_bench_lite_aug1_bm25_diff_27k_cl100k__seed=0,temperature=0__test/mode_vanillafuzzy", "ZeroShot"),
    ("gpt-4-1106-preview__swt_bench_lite_aug1_bm25_27k_cl100k__seed=0,temperature=0__test/mode_custom", "ZeroShotPlus"),
    ("gpt-4-1106-preview__swt_bench_lite_aug1_bm25_27k_cl100k__seed=1,temperature=07__test/mode_custom", "LIBRO",
     "gpt-4-1106-preview__libro_gpt-4-1106-preview__swt_bench_lite_aug1__test__test.jsonl"),
    ("acr_swt_bench_lite/mode_vanilla", "AutoCodeRover"),
    ("swe-agent-demo3__swt_bench_lite__test/mode_vanilla", "SWE-Agent"),
]:
    eval_dir = approach[0]
    name = approach[1]
    if "=07" in eval_dir:
        # Entries whose directory contains "=07" carry a third element: the
        # inference-results file forwarded as `libro_inference_results`.
        libro_dir = approach[2]
        ress = measure_coverage_patch_multi(
            f"evaluation_output/{eval_dir}",
            # Same dataset the bucket denominators were computed from.
            dataset=dataset,
            split="test",
            seeds="1,2,3,4,5",
            setting="IDEAL" if "ideal" in name else "LIBRO",
            libro_inference_results=f"inference_output/{libro_dir}",
        )
    else:
        ress = measure_coverage_patch(
            f"evaluation_output/{eval_dir}",
            dataset=dataset,
            split="test",
            fuzzy=False,
        )

    # Sum "good_case" per bucket; each result lands in the first bucket whose
    # upper bound its issue length does not exceed.
    # NOTE(review): presumably "good_case" is 1/True for a resolved instance --
    # confirm against measure_coverage_patch.
    local_buckets = [0 for _ in buckets]
    for res in ress:
        issue_len = instance_issue_len[res["instance_id"]]
        for i, bucket in enumerate(buckets):
            if issue_len <= bucket:
                local_buckets[i] += res.get("good_case", 0)
                break

    print(name, end=",")
    # An empty bucket has no meaningful rate; emit nan instead of raising
    # ZeroDivisionError.
    print(",".join(
        str(100 * lb / gp if gp else float("nan"))
        for lb, gp in zip(local_buckets, global_bucket_sizes)
    ))

# ---------------------------------------------------------------------------
# "applied" table: for every approach, the percentage of instances in each
# issue-length bucket whose result has no "message" (missing or None).
# ---------------------------------------------------------------------------
print("applied")
# Raw string: "\l" is not a recognised escape sequence, so the non-raw
# spelling triggers an invalid-escape warning on recent Python versions.
# The printed text is unchanged.
print(r",\leq 100,\leq 200,\leq 500,> 500")

for approach in [
    ("gpt-4-1106-preview__swt_bench_lite_aug1_bm25_diff_27k_cl100k__seed=0,temperature=0__test/mode_vanillafuzzy", "ZeroShot"),
    ("gpt-4-1106-preview__swt_bench_lite_aug1_bm25_27k_cl100k__seed=0,temperature=0__test/mode_custom", "ZeroShotPlus"),
    ("gpt-4-1106-preview__swt_bench_lite_aug1_bm25_27k_cl100k__seed=1,temperature=07__test/mode_custom", "LIBRO",
     "gpt-4-1106-preview__libro_gpt-4-1106-preview__swt_bench_lite_aug1__test__test.jsonl"),
    ("acr_swt_bench_lite/mode_vanilla", "AutoCodeRover"),
    ("swe-agent-demo3__swt_bench_lite__test/mode_vanilla", "SWE-Agent"),
]:
    eval_dir = approach[0]
    name = approach[1]
    if "=07" in eval_dir:
        # Entries whose directory contains "=07" carry a third element: the
        # inference-results file forwarded as `libro_inference_results`.
        libro_dir = approach[2]
        ress = measure_coverage_patch_multi(
            f"evaluation_output/{eval_dir}",
            # Same dataset the bucket denominators were computed from.
            dataset=dataset,
            split="test",
            seeds="1,2,3,4,5",
            setting="IDEAL" if "ideal" in name else "LIBRO",
            libro_inference_results=f"inference_output/{libro_dir}",
        )
    else:
        ress = measure_coverage_patch(
            f"evaluation_output/{eval_dir}",
            dataset=dataset,
            split="test",
            fuzzy=False,
        )

    # Count results without a "message" per bucket; each result lands in the
    # first bucket whose upper bound its issue length does not exceed.
    # NOTE(review): presumably a "message" is only set when the patch could not
    # be applied -- confirm against measure_coverage_patch.
    local_buckets_applied = [0 for _ in buckets]
    for res in ress:
        issue_len = instance_issue_len[res["instance_id"]]
        for i, bucket in enumerate(buckets):
            if issue_len <= bucket:
                # bool adds as 0/1.
                local_buckets_applied[i] += res.get("message") is None
                break

    print(name, end=",")
    # An empty bucket has no meaningful rate; emit nan instead of raising
    # ZeroDivisionError.
    print(",".join(
        str(100 * lb / gp if gp else float("nan"))
        for lb, gp in zip(local_buckets_applied, global_bucket_sizes)
    ))



