import json
import pathlib

import fire
from unidiff import PatchSet

from datasets import load_dataset

from metrics.log_parsers import MAP_REPO_TO_PARSER
from metrics.getters import get_file_name_from_lp, get_repo_from_lp, log_path_to_sms, FAIL_TO_NULL, PASS_TO_NULL, \
    NULL_TO_FAIL, NULL_TO_PASS

from measure_coverage_patch import log, BLACKLIST_FULL
from metrics.getters import FAIL_TO_FAIL, FAIL_TO_PASS, PASS_TO_FAIL, PASS_TO_PASS, test_failed, test_passed, TestStatus

from measure_coverage_patch import extract_number_added_tests_from_patch


def count(
    dataset: str = "princeton-nlp/SWE-bench_Lite_oracle",
    split: str = "test",
):
    """Print one CSV row summarising per-example test counts for a dataset split.

    For every example in ``split`` of the dataset ``dataset`` (examples whose
    ``instance_id`` is in ``BLACKLIST_FULL`` are skipped), the example's
    ``test_patch`` is parsed with ``unidiff.PatchSet`` and passed to
    ``extract_number_added_tests_from_patch``. The resulting numbers are
    bucketed by sign, and the row ``positive,negative,zero,total`` is printed.
    No row label is printed here; callers print it first with ``end=""``.

    Args:
        dataset: Dataset name/path passed to ``datasets.load_dataset``.
        split: Name of the split to iterate over.
    """
    # Keep the dataset name and the loaded object in separate variables rather
    # than rebinding the ``str``-annotated parameter to a DatasetDict.
    loaded = load_dataset(dataset)

    positive = negative = zero = total = 0
    for example in loaded[split]:
        if example["instance_id"] in BLACKLIST_FULL:
            continue
        # NOTE(review): presumably the net number of tests the patch adds
        # (negative when tests are removed) — confirm in measure_coverage_patch.
        num_tests = extract_number_added_tests_from_patch(PatchSet(example["test_patch"]))
        total += 1
        if num_tests > 0:
            positive += 1
        elif num_tests < 0:
            negative += 1
        elif num_tests == 0:
            zero += 1

    print(",".join(map(str, (positive, negative, zero, total))))

def main():
    """Print a CSV table of per-example test counts for SWE-bench Lite and SWE-bench.

    The header's first cell is empty because each data row begins with the
    benchmark label. That label is printed here, and ``count`` prints the
    remaining cells of the row.
    """
    # No spaces after the commas, so the header splits into the same clean
    # cells as the data rows that ``count`` builds with ``",".join``.
    print(",Added,Removed,Modified,Total")
    for label, dataset_name in (
        ("SWE-bench_Lite", "princeton-nlp/SWE-bench_Lite"),
        ("SWE-bench", "princeton-nlp/SWE-bench"),
    ):
        print(f"{label},", end="")
        count(dataset=dataset_name, split="test")



if __name__ == "__main__":
    # Let python-fire turn main() (which takes no arguments) into the CLI entry point.
    fire.Fire(component=main)