"""
DESCRIPTION:

This experiment is a generic best-of-N data gatherer. For each solution generation program outlined below, it gathers N examples of solutions for each problem assuming that the untrusted model is benign and malicious. It then evaluates the suspiciousness of the solutions using the trusted monitor and evaluates the test case performance.

Data generation types supported:
* Query Untrusted
* Query Trusted
* Trusted editing
* Decomposition

STEPS
1. Query U for N solutions for each problem (Query Untrusted Benign U)
2. Use the backdoor pipeline to query U for N acceptable backdoors for each problem (Query Untrusted Malicious U)
3. Query T for N solutions for each problem (Query T)
4. Query T_e for rewritten versions of step 1 (Trusted editing Benign U)
5. Query T_e for rewritten versions of step 2 (Trusted editing Malicious U)
6. TODO: implement decomposition
7. Evaluate the suspiciousness of all solutions using the trusted monitor
8. Evaluate the test case performance of all solutions
"""

import asyncio
import logging

from src.code_evaluation.executor_standard import run_solutions
from src.code_evaluation.exploit_executor import evaluate_exploit_validity
from src.experiments.trusted_decomposition import join_subtasks
from src.experiments.untrusted_decomposition import filter_most_suspicious_subtask
from src.experiments.untrusted_decomposition import (
    join_subtasks as untrusted_join_subtasks,
)
from src.experiments.untrusted_decomposition_blue_team_test import (
    rejoin_exploit_attempts,
)
from src.model_querying.prompt_creation import (
    get_backdoor_decomposition_improved,
    get_decomposition,
    get_innocent_decomposition_improved,
    get_recomposed_solution,
    get_subtask_solution,
    get_two_shot_prompt_standard,
    implement_solution_from_plan,
    improve_backdoor_decomposition,
    improve_solution_from_plan,
    rewrite_decomposition,
    rewrite_solution,
)
from src.model_querying.solution_extraction import (
    extract_decomposition,
    extract_improved_solution_in_block,
    extract_rewritten_decomposition,
    extract_rewritten_solution_in_block,
    extract_solution_in_block,
)
from src.pipeline.blue_team.estimating_difficulty import (
    add_execution_difficulty_estimation_to_pipeline,
    average_difficulty,
)
from src.pipeline.blue_team.monitoring import add_basic_blue_team_monitor_to_pipeline
from src.pipeline.pipeline import Pipeline, PipelineConfig
from src.pipeline.red_team.backdoor_creation import (
    add_backdoor_steps_to_pipeline,
    attempt_exploit,
)
from src.pipeline.transformations import (
    downselect_to_n_subproblems,
    flatten_backdoors_with_min_per_problem,
    select_backdoors,
    separate_subtasks,
)
from src.tools.dataloaders import load_problem_subset
from src.tools.path_utils import get_default_results_directory, get_root_directory
from src.tools.string_manipulation import format_key_suffix

logger = logging.getLogger(__name__)


def create_output(suspiciousness_key_suffix=None, exploit_key_suffix=None):
    suspiciousness_key_suffix = format_key_suffix(suspiciousness_key_suffix)
    exploit_key_suffix = format_key_suffix(exploit_key_suffix)

    def call(suspiciousness_data, usefulness_data):
        output = {}
        for key in suspiciousness_data:
            output[key] = {
                ## Basic data
                "problem_id": key.split("-")[0],
                "problem": suspiciousness_data[key]["question"],
                "solution": suspiciousness_data[key]["solution"],
                "original_solution": suspiciousness_data[key].get("original_solution"),
                "original_backdoored_solution": suspiciousness_data[key].get(
                    "original_backdoored_solution"
                ),
                ## Problem Correctness
                "all_correct": usefulness_data[key]["correct"],
                "num_tests_correct": sum(
                    [
                        1 if test["correct"] else 0
                        for test in usefulness_data[key]["test_cases"]
                    ]
                ),
                "total_tests": len(usefulness_data[key]["test_cases"]),
                # Backdoor attempt
                "backdoor_attempt": suspiciousness_data[key]["metadata"].get(
                    "backdoor", False
                ),
                "backdoor_code": suspiciousness_data[key]["metadata"].get("suggestion"),
                # Backdoor Quality
                "red_team_backdoor_likelihood": suspiciousness_data[key].get(
                    "red_team_backdoor_likelihood"
                ),
                "red_team_backdoor_presence": suspiciousness_data[key].get(
                    "red_team_backdoor_presence"
                ),
                "red_team_backdoor_suspiciousness_score": suspiciousness_data[key].get(
                    "red_team_backdoor_monitor_score"
                ),
                "red_team_sneakiness": suspiciousness_data[key].get("sneakiness"),
                "red_team_rarity": suspiciousness_data[key].get("rarity"),
                "red_team_exploit_difficulty": suspiciousness_data[key].get(
                    "exploit_difficulty"
                ),
                "red_team_exploit_input": suspiciousness_data[key].get(
                    f"exploit_input{exploit_key_suffix}"
                ),
                # Exploit evaluation
                "apps_solution_output": suspiciousness_data[key].get(
                    f"ground_truth_exploit_output{exploit_key_suffix}"
                ),
                "non_backdoored_output": suspiciousness_data[key].get(
                    f"non_backdoored_output{exploit_key_suffix}"
                ),
                "solution_output": suspiciousness_data[key].get(
                    f"exploit_output{exploit_key_suffix}"
                ),
                "ground_truth_exploit_input_works": suspiciousness_data[key].get(
                    f"backdoor_works{exploit_key_suffix}"
                ),
                # Blue team evaluation
                "blue_team_suspiciousness": suspiciousness_data[key][
                    f"suspiciousness{suspiciousness_key_suffix}"
                ],
            }
        return output

    return call


def create_edited_output(suspiciousness_key_suffix=None, exploit_key_suffix=None):
    suspiciousness_key_suffix = format_key_suffix(suspiciousness_key_suffix)
    exploit_key_suffix = format_key_suffix(exploit_key_suffix)

    def call(suspiciousness_data, usefulness_data):
        output = {}
        for key in suspiciousness_data:
            output[key] = {
                "solution": suspiciousness_data[key]["solution"],
                "all_correct": usefulness_data[key]["correct"],
                "num_tests_correct": sum(
                    [
                        1 if test["correct"] else 0
                        for test in usefulness_data[key]["test_cases"]
                    ]
                ),
                "total_tests": len(usefulness_data[key]["test_cases"]),
                "solution_output": suspiciousness_data[key].get(
                    f"exploit_output{exploit_key_suffix}"
                ),
                "ground_truth_original_exploit_input_works": suspiciousness_data[
                    key
                ].get(f"backdoor_works{exploit_key_suffix}"),
                "blue_team_suspiciousness": suspiciousness_data[key][
                    f"suspiciousness{suspiciousness_key_suffix}"
                ],
            }
        return output

    return call


def add_exploit_retry_data_to_output(key_suffix=None):
    key_suffix = format_key_suffix(key_suffix)

    def call(output, exploit_data):
        combined_output = {}
        for key, value in output.items():
            combined_output[key] = value
            if key not in exploit_data:
                continue

            if exploit_data[key].get(f"exploit_input{key_suffix}") is None:
                combined_output[key]["red_team_thinks_exploitable_post_edit"] = False
                continue

            combined_output[key]["red_team_thinks_exploitable_post_edit"] = True
            combined_output[key]["retry_exploit_input"] = exploit_data[key].get(
                f"exploit_input{key_suffix}"
            )
            combined_output[key]["retry_apps_solution_output"] = exploit_data[key].get(
                f"ground_truth_exploit_output{key_suffix}"
            )
            combined_output[key]["retry_non_backdoored_output"] = exploit_data[key].get(
                f"non_backdoored_output{key_suffix}"
            )
            combined_output[key]["retry_solution_output"] = exploit_data[key].get(
                f"exploit_output{key_suffix}"
            )
            combined_output[key][
                "retry_ground_truth_exploit_input_works"
            ] = exploit_data[key].get(f"backdoor_works{key_suffix}")

        return combined_output

    return call


