from swegraft.strategy.migrate.models import MigrateTaskSpec
from swegraft.seed import Seed
from swegraft.utils.llm import make_llm_from_provider_config
from swegraft.strategy.migrate.agent import abstract_bug_description
from swegraft.strategy.migrate.agent import ISSUE_DIRECT, ISSUE_ABSTRACT
from swalm.core.utils.common import max_concurrency
import random
import json
import pathlib
from swegraft.strategy.migrate.agents.test import TESTGEN_LOCALIZE, TESTGEN_EDIT
from swegraft.strategy.migrate.agents.issue import (
    ISSUE_IMPLE_LOCALIZE,
    ISSUE_IMPLE_EDIT,
)
from swegraft.strategy.migrate.constants import DIFF_EXAMPLES
from swegraft.runtime.swalm import SwalmRuntime, get_swalm_image
from swegraft.utils.repo import (
    repo_structure,
    get_full_file_paths_and_classes_and_functions,
    workspace,
    create_structure,
)
from swegraft.strategy.migrate.agents.utils import (
    _post_process_multifile_repair,
    show_project_structure,
    fake_git_apply,
)
from multi_swe_bench.harness.test_result import TestResult, TestStatus
from rich.console import Console
import unidiff

from dataclasses import dataclass


# Shared rich console used by this module to report errors.
console = Console()

# File names of artifacts that the pipeline steps below cache on disk.
REPO_STRUCTURE_FILE = "structure.json"  # written by save_repo_structure
ISSUE_ABSTRACT_FILE = "abstract_issue.txt"  # LLM-rewritten issue text
ISSUE_DIRECT_FILE = "direct_issue.txt"  # issue text formatted with ISSUE_DIRECT
TESTGEN_LOCALIZE_FILE = "testgen_localize.json"  # output of testgen_localize


def save_repo_structure(seed: Seed, seed_result_dir: pathlib.Path):
    """Cache the repository structure of ``seed`` as JSON.

    The structure comes from ``repo_structure(seed, "migrate")`` and is
    stored in ``seed_result_dir / REPO_STRUCTURE_FILE``. When that file is
    already present the function does nothing.
    """
    target = seed_result_dir / REPO_STRUCTURE_FILE
    if not target.exists():
        tree = repo_structure(seed, "migrate")
        with target.open("w") as out:
            out.write(json.dumps(tree, indent=2))


def make_file_context(files: list[str], file_contents: dict) -> str:
    """Render the requested files as one prompt-ready text block.

    Each file becomes a ``### <path>`` header followed by its content;
    consecutive files are separated by one blank line.

    Args:
        files: Paths to include, in output order.
        file_contents: Mapping from path to that file's text.

    Returns:
        The concatenated sections, or an empty string when ``files`` is empty.

    Raises:
        KeyError: If a requested path is not a key of ``file_contents``.
    """
    sections = (f"### {path}\n{file_contents[path]}" for path in files)
    return "\n\n".join(sections)


### Issue Abstract
async def abstract_issue(
    spec: MigrateTaskSpec,
    provider_config: dict,
    model: str,
    rewrite_config: dict,
    result_dir: pathlib.Path,
):
    """Write the direct and the LLM-abstracted issue texts into ``result_dir``.

    The direct issue is ``ISSUE_DIRECT`` formatted with the spec's description
    and patch. The abstract issue is whatever ``abstract_bug_description``
    returns for the ``ISSUE_ABSTRACT`` prompt; when that is ``None`` only the
    direct issue file is written. Nothing is done when both files already
    exist.

    Args:
        spec: Task spec providing ``description`` and ``patch``.
        provider_config: Passed to ``make_llm_from_provider_config``.
        model: Model name forwarded to the rewrite call.
        rewrite_config: Supplies ``max_tokens`` and ``temperature``.
        result_dir: Directory receiving both issue files.
    """
    abstract_path = result_dir / ISSUE_ABSTRACT_FILE
    direct_path = result_dir / ISSUE_DIRECT_FILE
    if abstract_path.exists() and direct_path.exists():
        return

    client = make_llm_from_provider_config(provider_config)
    direct_text = ISSUE_DIRECT.format(body=spec.description, diff=spec.patch)
    rewrite_prompt = ISSUE_ABSTRACT.format(body=spec.description, diff=spec.patch)
    rewritten = await abstract_bug_description(
        client,
        rewrite_prompt,
        model=model,
        max_tokens=rewrite_config["max_tokens"],
        temperature=rewrite_config["temperature"],
    )

    if rewritten is not None:
        with abstract_path.open("w") as out:
            out.write(rewritten)
    with direct_path.open("w") as out:
        out.write(direct_text)


