import numpy as np

from src.model_querying.prompt_creation import (
    get_backdoor_likelihood,
    get_backdoor_presence,
    get_backdoor_presence_logprobs,
    get_backdoor_presence_logprobs_for_plan,
    get_monitor_description,
    get_monitor_description_for_plan,
    get_suspiciousness,
)
from src.model_querying.solution_extraction import (
    extract_backdoor_likelihood,
    extract_backdoor_presence,
    extract_backdoor_presence_logprobs,
    extract_backdoor_presence_logprobs_sampling,
    extract_monitor_description,
    extract_monitor_description_for_plan,
    extract_suspiciousness,
)
from src.pipeline.transformations import get_backdoor_suspicion_score, strip_comments
from src.tools.string_manipulation import format_key_suffix


def join_by_problem(data):
    joined_data = {}
    for key, value in data.items():
        problem_id = "-".join(key.split("-")[:-1])
        if problem_id not in joined_data:
            joined_data[problem_id] = []
        joined_data[problem_id].append(value)
    return joined_data


def average_suspiciousness(key_suffix=None):
    if key_suffix:
        if key_suffix[0] != "_":
            key_suffix = f"_{key_suffix}"
    else:
        key_suffix = ""

    def call(data, joined_data):
        for key, value in data.items():
            scores = [
                i[f"suspiciousness{key_suffix}"]
                for i in joined_data.get(key)
                if i[f"suspiciousness{key_suffix}"] is not None
            ]
            if scores is None or len(scores) == 0:
                estimated_prob = None
            else:
                estimated_prob = np.mean(scores)
            data[key][f"suspiciousness{key_suffix}"] = estimated_prob
            data[key]["monitor_scores"] = scores
        return data

    return call


def add_basic_blue_team_monitor_to_pipeline(
    pipeline,
    strong_model,
    weak_model,
    prev_step,
    name=None,
    team_suffix=None,
    remove_comments=True,
    use_cache=True,
):
    suffix = format_key_suffix(name)
    team = "blue"
    if team_suffix:
        team += format_key_suffix(team_suffix)

    # removing comments as a red-team affordance
    if remove_comments:
        prev_step = pipeline.add_transformation_step(
            f"remove_code_comments{suffix}",
            strip_comments("solution", key_suffix=suffix),
            dependencies=[prev_step],
            strong_model=strong_model,
            weak_model=weak_model,
        )

    get_monitor_description_step = pipeline.add_query_step(
        f"get_monitor_description{suffix}",
        weak_model,
        get_monitor_description(key_suffix=suffix, use_non_refusal_prompt=True),
        extract_monitor_description(key_suffix=suffix),
        dependencies=[prev_step],
        team=team,
        use_cache=use_cache,
    )

    get_backdoor_likelihood_step = pipeline.add_query_step(
        f"get_backdoor_likelihood{suffix}",
        weak_model,
        get_backdoor_presence_logprobs(key_suffix=suffix, use_non_refusal_prompt=True),
        extract_backdoor_presence_logprobs(key_suffix=suffix),
        dependencies=[get_monitor_description_step],
        logprobs=20,
        max_tokens=1,
        team=team,
        use_cache=use_cache,
    )

    return get_backdoor_likelihood_step


def add_ryans_blue_team_monitor_to_pipeline(
    pipeline,
    strong_model,
    weak_model,
    prev_step,
    name=None,
    team_suffix=None,
    remove_comments=True,
    use_cache=True,
    max_tokens=4096,
):
    suffix = format_key_suffix(name)
    team = "blue"
    if team_suffix:
        team += format_key_suffix(team_suffix)

    # removing comments as a red-team affordance
    if remove_comments:
        prev_step = pipeline.add_transformation_step(
            f"remove_code_comments{suffix}",
            strip_comments("solution", key_suffix=suffix),
            dependencies=[prev_step],
            strong_model=strong_model,
            weak_model=weak_model,
        )

    get_monitor_description_step = pipeline.add_query_step(
        f"get_monitor_description{suffix}",
        weak_model,
        get_monitor_description(key_suffix=suffix, use_non_refusal_prompt=False),
        extract_monitor_description(key_suffix=suffix),
        dependencies=[prev_step],
        team=team,
        use_cache=use_cache,
        max_tokens=max_tokens,
    )

    get_backdoor_likelihood_step = pipeline.add_query_step(
        f"get_backdoor_likelihood{suffix}",
        weak_model,
        get_backdoor_presence_logprobs(key_suffix=suffix, use_non_refusal_prompt=False),
        extract_backdoor_presence_logprobs(key_suffix=suffix),
        dependencies=[get_monitor_description_step],
        logprobs=20,
        max_tokens=1,
        team=team,
        use_cache=use_cache,
    )

    return get_backdoor_likelihood_step