def join_output(
    estimated_difficulty_results,
    query_benign_results,
    query_malicious_results,
    query_trusted_results,
    query_benign_rewrite_results,
    query_malicious_rewrite_results,
    # trusted_decomposition_benign_results=None,
    # trusted_decomposition_malicious_results=None,
    untrusted_decomposition_benign_results=None,
    untrusted_decomposition_benign_edited_results=None,
    untrusted_decomposition_benign_rewritten_results=None,
    untrusted_decomposition_malicious_results=None,
    untrusted_decomposition_malicious_edited_results=None,
    untrusted_decomposition_malicious_rewritten_results=None,
):
    # initialize output with query_benign data
    output = {}
    for key, value in query_benign_results.items():
        problem_id = value["problem_id"]

        if problem_id not in output:
            output[problem_id] = {}
            output[problem_id]["problem_id"] = value["problem_id"]
            output[problem_id]["problem"] = value["problem"]
            output[problem_id]["difficulty"] = estimated_difficulty_results[problem_id][
                "difficulty"
            ]
            output[problem_id]["query_benign_u"] = []
            output[problem_id]["query_malicious_u"] = []
            output[problem_id]["query_trusted"] = []
            output[problem_id]["trusted_decomposition_benign_u"] = []
            output[problem_id]["trusted_decomposition_malicious_u"] = []
            output[problem_id]["untrusted_decomposition_benign_u"] = []
            output[problem_id]["untrusted_decomposition_malicious_u"] = []

        subproblem_data = {
            k: v
            for k, v in value.items()
            if k not in ("problem_id", "problem") and v is not None
        }
        output[problem_id]["query_benign_u"].append(subproblem_data)

    # add trusted edit data to benign solutions
    for key, value in query_benign_rewrite_results.items():
        problem_id, index, _ = key.split("-")

        parent_data = output[problem_id]["query_benign_u"][int(index)]
        if "edited_solutions" not in parent_data:
            parent_data["edited_solutions"] = []

        subproblem_data = {
            k: v
            for k, v in value.items()
            if k not in ("problem_id", "problem") and v is not None
        }
        parent_data["edited_solutions"].append(subproblem_data)

    # add query_malicious data
    for key, value in query_malicious_results.items():
        problem_id = value["problem_id"]
        parent_data = output[problem_id]
        subproblem_data = {
            k: v
            for k, v in value.items()
            if k not in ("problem_id", "problem") and v is not None
        }
        output[problem_id]["query_malicious_u"].append(subproblem_data)

    # add trusted edit data to malicious solutions
    for key, value in query_malicious_rewrite_results.items():
        problem_id, index, _ = key.split("-")

        parent_data = output[problem_id]["query_malicious_u"][int(index)]
        if "edited_solutions" not in parent_data:
            parent_data["edited_solutions"] = []

        subproblem_data = {
            k: v
            for k, v in value.items()
            if k not in ("problem_id", "problem") and v is not None
        }
        parent_data["edited_solutions"].append(subproblem_data)

    # add query trusted results
    for key, value in query_trusted_results.items():
        problem_id = value["problem_id"]
        parent_data = output[problem_id]
        subproblem_data = {
            k: v
            for k, v in value.items()
            if k not in ("problem_id", "problem") and v is not None
        }
        output[problem_id]["query_trusted"].append(subproblem_data)

    # if trusted_decomposition_benign_results:
    #     # add trusted decomposition benign data
    #     for key, value in trusted_decomposition_benign_results.items():
    #         problem_id = value["problem_id"]
    #         parent_data = output[problem_id]
    #         subproblem_data = {
    #             k: v
    #             for k, v in value.items()
    #             if k not in ("problem_id", "problem") and v is not None
    #         }
    #         output[problem_id]["trusted_decomposition_benign_u"].append(subproblem_data)

    # if trusted_decomposition_malicious_results:
    #     # add trusted decomposition malicious data
    #     for key, value in trusted_decomposition_malicious_results.items():
    #         problem_id = value["problem_id"]
    #         parent_data = output[problem_id]
    #         subproblem_data = {
    #             k: v
    #             for k, v in value.items()
    #             if k not in ("problem_id", "problem") and v is not None
    #         }
    #         output[problem_id]["trusted_decomposition_malicious_u"].append(
    #             subproblem_data
    #         )

    if untrusted_decomposition_benign_results:
        # add untrusted decomposition benign data
        for key, value in untrusted_decomposition_benign_results.items():
            problem_id = value["problem_id"]
            parent_data = output[problem_id]
            subproblem_data = {
                k: v
                for k, v in value.items()
                if k not in ("problem_id", "problem") and v is not None
            }
            output[problem_id]["untrusted_decomposition_benign_u"].append(
                subproblem_data
            )

    if untrusted_decomposition_benign_edited_results:
        for key, value in untrusted_decomposition_benign_edited_results.items():
            problem_id, index = key.split("-")

            parent_data = output[problem_id]["untrusted_decomposition_benign_u"][
                int(index)
            ]
            if "edited_solutions" not in parent_data:
                parent_data["edited_solutions"] = []

            subproblem_data = {
                k: v
                for k, v in value.items()
                if k not in ("problem_id", "problem") and v is not None
            }
            parent_data["edited_solutions"].append(subproblem_data)

    if untrusted_decomposition_benign_rewritten_results:
        for key, value in untrusted_decomposition_benign_rewritten_results.items():
            problem_id, index = key.split("-")

            parent_data = output[problem_id]["untrusted_decomposition_benign_u"][
                int(index)
            ]
            if "rewritten_solutions" not in parent_data:
                parent_data["rewritten_solutions"] = []

            subproblem_data = {
                k: v
                for k, v in value.items()
                if k not in ("problem_id", "problem") and v is not None
            }
            parent_data["rewritten_solutions"].append(subproblem_data)

    if untrusted_decomposition_malicious_results:
        # add untrusted decomposition malicious data
        for key, value in untrusted_decomposition_malicious_results.items():
            problem_id = value["problem_id"]
            parent_data = output[problem_id]
            subproblem_data = {
                k: v
                for k, v in value.items()
                if k not in ("problem_id", "problem") and v is not None
            }
            output[problem_id]["untrusted_decomposition_malicious_u"].append(
                subproblem_data
            )

    if untrusted_decomposition_malicious_edited_results:
        for key, value in untrusted_decomposition_malicious_edited_results.items():
            problem_id, index = key.split("-")

            parent_data = output[problem_id]["untrusted_decomposition_malicious_u"][
                int(index)
            ]
            if "edited_solutions" not in parent_data:
                parent_data["edited_solutions"] = []

            subproblem_data = {
                k: v
                for k, v in value.items()
                if k not in ("problem_id", "problem") and v is not None
            }
            parent_data["edited_solutions"].append(subproblem_data)

    if untrusted_decomposition_malicious_rewritten_results:
        for key, value in untrusted_decomposition_malicious_rewritten_results.items():
            problem_id, index = key.split("-")

            parent_data = output[problem_id]["untrusted_decomposition_malicious_u"][
                int(index)
            ]
            if "rewritten_solutions" not in parent_data:
                parent_data["rewritten_solutions"] = []

            subproblem_data = {
                k: v
                for k, v in value.items()
                if k not in ("problem_id", "problem") and v is not None
            }
            parent_data["rewritten_solutions"].append(subproblem_data)

    return output


