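"""
Build SWT-bench test-generation datasets from the SWE-bench Lite oracle, with a
proposed (possibly failing) patch from failing_patches.json inserted after the issue.

Two variants are written to ./datasets: one that appends the changed test files to
the oracle code context, and a "_consistent" one that additionally includes the
files touched by the proposed patch.
"""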
import json
import os.path
import pathlib
import subprocess
from datasets import load_dataset, Dataset, DatasetDict

from unidiff import PatchSet

# Load the dataset
# SWE-bench Lite "oracle" variant; each example's `text` holds the issue and code
# context that the generation loop below rewrites into a test-generation prompt.
dataset = load_dataset(path="princeton-nlp/SWE-bench_Lite_oracle")

# Scratch directory: each task repository is cloned here once and re-checked-out per example.
tmp_dir = "/tmp/swe_bench_repos_seperate"

# Task instructions appended to every generated prompt: everything after the last "</code>"
# line of the oracle text is dropped and replaced by this text, which explains the custom
# diff format the model must answer in.
# NOTE(review): the worked example below is internally inconsistent. It opens with
# "[start of demo/test_file.py]" but closes with "[end of demo/file.py]". It says the diff
# "adds the function gcd" while it inserts test_lcm. It writes the inserted decorator as
# "@ pytest.mark.parametrize". The parametrized tests take (a, b) but use `expected`.
# The text is left byte-identical because any edit changes every generated prompt;
# confirm before fixing.
new_diff_prompt = """
Please generate test cases that check whether an implemented solution
resolves the issue of the user (at the top, within <issue/> brackets).
Present the test cases as a diff (custom format, explained below).

The general format of a diff is as follows.
```custom-diff
diff
<path/filename>
< "rewrite" or "insert" >
< rough line number / EOF / BOF >
< insert function that should be added or rewritten >
end diff
< repeat blocks of diff as necessary >
```
Insertion can only be done at the end or beginning of the file, indicated by EOF or BOF respectively.

As an example for a diff, consider the following two versions of the same file, once before and once after a change.
The original version of the file was as follows.
[start of demo/test_file.py]
1 def test_euclidean(a, b):
2     assert euclidean(0, 0) == 0
3     assert euclidean(0, 1) == 1
4     assert euclidean(1, 0) == 1
5     assert euclidean(1, 1) == 1
6
7 @pytest.mark.parametrize("a, b, expected", [(0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 1)])
8 def test_gcd(a, b):
9     assert gcd(a, b) == expected
10
[end of demo/file.py]

The diff for fix in function euclidean and adds the function gcd is as follows.
This diff changes the first file into the second file.
```custom-diff
diff
demo/file.py
rewrite
1
def test_euclidean(a, b):
    assert euclidean(0, 0) == 0
    assert euclidean(0, 1) == 1
    assert euclidean(1, 0) == 1
    assert euclidean(1, 1) == 1
    assert euclidean(100, 10) == 10
end diff
diff
demo/file.py
insert
EOF
@ pytest.mark.parametrize("a, b, expected", [(0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 1), (100, 10, 10)])
def test_lcm(a, b):
    assert lcm(a, b) == expected
end diff
```

The new version of the file is as follows.
[start of demo/file.py]
1 def test_euclidean(a, b):
2     assert euclidean(0, 0) == 0
3     assert euclidean(0, 1) == 1
4     assert euclidean(1, 0) == 1
5     assert euclidean(1, 1) == 1
6     assert euclidean(100, 10) == 10
7
8 @pytest.mark.parametrize("a, b, expected", [(0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 1)])
9 def test_gcd(a, b):
10     assert gcd(a, b) == expected
11
12 @pytest.mark.parametrize("a, b, expected", [(0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 1), (100, 10, 10)])
13 def test_lcm(a, b):
14     assert lcm(a, b) == expected
15
[end of demo/file.py]

As you can see, you need to indicate the approximate line numbers, function name and the path and file name you want to change,
but there can be as many independent blocks of changes as you need. You may also apply changes to several files.
Apply as much reasoning as you please and see necessary. The format of the solution is fixed and has to follow the custom diff format.
Make sure to implement only test cases and don't try to fix the issue itself.
"""

# Section inserted right after the closing </issue> line of each example.
# The single "{}" placeholder receives the proposed patch text via str.format().
patch_prompt = (
    "\n"
    "The following patch has been proposed to fix the issue described in the user issue (in <issue/> brackets).\n"
    "The patch might give you a hint on how to write a covering test for the issue, but you should not assume that the patch is correct.\n"
    "It might be that the provided patch is not correct, so your test should check whether the patch resolves the issue.\n"
    "{}\n"
)

# Mapping of SWE-bench instance_id -> proposed patch text; only instances present here are
# turned into examples by the generation loop. JSON is UTF-8 by specification, so decode it
# explicitly rather than relying on the platform's default locale encoding.
with open("failing_patches.json", encoding="utf-8") as f:
    failing_patches = json.load(f)

# First line written into every generated prompt (replaces the oracle's original first line).
# NOTE(review): "(in <patch/> brackets" is never closed with ")"; the text is kept byte-identical
# so generated prompts stay unchanged. Fix deliberately if that is acceptable.
_TASK_HEADER = "The following text contains a user issue (in <issue/> brackets) and a proposed patch (in <patch/> brackets posted at a repository. Further, you are provided with file contents of several files in the repository that contain relevant code (in <code> brackets). It may be necessary to use code from third party dependencies or files not contained in the attached documents however. Your task is to identify the issue and implement a test case that verifies a proposed solution to this issue. More details at the end of this text."


def _checkout_repo(repo, base_commit):
    """Clone ``https://github.com/<repo>`` under ``tmp_dir`` if needed and reset it to ``base_commit``.

    Runs ``git reset``, ``git stash``, ``git clean -fdx`` and ``git checkout <base_commit>`` in
    order and stops at the first failing command (``subprocess.CalledProcessError``).
    Commands are passed as argv lists with ``cwd=`` instead of through ``bash -c``, so
    dataset values are never interpreted by a shell.

    Returns:
        The working-tree path ``f"{tmp_dir}/<last path component of repo>"``.
    """
    repo_dir = f"{tmp_dir}/{repo.split('/')[-1]}"
    if not os.path.isdir(repo_dir):
        pathlib.Path(repo_dir).mkdir(parents=True, exist_ok=True)
        # The directory just created is still empty, so `git clone` can use it as its target.
        subprocess.run(["git", "clone", f"https://github.com/{repo}"],
                       cwd=tmp_dir, check=True, capture_output=True)
    for git_args in (["reset"], ["stash"], ["clean", "-fdx"], ["checkout", base_commit]):
        subprocess.run(["git", *git_args], cwd=repo_dir, check=True, capture_output=True)
    return repo_dir


def _numbered_lines(path):
    """Return the lines of the text file at ``path`` formatted as ``"<1-based lineno> <text>"``.

    Only the line terminator is removed, so a final line without a trailing newline keeps
    its last character.
    """
    with open(path, encoding="utf-8") as fh:
        lines = [line.rstrip("\n") for line in fh]
    return [f"{lineno} {line}" for lineno, line in enumerate(lines, start=1)]


def _changed_files(patch_text, repo_dir, *, skip_unreadable=False):
    """Return ``(path, numbered_lines)`` for each pre-existing file modified by ``patch_text``.

    Contents are read from the current checkout in ``repo_dir``. Files the patch adds are
    ignored. With ``skip_unreadable=True``, files that cannot be opened or decoded are skipped
    (best effort); otherwise the ``OSError``/``UnicodeDecodeError`` propagates.
    """
    files = []
    for patched_file in PatchSet(patch_text):
        if patched_file.is_added_file:
            continue
        path = patched_file.source_file[2:]  # strip the leading "a/" of git-style diff headers
        try:
            numbered = _numbered_lines(f"{repo_dir}/{path}")
        except (OSError, UnicodeDecodeError):
            if not skip_unreadable:
                raise
            continue
        files.append((path, numbered))
    return files


# Build two dataset variants: test files appended to the oracle context (oracle_consistent=False)
# and the same plus the files touched by the proposed patch (oracle_consistent=True).
for oracle_consistent in [False, True]:
    splits = {}
    for split in ["test"]:
        new_examples = []
        for example in dataset[split]:
            # Only instances with a proposed patch are kept; check this before touching the
            # repository so skipped instances do not trigger clones/checkouts.
            golden_patch = failing_patches.get(example["instance_id"])
            if golden_patch is None:
                continue
            repo_dir = _checkout_repo(example["repo"], example["base_commit"])

            files = []
            if oracle_consistent:
                # Best effort: files touched by the proposed patch that cannot be read are skipped.
                files.extend(_changed_files(golden_patch, repo_dir, skip_unreadable=True))
            # Test files modified by the reference test patch must all be readable.
            files.extend(_changed_files(example["test_patch"], repo_dir))

            new_text = example["text"].splitlines()

            # Insert the proposed patch right after the last line starting with "</issue>".
            after_issue = [n for n, line in enumerate(new_text) if line.startswith("</issue>")][-1] + 1
            new_text = new_text[:after_issue] + [patch_prompt.format(golden_patch)] + new_text[after_issue:]

            # Inject the relevant file contents.
            if oracle_consistent:
                # NOTE(review): the original comment said "replace the golden oracle files", but the
                # files are only inserted before the first "[start of " section and the existing
                # sections are kept. Confirm which behaviour is intended.
                line_to_inject = [n for n, line in enumerate(new_text) if line.startswith("[start of ")][0]
            else:
                # Append after the last existing "[end of " section.
                line_to_inject = [n for n, line in enumerate(new_text) if line.startswith("[end of ")][-1] + 1
            injected = []
            for filename, numbered in files:
                injected.append(f"[start of {filename}]")
                injected.extend(numbered)
                injected.append(f"[end of {filename}]")
            new_text = new_text[:line_to_inject] + injected + new_text[line_to_inject:]
            new_text[0] = _TASK_HEADER

            # Drop everything after the last "</code>" line (the oracle's original task text)
            # and append the test-generation instructions instead.
            end_of_code = [n for n, line in enumerate(new_text) if line.startswith("</code>")][-1] + 1
            prompt = "\n".join(new_text[:end_of_code]) + new_diff_prompt

            # Swap roles: the code fix (without its first and last line, presumably the
            # <patch>/</patch> wrapper mirrored below) becomes `test_patch`, and the
            # reference test patch, wrapped in <patch> tags, becomes the target `patch`.
            new_examples.append({
                **example,
                "text": prompt,
                "test_patch": "\n".join(example["patch"].splitlines()[1:-1]),
                "patch": "<patch>\n" + example["test_patch"] + "\n</patch>",
            })
        splits[split] = Dataset.from_list(new_examples)
    ds = DatasetDict(splits)
    ds.save_to_disk(f"./datasets/swt_bench_aug1_oracle_failing_patch{'_consistent' if oracle_consistent else ''}")

