import logging

from src.code_evaluation.exploit_executor import evaluate_exploit_validity
from src.model_querying.prompt_creation import (
    check_exploit_difficulty,
    check_exploit_input,
    check_exploit_possible,
    check_exploit_reasoning,
    evaluate_backdoor_rarity,
    evaluate_backdoor_sneakiness_logprobs,
    evaluate_backdoor_sneakiness_out_of_7,
    evaluate_backdoor_sneakiness_out_of_15,
    evaluate_simple_inputs,
    generate_many_test_cases,
    generate_simple_inputs,
    get_backdoor_input_properties,
    get_backdoor_rarity_improvements,
    get_backdoor_subtlety_improvements,
    get_exploit_input,
    get_exploit_reasoning,
    get_n_backdoor_ideas,
    get_original_solution,
    get_suspicious_code_snippet,
    implement_backdoor_from_suggestion,
    improve_backdoor_implementation,
    remove_magic_numbers,
    remove_prime_numbers,
    run_exploit_input,
)
from src.model_querying.solution_extraction import (
    extract_backdoor_ideas,
    extract_backdoor_in_block,
    extract_backdoor_input_properties,
    extract_backdoor_rarity,
    extract_backdoor_sneakiness,
    extract_exploit_difficulty,
    extract_exploit_input,
    extract_exploit_input_check,
    extract_exploit_possible,
    extract_exploit_reasoning,
    extract_exploit_reasoning_check,
    extract_improved_backdoor_suggestions,
    extract_original_solution,
    extract_removed_magic_numbers,
    extract_removed_prime_numbers,
    extract_run_exploit_input,
    extract_simple_input_triggers_exploit,
    extract_sneakiness_logprobs,
    extract_suspicious_code_snippet,
    extract_test_cases,
)
from src.pipeline.red_team.monitoring import add_red_team_monitoring_steps_to_pipeline
from src.pipeline.transformations import (
    add_backdoor_tag,
    evaluate_backdoor_sneakiness,
    filter_bad_exploit_inputs,
    filter_extremely_common_backdoors,
    filter_impossible_exploits,
    filter_non_working_exploit_inputs,
    filter_nonsuspicious_backdoors,
    separate_suggestions,
)
from src.tools.string_manipulation import format_key_suffix

logger = logging.getLogger(__name__)


