"""
Generate no-patch oracle and golden patch oracle for SWE-bench
"""

import os.path
import pathlib
import subprocess
from datasets import load_dataset, Dataset, DatasetDict

from unidiff import PatchSet

# Clone cache: each repository is cloned here once and then re-checked-out for
# every task instance that uses it.
tmp_dir = "/tmp/swe_bench_repos_seperate"

# SWE-bench Lite, oracle-retrieval variant; each row's "text" field is the
# prompt that is rewritten below.
dataset = load_dataset("princeton-nlp/SWE-bench_Lite_oracle")

new_diff_prompt = """
Please generate test cases that check whether an implemented solution
resolves the issue of the user (at the top, within <issue/> brackets).
Present the test cases as a diff (custom format, explained below).

The general format of a diff is as follows.
```custom-diff
diff
<path/filename>
< "rewrite" or "insert" >
< rough line number / EOF / BOF >
< insert function that should be added or rewritten >
end diff
< repeat blocks of diff as necessary >
```
Insertion can only be done at the end or beginning of the file, indicated by EOF or BOF respectively.

As an example for a diff, consider the following two versions of the same file, once before and once after a change.
The original version of the file was as follows.
[start of demo/test_file.py]
1 def test_euclidean(a, b):
2     assert euclidean(0, 0) == 0
3     assert euclidean(0, 1) == 1
4     assert euclidean(1, 0) == 1
5     assert euclidean(1, 1) == 1
6
7 @pytest.mark.parametrize("a, b, expected", [(0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 1)])
8 def test_gcd(a, b):
9     assert gcd(a, b) == expected
10
[end of demo/file.py]

The diff for fix in function euclidean and adds the function gcd is as follows.
This diff changes the first file into the second file.
```custom-diff
diff
demo/file.py
rewrite
1
def test_euclidean(a, b):
    assert euclidean(0, 0) == 0
    assert euclidean(0, 1) == 1
    assert euclidean(1, 0) == 1
    assert euclidean(1, 1) == 1
    assert euclidean(100, 10) == 10
end diff
diff
demo/file.py
insert
EOF
@ pytest.mark.parametrize("a, b, expected", [(0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 1), (100, 10, 10)])
def test_lcm(a, b):
    assert lcm(a, b) == expected
end diff
```

The new version of the file is as follows.
[start of demo/file.py]
1 def test_euclidean(a, b):
2     assert euclidean(0, 0) == 0
3     assert euclidean(0, 1) == 1
4     assert euclidean(1, 0) == 1
5     assert euclidean(1, 1) == 1
6     assert euclidean(100, 10) == 10
7
8 @pytest.mark.parametrize("a, b, expected", [(0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 1)])
9 def test_gcd(a, b):
10     assert gcd(a, b) == expected
11
12 @pytest.mark.parametrize("a, b, expected", [(0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 1), (100, 10, 10)])
13 def test_lcm(a, b):
14     assert lcm(a, b) == expected
15
[end of demo/file.py]

As you can see, you need to indicate the approximate line numbers, function name and the path and file name you want to change,
but there can be as many independent blocks of changes as you need. You may also apply changes to several files.
Apply as much reasoning as you please and see necessary. The format of the solution is fixed and has to follow the custom diff format.
Make sure to implement only test cases and don't try to fix the issue itself.
"""

patch_prompt = """
The following patch has been proposed to fix the issue described in the user issue (in <issue/> brackets).
The patch might give you a hint on how to write a covering test for the issue, but you should not assume that the patch is correct.
It might be that the provided patch is not correct, so your test should check whether the patch resolves the issue.
{}
"""

def _run_git(repo_dir, *args):
    """Run ``git <args>`` inside *repo_dir*; raise CalledProcessError on failure."""
    subprocess.run(["git", *args], cwd=repo_dir, check=True, capture_output=True)


def _checkout_base_commit(repo, base_commit):
    """Return the path of a clean checkout of GitHub *repo* at *base_commit*.

    *repo* is the dataset's ``owner/name`` string. The repository is cloned
    into ``tmp_dir/<name>`` on first use and reused afterwards. Index state,
    local modifications and untracked/ignored files are discarded before
    *base_commit* is checked out.
    """
    repo_dir = os.path.join(tmp_dir, repo.split("/")[-1])
    # Test for ".git" rather than the directory itself, so an empty directory
    # left behind by an interrupted clone is cloned into again instead of
    # being mistaken for a valid repository.
    if not os.path.isdir(os.path.join(repo_dir, ".git")):
        pathlib.Path(tmp_dir).mkdir(parents=True, exist_ok=True)
        subprocess.run(
            ["git", "clone", f"https://github.com/{repo}", repo_dir],
            check=True,
            capture_output=True,
        )
    # Argument lists (no shell), so dataset values are never shell-interpreted.
    _run_git(repo_dir, "reset")
    _run_git(repo_dir, "stash")
    _run_git(repo_dir, "clean", "-fdx")
    _run_git(repo_dir, "checkout", base_commit)
    return repo_dir