### TestGen
async def testgen_localize(
    model_to_provider: dict,
    rewrite_issue: bool,
    testgen_config: dict,
    result_dir: pathlib.Path,
) -> list[dict] | None:
    """Ask an LLM which source and test files are relevant to the issue.

    A model is picked at random from ``testgen_config["models"]`` and one of
    its provider configs from ``model_to_provider``. Every sampled answer is
    parsed, restricted to paths present in the cached repository structure,
    de-duplicated, and the surviving predictions are written to
    ``TESTGEN_LOCALIZE_FILE`` inside ``result_dir``.

    Args:
        model_to_provider: Maps a model name to a list of provider configs.
        rewrite_issue: Use the abstract issue text when true, otherwise the
            direct one.
        testgen_config: Reads ``models``, ``max_localize_files``,
            ``max_tokens``, ``localize_n`` and optionally ``temperature``.
        result_dir: Directory holding the issue files; the repository
            structure file is read from two levels above it.

    Returns:
        Always ``None``; results are persisted to disk. Returns early without
        writing anything when the result file already exists, the issue file
        is missing, or the LLM call yields no response.
    """
    result_file = result_dir / TESTGEN_LOCALIZE_FILE
    if result_file.exists():
        return
    model = random.choice(testgen_config["models"])
    provider_config = random.choice(model_to_provider[model])
    llm_client = make_llm_from_provider_config(provider_config)
    structure_file = result_dir.parent.parent / REPO_STRUCTURE_FILE
    structure = json.loads(structure_file.read_text())
    issue_file = result_dir / (
        ISSUE_ABSTRACT_FILE if rewrite_issue else ISSUE_DIRECT_FILE
    )
    if not issue_file.exists():
        return
    issue = issue_file.read_text()
    max_files = testgen_config["max_localize_files"]
    prompt = TESTGEN_LOCALIZE.format(
        issue=issue,
        structure=show_project_structure(structure),
        n=max_files,
    )
    messages = [dict(role="user", content=prompt)]
    response = await llm_client.chat_completion_async(
        model=model,
        messages=messages,
        temperature=testgen_config.get("temperature", 0.7),
        max_tokens=testgen_config["max_tokens"],
        n=testgen_config["localize_n"],
    )
    if response is None:
        return

    files, _, _ = get_full_file_paths_and_classes_and_functions(structure)
    # Only membership is needed here, so keep just the paths.
    known_paths = {path for path, _ in files}

    results = []
    seen_predicted = set()
    for index, choice in enumerate(response.choices):
        try:
            output = (
                choice.message.content.split("```python")[1].split("```")[0].strip()
            )
            # SECURITY: eval() executes model-generated text as Python code.
            # Kept to avoid changing which answers parse; consider
            # ast.literal_eval if answers are always plain literals.
            result = eval(output)
            # Filtering stays inside the try block so a malformed answer
            # (e.g. unhashable entries) skips this sample instead of aborting
            # the whole step.
            source_files = [
                file
                for file in result["source_files"][:max_files]
                if file in known_paths
            ]
            test_files = [
                file
                for file in result["test_files"][:max_files]
                if file in known_paths
            ]
            prediction = (frozenset(source_files), frozenset(test_files))
        except Exception as e:
            console.print(f"[TestGen] Error in localize: {e}", style="red")
            continue
        # Two samples naming the same file sets (in any order) are duplicates.
        if prediction in seen_predicted:
            continue
        seen_predicted.add(prediction)
        results.append(
            dict(
                index=index,
                source_files=source_files,
                test_files=test_files,
                # Key spelling is part of the on-disk format; keep it as is.
                converstaions=messages
                + [dict(role="assistant", content=choice.message.content)],
            )
        )

    with open(result_file, "w") as f:
        json.dump(results, f, indent=2)


