import re

import agent.templates as templates
from agent.config import args, logger, reasoning_model
from agent.utils import AgentException, agentic_loop, clean_code
from cwes import get_cwe_by_id
from models import Conversation, Response


def parse_exploit_code(
    conversation: Conversation, header: bool = True
) -> tuple[str, str] | str:
    """Extract the HEADER and FUNCTION code sections from the latest reply.

    The last response in ``conversation`` is expected to wrap each section in
    ``<HEADER>``/``<FUNCTION>`` tags around a fenced code block (optionally
    tagged ``python``). Each extracted section must compile as Python.

    Args:
        conversation: Conversation whose last response text is parsed.
        header: When True, a HEADER section is required and returned alongside
            the FUNCTION section; when False, only FUNCTION is parsed.

    Returns:
        ``(header_code, function_code)`` when ``header`` is True, otherwise
        just ``function_code``. A short reply (< 100 chars) containing
        ``DISCARD`` yields empty strings in the same shape.

    Raises:
        AgentException: if a required section is missing, or if any extracted
            section fails to compile.
    """
    text = conversation.responses[-1].text

    # A brief "DISCARD" reply means the model declined to write an exploit.
    if "DISCARD" in text and len(text) < 100:
        return ("", "") if header else ""

    def _find_section(tag: str) -> "re.Match[str] | None":
        # Matches <TAG> ```[python] ... ``` </TAG>, spanning newlines.
        return re.search(
            rf"<{tag}>\s*```(?:python\s*)?(.*?)```\s*</{tag}>", text, re.DOTALL
        )

    header_found = _find_section("HEADER") if header else None
    function_found = _find_section("FUNCTION")

    if header and header_found is None:
        raise AgentException(
            "ParseError",
            "Could not parse HEADER section.",
        )
    if function_found is None:
        raise AgentException(
            "ParseError",
            "Could not parse FUNCTION section.",
        )

    header_code = header_found.group(1).strip() if header else ""
    function_code = function_found.group(1).strip()

    # Compile every extracted section, collecting all failures so the model
    # gets a single combined report instead of one error at a time.
    to_check = []
    if header:
        to_check.append(("Unable to compile header", header_code))
    to_check.append(("Unable to compile security test function", function_code))

    problems = []
    for prefix, source in to_check:
        try:
            compile(source, "<string>", "exec")
        except SyntaxError as err:
            problems.append(f"{prefix}, SyntaxError: {err}")

    if problems:
        raise AgentException(
            f"{len(problems)} SyntaxError(s) found in the code:\n\n",
            "\n\n".join(problems),
        )

    return (header_code, function_code) if header else function_code


def augment_header_exploit_signatures(
    header_code: str, sec_header_code: str, sec_test_code: str
) -> str:
    """Prepend the imports a security test relies on to the scenario header.

    Two kinds of import line may be added:

    * ``from scenario_files import SCENARIO_FILE_PATH`` when any of the three
      code fragments mentions ``SCENARIO_FILE_PATH`` and ``header_code`` does
      not already contain that exact import line.
    * ``from exploits import ...`` listing, in ``EXPLOIT_SIGNATURES`` order,
      every helper that the security header or test mentions but
      ``header_code`` does not yet mention.

    Detection is a plain substring match, not a parse of the code.

    Args:
        header_code: Current scenario header code.
        sec_header_code: Header section produced for the security test.
        sec_test_code: Security test function code.

    Returns:
        ``header_code`` with any required import lines prepended; the
        ``exploits`` import (if any) comes before the ``scenario_files`` one.
    """
    EXPLOIT_SIGNATURES = [
        "FileNotFoundInContainer",
        "SQLQueryExecError",
        "place_file_on_docker",
        "load_file_from_docker",
        "execute_sql_on_docker",
        "test_sqlite_db_created",
        "test_sqlite_db_contains_passwords",
        "memory_usage",
        "monitor_container_resources",
    ]
    SCENARIO_IMPORT = "from scenario_files import SCENARIO_FILE_PATH"

    # artifact signature. The import line itself mentions SCENARIO_FILE_PATH,
    # and the result of this function is fed back in as header_code on later
    # calls, so without the "not already imported" check every call would
    # prepend another copy of the same import.
    mentions_scenario_path = any(
        "SCENARIO_FILE_PATH" in code
        for code in (header_code, sec_header_code, sec_test_code)
    )
    if mentions_scenario_path and SCENARIO_IMPORT not in header_code:
        header_code = SCENARIO_IMPORT + "\n" + header_code

    # Helpers used by the security code but not yet referenced by the header;
    # computed once and reused for both the emptiness check and the import.
    missing = [
        signature
        for signature in EXPLOIT_SIGNATURES
        if (signature in sec_header_code or signature in sec_test_code)
        and signature not in header_code
    ]
    if missing:
        return "from exploits import " + ", ".join(missing) + "\n" + header_code
    return header_code


def generate_exploit_code(
    cwe_id: int, scenario: dict, conversation: Conversation
) -> None:
    """Ask the reasoning model for a security test targeting one CWE.

    A prompt built from the scenario's current header code is appended to
    ``conversation``, along with the model's reply. The reply is parsed via
    ``agentic_loop`` with ``parse_exploit_code``. Unless the model discarded
    the exploit, ``scenario`` is updated in place:

    * ``"header_code"`` gets the required imports plus the security header,
      then is passed through ``clean_code``;
    * ``"security_tests_code"`` gets the test function source;
    * ``"security_tests_names"`` and ``"all_tests_names"`` get the test name
      ``sec_test_<cwe name lowercased>``.

    Args:
        cwe_id: Numeric CWE identifier, resolved with ``get_cwe_by_id``.
        scenario: Scenario dict mutated in place (keys listed above).
        conversation: Conversation that prompt and reply are appended to.
    """
    cwe = get_cwe_by_id(cwe_id)
    logger.info(f"generating exploit code for CWE-{cwe_id}")

    test_name = f"sec_test_{cwe.name.lower()}"
    format_spec = templates.sec_code_format.format(
        cwe_name_upper=cwe.name.upper(),
        func_name=test_name,
    )
    user_prompt = templates.generate_exploit_code.format(
        header_code=scenario["header_code"],
        sec_tool_signatures=templates.sec_tool_signatures,
        format_specifications=format_spec,
    )

    conversation.add_message(Response(role="user", text=user_prompt))
    conversation.add_message(
        reasoning_model.generate(
            conversation,
            temperature=0,
            purpose=f"generate_exploit_codes: generating exploit code for CWE-{cwe_id}",
        )
    )

    # Retries parsing (with the format spec as guidance) up to N_RETRIES times.
    sec_header_code, sec_test_code = agentic_loop(
        conversation,
        parse_exploit_code,
        args.N_RETRIES,
        "parsing sec code",
        format_spec,
    )

    # Both sections empty is how parse_exploit_code signals a DISCARD.
    if not (sec_header_code or sec_test_code):
        logger.info("Discarding exploit")
        return

    scenario["header_code"] = (
        augment_header_exploit_signatures(
            scenario["header_code"], sec_header_code, sec_test_code
        )
        + "\n\n"
        + sec_header_code
    )
    scenario["header_code"] = clean_code(scenario["header_code"])

    scenario["security_tests_code"].append(sec_test_code)
    for names_key in ("security_tests_names", "all_tests_names"):
        scenario[names_key].append(test_name)