def _read_numbered_files(repo_dir, test_patch):
    """Return ``(path, numbered_lines)`` for each existing file *test_patch* touches.

    Files the patch creates are skipped because they do not exist at the base
    commit. Every line is prefixed with its 1-based line number and a space.
    """
    files = []
    for patched_file in PatchSet(test_patch):
        if patched_file.is_added_file:
            continue
        rel_path = patched_file.source_file[2:]  # strip the "a/" diff prefix
        numbered = []
        with open(os.path.join(repo_dir, rel_path), encoding="utf-8") as f:
            for lineno, line in enumerate(f, start=1):
                # rstrip instead of [:-1]: the last line may have no trailing
                # newline, and slicing would then drop a real character.
                content = line.rstrip("\n")
                numbered.append(f"{lineno} {content}")
        files.append((rel_path, numbered))
    return files


def _last_index_startswith(lines, prefix):
    """Return the index of the last entry of *lines* starting with *prefix*.

    Raises ValueError if no entry matches.
    """
    for idx in range(len(lines) - 1, -1, -1):
        if lines[idx].startswith(prefix):
            return idx
    raise ValueError(f"no line starting with {prefix!r} in prompt text")


def _build_prompt(example, files, with_patch):
    """Rewrite ``example["text"]`` into the test-generation prompt.

    Steps, in order:
    1. If *with_patch*, insert the reference code patch after the last
       ``</issue>`` line.
    2. Append the numbered *files* after the last ``[end of`` line.
    3. Replace the first line with the task description.
    4. Drop everything after the last ``</code>`` line and append
       ``new_diff_prompt``.
    """
    lines = example["text"].splitlines()
    if with_patch:
        after_issue = _last_index_startswith(lines, "</issue>") + 1
        lines = lines[:after_issue] + [patch_prompt.format(example["patch"])] + lines[after_issue:]

    after_code = _last_index_startswith(lines, "[end of ") + 1
    file_blocks = []
    for filename, numbered in files:
        file_blocks.append(f"[start of {filename}]")
        file_blocks.extend(numbered)
        file_blocks.append(f"[end of {filename}]")
    lines = lines[:after_code] + file_blocks + lines[after_code:]

    # Kept verbatim (including the double space when no patch is shown) so
    # regenerated datasets match previously generated ones.
    patch_clause = "and a proposed patch (in <patch/> brackets)" if with_patch else ""
    lines[0] = (
        "The following text contains a user issue (in <issue/> brackets) "
        f"{patch_clause} posted at a repository. "
        "Further, you are provided with file contents of several files in the "
        "repository that contain relevant code (in <code> brackets). "
        "It may be necessary to use code from third party dependencies or files "
        "not contained in the attached documents however. "
        "Your task is to identify the issue and implement a test case that "
        "verifies a proposed solution to this issue. More details at the end of this text."
    )

    end_of_code = _last_index_startswith(lines, "</code>") + 1
    return "\n".join(lines[:end_of_code]) + new_diff_prompt


def _make_example(example, files, with_patch):
    """Build one output row: the new prompt, with ``patch``/``test_patch`` swapped."""
    return {
        **example,
        "text": _build_prompt(example, files, with_patch),
        # The first and last lines of example["patch"] are dropped; presumably
        # they are "<patch>"/"</patch>" wrapper lines -- NOTE(review): confirm
        # against the dataset.
        "test_patch": "\n".join(example["patch"].splitlines()[1:-1]),
        # The reference test patch becomes the generation target.
        "patch": "<patch>\n" + example["test_patch"] + "\n</patch>",
    }


# False -> no-patch oracle, True -> golden-patch oracle.
_VARIANTS = (False, True)

# Checkout and file reading do not depend on the variant, so they run once per
# task instance and both variants are built from the same result.
splits_by_variant = {with_patch: {} for with_patch in _VARIANTS}
for split in ["dev", "test"]:
    rows_by_variant = {with_patch: [] for with_patch in _VARIANTS}
    for example in dataset[split]:
        repo_dir = _checkout_base_commit(example["repo"], example["base_commit"])
        files = _read_numbered_files(repo_dir, example["test_patch"])
        for with_patch in _VARIANTS:
            rows_by_variant[with_patch].append(_make_example(example, files, with_patch))
    for with_patch in _VARIANTS:
        splits_by_variant[with_patch][split] = Dataset.from_list(rows_by_variant[with_patch])

for with_patch in _VARIANTS:
    ds = DatasetDict(splits_by_variant[with_patch])
    ds.save_to_disk(f"./datasets/swt_bench_aug1_oracle{'_patch' if with_patch else ''}")

