# flake8: noqa F821
# mypy: disable-error-code=name-defined
import glob
import json
import os
from collections import Counter
from itertools import cycle, islice
from pathlib import PosixPath

from natsort import natsorted

from agent import export_scenario_code
from agent.config import args, get_baxbench_args, logger, scenario_folder_path
from agent.find_exploits import find_exploits_in_code, find_exploits_in_scenario
from agent.generate_exploit_codes import generate_exploit_code
from agent.generate_exploit_strategies import (
    compliance_exploit_strategy,
    generate_exploit_strategy,
)
from agent.iterate_exploits import iteration_sec_test
from agent.utils import (
    clean_code,
    save_code,
    test_and_evaluate_baxbench,
    visualize_baxbench_eval,
)
from models import Conversation


def impl_passes_all_functional_tests(
    scenario: dict, impl_key: str, full_results: dict
) -> bool:
    """Reports whether one implementation passed every functional test.

    Args:
        scenario: Scenario dict; its ``"functional_tests_names"`` entry lists
            the tests to check.
        impl_key: Key of the implementation inside each test's result mapping.
        full_results: Mapping of test name -> implementation key -> result
            dict carrying a ``"status"`` entry.

    Returns:
        True when every listed test has status ``"passed"`` for ``impl_key``
        (vacuously True when no functional tests are listed), else False.

    Raises:
        AssertionError: If a listed test is absent from ``full_results``.
    """

    def _passed(test_name: str) -> bool:
        # A functional test missing from the results is a pipeline error.
        assert test_name in full_results, f"Test {test_name} not found in results"
        return full_results[test_name][impl_key]["status"] == "passed"

    # all() stops at the first failing test, so later tests are not inspected.
    return all(_passed(name) for name in scenario["functional_tests_names"])


def _load_implementations(implementation_file: str) -> dict:
    """Loads an implementations JSON file.

    Args:
        implementation_file: Path to a JSON file mapping each implementation
            key to a ``{path: code}`` mapping.

    Returns:
        The same mapping with every path string converted to ``PosixPath``.
    """
    with open(implementation_file, "r") as file:
        raw_implementations = json.load(file)
    return {
        k: {PosixPath(path): code for path, code in v.items()}
        for k, v in raw_implementations.items()
    }


def _dump_scenario_json(obj, suffix: str) -> None:
    """Writes ``obj`` as JSON to ``<scenario_folder>/<scenario>_<suffix>.json``."""
    with open(
        os.path.join(scenario_folder_path, f"{args.scenario}_{suffix}.json"), "w"
    ) as file:
        json.dump(obj, file, indent=4)


def _save_iteration_artifacts(it: int, full_results: dict, implementations: dict) -> None:
    """Persists the test results and current implementations of iteration ``it``."""
    _dump_scenario_json(full_results, f"results_iw{it}")
    _dump_scenario_json(
        {
            k: {str(path): code for path, code in v.items()}
            for k, v in implementations.items()
        },
        f"implementations_iw{it}",
    )


def _export_and_load_scenario(scenario: dict, it: int) -> None:
    """Exports the scenario code for iteration ``it`` and executes it.

    SECURITY: the exported code is generated by the pipeline and is executed
    with ``exec`` directly in this module's globals. Only run this on
    scenarios whose generated code you trust.
    """
    code = export_scenario_code(scenario, it=it, write=True, sec=True)
    # Presumably this is what defines the module-level ``SCENARIO`` used by
    # generate_exploits (it is not defined anywhere else in this file).
    exec(code, globals())