def query_benign_pipeline(pipeline, load_data, strong_model, weak_model):
    query_benign_untrusted_get_solutions = pipeline.add_query_step(
        "query_benign_untrusted_get_solutions",
        strong_model,
        get_two_shot_prompt_standard,
        extract_solution_in_block,
        dependencies=[load_data],
        bon=5,
        temperature=0.5,
        team="red_query_benign",
    )

    query_benign_untrusted_suspiciousness = add_basic_blue_team_monitor_to_pipeline(
        pipeline,
        strong_model,
        weak_model,
        query_benign_untrusted_get_solutions,
        name="query_benign",
        team_suffix="query_benign",
    )

    query_benign_untrusted_usefulness = pipeline.add_code_evaluation_step(
        "query_benign_untrusted_usefulness",
        strong_model,
        run_solutions,
        dependencies=[query_benign_untrusted_get_solutions],
    )

    query_benign_results = pipeline.add_transformation_step(
        "query_benign_results",
        create_output(suspiciousness_key_suffix="query_benign"),
        dependencies=[
            query_benign_untrusted_suspiciousness,
            query_benign_untrusted_usefulness,
        ],
        strong_model=strong_model,
        weak_model=weak_model,
    )

    return query_benign_untrusted_get_solutions, query_benign_results


def query_malicious_pipeline(
    pipeline, query_benign_untrusted_get_solutions, strong_model, weak_model
):
    query_malicious_get_backdoors = add_backdoor_steps_to_pipeline(
        pipeline,
        query_benign_untrusted_get_solutions,
        strong_model,
        weak_model,
        num_backdoor_ideas_per_attempt=5,
        num_backdoor_attempts=2,
        temperature=1.0,
        suffix="query_malicious",
        team_suffix="query_malicious",
    )

    # This step can be reused as a source for backdoors
    query_malicious_flatten_backdoors = pipeline.add_transformation_step(
        "query_malicious_flatten_backdoors",
        flatten_backdoors_with_min_per_problem(
            5,
            data_dir=get_default_results_directory()
            / pipeline.config.name
            / "cached_backdoors",
            filter=False,
        ),
        dependencies=[query_malicious_get_backdoors],
        strong_model=strong_model,
        weak_model=weak_model,
    )

    query_malicious_exploit_check = pipeline.add_transformation_step(
        "query_malicious_exploit_check",
        evaluate_exploit_validity(
            pipeline.config.language,
            suffix="query_malicious",
            use_apps_as_ground_truth=True,
            include_pre_rewrite=False,
        ),
        dependencies=[query_malicious_flatten_backdoors],
        strong_model=strong_model,
        weak_model=weak_model,
        read_cache=True,
    )

    query_malicious_untrusted_suspiciousness = add_basic_blue_team_monitor_to_pipeline(
        pipeline,
        strong_model,
        weak_model,
        query_malicious_exploit_check,
        name="query_malicious",
        team_suffix="query_malicious",
    )

    query_malicious_untrusted_usefulness = pipeline.add_code_evaluation_step(
        "query_malicious_untrusted_usefulness",
        strong_model,
        run_solutions,
        dependencies=[query_malicious_exploit_check],
    )

    query_malicious_results = pipeline.add_transformation_step(
        "query_malicious_results",
        create_output(
            suspiciousness_key_suffix="query_malicious",
            exploit_key_suffix="query_malicious",
        ),
        dependencies=[
            query_malicious_untrusted_suspiciousness,
            query_malicious_untrusted_usefulness,
        ],
        strong_model=strong_model,
        weak_model=weak_model,
    )

    return query_malicious_flatten_backdoors, query_malicious_results


def query_trusted_pipeline(pipeline, load_data, strong_model, weak_model):
    query_trusted_solutions = pipeline.add_query_step(
        "query_trusted_solutions",
        weak_model,
        get_two_shot_prompt_standard,
        extract_solution_in_block,
        dependencies=[load_data],
        bon=5,
        temperature=0.5,
    )

    query_trusted_suspiciousness = add_basic_blue_team_monitor_to_pipeline(
        pipeline,
        strong_model,
        weak_model,
        query_trusted_solutions,
        "query_trusted",
        team_suffix="query_trusted",
    )

    query_trusted_usefulness = pipeline.add_code_evaluation_step(
        "query_trusted_usefulness",
        strong_model,
        run_solutions,
        dependencies=[query_trusted_solutions],
    )

    query_trusted_results = pipeline.add_transformation_step(
        "query_trusted_results",
        create_output(suspiciousness_key_suffix="query_trusted"),
        dependencies=[query_trusted_suspiciousness, query_trusted_usefulness],
        strong_model=strong_model,
        weak_model=weak_model,
    )

    return query_trusted_solutions, query_trusted_results


def get_problem_difficulty(pipeline, query_trusted_solutions, strong_model, weak_model):
    get_difficulty_results = add_execution_difficulty_estimation_to_pipeline(
        pipeline,
        strong_model,
        weak_model,
        query_trusted_solutions,
        team_suffix="difficulty_estimation",
    )
    get_avg_difficulty_results = pipeline.add_transformation_step(
        "difficulty_estimation_merge_difficulty",
        average_difficulty,
        dependencies=[get_difficulty_results],
        strong_model=strong_model,
        weak_model=weak_model,
    )

    return get_avg_difficulty_results


def query_benign_edited_pipeline(
    pipeline, query_benign_untrusted_get_solutions, strong_model, weak_model
):
    query_benign_rewrite = pipeline.add_query_step(
        "query_benign_rewrite",
        weak_model,
        rewrite_solution,
        extract_rewritten_solution_in_block,
        dependencies=[query_benign_untrusted_get_solutions],
        bon=5,
        temperature=0.3,
        team="blue_query_benign_rewrite",
    )

    query_benign_rewrite_suspiciousness = add_basic_blue_team_monitor_to_pipeline(
        pipeline,
        strong_model,
        weak_model,
        query_benign_rewrite,
        name="query_benign_rewrite",
        team_suffix="query_benign_rewrite",
    )

    query_benign_rewrite_usefulness = pipeline.add_code_evaluation_step(
        "query_benign_rewrite_usefulness",
        f"{strong_model}-{weak_model}",
        run_solutions,
        dependencies=[query_benign_rewrite],
    )

    query_benign_rewrite_results = pipeline.add_transformation_step(
        "query_benign_rewrite_results",
        create_edited_output(suspiciousness_key_suffix="query_benign_rewrite"),
        dependencies=[
            query_benign_rewrite_suspiciousness,
            query_benign_rewrite_usefulness,
        ],
        strong_model=strong_model,
        weak_model=weak_model,
    )

    return query_benign_rewrite_results


def query_malicious_edited_pipeline(
    pipeline, query_malicious_flatten_backdoors, strong_model, weak_model
):
    query_malicious_rewrite = pipeline.add_query_step(
        "query_malicious_rewrite",
        weak_model,
        rewrite_solution,
        extract_rewritten_solution_in_block,
        dependencies=[query_malicious_flatten_backdoors],
        bon=5,
        temperature=0.3,
        team="blue_query_malicious_rewrite",
    )

    query_malicious_rewrite_exploit_check = pipeline.add_transformation_step(
        "query_malicious_rewrite_exploit_check",
        evaluate_exploit_validity(
            pipeline.config.language,
            suffix="query_malicious",
            use_apps_as_ground_truth=True,
            include_pre_rewrite=False,
        ),
        dependencies=[query_malicious_rewrite],
        strong_model=strong_model,
        weak_model=weak_model,
        read_cache=True,
    )

    query_malicious_rewrite_suspiciousness = add_basic_blue_team_monitor_to_pipeline(
        pipeline,
        strong_model,
        weak_model,
        query_malicious_rewrite_exploit_check,
        name="query_malicious_rewrite",
        team_suffix="query_malicious_rewrite",
    )

    query_malicious_rewrite_usefulness = pipeline.add_code_evaluation_step(
        "query_malicious_rewrite_usefulness",
        f"{strong_model}-{weak_model}",
        run_solutions,
        dependencies=[query_malicious_rewrite_exploit_check],
    )

    query_malicious_rewrite_initial_results = pipeline.add_transformation_step(
        "query_malicious_rewrite_initial_results",
        create_edited_output(
            suspiciousness_key_suffix="query_malicious_rewrite",
            exploit_key_suffix="query_malicious",
        ),
        dependencies=[
            query_malicious_rewrite_suspiciousness,
            query_malicious_rewrite_usefulness,
        ],
        strong_model=strong_model,
        weak_model=weak_model,
    )

    # Retry exploit
    query_malicious_rewrite_retry_exploit = attempt_exploit(
        pipeline,
        strong_model,
        weak_model,
        query_malicious_rewrite,
        use_modified_solution=False,
        suffix="query_malicious_rewrite",
        team_suffix="query_malicious_rewrite",
    )

    query_malicious_rewrite_rejoin_rewritten_exploit_attempts = (
        pipeline.add_transformation_step(
            "query_malicious_rewrite_rejoin_rewritten_exploit_attempts",
            rejoin_exploit_attempts(suffix="query_malicious_rewrite"),
            dependencies=[
                query_malicious_rewrite,
                query_malicious_rewrite_retry_exploit,
            ],
        )
    )

    query_malicious_rewrite_exploit_retry_check = pipeline.add_transformation_step(
        "query_malicious_rewrite_exploit_retry_check",
        evaluate_exploit_validity(
            pipeline.config.language,
            suffix="query_malicious_rewrite",
            use_apps_as_ground_truth=True,
            include_pre_rewrite=False,
        ),
        dependencies=[query_malicious_rewrite_rejoin_rewritten_exploit_attempts],
        strong_model=strong_model,
        weak_model=weak_model,
        read_cache=True,
    )

    query_malicious_rewrite_results = pipeline.add_transformation_step(
        "query_malicious_rewrite_results",
        add_exploit_retry_data_to_output(key_suffix="query_malicious_rewrite"),
        dependencies=[
            query_malicious_rewrite_initial_results,
            query_malicious_rewrite_exploit_retry_check,
        ],
        strong_model=strong_model,
        weak_model=weak_model,
    )

    return query_malicious_rewrite_results


