import ast
import json

from multi_swe_bench.harness.test_result import TestStatus, TestResult
from rich.console import Console

from swegraft.seed import Seed
from swegraft.utils.repo import (
    repo_structure,
    get_full_file_paths_and_classes_and_functions,
)
from swegraft.runtime.swalm import SwalmRuntime, get_swalm_image
from swegraft.strategy.migrate.constants import DIFF_EXAMPLES
from swegraft.utils.llm import LLMClient
from swegraft.utils.common import syn_result_dir, wait_coros_with_progress
from swegraft.strategy.migrate.agents.utils import (
    show_project_structure,
    _post_process_multifile_repair,
    fake_git_diff,
)

# Module-wide rich console; every status and error message of the
# test-generation agent in this file is printed through it.
console = Console()


# Prompt template for the localization step.
#
# It is filled in with `str.format` using three fields: `issue`, `structure`
# and `n` (the maximum number of files to return per list). The braces of the
# example dictionary are doubled (`{{` / `}}`) so that `format` emits them
# literally. The model is asked to answer with a ```python fenced dictionary
# holding the keys `source_files` and `test_files`, which is the shape that
# `TestGenAgent.localize` parses. Leading/trailing whitespace is stripped.
TESTGEN_LOCALIZE = """\
Please look through a given issue description and repository structure and provide two list of files related to the issue:
- `source_files`: the files may contains code related to the functionality described in the issue
- `test_files`: the files which should contain the test cases for the functionality described in the issue

--- BEGIN ISSUE ---
{issue}
--- END ISSUE ---

--- BEGIN REPOSITORY STRUCTURE ---
{structure}
--- END REPOSITORY STRUCTURE ---

Only provide the full path and return at most {n} files for each list. 

Respond in the following format, wrapped your results in a markdown python code block with a dictionary with two keys `source_files` and `test_files`.
```python
{{
    "source_files": [
        "most/important/file1.xx",
        "less/important/file2.yy",
        ...
    ],
    "test_files": [
        "most/important/file1.xx",
        "less/important/file2.yy",
        ...
    ]
}}
```

""".strip()

# Prompt template for the edit-generation step.
#
# It is filled in with `str.format` using four fields: `issue`,
# `source_files`, `test_files` and `diff_example`. The model is asked to reply
# with fenced *SEARCH/REPLACE* edits against the listed test files. The
# backslash at the end of the "Please first localize ..." line is a line
# continuation inside the literal, so that sentence is a single line in the
# final prompt. The template is not stripped: its trailing newline is kept.
TESTGEN_EDIT = """We are currently adding unit tests to the avoid the future regression for functionality described in the issue.

--- BEGIN ISSUE ---
{issue}
--- END ISSUE ---

Below are some source code segments related to the functionality described in the issue.

--- BEGIN SOURCE FILES ---
{source_files}
--- END SOURCE FILES ---

Below are some files you can edit to add unit tests.
--- BEGIN TEST FILES ---
{test_files}
--- END TEST FILES ---

Please first localize the code in SOURCE FILES to the functionality described in the issue and \
then generate *SEARCH/REPLACE* edits to test to some of TEST FILES to test the issue.

Every *SEARCH/REPLACE* edit must use this format:
1. The file path
2. The start of search block: <<<<<<< SEARCH
3. A contiguous chunk of lines to search for in the existing source code
4. The dividing line: =======
5. The lines to replace into the source code
6. The end of the replace block: >>>>>>> REPLACE

Here is an example:

```
{diff_example}
```

Please note that the *SEARCH/REPLACE* edit REQUIRES PROPER INDENTATION. If you would like to add the line '        print(x)', you must fully write that out, with all those spaces before the code!
Wrap each *SEARCH/REPLACE* edit in a code block as shown in the example above. If you have multiple *SEARCH/REPLACE* edits, use a separate code block for each one.

Please make sure the tests you add are not too simple and can be passed by the existing code.
"""