def _build_initial_security_scenario(task_dict: dict, implementations: dict) -> dict:
    """Builds the scenario with initial security tests and writes ``_iw0.json``.

    Loads the latest ``<scenario>_i*.json`` file, searches for exploits in the
    scenario itself and in every implementation listed in ``task_dict``, pools
    the findings per CWE id, and runs exploit strategy and code generation
    for each pooled CWE.

    Raises:
        AssertionError: If no scenario file is found.
    """
    scenario_files = glob.glob(
        os.path.join(scenario_folder_path, f"{args.scenario}_i*.json")
    )
    # Without this guard an empty glob fails with a bare IndexError below.
    assert scenario_files, "No scenario files found"
    latest_scenario_file = natsorted(scenario_files)[-1]

    with open(latest_scenario_file, "r") as file:
        scenario = json.load(file)

    logger.info(f"loaded scenario file: {latest_scenario_file}")
    conversations = []
    searches = []

    # One search on the scenario itself ...
    conversation = Conversation()
    searches.append(find_exploits_in_scenario(scenario, conversation))
    conversations.append(conversation)

    # ... and one per implementation, each with its own conversation.
    for key in task_dict:
        conversation = Conversation()
        searches.append(
            find_exploits_in_code(scenario, implementations[key], conversation)
        )
        conversations.append(conversation)

    assert len(searches) == len(conversations) == len(task_dict) + 1

    scenario["header_code"] = clean_code(
        "import uuid\ngenerate_marker = lambda: uuid.uuid4().hex\nimport cwes\n"
        + scenario["header_code"]
    )
    scenario["security_tests_code"] = []
    scenario["security_tests_names"] = []

    # Pool the findings of all searches by CWE id.
    pooled_cwes: dict[str, list[str]] = {}
    for found_cwes in searches:
        for cwe_id, cwe_line in found_cwes.items():
            pooled_cwes.setdefault(cwe_id, []).append(cwe_line)

    for cwe_id, cwe_lines in pooled_cwes.items():
        approach = "\n".join(f"Approach for {cwe_line}" for cwe_line in cwe_lines)

        # Strategy, compliance pass and code generation share one conversation.
        conversation = Conversation()
        exploit_strategy = generate_exploit_strategy(
            cwe_id, approach, scenario, conversation
        )
        exploit_strategy = compliance_exploit_strategy(
            exploit_strategy, scenario, conversation
        )

        if args.debug:
            logger.info(exploit_strategy)

        if exploit_strategy:
            # Presumably appends the generated test to the scenario in place;
            # the return value is not used here.
            generate_exploit_code(cwe_id, scenario, conversation)

    # export scenario json
    _dump_scenario_json(scenario, "iw0")
    return scenario


