import posixpath
import re

import agent.templates as templates
from agent.config import MITRE_TOP_25, args, logger, reasoning_model
from models import Response


def parse_cwes(text):
    """Extract the MITRE Top 25 CWE entries mentioned in a block of text.

    Only lines that begin with ``CWE`` are considered. The numeric id is read
    from the leading ``CWE-<digits>`` prefix; lines that start with ``CWE``
    but carry no such prefix (for example a ``CWEs found:`` heading) are
    skipped instead of raising, because the text is free-form model output.

    Args:
        text: The text to scan; it is split on ``"\\n"``.

    Returns:
        A dict mapping each integer CWE id that is in ``MITRE_TOP_25`` to the
        full line that mentioned it. When an id appears on several lines the
        last line wins.
    """
    cwes = {}
    for line in text.split("\n"):
        if not line.startswith("CWE"):
            continue

        # Anchor on the "CWE-<digits>" prefix rather than splitting on ":"
        # and "-", so trailing text such as "CWE-79 (XSS): ..." still parses
        # and malformed lines cannot raise IndexError/ValueError.
        match = re.match(r"CWE\s*-\s*(\d+)", line)
        if match is None:
            if args.debug:
                logger.warning(f"could not parse CWE id from line: {line!r}")
            continue

        cwe_id = int(match.group(1))
        if cwe_id in MITRE_TOP_25:
            cwes[cwe_id] = line
        elif args.debug:
            logger.warning(f"CWE-{cwe_id} not in MITRE TOP 25")
    return cwes


def find_exploits_in_code(scenario, implementation, conversation):
    """Ask the reasoning model which CWEs are present in an implementation.

    The prompt is built from ``templates.find_exploits_in_code``, appended to
    ``conversation`` as a user message, and sent to ``reasoning_model`` with
    ``temperature=0``. The model's response is appended to ``conversation`` as
    well, then parsed with ``parse_cwes``.

    Args:
        scenario: Mapping providing the ``"title"``, ``"description"`` and
            ``"schema"`` entries that are substituted into the prompt.
        implementation: Mapping of file path to file content; only the base
            name of each path is shown in the prompt, and each content value
            is stripped of surrounding whitespace.
        conversation: Conversation object that receives both the prompt and
            the model response via ``add_message`` (it is mutated).

    Returns:
        The dict produced by ``parse_cwes`` for the response text.
    """
    # Each file becomes its own fenced section, labelled by base name only.
    prompt_fields = {
        "cwe_list": templates.cwe_list,
        "scenario_title": scenario["title"],
        "scenario_description": scenario["description"],
        "scenario_openapi": scenario["schema"],
        "implementation": "\n\n".join(
            f"File {posixpath.basename(file_path)}:\n```\n{source.strip()}\n```\n"
            for file_path, source in implementation.items()
        ),
    }
    user_prompt = templates.find_exploits_in_code.format(**prompt_fields)
    conversation.add_message(Response(role="user", text=user_prompt))

    reply = reasoning_model.generate(
        conversation,
        temperature=0,
        purpose="find_exploits: finding exploits in code",
    )
    # Keep the model's answer in the conversation history.
    conversation.add_message(reply)

    if args.debug:
        logger.info(reply.text)

    found_cwes = parse_cwes(reply.text)

    cwe_ids = ", ".join(map(str, found_cwes))
    logger.info(f"found CWE(s): {cwe_ids}")

    return found_cwes


def find_exploits_in_scenario(scenario, conversation):
    """Ask the reasoning model which CWEs are relevant to a scenario.

    The prompt is built from ``templates.find_exploits_in_scenario``, appended
    to ``conversation`` as a user message, and sent to ``reasoning_model``
    with ``temperature=0``. The model's response is appended to
    ``conversation`` as well, then parsed with ``parse_cwes``.

    Args:
        scenario: Mapping providing the ``"title"``, ``"description"`` and
            ``"schema"`` entries that are substituted into the prompt.
        conversation: Conversation object that receives both the prompt and
            the model response via ``add_message`` (it is mutated).

    Returns:
        The dict produced by ``parse_cwes`` for the response text.
    """
    prompt_fields = {
        "cwe_list": templates.cwe_list,
        "scenario_title": scenario["title"],
        "scenario_description": scenario["description"],
        "scenario_openapi": scenario["schema"],
    }
    user_prompt = templates.find_exploits_in_scenario.format(**prompt_fields)
    conversation.add_message(Response(role="user", text=user_prompt))

    reply = reasoning_model.generate(
        conversation,
        temperature=0,
        purpose="find_exploits: finding exploits in scenario",
    )
    # Keep the model's answer in the conversation history.
    conversation.add_message(reply)

    if args.debug:
        logger.info(reply.text)

    found_cwes = parse_cwes(reply.text)

    cwe_ids = ", ".join(map(str, found_cwes))
    logger.info(f"found CWE(s): {cwe_ids}")

    return found_cwes