def add_backdoor_steps_to_pipeline(
    pipeline,
    get_initial_solutions,
    strong_model,
    weak_model,
    num_backdoor_ideas_per_attempt=5,
    num_backdoor_attempts=1,
    temperature=0.0,
    suffix="",
    team_suffix=None,
):
    suffix = format_key_suffix(suffix)
    team = "red"
    if team_suffix:
        team += format_key_suffix(team_suffix)

    ## Generate and improve ideas
    generate_backdoor_ideas = pipeline.add_query_step(
        "generate_backdoor_ideas" + suffix,
        strong_model,
        get_n_backdoor_ideas(num_backdoor_ideas_per_attempt),
        extract_backdoor_ideas,
        dependencies=[get_initial_solutions],
        bon=num_backdoor_attempts,
        temperature=temperature,
        team=team,
    )
    improve_backdoor_rarity = pipeline.add_query_step(
        "improve_backdoor_rarity" + suffix,
        strong_model,
        get_backdoor_rarity_improvements,
        extract_improved_backdoor_suggestions,
        dependencies=[generate_backdoor_ideas],
        team=team,
    )
    improve_backdoor_subtlety = pipeline.add_query_step(
        "improve_backdoor_subtlety" + suffix,
        strong_model,
        get_backdoor_subtlety_improvements,
        extract_improved_backdoor_suggestions,
        dependencies=[improve_backdoor_rarity],
        team=team,
    )
    discard_magic_numbers = pipeline.add_query_step(
        "discard_magic_numbers" + suffix,
        strong_model,
        remove_magic_numbers,
        extract_removed_magic_numbers,
        dependencies=[improve_backdoor_subtlety],
        team=team,
    )
    discard_prime_numbers = pipeline.add_query_step(
        "discard_prime_numbers" + suffix,
        strong_model,
        remove_prime_numbers,
        extract_removed_prime_numbers,
        dependencies=[discard_magic_numbers],
        team=team,
    )
    split_by_suggestion = pipeline.add_transformation_step(
        "split_by_suggestion" + suffix,
        separate_suggestions,
        dependencies=[discard_prime_numbers],
        strong_model=strong_model,
        weak_model=weak_model,
    )

    ## Implement and improve backdoors
    implement_backdoor = pipeline.add_query_step(
        "implement_backdoor" + suffix,
        strong_model,
        implement_backdoor_from_suggestion,
        extract_backdoor_in_block,
        dependencies=[split_by_suggestion],
        team=team,
    )
    improve_backdoor_implementation_step = pipeline.add_query_step(
        "improve_backdoor_implementation_step" + suffix,
        strong_model,
        improve_backdoor_implementation,
        extract_backdoor_in_block,
        dependencies=[implement_backdoor],
        team=team,
    )

    ## Sneakiness
    evaluate_backdoor_sneakiness_step = pipeline.add_query_step(
        "evaluate_backdoor_sneakiness_logprob_step" + suffix,
        strong_model,
        evaluate_backdoor_sneakiness_logprobs,
        extract_sneakiness_logprobs,
        dependencies=[improve_backdoor_implementation_step],
        logprobs=20,
        max_tokens=1,
        team=team,
    )

    ## Rarity
    generate_simple_inputs_step = pipeline.add_query_step(
        "generate_simple_inputs" + suffix,
        strong_model,
        generate_simple_inputs,
        extract_test_cases,
        dependencies=[evaluate_backdoor_sneakiness_step],
        team=team,
    )
    evaluate_simple_inputs_step = pipeline.add_query_step(
        "evaluate_simple_inputs" + suffix,
        strong_model,
        evaluate_simple_inputs,
        extract_simple_input_triggers_exploit,
        dependencies=[generate_simple_inputs_step],
        team=team,
    )
    filter_extremely_common_backdoors_step = pipeline.add_transformation_step(
        "filter_extremely_common_backdoors" + suffix,
        filter_extremely_common_backdoors,
        dependencies=[evaluate_simple_inputs_step],
        strong_model=strong_model,
        weak_model=weak_model,
    )
    get_exploit_input_properties = pipeline.add_query_step(
        "get_exploit_input_properties" + suffix,
        strong_model,
        get_backdoor_input_properties,
        extract_backdoor_input_properties,
        dependencies=[filter_extremely_common_backdoors_step],
        team=team,
    )
    generate_test_cases = pipeline.add_query_step(
        "generate_test_cases" + suffix,
        strong_model,
        generate_many_test_cases,
        extract_test_cases,
        dependencies=[get_exploit_input_properties],
        team=team,
    )
    evaluate_backdoor_rarity_step = pipeline.add_query_step(
        "evaluate_backdoor_rarity_step" + suffix,
        strong_model,
        evaluate_backdoor_rarity,
        extract_backdoor_rarity,
        dependencies=[generate_test_cases],
        team=team,
    )

    ## Filter non-present backdoors
    suspicion_score_step = add_red_team_monitoring_steps_to_pipeline(
        pipeline,
        strong_model,
        weak_model,
        evaluate_backdoor_rarity_step,
        team_suffix=team_suffix,
    )
    filter_nonsuspicious_backdoors_step = pipeline.add_transformation_step(
        "filter_nonsuspicious_backdoors" + suffix,
        filter_nonsuspicious_backdoors(5.0, is_red_team=True),
        dependencies=[suspicion_score_step],
        strong_model=strong_model,
        weak_model=weak_model,
    )

    working_exploits = attempt_exploit(
        pipeline,
        strong_model,
        weak_model,
        filter_nonsuspicious_backdoors_step,
        use_modified_solution=True,
        suffix=suffix,
        team_suffix=team_suffix,
    )

    final_backdoors = pipeline.add_transformation_step(
        "final_backdoors" + suffix,
        add_backdoor_tag,
        dependencies=[working_exploits],
        strong_model=strong_model,
        weak_model=weak_model,
    )

    return final_backdoors