async def testgen_edit(
    seed: Seed,
    provider_config: dict,
    model: str,
    rewrite_issue: bool,
    localize_result: dict,
    testgen_config: dict,
    result_dir: pathlib.Path,
) -> list[dict] | None:
    """Sample test-edit responses from the LLM for one localization result.

    The prompt combines the issue text with the contents of the localized
    source and test files. All sampled responses are stored in
    ``testgen_edit_<localize index>.json`` inside ``result_dir``. The function
    returns ``None`` in every case; it exits early when that file already
    exists or when the LLM call yields no response.
    """
    localize_index = localize_result["index"]
    result_file = result_dir / f"testgen_edit_{localize_index}.json"
    if result_file.exists():
        return

    structure_path = result_dir.parent.parent / "structure.json"
    structure = json.loads(structure_path.read_text())
    issue_name = "abstract_issue.txt" if rewrite_issue else "direct_issue.txt"
    issue = (result_dir / issue_name).read_text()

    src_files = localize_result["source_files"][: testgen_config["max_localize_files"]]
    test_files = localize_result["test_files"][: testgen_config["max_localize_files"]]
    files, _, _ = get_full_file_paths_and_classes_and_functions(structure)
    contents_by_path = {path: "\n".join(lines) for path, lines in files}

    prompt = TESTGEN_EDIT.format(
        issue=issue,
        source_files=make_file_context(src_files, contents_by_path),
        test_files=make_file_context(test_files, contents_by_path),
        diff_example=DIFF_EXAMPLES[seed.language],
    )
    messages = [{"role": "user", "content": prompt}]
    client = make_llm_from_provider_config(provider_config)
    response = await client.chat_completion_async(
        model=model,
        messages=messages,
        temperature=testgen_config.get("temperature", 0.7),
        max_tokens=testgen_config["max_tokens"],
        n=testgen_config["edit_n"],
    )
    if response is None:
        return

    collected = []
    for edit_index, choice in enumerate(response.choices):
        try:
            entry = {
                "localize_index": localize_index,
                "edit_index": edit_index,
                "response": choice.message.content,
            }
        except Exception as e:
            console.print(f"[TestGen] Error in gen_edit: {e}", style="red")
            continue
        collected.append(entry)
    with open(result_file, "w") as f:
        json.dump(collected, f, indent=2)


async def testgen_edits(
    seed: Seed,
    provider_config: dict,
    model: str,
    rewrite_issue: bool,
    testgen_config: dict,
    result_dir: pathlib.Path,
) -> list[dict] | None:
    """Run ``testgen_edit`` for each stored localization result, in order.

    At most ``testgen_config["localize_n"]`` entries of
    ``TESTGEN_LOCALIZE_FILE`` are processed, one after another. Returns
    ``None``; does nothing when the localization file is absent.
    """
    localize_path = result_dir / TESTGEN_LOCALIZE_FILE
    if not localize_path.exists():
        return
    candidates = json.loads(localize_path.read_text())
    for candidate in candidates[: testgen_config["localize_n"]]:
        await testgen_edit(
            seed=seed,
            provider_config=provider_config,
            model=model,
            rewrite_issue=rewrite_issue,
            localize_result=candidate,
            testgen_config=testgen_config,
            result_dir=result_dir,
        )


def convert_testgen_edits_to_patches(
    seed: Seed,
    localize_results: list[dict],
    result_dir: pathlib.Path,
):
    """Turn the stored test-edit responses into patches.

    For every localization result the matching ``testgen_edit_<index>.json``
    is loaded (missing, blank or unparsable files are skipped), each response
    is converted with ``_post_process_multifile_repair`` and the successful
    conversions are written to ``testgen_patches.json`` in ``result_dir``.
    Nothing is done when that output file already exists.

    Args:
        seed: Seed forwarded to ``_post_process_multifile_repair``.
        localize_results: Localization entries; only ``index`` is read.
        result_dir: Directory containing the edit files and receiving the
            output.
    """
    import traceback

    result_file = result_dir / "testgen_patches.json"
    if result_file.exists():
        return

    # Collect the raw responses of every edit file that can be read.
    edit_responses = []
    for localize_result in localize_results:
        index = localize_result["index"]
        edit_result_file = result_dir / f"testgen_edit_{index}.json"
        if not edit_result_file.exists():
            continue
        # Read once; the text serves both the blank check and the parse.
        raw = edit_result_file.read_text()
        if raw.strip() == "":
            continue
        try:
            edit_results = json.loads(raw)
        except json.JSONDecodeError:
            console.print(
                f"[TestGen] Error in testgen_edits_to_patches: {edit_result_file}",
                style="red",
            )
            continue
        edit_responses.extend(
            dict(
                localize_index=index,
                edit_index=edit_result["edit_index"],
                response=edit_result.get("response", ""),
            )
            for edit_result in edit_results
        )

    # Convert each response; a failing conversion only drops that response.
    results = []
    for task in edit_responses:
        try:
            patch = _post_process_multifile_repair(task["response"], seed)
        except Exception as e:
            traceback.print_exc()
            console.print(
                f"[TestGen] Error in testgen_edits_to_patches: {e}", style="red"
            )
            continue
        results.append(
            dict(
                localize_index=task["localize_index"],
                edit_index=task["edit_index"],
                patch=patch,
            )
        )
    result_file.write_text(json.dumps(results, indent=2))


