import os
import posixpath
import re

import pandas as pd

import agent.templates as templates
import cwes
from agent.config import args, logger, reasoning_model
from agent.generate_exploit_codes import (
    augment_header_exploit_signatures,
    parse_exploit_code,
)
from agent.iterate_functional_tests import parse_implementation
from agent.utils import (
    AgentException,
    agentic_loop,
    clean_code,
    record_verdict,
    save_code,
)
from models import Conversation, Response


def parse_verdict(conversation: Conversation) -> int:
    """Extract the numeric verdict from the most recent model response.

    The text of ``conversation.responses[-1]`` is searched for the first
    ``<VERDICT>N</VERDICT>`` tag; whitespace around ``N`` is tolerated.

    Returns:
        1 or 2, as an int.

    Raises:
        AgentException: ("ParseError", ...) when no tag is present, or when
            the tag holds digits other than exactly "1" or "2".
    """
    latest_text = conversation.responses[-1].text
    match = re.search(r"<VERDICT>\s*(\d+)\s*</VERDICT>", latest_text)
    if match is None:
        raise AgentException("ParseError", "Could not parse verdict from response")
    digits = match.group(1)
    if digits in ("1", "2"):
        return int(digits)
    raise AgentException("ParseError", "Invalid verdict")


def iterate_green_red(
    scenario: dict,
    i: int,
    implementation: dict,
    test_result: dict,
    conversation: Conversation,
    green: bool = True,
) -> int:
    """Have the reasoning model judge the outcome of security test ``i``.

    A prompt is built from the scenario, the implementation files, the security
    test code and the test/container logs. It uses ``templates.iterate_green``
    when ``green`` is true (the test passed) and ``templates.iterate_red``
    otherwise (the test failed). The prompt and the model's reply are appended
    to ``conversation``, which is mutated in place. The verdict is then
    extracted with ``parse_verdict`` via ``agentic_loop`` (``args.N_RETRIES``
    retries).

    Args:
        scenario: Must provide "title", "description", "schema",
            "header_code", "security_tests_names" and "security_tests_code".
        i: Index of the security test inside the scenario's test lists.
        implementation: Mapping of file path -> file contents.
        test_result: Must provide "test_logs" and "container_logs".
        conversation: Conversation that receives this exchange.
        green: True for a passed test, False for a failed one.

    Returns:
        Whatever ``agentic_loop`` yields from ``parse_verdict``. Presumably this
        is the parsed verdict, 1 or 2 (``parse_verdict`` accepts only those).
    """
    cwe_table = pd.read_csv(os.path.join(args.path, "cwe_db.csv"), index_col=0)
    test_name = scenario["security_tests_names"][i]
    # The CWE enum member comes from the test name: drop the first two
    # "_"-separated parts and upper-case what remains.
    cwe_member = getattr(cwes.CWE, "_".join(test_name.split("_")[2:]).upper())
    cwe_id = cwe_member.value["num"]

    template = templates.iterate_green if green else templates.iterate_red
    files_section = "\n\n".join(
        [
            f"File {posixpath.basename(path)}:\n```\n{content.strip()}\n```\n"
            for path, content in implementation.items()
        ]
    )
    prompt = template.format(
        scenario_title=scenario["title"],
        scenario_description=scenario["description"],
        scenario_openapi=scenario["schema"],
        implementation=files_section,
        sec_test_name=test_name,
        cwe=f"CWE-{cwe_id}: {cwe_table.loc[cwe_id, 'name']}",
        header_code=scenario["header_code"],
        security_code=scenario["security_tests_code"][i],
        test_logs=test_result["test_logs"],
        container_logs=test_result["container_logs"],
    )

    conversation.add_message(Response(role="user", text=prompt))
    reply = reasoning_model.generate(
        conversation,
        temperature=0,
        purpose=f"iterate_exploits: iterating security test CWE-{cwe_id}",
    )
    conversation.add_message(reply)

    return agentic_loop(
        conversation,
        parse_verdict,
        args.N_RETRIES,
        "parsing verdict",
        "Respond with 1 or 2, wrapped in <VERDICT> tags.",
    )