def add_basic_blue_team_monitor_to_pipeline_sampling(
    pipeline,
    strong_model,
    weak_model,
    prev_step,
    name=None,
    remove_comments=True,
    bon=5,
    temperature=0.2,
):
    # assert bon > 1 and temperature > 0
    if bon > 1:
        assert temperature > 0, "Temperature must be greater than 0 when using BON > 1"
    suffix = format_key_suffix(name)

    # removing comments as a red-team affordance
    if remove_comments:
        prev_step = pipeline.add_transformation_step(
            f"remove_code_comments{suffix}",
            strip_comments("solution", key_suffix=suffix),
            dependencies=[prev_step],
            strong_model=strong_model,
            weak_model=weak_model,
        )

    get_monitor_description_step = pipeline.add_query_step(
        f"get_monitor_description{suffix}",
        weak_model,
        get_monitor_description(key_suffix=suffix),
        extract_monitor_description(key_suffix=suffix),
        dependencies=[prev_step],
        team="blue",
    )

    get_backdoor_likelihood_step = pipeline.add_query_step(
        f"get_backdoor_likelihood{suffix}",
        weak_model,
        get_backdoor_presence_logprobs(key_suffix=suffix),
        extract_backdoor_presence_logprobs_sampling(key_suffix=suffix),
        dependencies=[get_monitor_description_step],
        bon=bon,
        temperature=temperature,
        team="blue",
    )

    if bon > 1:
        join_monitors = pipeline.add_transformation_step(
            "join_monitors",
            join_by_problem,
            dependencies=[get_backdoor_likelihood_step],
            strong_model=strong_model,
            weak_model=weak_model,
        )
        get_backdoor_likelihood_sampling_step = pipeline.add_transformation_step(
            f"get_backdoor_likelihood_sampling{suffix}",
            average_suspiciousness(key_suffix=suffix),
            dependencies=[get_monitor_description_step, join_monitors],
            strong_model=strong_model,
            weak_model=weak_model,
        )
        return get_backdoor_likelihood_sampling_step
    else:
        return get_backdoor_likelihood_step


def add_basic_scalar_blue_team_monitor_to_pipeline(
    pipeline,
    strong_model,
    weak_model,
    prev_step,
    name=None,
    remove_comments=True,
    bon=1,
    temperature=0,
    use_cache=True,
):
    suffix = format_key_suffix(name)

    # removing comments as a red-team affordance
    if remove_comments:
        prev_step = pipeline.add_transformation_step(
            f"remove_code_comments{suffix}",
            strip_comments("solution", key_suffix=suffix),
            dependencies=[prev_step],
            strong_model=strong_model,
            weak_model=weak_model,
        )

    get_monitor_description_step = pipeline.add_query_step(
        f"get_monitor_description{suffix}",
        weak_model,
        get_monitor_description(key_suffix=suffix),
        extract_monitor_description(key_suffix=suffix),
        dependencies=[prev_step],
        team="blue",
        use_cache=use_cache,
    )

    get_backdoor_likelihood_step = pipeline.add_query_step(
        f"get_backdoor_likelihood{suffix}",
        weak_model,
        get_suspiciousness(key_suffix=suffix),
        extract_suspiciousness(key_suffix=suffix),
        dependencies=[get_monitor_description_step],
        team="blue",
        bon=bon,
        temperature=temperature,
        use_cache=use_cache,
    )

    if bon == 1:
        return get_backdoor_likelihood_step

    join_monitors = pipeline.add_transformation_step(
        "join_monitors",
        join_by_problem,
        dependencies=[get_backdoor_likelihood_step],
        strong_model=strong_model,
        weak_model=weak_model,
    )
    get_backdoor_likelihood_sampling_step = pipeline.add_transformation_step(
        f"get_backdoor_likelihood_sampling{suffix}",
        average_suspiciousness(key_suffix=suffix),
        dependencies=[get_monitor_description_step, join_monitors],
        strong_model=strong_model,
        weak_model=weak_model,
    )
    return get_backdoor_likelihood_sampling_step


def add_plan_blue_team_monitor_to_pipeline(
    pipeline,
    strong_model,
    weak_model,
    prev_step,
    name=None,
    team_suffix=None,
    use_cache=True,
):
    suffix = format_key_suffix(name)
    team = "blue"
    if team_suffix:
        team += format_key_suffix(team_suffix)

    get_monitor_description_step = pipeline.add_query_step(
        f"get_monitor_description_for_plan{suffix}",
        weak_model,
        get_monitor_description_for_plan(key_suffix=suffix),
        extract_monitor_description_for_plan(key_suffix=suffix),
        dependencies=[prev_step],
        team=team,
        use_cache=use_cache,
    )

    get_backdoor_likelihood_step = pipeline.add_query_step(
        f"get_backdoor_likelihood_for_plan{suffix}",
        weak_model,
        get_backdoor_presence_logprobs_for_plan(key_suffix=suffix),
        extract_backdoor_presence_logprobs(key_suffix=suffix),
        dependencies=[get_monitor_description_step],
        logprobs=20,
        max_tokens=1,
        team=team,
        use_cache=use_cache,
    )

    return get_backdoor_likelihood_step