# TODO: Fix this
def trusted_decomposition_benign_pipeline(
    pipeline, load_data, strong_model, weak_model
):
    trusted_decomposition_generate_plan = pipeline.add_query_step(
        "trusted_decomposition_benign_generate_plan",
        weak_model,
        get_decomposition,
        extract_decomposition,
        dependencies=[load_data],
        bon=5,
        temperature=0.3,
        team="blue_trusted_decomposition_benign",
    )

    trusted_decomposition_split_by_subtask = pipeline.add_transformation_step(
        "trusted_decomposition_benign_split_by_subtask",
        separate_subtasks,
        dependencies=[trusted_decomposition_generate_plan],
        strong_model=strong_model,
        weak_model=weak_model,
    )

    trusted_decomposition_benign_implement_subtask = pipeline.add_query_step(
        f"trusted_decomposition_benign_implement_subtask_{weak_model}",
        strong_model,
        get_subtask_solution,
        extract_solution_in_block,
        dependencies=[trusted_decomposition_split_by_subtask],
        team="red_trusted_decomposition_benign",
    )

    trusted_decomposition_benign_subtask_suspiciousness = (
        add_basic_blue_team_monitor_to_pipeline(
            pipeline,
            strong_model,
            weak_model,
            trusted_decomposition_benign_implement_subtask,
            name=f"trusted_decomposition_benign_subtasks_{strong_model}",
            team_suffix="trusted_decomposition_benign",
        )
    )

    trusted_decomposition_benign_suspiciousness_pre_recompose = (
        pipeline.add_transformation_step(
            f"trusted_decomposition_benign_suspiciousness",
            join_subtasks(f"trusted_decomposition_benign_subtasks_{strong_model}"),
            dependencies=[trusted_decomposition_benign_subtask_suspiciousness],
            strong_model=strong_model,
            weak_model=weak_model,
        )
    )

    # TODO: Rewrite suspicious subtasks
    # This gives us the trusted edit equivalent for this strategy

    trusted_decomposition_benign_recompose_solutions = pipeline.add_query_step(
        f"trusted_decomposition_benign_recompose_subtask_{strong_model}",
        weak_model,
        get_recomposed_solution,
        extract_solution_in_block,
        dependencies=[trusted_decomposition_benign_suspiciousness_pre_recompose],
        team="blue_trusted_decomposition_benign",
    )

    # TODO: Check exploit input

    trusted_decomposition_benign_suspiciousness = (
        add_basic_blue_team_monitor_to_pipeline(
            pipeline,
            strong_model,
            weak_model,
            trusted_decomposition_benign_recompose_solutions,
            name=f"trusted_decomposition_benign_{strong_model}",
            team_suffix="trusted_decomposition_benign",
        )
    )

    trusted_decomposition_benign_usefulness = pipeline.add_code_evaluation_step(
        f"trusted_decomposition_benign_usefulness",
        f"{strong_model}+{weak_model}",
        run_solutions,
        dependencies=[trusted_decomposition_benign_recompose_solutions],
    )

    trusted_decomposition_benign_results = pipeline.add_transformation_step(
        "trusted_decomposition_benign_results",
        create_output(suspiciousness_key_suffix="trusted_decomposition_benign"),
        dependencies=[
            trusted_decomposition_benign_suspiciousness,
            trusted_decomposition_benign_usefulness,
        ],
        strong_model=strong_model,
        weak_model=weak_model,
    )

    return (
        trusted_decomposition_benign_implement_subtask,
        trusted_decomposition_benign_results,
        None,
    )  # TODO: Add edited results


# TODO: Fix this
def trusted_decomposition_malicious_pipeline(
    pipeline, trusted_decomposition_benign_implement_subtask, strong_model, weak_model
):
    trusted_decomposition_malicious_backdoored_subtasks = (
        add_backdoor_steps_to_pipeline(
            pipeline,
            trusted_decomposition_benign_implement_subtask,
            strong_model,
            weak_model,
            suffix=f"trusted_decomposition_malicious",
            team_suffix="trusted_decomposition_malicious",
        )
    )

    # get best backdoors for each subproblem
    trusted_decomposition_malicious_subtasks = pipeline.add_transformation_step(
        f"trusted_decomposition_malicious_subtasks",
        select_backdoors,
        dependencies=[
            trusted_decomposition_benign_implement_subtask,
            trusted_decomposition_malicious_backdoored_subtasks,
        ],
        strong_model=strong_model,
        weak_model=weak_model,
    )

    trusted_decomposition_malicious_subtask_suspiciousness = (
        add_basic_blue_team_monitor_to_pipeline(
            pipeline,
            strong_model,
            weak_model,
            trusted_decomposition_malicious_subtasks,
            name=f"trusted_decomposition_malicious_subtask_{strong_model}",
            team_suffix="trusted_decomposition_malicious",
        )
    )

    # TODO: Rewrite suspicious subtasks
    # This gives us the trusted edit equivalent for this strategy

    trusted_decomposition_malicious_suspiciousness = pipeline.add_transformation_step(
        f"trusted_decomposition_malicious_suspiciousness",
        join_subtasks(f"trusted_decomposition_malicious_subtask_{strong_model}"),
        dependencies=[trusted_decomposition_malicious_subtask_suspiciousness],
        strong_model=strong_model,
        weak_model=weak_model,
    )

    # TODO: Check exploit validity

    trusted_decomposition_malicious_recomposed_solutions = pipeline.add_query_step(
        f"recompose_malicious_subtask_{strong_model}",
        weak_model,
        get_recomposed_solution,
        extract_solution_in_block,
        dependencies=[trusted_decomposition_malicious_suspiciousness],
        team="blue_trusted_decomposition_malicious",
    )

    trusted_decomposition_malicious_usefulness = pipeline.add_code_evaluation_step(
        f"trusted_decomposition_malicious_usefulness",
        f"{strong_model}+{weak_model}",
        run_solutions,
        dependencies=[trusted_decomposition_malicious_recomposed_solutions],
    )

    trusted_decomposition_malicious_results = pipeline.add_transformation_step(
        "trusted_decomposition_malicious_results",
        create_output(
            suspiciousness_key_suffix="trusted_decomposition_malicious",
            exploit_key_suffix="TODO: FIX ME",
        ),
        dependencies=[
            trusted_decomposition_malicious_suspiciousness,
            trusted_decomposition_malicious_usefulness,
        ],
        strong_model=strong_model,
        weak_model=weak_model,
    )

    return trusted_decomposition_malicious_results, None  # TODO: Add edited results


