import json
import os

from matplotlib import pyplot as plt

from src.tools.path_utils import get_default_results_directory


def plot(
    results,
    name,
    strong_model,
    weak_model,
    prefix="",
    suffix="",
    save_fig=True,
    close=True,
):
    x = []
    y = []
    normalized_y = []
    for threshold in range(0, 11, 1):
        threshold /= 10
        eval_result = results[f"{prefix}calculate_metrics_threshold{threshold}{suffix}"]
        x.append(eval_result.safety)
        y.append(eval_result.usefulness)
        normalized_y.append(eval_result.normalized_usefulness)

    # Get the strong and weak usefulness values from the first eval_result
    strong_usefulness = results[
        f"{prefix}calculate_metrics_threshold0.0{suffix}"
    ].strong_usefulness
    weak_usefulness = results[
        f"{prefix}calculate_metrics_threshold0.0{suffix}"
    ].weak_usefulness

    # Add horizontal lines for strong and weak usefulness
    plt.axhline(
        y=strong_usefulness, color="r", linestyle="--", label="Strong Model Usefulness"
    )
    plt.axhline(
        y=weak_usefulness, color="g", linestyle="--", label="Weak Model Usefulness"
    )
    # plot normalized usefulness vs safety
    # plt.subplot(1, 2, 1)
    # plt.scatter(x, normalized_y, label=name)
    # plt.xlabel("Safety")
    # plt.ylabel("Usefulness")
    # plt.xlim([0, 1])
    # plt.ylim([0, 1])
    # plt.legend()

    # plt.subplot(1, 2, 2)
    plt.scatter(x, y, label=f"{prefix}{name}{suffix}")
    plt.xlabel("Safety")
    plt.ylabel("Denormalized Usefulness")
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.legend()

    plt.tight_layout()
    save_dir = f"{get_default_results_directory()}/{name}/merged_results/{strong_model}+{weak_model}"
    if save_fig:
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(f"{save_dir}/fig.png", dpi=1000)

        with open(f"{save_dir}/data.json", "w") as f:
            json.dump(
                {"safety": x, "usefulness": y, "normalized_usefulness": normalized_y},
                f,
                indent=2,
            )
    else:
        plt.show()

    if close:
        plt.close()