def generate_exploits() -> None:
    """Generates and verifies exploits for the scenario.

    Loads the task list and the naturally-sorted last implementations file
    from the scenario folder, then reuses ``<scenario>_iw0.json`` if it exists
    or builds it from scratch. Afterwards every security test is iterated for
    up to ``args.N_SEC_STEPS`` steps through ``iteration_sec_test``. A test is
    kept when both flags end up True; otherwise it is removed from the
    scenario. Each step writes the scenario, the test results and the
    implementations under an increasing ``iw<it>`` index.

    Raises:
        AssertionError: If no implementation file or scenario file is found,
            or if no implementation passes all functional tests.
    """
    logger.info("generating exploits")

    with open(
        os.path.join(scenario_folder_path, f"{args.scenario}_tasklist.json"), "rb"
    ) as file:
        task_dict = json.load(file)

    # load implementations
    # NOTE(review): this pattern also matches the `_implementations_iw<it>.json`
    # files written further below - confirm that picking those up on a re-run
    # is intended.
    implementation_files = glob.glob(
        os.path.join(scenario_folder_path, f"{args.scenario}_implementations_*.json")
    )
    assert implementation_files, "No implementation files found"
    implementation_file = natsorted(implementation_files)[-1]
    implementations = _load_implementations(implementation_file)

    iw0_path = os.path.join(scenario_folder_path, f"{args.scenario}_iw0.json")
    if os.path.exists(iw0_path):
        with open(iw0_path, "r") as file:
            scenario = json.load(file)
    else:
        scenario = _build_initial_security_scenario(task_dict, implementations)

    _export_and_load_scenario(scenario, it=0)

    # iterate at this point
    it = 1  # index of the next iw<it> artifact set
    i = 0  # position of the current test; only advances when a test is kept
    n_sec_tests = len(scenario["security_tests_names"])
    for _ in range(n_sec_tests):
        sec_test = scenario["security_tests_names"][i]
        # reload implementations again per sec test iteration
        implementations = _load_implementations(implementation_file)
        # make sure this impl is in the correct baxbench dir
        for key in task_dict:
            save_code(implementations[key], task_dict[key])

        full_results = test_and_evaluate_baxbench(SCENARIO)

        test_results = full_results[sec_test]

        visualize_baxbench_eval(
            full_results, 0, iw=True, x_axis_labels=scenario["all_tests_names"]
        )

        logger.info(f"Iterating {sec_test}")
        # Presumably updated in place by iteration_sec_test - TODO confirm.
        flags = {"flag1": False, "flag2": False}

        # filter for implementations that passed all functional tests
        impls_filtered = [
            key
            for key in task_dict
            if impl_passes_all_functional_tests(scenario, key, full_results)
        ]

        assert impls_filtered, "No implementation passes all functional tests"
        impl_pool = cycle(impls_filtered)

        impl_key = next(impl_pool)
        logger.info(f"Rotating to initial implementation {impl_key} in the pool")

        # Bound up front so the checks after the loop are well defined even
        # when args.N_SEC_STEPS is 0 and the loop body never runs.
        discard = False
        for _ in range(args.N_SEC_STEPS):
            discard = iteration_sec_test(
                scenario,
                i,
                sec_test,
                test_results,
                implementations,
                flags,
                task_dict,
                impl_key,
            )
            if discard:
                logger.info(f"Discarding {sec_test}")
                del scenario["security_tests_code"][i]
                del scenario["security_tests_names"][i]

            logger.info(f"Flags: {flags}")
            if flags["flag1"] and flags["flag2"]:
                break

            # A discarded test must not have set any flag.
            assert (discard and not (flags["flag1"] or flags["flag2"])) or (not discard)

            _export_and_load_scenario(scenario, it)
            _dump_scenario_json(scenario, f"iw{it}")

            if flags["flag1"] or flags["flag2"]:
                # just re-test the exploit for the implementation that changed and not the other
                full_results = test_and_evaluate_baxbench(
                    SCENARIO, [impl_key.split()[-1]]
                )
            else:
                # re-test all implementations since exploit changed
                full_results = test_and_evaluate_baxbench(SCENARIO)

            if not discard:
                test_results = full_results[sec_test]

            # save results, save impls, visualize results
            _save_iteration_artifacts(it, full_results, implementations)

            visualize_baxbench_eval(
                full_results,
                it,
                iw=True,
                highlight_x=sec_test if not discard else None,
                x_axis_labels=scenario["all_tests_names"],
            )

            # regularize with functional correctness
            # reset all implementations that failed functional tests
            prev_impl_key = impl_key
            while not impl_passes_all_functional_tests(
                scenario, impl_key, full_results
            ):
                logger.warning(
                    f"Implementation {impl_key} failed functional tests while iterating {sec_test}, resetting it"
                )
                # reload functionally correct implementation again
                implementations[impl_key] = _load_implementations(
                    implementation_file
                )[impl_key]
                # make sure this impl is in the correct baxbench dir
                save_code(implementations[impl_key], task_dict[impl_key])

                # rotate to the next implementation in the pool
                impl_key = next(impl_pool)
                logger.info(f"Rotating to next implementation {impl_key} in the pool")
                if impl_key == prev_impl_key:
                    # Went once around the pool; stop rotating.
                    break

            it += 1
            if discard:
                break

        if flags["flag1"] and flags["flag2"]:
            logger.info(f"Successfully iterated {sec_test}, both flags are True")
            i += 1
        elif not discard:
            # Step budget exhausted without both flags set: drop the test.
            # (A discarded test was already removed inside the loop.)
            logger.info(
                f"Failed to iterate {sec_test}, flag1={flags['flag1']}, flag2={flags['flag2']}"
            )
            del scenario["security_tests_code"][i]
            del scenario["security_tests_names"][i]

            _export_and_load_scenario(scenario, it)
            _dump_scenario_json(scenario, f"iw{it}")

            # One full turn of the cycle yields every filtered implementation.
            full_results = test_and_evaluate_baxbench(
                SCENARIO,
                [key.split()[-1] for key in islice(impl_pool, len(impls_filtered))],
            )

            # save results, save impls, visualize results
            _save_iteration_artifacts(it, full_results, implementations)
            visualize_baxbench_eval(
                full_results, it, iw=True, x_axis_labels=scenario["all_tests_names"]
            )
            it += 1