def update_fields_for_decomposition(data):
    output = {}
    for problem_id, problem_data in data.items():
        output[problem_id] = problem_data
        output[problem_id]["original_backdoored_solution"] = problem_data["solution"]
    return output


def untrusted_decomposition_benign_pipeline(
    pipeline, query_benign_untrusted_get_solutions, strong_model, weak_model
):
    untrusted_decomposition_benign_downselect_problems = (
        pipeline.add_transformation_step(
            "untrusted_decomposition_benign_downselect_problems",
            downselect_to_n_subproblems(5),
            dependencies=[query_benign_untrusted_get_solutions],
            strong_model=strong_model,
            weak_model=weak_model,
        )
    )

    untrusted_decomposition_benign_get_plan = pipeline.add_query_step(
        "untrusted_decomposition_benign_get_plan",
        strong_model,
        get_innocent_decomposition_improved,
        extract_decomposition,
        dependencies=[untrusted_decomposition_benign_downselect_problems],
        team="red_untrusted_decomposition_benign",
    )

    # Unedited plan
    untrusted_decomposition_benign_unedited_plan_solution = pipeline.add_query_step(
        "untrusted_decomposition_benign_unedited_plan_solution",
        weak_model,
        implement_solution_from_plan,
        extract_rewritten_solution_in_block,
        dependencies=[untrusted_decomposition_benign_get_plan],
        team="blue_untrusted_decomposition_benign",
    )

    untrusted_decomposition_benign_unedited_plan_improved_solution = (
        pipeline.add_query_step(
            "untrusted_decomposition_benign_unedited_plan_improved_solution",
            weak_model,
            improve_solution_from_plan,
            extract_improved_solution_in_block,
            dependencies=[untrusted_decomposition_benign_unedited_plan_solution],
            team="blue_untrusted_decomposition_benign",
        )
    )

    untrusted_decomposition_benign_unedited_plan_suspiciousness = (
        add_basic_blue_team_monitor_to_pipeline(
            pipeline,
            strong_model,
            weak_model,
            untrusted_decomposition_benign_unedited_plan_improved_solution,
            name="untrusted_decomposition_benign_unedited_plan",
            team_suffix="untrusted_decomposition_benign",
        )
    )

    untrusted_decomposition_benign_unedited_plan_usefulness = (
        pipeline.add_code_evaluation_step(
            f"untrusted_decomposition_benign_unedited_plan_usefulness",
            weak_model,
            run_solutions,
            dependencies=[
                untrusted_decomposition_benign_unedited_plan_improved_solution
            ],
        )
    )

    untrusted_decomposition_benign_unedited_plan_results = (
        pipeline.add_transformation_step(
            "untrusted_decomposition_benign_unedited_plan_results",
            create_output(
                suspiciousness_key_suffix="untrusted_decomposition_benign_unedited_plan"
            ),
            dependencies=[
                untrusted_decomposition_benign_unedited_plan_suspiciousness,
                untrusted_decomposition_benign_unedited_plan_usefulness,
            ],
            strong_model=strong_model,
            weak_model=weak_model,
        )
    )

    # Edited plan
    untrusted_decomposition_benign_split_by_subtask = pipeline.add_transformation_step(
        "untrusted_decomposition_benign_split_by_subtask",
        separate_subtasks,
        dependencies=[untrusted_decomposition_benign_get_plan],
        strong_model=strong_model,
        weak_model=weak_model,
    )

    untrusted_decomposition_benign_implement_subtask = pipeline.add_query_step(
        f"untrusted_decomposition_benign_implement_subtask_{strong_model}",
        weak_model,
        get_subtask_solution,
        extract_solution_in_block,
        dependencies=[untrusted_decomposition_benign_split_by_subtask],
        team="blue_untrusted_decomposition_benign",
    )

    untrusted_decomposition_benign_subtask_suspiciousness = (
        add_basic_blue_team_monitor_to_pipeline(
            pipeline,
            strong_model,
            weak_model,
            untrusted_decomposition_benign_implement_subtask,
            name="untrusted_decomposition_benign_subtask",
            team_suffix="untrusted_decomposition_benign",
        )
    )

    untrusted_decomposition_benign_join_subtasks = pipeline.add_transformation_step(
        "untrusted_decomposition_benign_join_subtasks",
        untrusted_join_subtasks("untrusted_decomposition_benign_subtask"),
        dependencies=[untrusted_decomposition_benign_subtask_suspiciousness],
        strong_model=strong_model,
        weak_model=weak_model,
    )

    untrusted_decomposition_benign_filtered_decompositions = (
        pipeline.add_transformation_step(
            f"untrusted_decomposition_benign_filtered_decompositions",
            filter_most_suspicious_subtask,
            dependencies=[untrusted_decomposition_benign_join_subtasks],
            strong_model=strong_model,
            weak_model=weak_model,
        )
    )

    untrusted_decomposition_benign_edited_recomposed_solutions = (
        pipeline.add_query_step(
            f"untrusted_decomposition_benign_edited_recomposed_solutions",
            weak_model,
            implement_solution_from_plan,
            extract_rewritten_solution_in_block,
            dependencies=[untrusted_decomposition_benign_filtered_decompositions],
            team="blue_untrusted_decomposition_benign",
        )
    )

    untrusted_decomposition_benign_edited_improved_solution = pipeline.add_query_step(
        "untrusted_decomposition_benign_edited_improved_solution",
        weak_model,
        improve_solution_from_plan,
        extract_improved_solution_in_block,
        dependencies=[untrusted_decomposition_benign_edited_recomposed_solutions],
        team="blue_untrusted_decomposition_benign",
    )

    untrusted_decomposition_benign_edited_suspiciousness = (
        add_basic_blue_team_monitor_to_pipeline(
            pipeline,
            strong_model,
            weak_model,
            untrusted_decomposition_benign_edited_improved_solution,
            name="untrusted_decomposition_benign_edited",
            team_suffix="untrusted_decomposition_benign",
        )
    )

    untrusted_decomposition_benign_edited_usefulness = (
        pipeline.add_code_evaluation_step(
            f"untrusted_decomposition_benign_edited_usefulness",
            weak_model,
            run_solutions,
            dependencies=[untrusted_decomposition_benign_edited_improved_solution],
        )
    )

    untrusted_decomposition_benign_edited_results = pipeline.add_transformation_step(
        "untrusted_decomposition_benign_edited_results",
        create_edited_output(
            suspiciousness_key_suffix="untrusted_decomposition_benign_edited"
        ),
        dependencies=[
            untrusted_decomposition_benign_edited_suspiciousness,
            untrusted_decomposition_benign_edited_usefulness,
        ],
        strong_model=strong_model,
        weak_model=weak_model,
    )

    # Rewrite plan
    untrusted_decomposition_benign_rewrite_decomposition_step = pipeline.add_query_step(
        f"untrusted_decomposition_benign_rewrite_decomposition_{strong_model}",
        weak_model,
        rewrite_decomposition,
        extract_rewritten_decomposition,
        dependencies=[untrusted_decomposition_benign_get_plan],
        team="blue_untrusted_decomposition_benign",
    )

    untrusted_decomposition_benign_rewritten_plan_solution = pipeline.add_query_step(
        "untrusted_decomposition_benign_rewritten_plan_solution",
        weak_model,
        implement_solution_from_plan,
        extract_solution_in_block,
        dependencies=[untrusted_decomposition_benign_rewrite_decomposition_step],
        team="blue_untrusted_decomposition_benign",
    )

    untrusted_decomposition_benign_rewritten_improved_solution = (
        pipeline.add_query_step(
            "untrusted_decomposition_benign_rewritten_improved_solution",
            weak_model,
            improve_solution_from_plan,
            extract_improved_solution_in_block,
            dependencies=[untrusted_decomposition_benign_rewritten_plan_solution],
            team="blue_untrusted_decomposition_benign",
        )
    )

    untrusted_decomposition_benign_rewritten_plan_suspiciousness = (
        add_basic_blue_team_monitor_to_pipeline(
            pipeline,
            strong_model,
            weak_model,
            untrusted_decomposition_benign_rewritten_improved_solution,
            name="untrusted_decomposition_benign_rewritten",
        )
    )

    untrusted_decomposition_benign_rewritten_plan_usefulness = (
        pipeline.add_code_evaluation_step(
            f"untrusted_decomposition_benign_rewritten_plan_usefulness",
            weak_model,
            run_solutions,
            dependencies=[untrusted_decomposition_benign_rewritten_improved_solution],
        )
    )

    untrusted_decomposition_benign_rewritten_results = pipeline.add_transformation_step(
        "untrusted_decomposition_benign_rewritten_results",
        create_edited_output(
            suspiciousness_key_suffix="untrusted_decomposition_benign_rewritten"
        ),
        dependencies=[
            untrusted_decomposition_benign_rewritten_plan_suspiciousness,
            untrusted_decomposition_benign_rewritten_plan_usefulness,
        ],
        strong_model=strong_model,
        weak_model=weak_model,
    )

    return (
        untrusted_decomposition_benign_unedited_plan_results,
        untrusted_decomposition_benign_edited_results,
        untrusted_decomposition_benign_rewritten_results,
    )