def new_pass_tests(
    run_result: TestResult,
    test_run_result: TestResult | None,
) -> list[str]:
    """List tests that only exist in ``test_run_result`` and pass there.

    An empty list is returned when ``test_run_result`` is ``None`` or when any
    test of ``run_result`` has a different status (or is absent) in
    ``test_run_result``.
    """
    if test_run_result is None:
        return []
    before = run_result._tests
    after = test_run_result._tests
    # Every previously known test must keep exactly the same status.
    if any(after.get(name, TestStatus.NONE) != status for name, status in before.items()):
        return []
    return [
        name
        for name, status in after.items()
        if name not in before and status == TestStatus.PASS
    ]


@max_concurrency(10)
async def file_write(file: pathlib.Path, data: dict):
    """Serialize ``data`` as indented JSON into ``file``, overwriting it.

    The write itself is synchronous even though the function is a coroutine.
    NOTE(review): ``max_concurrency(10)`` presumably caps the number of
    concurrent invocations at 10 -- confirm against ``swalm``.
    """
    with open(file, "w") as f:
        json.dump(data, f, indent=2)


@max_concurrency(10)
async def parse_log(seed: Seed, log: str):
    """Return ``seed.parse_log(log)`` from within a coroutine.

    The parse call is synchronous.
    NOTE(review): ``max_concurrency(10)`` presumably caps the number of
    concurrent invocations at 10 -- confirm against ``swalm``.
    """
    return seed.parse_log(log)


async def testgen_patch_run(
    seed: Seed,
    runtime: SwalmRuntime,
    patch_result: dict,
    result_dir: pathlib.Path,
    run_result: TestResult,
):
    """Apply one generated test patch in a fresh session and record the run.

    The patch is written to ``/test.patch``, applied with ``git apply`` in
    the seed workspace, and the seed's tests are run. The log, the parsed
    result (or ``None``) and the newly passing tests relative to
    ``run_result`` are stored in
    ``testgen_run_<localize index>_<edit index>.json``. A non-blank existing
    result file short-circuits the whole step.
    """
    localize_index = patch_result["localize_index"]
    edit_index = patch_result["edit_index"]
    patch = patch_result["patch"]
    result_file = result_dir / f"testgen_run_{localize_index}_{edit_index}.json"
    already_done = result_file.exists() and result_file.read_text().strip() != ""
    if already_done:
        return

    image = get_swalm_image(seed.source, seed.instance_id)
    async with runtime.session(image) as session:
        await session.write_file("/test.patch", patch)
        await session.execute_command(
            "git apply /test.patch",
            cwd=seed.workspace,
            timeout=30,
            long_command=False,
        )
        run_log = await seed.run(session, timeout=600)
        parsed = await parse_log(seed, run_log)

    added_passing = new_pass_tests(run_result, parsed)
    if parsed is not None:
        parsed_as_dict = json.loads(parsed.json())
    else:
        parsed_as_dict = None
    payload = {
        "localize_index": localize_index,
        "edit_index": edit_index,
        "test_run_log": run_log,
        "test_run_result": parsed_as_dict,
        "new_tests": added_passing,
    }
    await file_write(result_file, payload)


