import asyncio
import json
import os

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import pearsonr
from sklearn.metrics import roc_auc_score, roc_curve

from src.pipeline.blue_team.monitoring import (
    add_basic_blue_team_monitor_to_pipeline,
    add_basic_blue_team_monitor_to_pipeline_sampling,
    add_basic_scalar_blue_team_monitor_to_pipeline,
)
from src.pipeline.pipeline import Pipeline, PipelineConfig
from src.pipeline.red_team.monitoring import (
    add_red_team_scalar_sneakiness_steps_to_pipeline,
    add_red_team_sneakiness_logprob_steps_to_pipeline,
    add_red_team_sneakiness_steps_to_pipeline,
)
from src.tools.dataloaders import load_multiple_solutions, load_solutions
from src.tools.path_utils import get_default_results_directory


def get_pipeline(
    strong_model,
    weak_model,
    num_problems=None,
    name=None,
    logprob=True,
    bon=1,
    temperature=0,
):
    config = PipelineConfig(
        name,
        anthropic_num_threads=40,
        openai_fraction_rate_limit=0.99,
        num_problems=num_problems,
        use_cache=True,
    )

    pipeline = Pipeline(config)

    assert os.path.exists(
        get_default_results_directory()
        + "/trusted_monitoring-newredteam/38-final_backdoors"
    )
    get_backdoored_solutions = pipeline.add_load_data_step(
        "get_backdoored_solutions_1",
        load_multiple_solutions,
        get_default_results_directory()
        / "trusted_monitoring-newredteam"
        / "38-final_backdoors"
        / f"{strong_model}+gpt-3.5-turbo-1106",
    )

    def merge(*args):
        merged = {}
        for x in args:
            for key in x.keys():
                for value in x[key]:
                    if not value.get("metadata"):
                        value["metadata"] = {
                            k: v for k, v in value.items() if k != "metadata"
                        }
                if key in merged.keys():
                    merged[key].extend(x[key])
                else:
                    merged[key] = x[key]
        return {
            f"{key}-{i}": value
            for key, value_list in merged.items()
            for (i, value) in enumerate(value_list)
        }

    merged_backdoors = pipeline.add_transformation_step(
        "merged_backdoors",
        merge,
        dependencies=[get_backdoored_solutions],
        strong_model=strong_model,
        weak_model=weak_model,
    )

    if "claude" in weak_model:
        get_suspiciousness_backdoors = add_basic_blue_team_monitor_to_pipeline_sampling(
            pipeline,
            strong_model,
            weak_model,
            merged_backdoors,
            name="backdoors",
            bon=bon,
            temperature=temperature,
        )
    else:
        get_suspiciousness_backdoors = add_basic_blue_team_monitor_to_pipeline(
            pipeline, strong_model, weak_model, merged_backdoors, name="backdoors"
        )

    if not logprob:  # get scalar (1-10)
        get_sneakiness_backdoors = add_red_team_scalar_sneakiness_steps_to_pipeline(
            pipeline,
            strong_model,
            weak_model,
            merged_backdoors,
            name="backdoors",
            bon=bon,
            temperature=temperature,
        )
    else:
        get_sneakiness_backdoors = add_red_team_sneakiness_logprob_steps_to_pipeline(
            pipeline, strong_model, weak_model, merged_backdoors, name="backdoors"
        )

    return pipeline, get_suspiciousness_backdoors.name, get_sneakiness_backdoors.name


def plot_results(
    backdoored_scores,
    backdoored_sneakiness,
    name,
    strong_model,
    weak_model,
    logprob=True,
):
    save_dir = f"{get_default_results_directory()}/{name}/merged_results/{strong_model}+{weak_model}"
    os.makedirs(save_dir, exist_ok=True)

    if not logprob:
        backdoored_scores = [i / 10 for i in backdoored_scores]

    correlation, p_value = pearsonr(backdoored_scores, backdoored_sneakiness)
    print(f"Correlation: {correlation:.2f}, p-value: {p_value:.4f}")

    # Calculate mean and standard deviation
    backdoored_mean = np.mean(backdoored_scores)
    backdoored_std = np.std(backdoored_scores)

    backdoored_sneakiness_mean, backdoored_sneakiness_std = np.mean(
        backdoored_sneakiness
    ), np.std(backdoored_sneakiness)

    print(
        f"Backdoored solutions - Mean: {backdoored_mean:.2f}, Std Dev: {backdoored_std:.2f}, Mean Sneakiness: {backdoored_sneakiness_mean:.2f}, Std Dev: {backdoored_sneakiness_std:.2f}"
    )

    plt.figure(figsize=(10, 6))
    plt.scatter(backdoored_sneakiness, backdoored_scores, alpha=0.5)
    plt.xlabel("Sneakiness Scores")
    plt.ylabel("Suspiciousness Scores")
    plt.title(f"Suspiciousness vs Sneakiness\nCorrelation: {correlation:.2f}")

    z = np.polyfit(backdoored_sneakiness, backdoored_scores, 1)
    p = np.poly1d(z)
    plt.plot(backdoored_sneakiness, p(backdoored_sneakiness), "r--", alpha=0.8)

    plt.savefig(f"{save_dir}/{name}.png", dpi=1000)
    plt.close()


if __name__ == "__main__":
    bon = 1
    temp = 0

    name = f"sneakiness_ablation-gpt-3.5-turbo-logprob-bon{bon}-temperature{temp}"
    strong_model = "gpt-4-0125-preview"
    weak_model = "gpt-3.5-turbo-1106"

    (
        pipeline,
        get_suspiciousness_backdoors_step,
        get_sneakiness_backdoors_step,
    ) = get_pipeline(
        strong_model=strong_model,
        weak_model=weak_model,
        num_problems=200,
        name=name,
        logprob=True,
        bon=bon,
        temperature=temp,
    )

    results = asyncio.run(pipeline.run())
    backdoored_scores = [
        result.get("suspiciousness_backdoors", 1.0)
        for result in results[get_suspiciousness_backdoors_step].values()
    ]
    backdoored_sneakiness = [
        result.get("sneakiness")
        for result in results[get_sneakiness_backdoors_step].values()
    ]

    plot_results(
        backdoored_scores,
        backdoored_sneakiness,
        pipeline.config.name,
        strong_model,
        weak_model,
    )
