import ast
import json
from dataclasses import dataclass

from multi_swe_bench.harness.test_result import TestResult, TestStatus
from rich.console import Console

from swegraft.utils.common import wait_coros_with_progress, syn_result_dir
from swegraft.utils.repo import workspace, create_structure
from swegraft.runtime.swalm import get_swalm_image, SwalmRuntime
from swegraft.seed import Seed
from swegraft.utils.repo import get_full_file_paths_and_classes_and_functions
from swegraft.strategy.migrate.agents.utils import (
    show_project_structure,
    _post_process_multifile_repair,
    fake_git_diff,
    fake_git_apply,
)
from swegraft.strategy.migrate.constants import DIFF_EXAMPLES
from swegraft.utils.llm import LLMClient


@dataclass()
class Test:
    """Status of one test case in each of the three runs being compared.

    ``run`` is the status in the baseline run, ``test`` the status once the
    test-generation patch is applied, and ``bug`` the status once a candidate
    issue-implementation patch is applied on top of it. A test absent from a
    run is expected to be recorded as ``TestStatus.NONE``.
    """

    run: TestStatus  # baseline run
    test: TestStatus  # testgen patch applied
    bug: TestStatus  # testgen patch + candidate patch applied


# Module-wide rich console used for this module's progress and error messages.
console = Console()


# Prompt asking the LLM which repository files relate to the issue; the model
# is asked to answer with a list of paths inside a ```python code block.
# Filled with str.format; placeholders: {issue}, {structure}, {testgen_patch},
# {n}. Literal braces added to this text must be doubled ("{{" / "}}").
ISSUE_IMPLE_LOCALIZE = """\
Please look through a given issue description, repository structure, a patch related to test the issue and provide a list of files related to the issue

Below is the issue description and repository structure.
--- BEGIN ISSUE ---
{issue}
--- END ISSUE ---

Below is the repository structure.
--- BEGIN REPOSITORY STRUCTURE ---
{structure}
--- END REPOSITORY STRUCTURE ---

Below is the patch applied to the repository to test the issue.
--- BEGIN TEST PATCH ---
{testgen_patch}
--- END TEST PATCH ---

Only provide the full path and return at most {n} files. 

Respond in the following format, wrapped your results in a markdown python code block with a list of files.
```python
[
    "most/important/file1.xx",
    "less/important/file2.yy",
    ...
]
```

""".strip()


# Prompt asking the LLM for *SEARCH/REPLACE* edits that re-implement the issue
# so that the tests listed in the TESTS section break, without touching tests.
# Filled with str.format; placeholders: {issue}, {files}, {testgen_patch},
# {tests}, {diff_example}. Every one of them must be passed to .format(),
# otherwise str.format raises KeyError.
ISSUE_IMPLE_EDIT = """We are currently implementing the issue described in the following issue description.

--- BEGIN ISSUE ---
{issue}
--- END ISSUE ---

Below are some code segments related to the issue.

--- BEGIN FILES---
{files}
--- END FILES---

Below is the patch applied to the repository to test the issue, please DO NOT modify any test code or test files.
--- BEGIN TEST PATCH ---
{testgen_patch}
--- END TEST PATCH ---


Here is the list of testcases related to the issue.
--- BEGIN TESTS ---
{tests}
--- END TESTS ---


Please first localize the related source code based on the issue description, and then generate *SEARCH/REPLACE* edits to re-implement the issue via breaking the tests in the TESTS section.
DO NOT modify any test code or test files, you should only modify the non-test files and code related to the issue.

Every *SEARCH/REPLACE* edit must use this format:
1. The file path
2. The start of search block: <<<<<<< SEARCH
3. A contiguous chunk of lines to search for in the existing source code
4. The dividing line: =======
5. The lines to replace into the source code
6. The end of the replace block: >>>>>>> REPLACE



Here is an example:

```
{diff_example}
```

Please note that the *SEARCH/REPLACE* edit REQUIRES PROPER INDENTATION. If you would like to add the line '        print(x)', you must fully write that out, with all those spaces before the code!
Wrap each *SEARCH/REPLACE* edit in a code block as shown in the example above. If you have multiple *SEARCH/REPLACE* edits, use a separate code block for each one.

Tips about the issue-implementing task:
- It should break the tests in the TESTS section.
- It should not cause compilation errors.
- It should not be a syntax error.
- It should be subtle and challenging to detect.
- It should not modify the function signature.
- It should not modify the documentation significantly.
- Please DO NOT INCLUDE COMMENTS IN THE CODE indicating the bug location or the bug itself.
- Please DO NOT modify the test code or test files.
- Please DO NOT modify the files modifed by TEST PATCH.
"""