class TestGenAgent:
    """LLM-driven agent that proposes regression tests for an issue.

    The pipeline implemented by :meth:`run` is: localize candidate source and
    test files, ask the model for SEARCH/REPLACE edits, apply every resulting
    patch in its own runtime session, run the test suite, and keep the patch
    that adds the most newly passing tests without changing the outcome of
    any previously recorded test.
    """

    def __init__(
        self,
        model: str,
        llm_client: LLMClient,
        swalm_run_time: SwalmRuntime,
        config: dict,
    ):
        """Store the collaborators used by the agent.

        Args:
            model: Model name passed to every chat-completion request.
            llm_client: Client whose ``chat_completion_async`` is awaited.
            swalm_run_time: Runtime whose ``session`` context manager is used
                to apply patches and run tests.
            config: Settings read by the agent. ``max_localize_files``,
                ``max_tokens``, ``localize_n`` and ``edit_n`` are required;
                ``temperature`` is optional and defaults to ``0.7``.
        """
        self.model = model
        self.llm_client = llm_client
        self.config = config
        self.swalm_run_time = swalm_run_time

    @staticmethod
    def _keep_known_paths(
        entries: list, known_paths: set[str], limit: int
    ) -> list[str]:
        """Return the first ``limit`` entries that are paths in ``known_paths``.

        Truncation happens before filtering. Non-string entries (for example
        the ``...`` placeholder from the prompt example, or a nested list) are
        skipped instead of being looked up, so an unhashable entry cannot
        raise.
        """
        return [
            entry
            for entry in entries[:limit]
            if isinstance(entry, str) and entry in known_paths
        ]

    async def localize(
        self, issue: str, structure: dict
    ) -> list[tuple[list[str], list[str]]]:
        """Ask the model which source and test files relate to ``issue``.

        Args:
            issue: Issue description inserted into the prompt.
            structure: Repository structure, rendered into the prompt and used
                to validate the predicted paths.

        Returns:
            One ``(source_files, test_files)`` pair per distinct prediction.
            Paths that do not exist in ``structure`` are dropped, and two
            predictions are considered duplicates when both of their file
            sets are equal regardless of order. Completions that cannot be
            parsed are reported on the console and skipped.
        """
        max_files = self.config["max_localize_files"]
        prompt = TESTGEN_LOCALIZE.format(
            issue=issue,
            structure=show_project_structure(structure),
            n=max_files,
        )
        messages = [dict(role="user", content=prompt)]
        response = await self.llm_client.chat_completion_async(
            model=self.model,
            messages=messages,
            temperature=self.config.get("temperature", 0.7),
            max_tokens=self.config["max_tokens"],
            n=self.config["localize_n"],
        )
        files, _, _ = get_full_file_paths_and_classes_and_functions(structure)
        known_paths = {path for path, _ in files}

        candidates: list[tuple[list[str], list[str]]] = []
        seen: set[tuple[frozenset, frozenset]] = set()
        for choice in response.choices:
            try:
                output = (
                    choice.message.content.split("```python")[1].split("```")[0].strip()
                )
                # The payload is model output, i.e. untrusted text: parse it
                # as a literal instead of executing it with eval().
                result = ast.literal_eval(output)
                source_files = self._keep_known_paths(
                    result["source_files"], known_paths, max_files
                )
                test_files = self._keep_known_paths(
                    result["test_files"], known_paths, max_files
                )
            except Exception as e:
                # Best effort: one malformed completion must not abort the
                # remaining ones.
                console.print(f"[TestGen] Error in localize: {e}", style="red")
                continue
            key = (frozenset(source_files), frozenset(test_files))
            if key in seen:
                continue
            seen.add(key)
            candidates.append((source_files, test_files))
        return candidates

    async def make_file_context(self, files: list[str], file_contents: dict) -> str:
        """Render ``files`` as ``### <path>`` sections separated by blank lines.

        Args:
            files: Paths to include, in output order.
            file_contents: Mapping from path to file text; every path in
                ``files`` must be present.

        Returns:
            The concatenated sections, or an empty string for no files.
        """
        return "\n\n".join(f"### {file}\n{file_contents[file]}" for file in files)

    async def gen_edit(
        self,
        issue: str,
        seed: Seed,
        structure: dict,
        src_files: list[str],
        test_files: list[str],
    ) -> list[str]:
        """Ask the model for test edits and turn each completion into a patch.

        Args:
            issue: Issue description inserted into the prompt.
            seed: Seed whose ``language`` selects the diff example and is
                passed to the edit post-processor.
            structure: Repository structure providing the file contents.
            src_files: Paths shown to the model as source context.
            test_files: Paths shown to the model as editable test files.

        Returns:
            One patch per completion that could be post-processed; failing
            completions are reported on the console and skipped.
        """
        files, _, _ = get_full_file_paths_and_classes_and_functions(structure)
        files = {path: "\n".join(lines) for path, lines in files}
        src_context = await self.make_file_context(src_files, files)
        test_context = await self.make_file_context(test_files, files)
        prompt = TESTGEN_EDIT.format(
            issue=issue,
            source_files=src_context,
            test_files=test_context,
            diff_example=DIFF_EXAMPLES[seed.language],
        )
        messages = [dict(role="user", content=prompt)]
        response = await self.llm_client.chat_completion_async(
            model=self.model,
            messages=messages,
            temperature=self.config.get("temperature", 0.7),
            max_tokens=self.config["max_tokens"],
            n=self.config["edit_n"],
        )
        results = []
        for choice in response.choices:
            try:
                edited_files, new_contents = _post_process_multifile_repair(
                    choice.message.content, files, seed.language
                )
                old_contents = [files[file] for file in edited_files]
                patch = fake_git_diff(edited_files, old_contents, new_contents)
                results.append(patch)
            except Exception as e:
                # Best effort: a completion that cannot be converted into a
                # patch is logged and skipped.
                console.print(f"[TestGen] Error in gen_edit: {e}", style="red")
                continue
        return results

    async def edit_and_run(self, seed: Seed, patch: str) -> dict:
        """Apply ``patch`` in a fresh runtime session and run the seed's tests.

        Args:
            seed: Seed providing the image coordinates, the workspace path and
                the ``run`` / ``parse_log`` helpers.
            patch: Patch text written to ``/test.patch`` and applied with
                ``git apply`` inside ``seed.workspace``.

        Returns:
            A dict with the keys ``patch``, ``run_log`` and ``run_result``.
        """
        async with self.swalm_run_time.session(
            get_swalm_image(seed.source, seed.instance_id)
        ) as session:
            await session.write_file("/test.patch", patch)
            # NOTE(review): the outcome of `git apply` is not inspected here.
            # Presumably a patch that fails to apply leaves the workspace
            # untouched, so it yields no new tests and is filtered out by the
            # caller - confirm against the session API.
            await session.execute_command(
                "git apply /test.patch",
                cwd=seed.workspace,
                timeout=30,
                long_command=False,
            )
            run_log = await seed.run(session, timeout=300)
            run_result = seed.parse_log(run_log)
            return dict(
                patch=patch,
                run_log=run_log,
                run_result=run_result,
            )

    def new_pass_tests(
        self, test_run_result: TestResult, run_result: TestResult
    ) -> list[str]:
        """List tests that pass after the patch and are absent from the baseline.

        Args:
            test_run_result: Result of the run with the patch applied.
            run_result: Baseline result recorded before the patch.

        Returns:
            Names of tests missing from ``run_result`` whose status in
            ``test_run_result`` is ``TestStatus.PASS``. An empty list is
            returned as soon as any baseline test has a different status
            (or is missing) in the patched run.
        """
        patched = test_run_result._tests
        baseline = run_result._tests
        # The patch must not modify the behaviour of any existing test.
        if any(
            patched.get(test, TestStatus.NONE) != status
            for test, status in baseline.items()
        ):
            return []
        return [
            test
            for test, status in patched.items()
            if test not in baseline and status == TestStatus.PASS
        ]

    async def run(self, seed: Seed, issue: str) -> dict | None:
        """Generate test patches for ``issue`` and return the best one.

        Args:
            seed: Seed describing the repository instance under test.
            issue: Issue description the new tests should cover.

        Returns:
            The result dict of the patch with the most new passing tests
            (keys ``patch``, ``run_log``, ``run_result`` and ``new_tests``),
            or ``None`` when no patch adds a valid new test.
        """
        structure = repo_structure(seed, "migrate")

        # Localize
        localize_results = await self.localize(issue, structure)
        console.print(
            f"[TestGen] Generated {len(localize_results)} unique localizations",
            style="green",
        )

        # Generate edits
        coros = [
            self.gen_edit(issue, seed, structure, src_files, test_files)
            for src_files, test_files in localize_results
        ]
        results: list[list[str]] = await wait_coros_with_progress(
            coros, "[TestGen] Generating edits"
        )
        # Flatten the per-localization batches and drop blank patches.
        patches = [patch for batch in results for patch in batch if patch.strip()]
        console.print(f"[TestGen] Generated {len(patches)} edits", style="green")

        # Run testsuites
        run_result_file = syn_result_dir("migrate", seed) / "run.json"
        run_result = TestResult.from_dict(json.loads(run_result_file.read_text()))
        coros = [self.edit_and_run(seed, patch) for patch in patches]
        patch_run_results = await wait_coros_with_progress(
            coros, "[TestGen] Running testsuites after edits"
        )

        # Analyze new tests
        valid_results = []
        for result in patch_run_results:
            result["new_tests"] = self.new_pass_tests(result["run_result"], run_result)
            if result["new_tests"]:
                valid_results.append(result)
        console.print(
            f"[TestGen] Collected {len(valid_results)} valid edits", style="green"
        )
        # Joined outside the f-string: a backslash inside a replacement field
        # is a SyntaxError before Python 3.12.
        new_tests_summary = "\n".join(
            str(result["new_tests"]) for result in valid_results
        )
        console.print(
            f"[TestGen] New tests: {new_tests_summary}, select one with most new tests",
            style="green",
        )
        if not valid_results:
            return None
        return max(valid_results, key=lambda x: len(x["new_tests"]))