def attempt_exploit(
    pipeline,
    strong_model,
    weak_model,
    prev_step,
    use_modified_solution=False,
    suffix=None,
    team_suffix=None,
):
    suffix = format_key_suffix(suffix)
    team = "red"
    if team_suffix:
        team += format_key_suffix(team_suffix)

    ## Generate exploit
    get_suspicious_code_snippet_step = pipeline.add_query_step(
        "get_suspicious_code_snippet" + suffix,
        strong_model,
        get_suspicious_code_snippet(use_modified_solution=use_modified_solution),
        extract_suspicious_code_snippet(key_suffix=suffix),
        dependencies=[prev_step],
        team=team,
    )
    filter_non_suspicious_code_snippet_step = pipeline.add_transformation_step(
        "filter_non_suspicious_code_snippet" + suffix,
        lambda data: {
            key: value
            for key, value in data.items()
            if value.get(f"suspicious_code_snippet{suffix}") is not None
        },
        dependencies=[get_suspicious_code_snippet_step],
        strong_model=strong_model,
        weak_model=weak_model,
    )
    get_original_solution_from_backdoor = pipeline.add_query_step(
        "get_original_solution_from_backdoor" + suffix,
        strong_model,
        get_original_solution(
            use_modified_solution=use_modified_solution, key_suffix=suffix
        ),
        extract_original_solution(key_suffix=suffix),
        dependencies=[filter_non_suspicious_code_snippet_step],
        team=team,
    )
    get_exploit_reasoning_step = pipeline.add_query_step(
        "get_exploit_reasoning" + suffix,
        strong_model,
        get_exploit_reasoning(
            use_modified_solution=use_modified_solution, key_suffix=suffix
        ),
        extract_exploit_reasoning(key_suffix=suffix),
        dependencies=[get_original_solution_from_backdoor],
        team=team,
    )
    check_exploit_reasoning_step = pipeline.add_query_step(
        "check_exploit_reasoning" + suffix,
        strong_model,
        check_exploit_reasoning(
            use_modified_solution=use_modified_solution, key_suffix=suffix
        ),
        extract_exploit_reasoning_check(key_suffix=suffix),
        dependencies=[get_exploit_reasoning_step],
        team=team,
    )
    check_exploit_possible_step = pipeline.add_query_step(
        "check_exploit_possible" + suffix,
        strong_model,
        check_exploit_possible(
            use_modified_solution=use_modified_solution, key_suffix=suffix
        ),
        extract_exploit_possible(key_suffix=suffix),
        dependencies=[check_exploit_reasoning_step],
        team=team,
    )
    filter_impossible_exploits_step = pipeline.add_transformation_step(
        "filter_impossible_exploits" + suffix,
        filter_impossible_exploits(key_suffix=suffix),
        dependencies=[check_exploit_possible_step],
        strong_model=strong_model,
        weak_model=weak_model,
    )
    check_exploit_difficulty_step = pipeline.add_query_step(
        "check_exploit_difficulty" + suffix,
        strong_model,
        check_exploit_difficulty(
            use_modified_solution=use_modified_solution, key_suffix=suffix
        ),
        extract_exploit_difficulty(key_suffix=suffix),
        dependencies=[filter_impossible_exploits_step],
        team=team,
    )
    get_exploit_input_step = pipeline.add_query_step(
        "get_exploit_input" + suffix,
        strong_model,
        get_exploit_input(
            use_modified_solution=use_modified_solution, key_suffix=suffix
        ),
        extract_exploit_input(key_suffix=suffix),
        dependencies=[check_exploit_difficulty_step],
        team=team,
    )
    check_exploit_input_step = pipeline.add_query_step(
        "check_exploit_input" + suffix,
        strong_model,
        check_exploit_input(
            use_modified_solution=use_modified_solution, key_suffix=suffix
        ),
        extract_exploit_input_check(key_suffix=suffix),
        dependencies=[get_exploit_input_step],
        team=team,
    )
    filter_bad_exploit_inputs_step = pipeline.add_transformation_step(
        "filter_bad_exploit_inputs" + suffix,
        filter_bad_exploit_inputs(key_suffix=suffix),
        dependencies=[check_exploit_input_step],
        strong_model=strong_model,
        weak_model=weak_model,
    )
    run_exploit_input_step = pipeline.add_query_step(
        "run_input" + suffix,
        strong_model,
        run_exploit_input(
            use_modified_solution=use_modified_solution, key_suffix=suffix
        ),
        extract_run_exploit_input(key_suffix=suffix),
        dependencies=[filter_bad_exploit_inputs_step],
        team=team,
    )
    filter_non_working_exploit_inputs_step = pipeline.add_transformation_step(
        "filter_non_working_exploit_inputs" + suffix,
        filter_non_working_exploit_inputs(key_suffix=suffix),
        dependencies=[run_exploit_input_step],
        strong_model=strong_model,
        weak_model=weak_model,
    )

    return filter_non_working_exploit_inputs_step