class IssueImplementAgent:
    """Agent that asks an LLM to re-implement an issue by breaking tests.

    Pipeline driven by ``run``:

    1. apply the test-generation patch to a local workspace and snapshot the
       repository structure;
    2. ``localize``: ask the LLM which files relate to the issue;
    3. ``gen_edit``: ask the LLM for SEARCH/REPLACE edits on those files and
       turn each answer into a diff;
    4. ``edit_and_run``: run the test suite in a sandbox with the testgen
       patch and one candidate diff applied;
    5. ``broken_tests``: keep candidates that break tests without otherwise
       disturbing the suite, and return the one breaking the most tests.

    ``config`` keys read: ``max_localize_files``, ``max_tokens``,
    ``localize_n``, ``edit_n`` and optionally ``temperature`` (default 0.7).
    """

    def __init__(
        self,
        model: str,
        llm_client: LLMClient,
        swalm_run_time: SwalmRuntime,
        config: dict,
    ):
        self.model = model
        self.llm_client = llm_client
        self.swalm_run_time = swalm_run_time
        self.config = config

    async def make_file_context(self, files: list[str], file_contents: dict) -> str:
        """Render ``files`` as ``### <path>`` sections separated by blank lines.

        Raises:
            KeyError: if a path in ``files`` is missing from ``file_contents``.
        """
        return "\n\n".join(f"### {file}\n{file_contents[file]}" for file in files)

    async def localize(
        self, issue: str, testgen_patch: str, structure: dict
    ) -> list[list[str]]:
        """Ask the LLM for the files related to ``issue``.

        Returns one list of paths per usable model choice. Paths that do not
        exist in ``structure`` are dropped, and a choice whose resulting file
        set was already returned is skipped. Choices whose reply cannot be
        parsed as a list literal are logged and skipped.
        """
        prompt = ISSUE_IMPLE_LOCALIZE.format(
            issue=issue,
            structure=show_project_structure(structure),
            testgen_patch=testgen_patch,
            n=self.config["max_localize_files"],
        )
        messages = [dict(role="user", content=prompt)]
        response = await self.llm_client.chat_completion_async(
            model=self.model,
            messages=messages,
            temperature=self.config.get("temperature", 0.7),
            max_tokens=self.config["max_tokens"],
            n=self.config["localize_n"],
        )
        files, _, _ = get_full_file_paths_and_classes_and_functions(structure)
        known_paths = {path for path, _ in files}

        results: list[list[str]] = []
        seen: set[frozenset[str]] = set()
        for choice in response.choices:
            try:
                output = (
                    choice.message.content.split("```python")[1].split("```")[0].strip()
                )
                # Model output is untrusted: parse it as a Python literal
                # instead of executing it with eval().
                predicted = ast.literal_eval(output)
            except Exception as e:
                console.print(
                    f"[IssueImplement] Error in localize: {e}",
                    style="red",
                    markup=False,
                )
                continue
            if not isinstance(predicted, (list, tuple)):
                console.print(
                    "[IssueImplement] Error in localize: expected a list, "
                    f"got {type(predicted).__name__}",
                    style="red",
                    markup=False,
                )
                continue
            located = [
                path
                for path in predicted
                if isinstance(path, str) and path in known_paths
            ]
            # Deduplicate on the *filtered* file set, ignoring order.
            key = frozenset(located)
            if key in seen:
                continue
            seen.add(key)
            results.append(located)
        return results

    async def gen_edit(
        self,
        issue: str,
        seed: Seed,
        structure: dict,
        localize_files: list[str],
        testgen_patch: str,
        tests: list[str] | None = None,
    ) -> list[str]:
        """Generate candidate patches that re-implement ``issue``.

        Args:
            issue: Issue description shown to the LLM.
            seed: Provides ``language`` (selects the diff example and the
                SEARCH/REPLACE post-processing).
            structure: Repository structure; its file contents are the base
                the edits are applied to.
            localize_files: Paths whose contents are shown to the LLM; each
                must exist in ``structure``.
            testgen_patch: Test patch shown to the LLM as read-only context.
            tests: Test names the edit should break, listed in the prompt's
                TESTS section. When empty or ``None``, a placeholder is used.

        Returns:
            One diff per model choice that could be post-processed; choices
            that fail are logged and skipped.
        """
        files, _, _ = get_full_file_paths_and_classes_and_functions(structure)
        files = {path: "\n".join(lines) for path, lines in files}
        file_context = await self.make_file_context(localize_files, files)
        # ISSUE_IMPLE_EDIT contains a {tests} placeholder; omitting it makes
        # str.format raise KeyError.
        tests_text = "\n".join(tests) if tests else "(no specific tests provided)"
        prompt = ISSUE_IMPLE_EDIT.format(
            issue=issue,
            files=file_context,
            testgen_patch=testgen_patch,
            tests=tests_text,
            diff_example=DIFF_EXAMPLES[seed.language],
        )
        messages = [dict(role="user", content=prompt)]
        response = await self.llm_client.chat_completion_async(
            model=self.model,
            messages=messages,
            temperature=self.config.get("temperature", 0.7),
            max_tokens=self.config["max_tokens"],
            n=self.config["edit_n"],
        )
        patches: list[str] = []
        for choice in response.choices:
            try:
                edited_files, new_contents = _post_process_multifile_repair(
                    choice.message.content, files, seed.language
                )
                # Inside the try so one bad choice cannot abort the others.
                old_contents = [files[file] for file in edited_files]
                patch = fake_git_diff(edited_files, old_contents, new_contents)
            except Exception as e:
                console.print(
                    f"[IssueImplement] Error in gen_edit: {e}",
                    style="red",
                    markup=False,
                )
                continue
            patches.append(patch)
        return patches

    async def edit_and_run(
        self,
        seed: Seed,
        test_patch: str,
        issue_patch: str,
    ):
        """Run ``seed``'s test suite in a sandbox with both patches applied.

        ``test_patch`` and then ``issue_patch`` are written into the session
        and applied with ``git apply`` in ``seed.workspace``.

        Returns:
            dict with ``patch`` (``issue_patch``), ``run_log`` (output of
            ``seed.run``) and ``run_result`` (``seed.parse_log`` of that log).
        """
        async with self.swalm_run_time.session(
            get_swalm_image(seed.source, seed.instance_id)
        ) as session:
            # NOTE(review): the outcome of `git apply` is not checked; if a
            # patch fails to apply, the tests run without it. Confirm what
            # execute_command returns before adding a check.
            # First apply testgen patch
            testgen_patch = "/test.patch"
            await session.write_file(testgen_patch, test_patch)
            await session.execute_command(
                f"git apply {testgen_patch}",
                cwd=seed.workspace,
                timeout=30,
                long_command=False,
            )
            # Then apply issue implement patch
            issue_impl_patch = "/issue.patch"
            await session.write_file(issue_impl_patch, issue_patch)
            await session.execute_command(
                f"git apply {issue_impl_patch}",
                cwd=seed.workspace,
                timeout=30,
                long_command=False,
            )
            # Run and parse
            run_log = await seed.run(session, timeout=300)
            run_result = seed.parse_log(run_log)
            return dict(
                patch=issue_patch,
                run_log=run_log,
                run_result=run_result,
            )

    def broken_tests(
        self,
        run_result: TestResult,
        test_run_result: TestResult,
        bug_run_result: TestResult,
    ) -> dict[str, Test]:
        """Return the tests broken by a candidate patch, or ``{}`` if rejected.

        Each test's status is compared across ``run_result`` (baseline),
        ``test_run_result`` (testgen patch) and ``bug_run_result`` (testgen +
        candidate patch). The candidate is rejected (empty dict) when:

        0. the testgen run or the candidate run executed no tests;
        1. a test's status under the testgen patch differs from the baseline,
           except that a test missing from the baseline must PASS. This also
           rejects tests that exist only in the candidate run, i.e. tests
           added by the candidate patch.

        Otherwise the result maps every test that PASSES with the testgen
        patch and FAILS with the candidate patch to its three statuses.
        """
        # 0. Fast check: both patched runs must have executed something.
        if bug_run_result.all_count == 0 or test_run_result.all_count == 0:
            return {}

        all_tests = (
            run_result._tests.keys()
            | bug_run_result._tests.keys()
            | test_run_result._tests.keys()
        )
        tests: dict[str, Test] = {
            name: Test(
                run=run_result._tests.get(name, TestStatus.NONE),
                test=test_run_result._tests.get(name, TestStatus.NONE),
                bug=bug_run_result._tests.get(name, TestStatus.NONE),
            )
            for name in all_tests
        }

        # 1. permit: PASS->PASS, FAIL->FAIL, SKIP->SKIP, NONE->PASS.
        # A separate "no test missing under the testgen patch" check is not
        # needed: test == NONE always fails one of the branches below.
        for status in tests.values():
            if status.run == TestStatus.NONE:
                if status.test != TestStatus.PASS:
                    return {}
            elif status.run != status.test:
                return {}

        # 2. The bug must be detected: (PASS|NONE)->PASS->FAIL.
        return {
            name: status
            for name, status in tests.items()
            if status.test == TestStatus.PASS and status.bug == TestStatus.FAIL
        }

    @staticmethod
    def _issue_tests(run_result: TestResult, test_run_result: TestResult) -> list[str]:
        """Return the sorted tests that PASS with the testgen patch but are
        absent from the baseline run, i.e. presumably the tests that the
        testgen patch adds (heuristic: modified existing tests are missed)."""
        return sorted(
            name
            for name, status in test_run_result._tests.items()
            if status == TestStatus.PASS
            and run_result._tests.get(name, TestStatus.NONE) == TestStatus.NONE
        )

    async def run(
        self, seed: Seed, issue: str, testgen_patch: str, test_run_result: TestResult
    ) -> dict | None:
        """Generate, execute and rank issue-implementation patches for ``seed``.

        Returns the ``edit_and_run`` result (plus a ``breaked_tests`` entry)
        that breaks the most tests, or ``None`` when no candidate is valid.
        """
        # Load the baseline first: it is needed to pick the tests to break and
        # a missing file should fail before any LLM or sandbox work is done.
        run_result_file = syn_result_dir("migrate", seed) / "run.json"
        run_result = TestResult.from_dict(json.loads(run_result_file.read_text()))
        issue_tests = self._issue_tests(run_result, test_run_result)

        # Apply new tests to content
        with workspace(seed, "migrate") as ws:
            fake_git_apply(testgen_patch, ws)
            structure = create_structure(ws)

        # Localize
        localize_results = await self.localize(issue, testgen_patch, structure)
        console.print(
            f"Generated {len(localize_results)} localize results", style="green"
        )

        # Generate edits
        coros = [
            self.gen_edit(
                issue, seed, structure, localize_files, testgen_patch, tests=issue_tests
            )
            for localize_files in localize_results
        ]
        results: list[list[str]] = await wait_coros_with_progress(
            coros, "Generating edits"
        )
        # Flatten, drop blank patches and de-duplicate (keeping first-seen order).
        patches = list(
            dict.fromkeys(
                patch for batch in results for patch in batch if patch.strip()
            )
        )
        console.print(f"Generated {len(patches)} unique patches", style="green")

        # Run testsuites after edits
        coros = [self.edit_and_run(seed, testgen_patch, patch) for patch in patches]
        patch_run_results: list[dict] = await wait_coros_with_progress(
            coros, "Running testsuites after edits"
        )

        # Analyze breaked tests
        for result in patch_run_results:
            result["breaked_tests"] = self.broken_tests(
                run_result, test_run_result, result["run_result"]
            )

        # Filter out invalid results
        valid_results = [
            result for result in patch_run_results if result["breaked_tests"]
        ]
        console.print(
            f"[IssueImplement] Collected {len(valid_results)} valid edits",
            style="green",
        )
        # Join outside the f-string (a backslash inside an f-string expression
        # is a SyntaxError before Python 3.12); markup=False keeps test ids
        # such as "test_x[param]" from being parsed as rich markup.
        broken_names = "\n".join(
            name for result in valid_results for name in result["breaked_tests"]
        )
        console.print(f"[IssueImplement] Break tests: {broken_names}", markup=False)
        if not valid_results:
            return None
        console.print(
            "[IssueImplement] select the best result with most breaked tests",
            style="green",
        )
        return max(valid_results, key=lambda result: len(result["breaked_tests"]))