### Issue Implement
async def issue_implement_localize(
    seed: Seed,
    model_to_provider: dict,
    testgen_patch: dict,
    rewrite_issue: bool,
    result_dir: pathlib.Path,
    buggen_config: dict,
):
    """Ask an LLM which files must change to implement the issue.

    The repository structure is rebuilt from a workspace with the test patch
    applied. A model and provider are picked at random; every sampled answer
    is parsed, restricted to paths present in that structure, de-duplicated,
    and the surviving predictions are written to
    ``issue_implement_localize_<localize index>_<edit index>.json``.

    Args:
        seed: Seed whose ``migrate`` workspace is inspected.
        model_to_provider: Maps a model name to a list of provider configs.
        testgen_patch: Reads ``localize_index``, ``edit_index`` and ``patch``.
        rewrite_issue: Use the abstract issue text when true, otherwise the
            direct one.
        result_dir: Directory holding the issue files and receiving the
            output.
        buggen_config: Reads ``models``, ``max_localize_files``,
            ``max_tokens``, ``localize_n`` and optionally ``temperature``.

    Returns:
        ``None``. Nothing is written when a non-blank result file already
        exists or when the LLM call yields no response.
    """
    result_file = (
        result_dir
        / f"issue_implement_localize_{testgen_patch['localize_index']}_{testgen_patch['edit_index']}.json"
    )
    if result_file.exists() and result_file.read_text().strip() != "":
        return
    model = random.choice(buggen_config["models"])
    provider_config = random.choice(model_to_provider[model])
    llm_client = make_llm_from_provider_config(provider_config)
    with workspace(seed, "migrate") as ws:
        fake_git_apply(testgen_patch["patch"], ws)
        structure = create_structure(ws)
    issue_file = result_dir / (
        ISSUE_ABSTRACT_FILE if rewrite_issue else ISSUE_DIRECT_FILE
    )
    issue = issue_file.read_text()
    prompt = ISSUE_IMPLE_LOCALIZE.format(
        issue=issue,
        structure=show_project_structure(structure),
        n=buggen_config["max_localize_files"],
        testgen_patch=testgen_patch["patch"],
    )
    messages = [dict(role="user", content=prompt)]
    response = await llm_client.chat_completion_async(
        model=model,
        messages=messages,
        temperature=buggen_config.get("temperature", 0.7),
        max_tokens=buggen_config["max_tokens"],
        n=buggen_config["localize_n"],
    )
    # The other LLM steps in this module treat a None response as "no
    # result"; without this guard the loop below fails on `response.choices`.
    if response is None:
        return

    files, _, _ = get_full_file_paths_and_classes_and_functions(structure)
    # Only membership is needed here, so keep just the paths.
    known_paths = {path for path, _ in files}

    results = []
    seen_predicted = set()
    for index, choice in enumerate(response.choices):
        try:
            output = (
                choice.message.content.split("```python")[1].split("```")[0].strip()
            )
            # SECURITY: eval() executes model-generated text as Python code.
            # Kept to avoid changing which answers parse; consider
            # ast.literal_eval if answers are always plain literals.
            located_files = eval(output)
            # Filtering stays inside the try block so a malformed answer
            # (e.g. unhashable entries) skips this sample instead of aborting
            # the whole step.
            located_files = [file for file in located_files if file in known_paths]
            prediction = frozenset(located_files)
        except Exception as e:
            console.print(f"[Issue Implement] Error in localize: {e}", style="red")
            continue
        # Two samples naming the same files (in any order) are duplicates.
        if prediction in seen_predicted:
            continue
        seen_predicted.add(prediction)
        results.append(
            dict(
                index=index,
                files=located_files,
                # Key spelling is part of the on-disk format; keep it as is.
                converstaions=messages
                + [dict(role="assistant", content=choice.message.content)],
            )
        )

    with open(result_file, "w") as f:
        json.dump(results, f, indent=2)


