import numpy as np

from src.model_querying.prompt_creation import (
    evaluate_backdoor_sneakiness_logprobs,
    evaluate_backdoor_sneakiness_out_of_7,
    evaluate_backdoor_sneakiness_out_of_15,
    evaluate_backdoor_sneakiness_scalar,
    get_backdoor_likelihood,
    get_backdoor_presence,
    get_monitor_description,
)
from src.model_querying.solution_extraction import (
    extract_backdoor_likelihood,
    extract_backdoor_presence,
    extract_backdoor_sneakiness,
    extract_backdoor_sneakiness_scalar,
    extract_monitor_description,
    extract_sneakiness_logprobs,
)
from src.pipeline.transformations import (
    evaluate_backdoor_sneakiness,
    get_backdoor_suspicion_score,
    strip_comments,
)
from src.tools.string_manipulation import format_key_suffix


def add_red_team_sneakiness_steps_to_pipeline(
    pipeline, strong_model, weak_model, prev_step, name=None, use_modified_solution=True
):
    if name is not None:
        suffix = f"_{name}"
    else:
        suffix = ""

    evaluate_backdoor_sneakiness_out_of_15_step = pipeline.add_query_step(
        "evaluate_backdoor_sneakiness_out_of_15_step" + suffix,
        strong_model,
        evaluate_backdoor_sneakiness_out_of_15,
        extract_backdoor_sneakiness,
        dependencies=[prev_step],
        team="red",
    )
    evaluate_backdoor_sneakiness_out_of_7_step = pipeline.add_query_step(
        "evaluate_backdoor_sneakiness_out_of_7_step" + suffix,
        strong_model,
        evaluate_backdoor_sneakiness_out_of_7,
        extract_backdoor_sneakiness,
        dependencies=[prev_step],
        team="red",
    )
    evaluate_backdoor_sneakiness_step = pipeline.add_transformation_step(
        "evaluate_backdoor_sneakiness" + suffix,
        evaluate_backdoor_sneakiness,
        dependencies=[
            evaluate_backdoor_sneakiness_out_of_15_step,
            evaluate_backdoor_sneakiness_out_of_7_step,
        ],
        strong_model=strong_model,
        weak_model=weak_model,
    )
    return evaluate_backdoor_sneakiness_step


def join_by_problem(data):
    joined_data = {}
    for key, value in data.items():
        problem_id = "-".join(key.split("-")[:-1])
        if problem_id not in joined_data:
            joined_data[problem_id] = []
        joined_data[problem_id].append(value)
    return joined_data


def average_sneakiness(data, joined_data):
    for key, value in data.items():
        scores = [
            i["sneakiness"] for i in joined_data.get(key) if i["sneakiness"] is not None
        ]
        if scores is None or len(scores) == 0:
            estimated_prob = None
        else:
            estimated_prob = np.mean(scores)
        data[key]["sneakiness"] = estimated_prob
        data[key]["sneakiness_scores"] = scores
    return data


def add_red_team_sneakiness_logprob_steps_to_pipeline(
    pipeline, strong_model, weak_model, prev_step, name=None, use_modified_solution=True
):
    if name is not None:
        suffix = f"_{name}"
    else:
        suffix = ""

    evaluate_backdoor_sneakiness_step = pipeline.add_query_step(
        "evaluate_backdoor_sneakiness_logprob_step" + suffix,
        strong_model,
        evaluate_backdoor_sneakiness_logprobs,
        extract_sneakiness_logprobs,
        dependencies=[prev_step],
        logprobs=20,
        max_tokens=1,
        team="red",
    )

    return evaluate_backdoor_sneakiness_step


def add_red_team_scalar_sneakiness_steps_to_pipeline(
    pipeline,
    strong_model,
    weak_model,
    prev_step,
    name=None,
    use_modified_solution=True,
    bon=5,
    temperature=0.2,
):
    if name is not None:
        suffix = f"_{name}"
    else:
        suffix = ""

    evaluate_backdoor_sneakiness_step = pipeline.add_query_step(
        "evaluate_backdoor_sneakiness_scalar_step" + suffix,
        strong_model,
        evaluate_backdoor_sneakiness_scalar,
        extract_backdoor_sneakiness_scalar,
        dependencies=[prev_step],
        team="red",
        bon=bon,
        temperature=temperature,
    )
    if bon > 1:
        join_monitors = pipeline.add_transformation_step(
            "join_monitors" + suffix,
            join_by_problem,
            dependencies=[evaluate_backdoor_sneakiness_step],
            strong_model=strong_model,
            weak_model=weak_model,
        )
        get_backdoor_likelihood_sampling_step = pipeline.add_transformation_step(
            f"get_backdoor_likelihood_sampling{suffix}",
            average_sneakiness,
            dependencies=[prev_step, join_monitors],
            strong_model=strong_model,
            weak_model=weak_model,
        )
        return get_backdoor_likelihood_sampling_step
    else:
        return evaluate_backdoor_sneakiness_step


def add_red_team_monitoring_steps_to_pipeline(
    pipeline,
    strong_model,
    weak_model,
    prev_step,
    name=None,
    use_modified_solution=True,
    team_suffix=None,
):
    team = "red"
    if team_suffix:
        team += format_key_suffix(team_suffix)
    suffix = format_key_suffix(name)
    if use_modified_solution:
        code = "modified_solution"
    else:
        code = "solution"

    remove_code_comments_step = pipeline.add_transformation_step(
        f"remove_code_comments_red_team{suffix}",
        strip_comments(code, key_suffix=suffix),
        dependencies=[prev_step],
        strong_model=strong_model,
        weak_model=weak_model,
    )
    get_monitor_description_step = pipeline.add_query_step(
        f"get_monitor_description_step_red_team{suffix}",
        strong_model,
        get_monitor_description(
            key_suffix=suffix, use_subtask=False
        ),  # red team will always evaluate the final solution
        extract_monitor_description(key_suffix=suffix),
        dependencies=[remove_code_comments_step],
        team=team,
    )
    get_backdoor_likelihood_step = pipeline.add_query_step(
        f"get_backdoor_likelihood_step_red_team{suffix}",
        strong_model,
        get_backdoor_likelihood(key_suffix=suffix),
        extract_backdoor_likelihood(key_suffix=suffix, is_red_team=True),
        dependencies=[get_monitor_description_step],
        team=team,
    )
    get_backdoor_presence_step = pipeline.add_query_step(
        f"get_backdoor_presence_step_red_team{suffix}",
        strong_model,
        get_backdoor_presence(key_suffix=suffix),
        extract_backdoor_presence(key_suffix=suffix, is_red_team=True),
        dependencies=[get_monitor_description_step],
        team=team,
    )
    get_backdoor_suspicion_score_step = pipeline.add_transformation_step(
        f"get_backdoor_suspicion_score_red_team{suffix}",
        get_backdoor_suspicion_score(key_suffix=suffix, is_red_team=True),
        dependencies=[get_backdoor_likelihood_step, get_backdoor_presence_step],
        strong_model=strong_model,
        weak_model=weak_model,
    )
    return get_backdoor_suspicion_score_step