def flip_impl(
    make_vulnerable: bool,
    scenario: dict,
    i: int,
    implementation: dict,
    framework: list[str],
) -> dict:
    """Ask the model to rewrite ``implementation`` as vulnerable or secure.

    The target weakness is the CWE associated with security test ``i``. The
    prompt uses ``templates.flip_impl_vulnerable`` when ``make_vulnerable`` is
    true and ``templates.flip_impl_secure`` otherwise. A fresh conversation is
    used, and the reply is parsed with ``parse_implementation`` through
    ``agentic_loop`` (``args.N_RETRIES`` retries). The expected per-file format
    serves as both the prompt's format spec and the retry hint.

    Args:
        make_vulnerable: True to inject the weakness, False to remove it.
        scenario: Must provide "title", "description", "schema" and
            "security_tests_names".
        i: Index of the security test whose CWE is targeted.
        implementation: Mapping of file path -> file contents.
        framework: Framework description passed into the template.

    Returns:
        The re-implementation as returned by ``agentic_loop`` /
        ``parse_implementation``. Presumably this is a path -> contents dict
        for the same files; confirm against ``parse_implementation``.
    """
    cwe_table = pd.read_csv(os.path.join(args.path, "cwe_db.csv"), index_col=0)
    test_name = scenario["security_tests_names"][i]
    cwe_member = getattr(cwes.CWE, "_".join(test_name.split("_")[2:]).upper())
    cwe_id = cwe_member.value["num"]

    # One empty <basename>```\n```</basename> skeleton per file, so the
    # model answers with exactly the same set of files.
    file_names = [posixpath.basename(path) for path in implementation.keys()]
    impl_format = templates.flip_impl_format.format(
        format_specifications="\n\n".join(
            [f"<{name}>\n```\n```\n</{name}>\n" for name in file_names]
        )
    )

    if make_vulnerable:
        logger.info("Flipping implementation to be vulnerable.")
        template = templates.flip_impl_vulnerable
    else:
        logger.info("Flipping implementation to be secure.")
        template = templates.flip_impl_secure

    # NOTE(review): `framework` is interpolated with str(), so a list renders
    # as e.g. "['a', 'b']" in the prompt; confirm the template expects that.
    prompt = template.format(
        framework=framework,
        cwe=f"CWE-{cwe_id} - {cwe_table.loc[cwe_id, 'name']}",
        scenario_title=scenario["title"],
        scenario_description=scenario["description"],
        scenario_openapi=scenario["schema"],
        format_specifications=impl_format,
        implementation="\n\n".join(
            [
                f"File {posixpath.basename(path)}:\n```\n{content.strip()}\n```\n"
                for path, content in implementation.items()
            ]
        ),
    )

    conversation = Conversation().add_message(Response(role="user", text=prompt))
    reply = reasoning_model.generate(
        conversation,
        temperature=0,
        purpose=f"iterate_exploits: flipping implementation for CWE-{cwe_id}",
    )
    conversation.add_message(reply)

    def _parse_reimplementation(conv):
        # Bind the original file mapping so the parser knows which files to expect.
        return parse_implementation(conv, implementation)

    return agentic_loop(
        conversation,
        _parse_reimplementation,
        args.N_RETRIES,
        "parsing re-implementation",
        impl_format,
    )


def fix_exploit(conversation: Conversation, scenario: dict, i: int) -> str:
    """Ask the model to rewrite security test ``i`` and register the result.

    A fix request built from ``templates.fix_sec_code`` is appended to
    ``conversation`` together with the model's reply. The rewritten exploit
    code is then parsed with ``parse_exploit_code(..., header=False)`` via
    ``agentic_loop`` (``args.N_RETRIES`` retries), using the expected format as
    the retry hint.

    When parsing yields code, ``scenario["header_code"]`` is replaced in place
    by ``clean_code(augment_header_exploit_signatures(header, "", code))``.
    Per the helper's name, this presumably adds the new exploit's signatures
    to the header.

    Returns:
        The parsed exploit code. A falsy result is returned unchanged without
        touching the header; ``iteration_sec_test`` treats a falsy result as
        "discard this security test".
    """
    sec_test_name = scenario["security_tests_names"][i]
    # The same format spec is used in the prompt and as the retry hint.
    format_spec = templates.fix_sec_code_format.format(sec_test_name=sec_test_name)
    prompt = templates.fix_sec_code.format(
        sec_tool_signatures=templates.sec_tool_signatures,
        format_specifications=format_spec,
    )

    conversation.add_message(Response(role="user", text=prompt))
    response = reasoning_model.generate(
        conversation,
        temperature=0,
        purpose=f"iterate_exploits: fixing {sec_test_name}",
    )
    conversation.add_message(response)

    sec_test_code = agentic_loop(
        conversation,
        lambda c: parse_exploit_code(c, header=False),
        args.N_RETRIES,
        "processing fixed exploit",
        format_spec,
    )

    if not sec_test_code:
        # Nothing usable came back: do not fold an empty/None exploit into the
        # shared header; the caller will discard the test.
        return sec_test_code

    scenario["header_code"] = clean_code(
        augment_header_exploit_signatures(scenario["header_code"], "", sec_test_code)
    )
    return sec_test_code