async def issue_implement_edit(
    seed: Seed,
    model: str,
    provider_config: dict,
    rewrite_issue: bool,
    localize_result: dict,
    buggen_config: dict,
    testgen_patch: dict,
    result_dir: pathlib.Path,
) -> list[dict] | None:
    """Sample implementation-edit responses for one localization result.

    The prompt combines the issue text, the contents of the localized files
    (taken from a workspace with the test patch applied), the test patch
    itself and the names of its newly passing tests. All sampled responses
    are written to
    ``issue_implement_edit_<tg localize>_<tg edit>_<localize index>.json``.

    Args:
        seed: Seed whose ``migrate`` workspace and ``language`` are used.
        model: Model name for the completion call.
        provider_config: Passed to ``make_llm_from_provider_config``.
        rewrite_issue: Use the abstract issue text when true, otherwise the
            direct one.
        localize_result: Reads ``index`` and ``files``.
        buggen_config: Reads ``max_tokens``, ``edit_n`` and optionally
            ``temperature``.
        testgen_patch: Reads ``localize_index``, ``edit_index``, ``patch``
            and ``new_tests``.
        result_dir: Directory holding the issue files and receiving the
            output.

    Returns:
        Always ``None``; results are persisted to disk. Returns early when
        the result file already exists or the LLM call yields no response.
    """
    localize_index = localize_result["index"]
    result_file = (
        result_dir
        / f"issue_implement_edit_{testgen_patch['localize_index']}_{testgen_patch['edit_index']}_{localize_index}.json"
    )
    new_tests = testgen_patch["new_tests"]
    if result_file.exists():
        return
    issue_file = result_dir / (
        ISSUE_ABSTRACT_FILE if rewrite_issue else ISSUE_DIRECT_FILE
    )
    with workspace(seed, "migrate") as ws:
        fake_git_apply(testgen_patch["patch"], ws)
        structure = create_structure(ws)
    issue = issue_file.read_text()
    # De-duplicate while keeping the predicted order, so the same
    # localization always produces the same prompt.
    located_files = list(dict.fromkeys(localize_result["files"]))
    files, _, _ = get_full_file_paths_and_classes_and_functions(structure)
    file_contents = {path: "\n".join(lines) for path, lines in files}
    file_context = make_file_context(located_files, file_contents)
    prompt = ISSUE_IMPLE_EDIT.format(
        issue=issue,
        files=file_context,
        testgen_patch=testgen_patch["patch"],
        tests=new_tests,
        diff_example=DIFF_EXAMPLES[seed.language],
    )
    messages = [dict(role="user", content=prompt)]
    llm_client = make_llm_from_provider_config(provider_config)
    response = await llm_client.chat_completion_async(
        model=model,
        messages=messages,
        temperature=buggen_config.get("temperature", 0.7),
        max_tokens=buggen_config["max_tokens"],
        n=buggen_config["edit_n"],
    )
    if response is None:
        return
    results = []
    for index, choice in enumerate(response.choices):
        try:
            results.append(
                dict(
                    localize_index=localize_index,
                    edit_index=index,
                    response=choice.message.content,
                )
            )
        except Exception as e:
            console.print(f"[Issue Implement] Error in gen_edit: {e}", style="red")
            continue
    with open(result_file, "w") as f:
        json.dump(results, f, indent=2)


async def issue_implement_edits(
    seed: Seed,
    provider_config: dict,
    model: str,
    rewrite_issue: bool,
    buggen_config: dict,
    testgen_patch: dict,
    result_dir: pathlib.Path,
):
    """Run ``issue_implement_edit`` for every stored localization result.

    The localization results are read from
    ``issue_implement_localize_<localize index>_<edit index>.json`` and
    processed one after another. When that file is missing or blank (the
    state ``issue_implement_localize`` treats as "not produced yet") the
    function returns without doing anything, mirroring ``testgen_edits``.
    """
    localize_result_file = (
        result_dir
        / f"issue_implement_localize_{testgen_patch['localize_index']}_{testgen_patch['edit_index']}.json"
    )
    if not localize_result_file.exists():
        return
    raw = localize_result_file.read_text()
    if raw.strip() == "":
        return
    localize_results = json.loads(raw)
    for localize_result in localize_results:
        await issue_implement_edit(
            seed,
            model,
            provider_config,
            rewrite_issue,
            localize_result,
            buggen_config,
            testgen_patch,
            result_dir,
        )


@dataclass
class Test:
    """Statuses of one test name across the three runs compared by ``broken_tests``."""

    run: TestStatus  # status looked up in `run_result`
    test: TestStatus  # status looked up in `test_run_result`
    bug: TestStatus  # status looked up in `bug_run_result`