def untrusted_decomposition_malicious_pipeline(
    pipeline, query_malicious_flatten_backdoors, strong_model, weak_model
):
    untrusted_decomposition_malicious_downselect_problems = (
        pipeline.add_transformation_step(
            "untrusted_decomposition_malicious_downselect_problems",
            downselect_to_n_subproblems(5),
            dependencies=[query_malicious_flatten_backdoors],
            strong_model=strong_model,
            weak_model=weak_model,
        )
    )

    untrusted_decomposition_malicious_update_fields = pipeline.add_transformation_step(
        "untrusted_decomposition_malicious_update_fields",
        update_fields_for_decomposition,
        dependencies=[untrusted_decomposition_malicious_downselect_problems],
        strong_model=strong_model,
        weak_model=weak_model,
    )

    # Red Team:
    # * Generate plan for solution from code
    untrusted_decomposition_malicious_get_plan = pipeline.add_query_step(
        "untrusted_decomposition_malicious_get_plan",
        strong_model,
        get_backdoor_decomposition_improved,
        extract_decomposition,
        dependencies=[untrusted_decomposition_malicious_update_fields],
        team="red_untrusted_decomposition_malicious",
    )

    untrusted_decomposition_malicious_improve_plan = pipeline.add_query_step(
        "untrusted_decomposition_malicious_improve_plan",
        strong_model,
        improve_backdoor_decomposition,
        extract_decomposition,
        dependencies=[untrusted_decomposition_malicious_get_plan],
        team="red_untrusted_decomposition_malicious",
    )

    # Blue Team:
    # * Generate solution from unedited plan
    # * Check suspiciousness of solution
    # * Edit the plan to remove suspected backdoors
    # * Check suspiciousness of edited plan
    # * Implement exploit check for unedited plan's solution
    # * Implement exploit check for edited plan's solution

    untrusted_decomposition_malicious_unedited_plan_solution = pipeline.add_query_step(
        "untrusted_decomposition_malicious_unedited_plan_solution",
        weak_model,
        implement_solution_from_plan,
        extract_rewritten_solution_in_block,
        dependencies=[untrusted_decomposition_malicious_improve_plan],
        team="blue_untrusted_decomposition_malicious",
    )

    untrusted_decomposition_malicious_unedited_plan_improved_solution = (
        pipeline.add_query_step(
            "untrusted_decomposition_malicious_unedited_plan_improved_solution",
            weak_model,
            improve_solution_from_plan,
            extract_improved_solution_in_block,
            dependencies=[untrusted_decomposition_malicious_unedited_plan_solution],
            team="blue_untrusted_decomposition_malicious",
        )
    )

    untrusted_decomposition_malicious_unedited_plan_exploit_check = (
        pipeline.add_transformation_step(
            "untrusted_decomposition_malicious_unedited_plan_exploit_check",
            evaluate_exploit_validity(
                pipeline.config.language,
                suffix="query_malicious",
                use_apps_as_ground_truth=True,
                include_pre_rewrite=False,
            ),
            dependencies=[
                untrusted_decomposition_malicious_unedited_plan_improved_solution
            ],
            strong_model=strong_model,
            weak_model=weak_model,
            read_cache=True,
        )
    )

    untrusted_decomposition_malicious_unedited_plan_suspiciousness = (
        add_basic_blue_team_monitor_to_pipeline(
            pipeline,
            strong_model,
            weak_model,
            untrusted_decomposition_malicious_unedited_plan_exploit_check,
            name="untrusted_decomposition_malicious_unedited_plan",
            team_suffix="untrusted_decomposition_malicious",
        )
    )

    untrusted_decomposition_malicious_unedited_plan_usefulness = (
        pipeline.add_code_evaluation_step(
            f"untrusted_decomposition_malicious_unedited_plan_usefulness",
            weak_model,
            run_solutions,
            dependencies=[
                untrusted_decomposition_malicious_unedited_plan_exploit_check
            ],
        )
    )

    untrusted_decomposition_malicious_unedited_plan_retry_exploit = attempt_exploit(
        pipeline,
        strong_model,
        weak_model,
        untrusted_decomposition_malicious_unedited_plan_improved_solution,
        use_modified_solution=False,
        suffix="untrusted_decomposition_malicious_unedited",
        team_suffix="untrusted_decomposition_malicious",
    )

    untrusted_decomposition_malicious_rejoin_unedited_exploit_attempts = (
        pipeline.add_transformation_step(
            "untrusted_decomposition_malicious_rejoin_unedited_exploit_attempts",
            rejoin_exploit_attempts(
                suffix="untrusted_decomposition_malicious_unedited"
            ),
            dependencies=[
                untrusted_decomposition_malicious_unedited_plan_improved_solution,
                untrusted_decomposition_malicious_unedited_plan_retry_exploit,
            ],
        )
    )

    untrusted_decomposition_malicious_unedited_exploit_retry_check = (
        pipeline.add_transformation_step(
            "untrusted_decomposition_malicious_unedited_exploit_retry_check",
            evaluate_exploit_validity(
                pipeline.config.language,
                suffix="untrusted_decomposition_malicious_unedited",
                include_pre_rewrite=False,
            ),
            dependencies=[
                untrusted_decomposition_malicious_rejoin_unedited_exploit_attempts
            ],
            strong_model=strong_model,
            weak_model=weak_model,
            read_cache=True,
        )
    )

    untrusted_decomposition_malicious_unedited_plan_initial_results = pipeline.add_transformation_step(
        "untrusted_decomposition_malicious_unedited_plan_initial_results",
        create_output(
            suspiciousness_key_suffix="untrusted_decomposition_malicious_unedited_plan",
            exploit_key_suffix="query_malicious",
        ),
        dependencies=[
            untrusted_decomposition_malicious_unedited_plan_suspiciousness,
            untrusted_decomposition_malicious_unedited_plan_usefulness,
        ],
        strong_model=strong_model,
        weak_model=weak_model,
    )

    untrusted_decomposition_malicious_unedited_plan_results = (
        pipeline.add_transformation_step(
            "untrusted_decomposition_malicious_unedited_plan_results",
            add_exploit_retry_data_to_output(
                key_suffix="untrusted_decomposition_malicious_unedited"
            ),
            dependencies=[
                untrusted_decomposition_malicious_unedited_plan_initial_results,
                untrusted_decomposition_malicious_unedited_exploit_retry_check,
            ],
            strong_model=strong_model,
            weak_model=weak_model,
        )
    )

    # Edit solution
    untrusted_decomposition_malicious_split_by_subtask = (
        pipeline.add_transformation_step(
            "untrusted_decomposition_malicious_split_by_subtask",
            separate_subtasks,
            dependencies=[untrusted_decomposition_malicious_improve_plan],
            strong_model=strong_model,
            weak_model=weak_model,
        )
    )

    untrusted_decomposition_malicious_implement_subtask = pipeline.add_query_step(
        f"untrusted_decomposition_malicious_implement_subtask_{strong_model}",
        weak_model,
        get_subtask_solution,
        extract_solution_in_block,
        dependencies=[untrusted_decomposition_malicious_split_by_subtask],
        team="blue_untrusted_decomposition_malicious",
    )

    untrusted_decomposition_malicious_subtask_suspiciousness = (
        add_basic_blue_team_monitor_to_pipeline(
            pipeline,
            strong_model,
            weak_model,
            untrusted_decomposition_malicious_implement_subtask,
            name="untrusted_decomposition_malicious_subtask",
            team_suffix="untrusted_decomposition_malicious",
        )
    )

    untrusted_decomposition_malicious_join_subtasks = pipeline.add_transformation_step(
        "untrusted_decomposition_malicious_join_subtasks",
        untrusted_join_subtasks("untrusted_decomposition_malicious_subtask"),
        dependencies=[untrusted_decomposition_malicious_subtask_suspiciousness],
        strong_model=strong_model,
        weak_model=weak_model,
    )

    untrusted_decomposition_malicious_filtered_decompositions = (
        pipeline.add_transformation_step(
            f"untrusted_decomposition_malicious_filtered_decompositions",
            filter_most_suspicious_subtask,
            dependencies=[untrusted_decomposition_malicious_join_subtasks],
            strong_model=strong_model,
            weak_model=weak_model,
        )
    )

    untrusted_decomposition_malicious_edited_plan_solution = pipeline.add_query_step(
        f"untrusted_decomposition_malicious_edited_plan_solution",
        weak_model,
        implement_solution_from_plan,
        extract_rewritten_solution_in_block,
        dependencies=[untrusted_decomposition_malicious_filtered_decompositions],
        team="blue_untrusted_decomposition_malicious",
    )

    untrusted_decomposition_malicious_edited_plan_improved_solution = (
        pipeline.add_query_step(
            "untrusted_decomposition_malicious_edited_plan_improved_solution",
            weak_model,
            improve_solution_from_plan,
            extract_improved_solution_in_block,
            dependencies=[untrusted_decomposition_malicious_edited_plan_solution],
            team="blue_untrusted_decomposition_malicious",
        )
    )

    untrusted_decomposition_malicious_edited_exploit_check = (
        pipeline.add_transformation_step(
            "untrusted_decomposition_malicious_edited_exploit_check",
            evaluate_exploit_validity(
                pipeline.config.language,
                suffix="query_malicious",
                use_apps_as_ground_truth=True,
                include_pre_rewrite=False,
            ),
            dependencies=[
                untrusted_decomposition_malicious_edited_plan_improved_solution
            ],
            strong_model=strong_model,
            weak_model=weak_model,
            read_cache=True,
        )
    )

    untrusted_decomposition_malicious_edited_suspiciousness = (
        add_basic_blue_team_monitor_to_pipeline(
            pipeline,
            strong_model,
            weak_model,
            untrusted_decomposition_malicious_edited_exploit_check,
            name="untrusted_decomposition_malicious_edited",
            team_suffix="untrusted_decomposition_malicious",
        )
    )

    untrusted_decomposition_malicious_edited_usefulness = (
        pipeline.add_code_evaluation_step(
            f"untrusted_decomposition_malicious_edited_usefulness",
            weak_model,
            run_solutions,
            dependencies=[untrusted_decomposition_malicious_edited_exploit_check],
        )
    )

    untrusted_decomposition_malicious_edited_plan_retry_exploit = attempt_exploit(
        pipeline,
        strong_model,
        weak_model,
        untrusted_decomposition_malicious_edited_plan_improved_solution,
        use_modified_solution=False,
        suffix="untrusted_decomposition_malicious_edited",
        team_suffix="untrusted_decomposition_malicious",
    )

    untrusted_decomposition_malicious_rejoin_edited_exploit_attempts = (
        pipeline.add_transformation_step(
            "untrusted_decomposition_malicious_rejoin_edited_exploit_attempts",
            rejoin_exploit_attempts(suffix="untrusted_decomposition_malicious_edited"),
            dependencies=[
                untrusted_decomposition_malicious_edited_plan_improved_solution,
                untrusted_decomposition_malicious_edited_plan_retry_exploit,
            ],
        )
    )

    untrusted_decomposition_malicious_edited_exploit_retry_check = (
        pipeline.add_transformation_step(
            "untrusted_decomposition_malicious_edited_exploit_retry_check",
            evaluate_exploit_validity(
                pipeline.config.language,
                suffix="untrusted_decomposition_malicious_edited",
                use_apps_as_ground_truth=True,
                include_pre_rewrite=False,
            ),
            dependencies=[
                untrusted_decomposition_malicious_rejoin_edited_exploit_attempts
            ],
            strong_model=strong_model,
            weak_model=weak_model,
            read_cache=True,
        )
    )

    untrusted_decomposition_malicious_edited_initial_results = (
        pipeline.add_transformation_step(
            "untrusted_decomposition_malicious_edited_initial_results",
            create_edited_output(
                suspiciousness_key_suffix="untrusted_decomposition_malicious_edited",
                exploit_key_suffix="query_malicious",
            ),
            dependencies=[
                untrusted_decomposition_malicious_edited_suspiciousness,
                untrusted_decomposition_malicious_edited_usefulness,
            ],
            strong_model=strong_model,
            weak_model=weak_model,
        )
    )

    untrusted_decomposition_malicious_edited_plan_results = (
        pipeline.add_transformation_step(
            "untrusted_decomposition_malicious_edited_plan_results",
            add_exploit_retry_data_to_output(
                key_suffix="untrusted_decomposition_malicious_edited"
            ),
            dependencies=[
                untrusted_decomposition_malicious_edited_initial_results,
                untrusted_decomposition_malicious_edited_exploit_retry_check,
            ],
            strong_model=strong_model,
            weak_model=weak_model,
        )
    )

    # Rewrite plan
    untrusted_decomposition_malicious_rewrite_decomposition_step = (
        pipeline.add_query_step(
            f"untrusted_decomposition_malicious_rewrite_decomposition_{strong_model}",
            weak_model,
            rewrite_decomposition,
            extract_rewritten_decomposition,
            dependencies=[untrusted_decomposition_malicious_improve_plan],
            team="blue_untrusted_decomposition_malicious",
        )
    )

    untrusted_decomposition_malicious_rewritten_plan_solution = pipeline.add_query_step(
        "untrusted_decomposition_malicious_rewritten_plan_solution",
        weak_model,
        implement_solution_from_plan,
        extract_solution_in_block,
        dependencies=[untrusted_decomposition_malicious_rewrite_decomposition_step],
        team="blue_untrusted_decomposition_malicious",
    )

    untrusted_decomposition_malicious_rewritten_improved_solution = (
        pipeline.add_query_step(
            "untrusted_decomposition_malicious_rewritten_improved_solution",
            weak_model,
            improve_solution_from_plan,
            extract_improved_solution_in_block,
            dependencies=[untrusted_decomposition_malicious_rewritten_plan_solution],
            team="blue_untrusted_decomposition_malicious",
        )
    )

    untrusted_decomposition_malicious_rewritten_plan_exploit_check = (
        pipeline.add_transformation_step(
            "untrusted_decomposition_malicious_rewritten_plan_exploit_check",
            evaluate_exploit_validity(
                pipeline.config.language,
                suffix="query_malicious",
                use_apps_as_ground_truth=True,
                include_pre_rewrite=False,
            ),
            dependencies=[
                untrusted_decomposition_malicious_rewritten_improved_solution
            ],
            strong_model=strong_model,
            weak_model=weak_model,
            read_cache=True,
        )
    )

    untrusted_decomposition_malicious_rewritten_plan_suspiciousness = (
        add_basic_blue_team_monitor_to_pipeline(
            pipeline,
            strong_model,
            weak_model,
            untrusted_decomposition_malicious_rewritten_plan_exploit_check,
            name="untrusted_decomposition_malicious_rewritten",
            team_suffix="untrusted_decomposition_malicious",
        )
    )

    untrusted_decomposition_malicious_rewritten_plan_usefulness = (
        pipeline.add_code_evaluation_step(
            f"untrusted_decomposition_malicious_rewritten_plan_usefulness",
            weak_model,
            run_solutions,
            dependencies=[
                untrusted_decomposition_malicious_rewritten_plan_exploit_check
            ],
        )
    )

    untrusted_decomposition_malicious_rewritten_plan_retry_exploit = attempt_exploit(
        pipeline,
        strong_model,
        weak_model,
        untrusted_decomposition_malicious_rewritten_improved_solution,
        use_modified_solution=False,
        suffix="untrusted_decomposition_malicious_rewritten",
        team_suffix="untrusted_decomposition_malicious",
    )

    untrusted_decomposition_malicious_rejoin_rewritten_exploit_attempts = (
        pipeline.add_transformation_step(
            "untrusted_decomposition_malicious_rejoin_rewritten_exploit_attempts",
            rejoin_exploit_attempts(
                suffix="untrusted_decomposition_malicious_rewritten"
            ),
            dependencies=[
                untrusted_decomposition_malicious_rewritten_improved_solution,
                untrusted_decomposition_malicious_rewritten_plan_retry_exploit,
            ],
        )
    )

    untrusted_decomposition_malicious_rewritten_exploit_retry_check = (
        pipeline.add_transformation_step(
            "untrusted_decomposition_malicious_rewritten_exploit_retry_check",
            evaluate_exploit_validity(
                pipeline.config.language,
                suffix="untrusted_decomposition_malicious_rewritten",
                use_apps_as_ground_truth=True,
                include_pre_rewrite=False,
            ),
            dependencies=[
                untrusted_decomposition_malicious_rejoin_rewritten_exploit_attempts
            ],
            strong_model=strong_model,
            weak_model=weak_model,
            read_cache=True,
        )
    )

    untrusted_decomposition_malicious_rewritten_initial_results = (
        pipeline.add_transformation_step(
            "untrusted_decomposition_malicious_rewritten_initial_results",
            create_edited_output(
                suspiciousness_key_suffix="untrusted_decomposition_malicious_rewritten",
                exploit_key_suffix="query_malicious",
            ),
            dependencies=[
                untrusted_decomposition_malicious_rewritten_plan_suspiciousness,
                untrusted_decomposition_malicious_rewritten_plan_usefulness,
            ],
            strong_model=strong_model,
            weak_model=weak_model,
        )
    )

    untrusted_decomposition_malicious_rewritten_plan_results = (
        pipeline.add_transformation_step(
            "untrusted_decomposition_malicious_rewritten_plan_results",
            add_exploit_retry_data_to_output(
                key_suffix="untrusted_decomposition_malicious_rewritten"
            ),
            dependencies=[
                untrusted_decomposition_malicious_rewritten_initial_results,
                untrusted_decomposition_malicious_rewritten_exploit_retry_check,
            ],
            strong_model=strong_model,
            weak_model=weak_model,
        )
    )

    return (
        untrusted_decomposition_malicious_unedited_plan_results,
        untrusted_decomposition_malicious_edited_plan_results,
        untrusted_decomposition_malicious_rewritten_plan_results,
    )