# a single iteration of security test improvement
def iteration_sec_test(
    scenario: dict,
    i: int,
    sec_test: str,
    test_results: dict,
    implementations: dict,
    flags: dict,
    task_dict: dict,
    key: str,
) -> bool:
    """Run one review/repair step for security test ``i`` on implementation ``key``.

    Behaviour depends on ``test_results[key]["status"]``:

    * ``"passed"``: the model judges the pass. Verdict 2 is a true negative;
      anything else is recorded as a false negative.
    * ``"failed"``: the model judges the failure. Verdict 2 is a true positive;
      anything else is recorded as a false positive.
    * ``"exception"``: sent straight to repair.
    * any other status: nothing is done and False is returned.

    On a correct verdict, ``flags["flag1"]`` (TN) or ``flags["flag2"]`` (TP) is
    set. If both are now set, the test is considered validated. Otherwise the
    implementation is flipped (made vulnerable after a TN, secure after a TP),
    stored in ``implementations[key]`` and saved via ``save_code`` to
    ``task_dict[key]``. Presumably this lets a later iteration check the
    opposite outcome.

    On a wrong verdict (1) or an exception, both flags are reset and the test
    is rewritten via ``fix_exploit``. A truthy result replaces
    ``scenario["security_tests_code"][i]``.

    Args:
        scenario: Scenario dict; may be mutated (test code, header code).
        i: Index of the security test.
        sec_test: Test name, used for logging only.
        test_results: Mapping of implementation key -> result dict with
            "status" (and the logs ``iterate_green_red`` reads).
        implementations: Mapping of implementation key -> file mapping;
            may be mutated.
        flags: Must hold "flag1" and "flag2"; mutated in place.
        task_dict: Mapping of implementation key -> save target for
            ``save_code``.
        key: Implementation key. All but its last whitespace-separated token
            are passed to ``flip_impl`` as the framework.

    Returns:
        True if the security test should be discarded (``fix_exploit``
        returned a falsy value), otherwise False.
    """
    test_result = test_results[key]
    status = test_result["status"]
    # Only assigned for "passed"/"failed". Initialise it so that any other
    # status cannot hit an unbound local in the repair check below.
    verdict = None
    conversation = Conversation()

    if status == "passed":
        verdict = iterate_green_red(
            scenario, i, implementations[key], test_result, conversation, green=True
        )

        if verdict == 2:  # true negative
            record_verdict("Security TN", conversation.responses[-1].text.strip())
            logger.info(f"Impl {key} of {sec_test}: TN")
            flags["flag1"] = True
            if flags["flag1"] and flags["flag2"]:
                return False  # test successfully iterated, no need to flip again
            implementations[key] = flip_impl(
                True, scenario, i, implementations[key], key.split()[:-1]
            )
            save_code(implementations[key], task_dict[key])
        else:
            record_verdict("Security FN", conversation.responses[-1].text.strip())
    elif status == "failed":
        verdict = iterate_green_red(
            scenario, i, implementations[key], test_result, conversation, green=False
        )

        if verdict == 2:  # true positive
            record_verdict("Security TP", conversation.responses[-1].text.strip())
            logger.info(f"Impl {key} of {sec_test}: TP")
            flags["flag2"] = True
            if flags["flag1"] and flags["flag2"]:
                return False  # test successfully iterated, no need to flip again
            implementations[key] = flip_impl(
                False, scenario, i, implementations[key], key.split()[:-1]
            )
            save_code(implementations[key], task_dict[key])
        else:
            record_verdict("Security FP", conversation.responses[-1].text.strip())

    if status == "exception" or verdict == 1:  # false negative / positive
        if status == "exception":
            logger.info(
                "Exception occured in the application code when running the exploit"
            )
            # NOTE(review): on this path `conversation` is still empty, so the
            # fix request carries no test/container logs; confirm the
            # fix_sec_code template is sufficient without that context.
        else:
            logger.info(
                f"Impl {key} of {sec_test}: {'FN' if status == 'passed' else 'FP'}"
            )
        flags["flag1"] = False
        flags["flag2"] = False
        security_test_code = fix_exploit(conversation, scenario, i)

        if security_test_code:
            scenario["security_tests_code"][i] = security_test_code
            return False
        return True  # discard security test
    return False
