import matplotlib.pyplot as plt


def plot_equations():
    """Plot retrieval accuracy against the number of retrieval steps D.

    One line per model (each with a distinct marker) shows accuracy in percent
    for D = 2..13 on the *equations* formulation of the retrieval task. A dashed
    gray reference line at 25% is labelled "Random guessing". The figure is
    saved as ``llm_retrieval_equations.{png,svg,pdf}`` in the current working
    directory. It is then closed so repeated calls do not accumulate open
    pyplot figures.
    """
    # Every raw score below is out of this many (a score of 500 is 100%).
    num_questions = 500

    # Raw scores per model, one entry per value of D in ``D_values``.
    raw_scores = {
        "GPT-4o": [500, 500, 500, 500, 494, 476, 414, 305, 259, 163, 142, 129],
        "Claude 3.5 Sonnet": [
            500,
            500,
            500,
            499,
            477,
            364,
            260,
            207,
            151,
            136,
            124,
            132,
        ],
        "Gemini 1.5 Pro": [500, 500, 500, 498, 481, 395, 318, 223, 189, 146, 136, 104],
        "Llama 3.1 405b": [500, 500, 498, 496, 455, 337, 226, 171, 158, 149, 134, 134],
    }
    markers = {
        "GPT-4o": "o",
        "Claude 3.5 Sonnet": "s",
        "Gemini 1.5 Pro": "D",
        "Llama 3.1 405b": "^",
    }

    # Values of D (number of retrieval steps), one per score position above.
    D_values = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]

    # Convert raw scores to percentages. A new dict is built rather than
    # reassigning entries of the dict being iterated.
    accuracies_pct = {
        model: [score / num_questions * 100 for score in scores] for model, scores in raw_scores.items()
    }

    fig = plt.figure(figsize=(5.8, 4))

    for model, accuracies in accuracies_pct.items():
        plt.plot(D_values, accuracies, marker=markers[model], label=model)

    # Random-guessing reference line, labelled just above itself near the y-axis.
    plt.axhline(y=25, color="gray", linestyle="--")
    plt.text(
        2.1,
        26,
        "Random guessing",
        verticalalignment="bottom",
        horizontalalignment="left",
        color="gray",
    )

    plt.title("Accuracy of state-of-the-art language models\non the retrieval task ($\\it{equations}$ formulation)")
    plt.xlabel("Number of retrieval steps (D)")
    plt.ylabel("Accuracy (%)")
    plt.xticks(D_values)
    plt.grid(True)
    plt.legend()

    # Adjust subplot margins so the title, labels and legend fit inside the figure.
    plt.tight_layout()

    for ext in ("png", "svg", "pdf"):
        plt.savefig(f"llm_retrieval_equations.{ext}")
    # Release the figure; pyplot otherwise keeps every created figure alive.
    plt.close(fig)


def plot_others():
    """Plot GPT-4o accuracy on the other retrieval task formulations.

    The plot has one pale-green horizontal bar per formulation, listed
    top-to-bottom in the order of ``categories``. A dashed gray vertical
    segment marks each task's base rate across that bar's row. The figure is
    saved as ``llm_retrieval_tasks.{png,svg,pdf}`` in the current working
    directory. It is then closed so repeated calls do not accumulate open
    pyplot figures.
    """
    # Tick labels (mathtext italics, multi-line); aligned with the two lists below.
    categories = [
        "$\\it{“Lives\\ with”}$\nformulation\n(D=5)",
        "$\\it{“Kingdoms”}$\nformulation\n(D=5)",
        "$\\it{“Functions”}$\nformulation",
        "$\\it{“Relatives”}$\nformulation",
    ]
    # Accuracy (%) of GPT-4o on each task.
    accuracy_values = [95, 89, 91, 67]
    # Base-rate accuracy (%) for each task, drawn as the dashed reference segments.
    base_rates = [25, 25, 25, 6.25]
    num_rows = len(categories)

    fig = plt.figure(figsize=(5, 4))

    # barh puts the first category at the bottom (y=0). Reverse the lists so the
    # rows read top-to-bottom in the order written above.
    plt.barh(categories[::-1], accuracy_values[::-1], color="palegreen")

    # axvline's ymin/ymax are fractions of the axes height. Pinning the y-limits
    # to the category bands makes row i (centred on y=i) occupy exactly the
    # fraction [i/n, (i+1)/n]. With the default y-margins, the segments would be
    # slightly offset from their rows.
    plt.ylim(-0.5, num_rows - 0.5)
    for row, base_rate in enumerate(base_rates[::-1]):
        plt.axvline(
            x=base_rate,
            color="gray",
            linestyle="--",
            ymin=row / num_rows,
            ymax=(row + 1) / num_rows,
        )

    plt.title("Accuracy of GPT-4o on different $\\it{retrieval}$\nand $\\it{conditional\\ retrieval}$ tasks")
    plt.xlabel("Accuracy (%)")
    plt.xlim(0, 100)

    # Centre the multi-line category labels horizontally. The padding pushes
    # them clear of the axis.
    plt.yticks(ha="center")
    plt.gca().yaxis.set_tick_params(pad=35)

    # Adjust subplot margins so the title and tick labels fit inside the figure.
    plt.tight_layout()

    for ext in ("png", "svg", "pdf"):
        plt.savefig(f"llm_retrieval_tasks.{ext}")
    # Release the figure; pyplot otherwise keeps every created figure alive.
    plt.close(fig)


# Generate both figures only when run as a script, not as a side effect of import.
if __name__ == "__main__":
    plot_equations()
    plot_others()