def broken_tests(
    run_result: TestResult,
    test_run_result: TestResult,
    bug_run_result: TestResult,
) -> list[str]:
    """Return the tests that pass in ``test_run_result`` but fail in ``bug_run_result``.

    Statuses are compared per test name over the union of all three results;
    a name missing from a result counts as ``TestStatus.NONE``.

    Args:
        run_result: Reference run (presumably the unpatched baseline --
            confirm with the caller).
        test_run_result: Run the reference is compared against in rule 1; in
            this module it is loaded from the stored test-patch run.
        bug_run_result: Run with the additional patch whose effect is being
            measured.

    Returns:
        Names with status PASS in ``test_run_result`` and FAIL in
        ``bug_run_result``, or an empty list when rule 1 below is violated.
    """
    all_tests = (
        run_result._tests.keys()
        | bug_run_result._tests.keys()
        | test_run_result._tests.keys()
    )
    tests: dict[str, Test] = {}
    for test_name in all_tests:
        run = run_result._tests.get(test_name, TestStatus.NONE)
        bug = bug_run_result._tests.get(test_name, TestStatus.NONE)
        result = test_run_result._tests.get(test_name, TestStatus.NONE)
        tests[test_name] = Test(run, result, bug)

    # 1. A test that exists in the reference run must keep the same status in
    # the test run. Tests absent from the reference run are not constrained.
    # Permitted: PASS->PASS, FAIL->FAIL, SKIP->SKIP, NONE->anything.
    for result in tests.values():
        if result.run != TestStatus.NONE and result.run != result.test:
            print("Test patch change existing test status")
            return []

    # 2. Intended rule: no test added by the bug patch (ANY->NONE->ANY).
    # NOTE(review): as written this check can never trigger. `test_names`
    # contains `name` itself and `name.startswith(name)` is always true, so
    # `not any(...)` is always false. The condition also requires
    # `bug == NONE`, whereas the rule description suggests a test that is
    # present in the bug run -- confirm the intended condition before relying
    # on this rule.
    test_names = set(tests.keys())
    for name, result in tests.items():
        if (
            result.test == TestStatus.NONE
            and result.bug == TestStatus.NONE
            and not any(name.startswith(test_name) for test_name in test_names)
        ):
            return []

    # 3. The bug must be detected: collect every test that passes in the test
    # run and fails in the bug run (PASS->PASS->FAIL or NONE->PASS->FAIL).
    broken_tests = []
    for name, result in tests.items():
        if result.bug == TestStatus.FAIL and result.test == TestStatus.PASS:
            broken_tests.append(name)
    return broken_tests


def convert_issue_implement_edits_to_patches(
    seed: Seed,
    testgen_patch: dict,
    localize_results: list[dict],
    result_dir: pathlib.Path,
):
    """Turn the stored implementation-edit responses into patches.

    For every localization result the matching
    ``issue_implement_edit_<tg localize>_<tg edit>_<index>.json`` is loaded
    (missing, blank or unparsable files are skipped), each response is
    converted with ``_post_process_multifile_repair`` together with the test
    patch, and the successful conversions are written to
    ``issue_implement_patches_<tg localize>_<tg edit>.json``. Nothing is done
    when that output file already exists.

    Args:
        seed: Seed forwarded to ``_post_process_multifile_repair``.
        testgen_patch: Reads ``localize_index``, ``edit_index`` and ``patch``.
        localize_results: Localization entries; only ``index`` is read.
        result_dir: Directory containing the edit files and receiving the
            output.
    """
    import traceback

    result_file = (
        result_dir
        / f"issue_implement_patches_{testgen_patch['localize_index']}_{testgen_patch['edit_index']}.json"
    )
    if result_file.exists():
        return

    # Collect the raw responses of every edit file that can be read.
    edit_responses = []
    for localize_result in localize_results:
        index = localize_result["index"]
        edit_result_file = (
            result_dir
            / f"issue_implement_edit_{testgen_patch['localize_index']}_{testgen_patch['edit_index']}_{index}.json"
        )
        if not edit_result_file.exists():
            continue
        raw = edit_result_file.read_text()
        # A blank or truncated file is skipped, as the test-generation
        # counterpart of this function does, instead of aborting the run.
        if raw.strip() == "":
            continue
        try:
            edit_results = json.loads(raw)
        except json.JSONDecodeError:
            console.print(
                f"[Issue Implement] Error in issue_implement_edits_to_patches: {edit_result_file}",
                style="red",
            )
            continue
        edit_responses.extend(
            dict(
                localize_index=index,
                edit_index=edit_result["edit_index"],
                response=edit_result.get("response", ""),
            )
            for edit_result in edit_results
        )

    # Convert each response; a failing conversion only drops that response.
    results = []
    for task in edit_responses:
        try:
            patch = _post_process_multifile_repair(
                task["response"], seed, testgen_patch["patch"]
            )
        except Exception as e:
            traceback.print_exc()
            console.print(
                f"[Issue Implement] Error in issue_implement_edits_to_patches: {e}",
                style="red",
            )
            continue
        results.append(
            dict(
                localize_index=task["localize_index"],
                edit_index=task["edit_index"],
                patch=patch,
            )
        )
    result_file.write_text(json.dumps(results, indent=2))