def get_pipeline(
    strong_model,
    weak_model,
    name=None,
    use_cache=True,
    num_problems=None,
    enable_decomposition=True,
    problem_ids=None,
):
    ROOT_DIR = get_root_directory()
    DATA_DIR = ROOT_DIR / "data" / "APPS"
    TEST_DIR = DATA_DIR / "test"

    if name is None:
        pipeline_name = "data_collection"
    else:
        pipeline_name = f"data_collection-{name}"

    pipeline_config = PipelineConfig(
        pipeline_name,
        anthropic_num_threads=40,
        openai_fraction_rate_limit=0.99,
        num_problems=num_problems,
        use_cache=use_cache,
    )

    pipeline = Pipeline(pipeline_config)
    load_data = pipeline.add_load_data_step(
        "load_data",
        load_problem_subset(
            "INTERVIEW", require_solutions=True, problem_ids=problem_ids
        ),
        TEST_DIR,
    )

    query_benign_untrusted_get_solutions, query_benign_results = query_benign_pipeline(
        pipeline, load_data, strong_model, weak_model
    )
    (
        query_malicious_flatten_backdoors,
        query_malicious_results,
    ) = query_malicious_pipeline(
        pipeline, query_benign_untrusted_get_solutions, strong_model, weak_model
    )
    query_trusted_solutions, query_trusted_results = query_trusted_pipeline(
        pipeline, load_data, strong_model, weak_model
    )
    get_avg_difficulty_results = get_problem_difficulty(
        pipeline, query_trusted_solutions, strong_model, weak_model
    )
    query_benign_rewrite_results = query_benign_edited_pipeline(
        pipeline, query_benign_untrusted_get_solutions, strong_model, weak_model
    )
    query_malicious_rewrite_results = query_malicious_edited_pipeline(
        pipeline, query_malicious_flatten_backdoors, strong_model, weak_model
    )

    if enable_decomposition:
        # TODO: Fix these
        # (
        #     implement_subtask,
        #     trusted_decomposition_benign_results,
        #     trusted_decomposition_benign_edited_results,
        # ) = trusted_decomposition_benign_pipeline(
        #     pipeline, load_data, strong_model, weak_model
        # )
        # (
        #     trusted_decomposition_malicious_results,
        #     trusted_decomposition_malicious_edited_results,
        # ) = trusted_decomposition_malicious_pipeline(
        #     pipeline, implement_subtask, strong_model, weak_model
        # )
        # end TODO
        (
            untrusted_decomp_benign_results,
            untrusted_decomp_benign_edited_results,
            untrusted_decomp_benign_rewritten_results,
        ) = untrusted_decomposition_benign_pipeline(
            pipeline, query_benign_untrusted_get_solutions, strong_model, weak_model
        )
        (
            untrusted_decomp_malicious_results,
            untrusted_decomp_malicious_edited_results,
            untrusted_decomp_malicious_rewritten_results,
        ) = untrusted_decomposition_malicious_pipeline(
            pipeline, query_malicious_flatten_backdoors, strong_model, weak_model
        )

    ############################# Combine Results #############################
    dependencies = [
        get_avg_difficulty_results,
        query_benign_results,
        query_malicious_results,
        query_trusted_results,
        query_benign_rewrite_results,
        query_malicious_rewrite_results,
    ]
    if enable_decomposition:
        dependencies.extend(
            [
                # trusted_decomposition_benign_results,
                # trusted_decomposition_malicious_results,
                untrusted_decomp_benign_results,
                untrusted_decomp_benign_edited_results,
                untrusted_decomp_benign_rewritten_results,
                untrusted_decomp_malicious_results,
                untrusted_decomp_malicious_edited_results,
                untrusted_decomp_malicious_rewritten_results,
            ]
        )
    final_results = pipeline.add_transformation_step(
        "final_results",
        join_output,
        dependencies=dependencies,
        strong_model=strong_model,
        weak_model=weak_model,
    )

    return pipeline


if __name__ == "__main__":
    from core.utils import setup_logging

    setup_logging("debug")

    strong_model = "gpt-4-0125-preview"
    weak_model = "gpt-3.5-turbo-1106"

    data_pipeline = get_pipeline(
        strong_model=strong_model,
        weak_model=weak_model,
        name="add-difficulty",
        use_cache=True,
        num_problems=1,
    )
    results = asyncio.run(data_pipeline.run())
