import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def plot_sascore(
    input_path="visualizations/raw_data/synthesizability_score.csv",
    output_path="visualizations/plots/synthesizability_score.png",
):
    """Render per-dataset synthesizability-score box plots and save them as a PNG.

    Each row of the input CSV (first column used as the index) becomes one box
    labelled with its index value. Boxes are drawn from precomputed summary
    statistics rather than raw samples: ``q1``/``median``/``q3`` define the box
    and median line, ``p5``/``p95`` define the whisker ends. No outlier markers
    are drawn.

    Args:
        input_path: CSV file containing the columns ``p5``, ``q1``, ``median``,
            ``q3`` and ``p95``. Defaults to the previously hard-coded location.
        output_path: Destination PNG. Missing parent directories are created.
            Defaults to the previously hard-coded location.

    Raises:
        KeyError: If one of the required statistic columns is missing.
        ValueError: If the CSV contains no data rows.
    """
    from pathlib import Path

    df = pd.read_csv(input_path, index_col=0)
    if len(df.index) == 0:
        # Fail early with a clear message instead of an opaque min() error on
        # an empty position list after a figure has already been created.
        raise ValueError(f"No rows to plot in {input_path!r}")

    datasets = df.index.tolist()
    # One bxp stats dict per dataset; whiskers come from the 5th/95th
    # percentile columns instead of matplotlib's default IQR rule.
    box_data = [
        {"whislo": lo, "q1": q1, "med": med, "q3": q3, "whishi": hi, "fliers": []}
        for lo, q1, med, q3, hi in zip(
            df["p5"].tolist(),
            df["q1"].tolist(),
            df["median"].tolist(),
            df["q3"].tolist(),
            df["p95"].tolist(),
            strict=True,
        )
    ]

    # Leave extra horizontal space on both sides of the box at index 5
    # ("SuperNatural3" in the current data): one extra gap before it, and a
    # second extra gap before every box that follows it.
    base_spacing = 0.4
    extra_gap = 0.1
    gap_index = 5
    positions = []
    for i in range(len(datasets)):
        if i < gap_index:
            n_gaps = 0
        elif i == gap_index:
            n_gaps = 1
        else:
            n_gaps = 2
        positions.append(i * base_spacing + n_gaps * extra_gap)

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        colors = plt.cm.viridis(np.linspace(0, 1, len(datasets)))
        for pos, stats, color in zip(positions, box_data, colors, strict=True):
            ax.bxp(
                [stats],
                patch_artist=True,
                positions=[pos],
                widths=0.3,
                boxprops=dict(facecolor=color, edgecolor="black", linewidth=1.5),
                medianprops=dict(color="black", linewidth=2),
                whiskerprops=dict(color="black", linewidth=1.5),
                capprops=dict(color="black", linewidth=1.5),
                flierprops=dict(
                    marker="o",
                    markersize=5,
                    markerfacecolor="black",
                    markeredgecolor="black",
                ),
            )

        ax.set_xticks(positions)
        ax.set_xticklabels(datasets, fontsize=10)
        ax.tick_params(axis="y", labelsize=10)
        ax.grid(True, axis="y", linestyle="--", alpha=0.7)

        # Tighter margins before the first and after the last box.
        ax.set_xlim(min(positions) - 0.25, max(positions) + 0.25)

        fig.tight_layout()

        # savefig does not create directories; make sure the target exists.
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(
            output_path,
            dpi=300,
            bbox_inches="tight",
            facecolor="white",
            edgecolor="none",
        )
    finally:
        # pyplot keeps figures alive until explicitly closed; release it even
        # if drawing or saving failed.
        plt.close(fig)


if __name__ == "__main__":
    # Script entry point: regenerate the synthesizability-score plot.
    plot_sascore()