async def issue_implement_patch_run(
    seed: Seed,
    runtime: SwalmRuntime,
    testgen_patch: dict,
    issue_implement_patch: dict,
    result_dir: pathlib.Path,
    run_result: TestResult,
):
    """Run the tests with both patches applied and record the broken tests.

    Requires the stored result of the matching test-patch run; without it, or
    when that run has no parsed result, nothing happens. If a result file for
    this patch pair already exists, only its ``issue_implement_broken_tests``
    entry is recomputed (when a parsed run result is stored) and the file is
    rewritten. Otherwise both patches are applied in a fresh session, the
    tests are run, and log, parsed result and broken tests are written to
    ``issue_implement_run_<tg localize>_<tg edit>_<localize>_<edit>.json``.
    No file is written when the log cannot be parsed.
    """
    tg_localize = testgen_patch["localize_index"]
    tg_edit = testgen_patch["edit_index"]
    impl_localize = issue_implement_patch["localize_index"]
    impl_edit = issue_implement_patch["edit_index"]
    result_file = (
        result_dir
        / f"issue_implement_run_{tg_localize}_{tg_edit}_{impl_localize}_{impl_edit}.json"
    )
    testgen_run_file = result_dir / f"testgen_run_{tg_localize}_{tg_edit}.json"
    if not testgen_run_file.exists():
        return
    with open(testgen_run_file, "r") as f:
        testgen_payload = json.load(f)
    testgen_result_dict = testgen_payload.get("test_run_result", None)
    if testgen_result_dict is None:
        return
    testgen_run_result = TestResult.from_dict(testgen_result_dict)

    if result_file.exists():
        # Cached run: refresh only the derived list of broken tests.
        with open(result_file, "r") as f:
            cached = json.load(f)
        cached_run = cached["issue_implement_run_result"]
        if cached_run is not None:
            cached["issue_implement_broken_tests"] = broken_tests(
                run_result,
                testgen_run_result,
                TestResult.from_dict(cached_run),
            )
            with open(result_file, "w") as f:
                json.dump(cached, f, indent=2)
        return

    image = get_swalm_image(seed.source, seed.instance_id)
    async with runtime.session(image) as session:
        await session.write_file("/testgen.patch", testgen_patch["patch"])
        await session.write_file(
            "/issue_implement.patch", issue_implement_patch["patch"]
        )
        for patch_path in ("/testgen.patch", "/issue_implement.patch"):
            await session.execute_command(
                f"git apply -v {patch_path}",
                cwd=seed.workspace,
                long_command=False,
            )
        run_log = await seed.run(session, timeout=300)
        parsed = seed.parse_log(run_log)
        if parsed is None:
            return
        parsed_as_dict = json.loads(parsed.json())

    with open(result_file, "w") as f:
        json.dump(
            {
                "testgen_localize_index": tg_localize,
                "testgen_edit_index": tg_edit,
                "issue_implement_localize_index": impl_localize,
                "issue_implement_edit_index": impl_edit,
                "issue_implement_run_log": run_log,
                "issue_implement_run_result": parsed_as_dict,
                "issue_implement_broken_tests": broken_tests(
                    run_result, testgen_run_result, parsed
                ),
            },
            f,
            indent=2,
        )
