import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scienceplots
from matplotlib.ticker import FuncFormatter
from scipy.stats import gaussian_kde


def clean_output(input, completion):
    """Pull a 4-character integer out of a model completion.

    The completion is scanned for a fixed, ordered list of two-digit
    prefixes (presumably the leading digits of plausible years -- confirm
    with the prompt design). For each prefix, in order:

    * if the completion *starts* with the two digits, its first four
      characters are parsed as an int;
    * otherwise, if the prefix preceded by a space occurs anywhere, the four
      characters following that space (first occurrence only) are parsed.

    The first candidate that parses is returned. A candidate that fails to
    parse moves the search on to the next prefix.

    Args:
        input: The prompt that produced ``completion``; only used in the
            diagnostic print. (The name shadows the builtin but is kept for
            interface compatibility.)
        completion: The model output to parse (a ``str``).

    Returns:
        The parsed ``int``, or ``""`` when nothing could be extracted. In the
        latter case ``input`` and ``completion`` are printed to stdout.
    """
    # Order matters: earlier prefixes win when several occur in the text.
    year_prefixes = (
        " 20",
        " 21",
        " 22",
        " 23",
        " 24",
        " 25",
        " 26",
        " 27",
        " 28",
        " 29",
        " 30",
        " 31",
        " 32",
        " 34",
        " 40",
        " 19",
        " 18",
        " 17",
        " 16",
        " 15",
        " 14",
        " 12",
    )

    for prefix in year_prefixes:
        if completion[:2] == prefix[1:]:
            candidate = completion[:4]
        else:
            start = completion.find(prefix)
            if start == -1:
                continue
            # Skip the leading space and take the next four characters.
            # NOTE: only the first occurrence of the prefix is examined.
            candidate = completion[start + 1 : start + 5]

        try:
            return int(candidate)
        except ValueError:
            # Not a clean integer (e.g. "20th"); try the next prefix.
            continue

    # Nothing parsed: surface the offending pair so it can be inspected.
    print(input, ":", completion)
    return ""


def format_sigfigs(x, pos):
    """Render a tick value with exactly two digits after the decimal point.

    Despite the name, this formats to a fixed number of decimals rather than
    significant figures. ``pos`` is accepted so the function matches the
    two-argument callback signature used with ``FuncFormatter``; it is not
    used.
    """
    return format(x, ".2f")


def get_agg_stats(df, topic, bin_types=("mean", "median")):
    """Aggregate year statistics per (topic, model) and bin them per model.

    Rows whose ``year`` equals the empty string are dropped, ``year`` is cast
    to ``int``, and raw model identifiers are mapped to display labels. All
    work happens on a private copy, so the caller's ``df`` is not modified.

    Args:
        df: DataFrame containing the ``topic`` column plus ``"model"`` and
            ``"year"`` columns.
        topic: Name of the column to group by together with ``"model"``.
        bin_types: Iterable of statistic prefixes. For each ``b`` the
            aggregate column ``f"{b}_year"`` is binned, so ``b`` must be one
            of the aggregated statistics (``mean``, ``median``, ...).

    Returns:
        A tuple ``(agg_df, binned_df)``:

        * ``agg_df`` -- one row per (topic, model) with ``over_cutoff``
          (number of rows with ``year > 2023``), ``mean_year``,
          ``median_year``, ``std_year``, ``min_year``, ``max_year``,
          ``over_cutoff_perc`` and one ``f"{b}_year_bin"`` column per bin
          type.
        * ``binned_df`` -- dict mapping each bin type to a DataFrame with one
          row per (model, bin) holding ``count_in_bin``, mean/median of the
          std/min/max columns, and ``perc_in_bin``.
    """
    model_labels = {
        "gpt_3.5": "GPT3.5",
        "gpt_4": "GPT4",
        "LLAMA2_7B": "Llama2-7B",
        "LLAMA2_13B": "Llama2-13B",
        "LLAMA2_70B": "Llama2-70B",
    }

    # Filter and copy first: every later column assignment lands on our own
    # frame instead of the caller's data or a view of it.
    df = df[df["year"] != ""].copy()
    df["model"] = df["model"].replace(model_labels)
    df["year"] = df["year"].astype(int)
    df["over_cutoff"] = np.where(df["year"] > 2023, 1, 0)

    agg_df = (
        df.groupby([topic, "model"])
        .agg({"over_cutoff": "sum", "year": ["mean", "median", "std", "min", "max"]})
        .reset_index()
    )
    # Flatten the two-level column index before deriving further columns.
    agg_df.columns = [
        topic,
        "model",
        "over_cutoff",
        "mean_year",
        "median_year",
        "std_year",
        "min_year",
        "max_year",
    ]
    # NOTE(review): the divisor 10 presumably is the number of samples per
    # (topic, model) pair -- confirm against the data-generation setup.
    agg_df["over_cutoff_perc"] = agg_df["over_cutoff"] / 10

    int_columns = ["count_in_bin", "mean_min", "median_min", "mean_max", "median_max"]
    binned_df = {}
    for b in bin_types:
        bin_col = f"{b}_year_bin"
        # Round to one decimal, then truncate to the integer year.
        agg_df[bin_col] = agg_df[f"{b}_year"].round(1).astype(int)

        binned = (
            agg_df[[topic, "model", bin_col, "std_year", "min_year", "max_year"]]
            .groupby(["model", bin_col])
            .agg(
                {
                    topic: "count",
                    "std_year": ["mean", "median"],
                    "min_year": ["mean", "median"],
                    "max_year": ["mean", "median"],
                }
            )
            .reset_index()
        )
        binned.columns = [
            "model",
            bin_col,
            "count_in_bin",
            "mean_std",
            "median_std",
            "mean_min",
            "median_min",
            "mean_max",
            "median_max",
        ]
        binned = binned.astype({col: int for col in [bin_col, *int_columns]})

        # NOTE(review): the divisor 100 presumably is the number of topics
        # per model -- confirm.
        binned["perc_in_bin"] = binned["count_in_bin"] / 100
        binned_df[b] = binned

    return agg_df, binned_df


def plot_pdf(agg_df, avg_metric, title, xmin, xmax):
    """Plot one smoothed density curve per model on a single axis.

    For each of the five hard-coded model labels, a Gaussian KDE is fitted to
    that model's ``f"{avg_metric}_year"`` values from ``agg_df`` and drawn
    over ``[xmin, xmax]``. The legend entry shows the mean and standard
    deviation of those values. The figure is shown, not saved.

    Args:
        agg_df: DataFrame with a ``"model"`` column and an
            ``f"{avg_metric}_year"`` column.
        avg_metric: Prefix of the column to plot (e.g. ``"mean"``).
        title: Figure title.
        xmin: Lower x-axis limit and start of the evaluation grid.
        xmax: Upper x-axis limit and end of the evaluation grid.
    """
    models = ["Llama2-7B", "Llama2-13B", "Llama2-70B", "GPT3.5", "GPT4"]
    colors = ["#BBBBBB", "#66CCEE", "#4477AA", "#CCBB44", "#228833"]
    year_col = f"{avg_metric}_year"

    plt.style.use(["science", "bright"])

    fig, ax = plt.subplots(figsize=(8, 6))

    # Evaluation grid: (xmax - xmin) evenly spaced points, endpoints included.
    x_range = np.linspace(xmin, xmax, int(xmax - xmin))
    for model, color in zip(models, colors):
        years = agg_df.loc[agg_df["model"] == model, year_col]
        pdf_smooth = gaussian_kde(years).evaluate(x_range)

        ax.plot(
            x_range,
            pdf_smooth,
            linewidth=2,
            color=color,
            # Raw f-string so the TeX backslashes are not treated as escapes.
            label=rf"{model} ($\mu$={years.mean():.0f}, $\sigma$={years.std():.0f})",
        )

    # Build the legend once, after every curve has been added.
    ax.legend()
    # Set the limits on the axis; assigning to ``plt.xlim`` would overwrite
    # the pyplot function and leave the limits unchanged.
    ax.set_xlim(xmin, xmax)
    plt.title(title)
    plt.tight_layout()
    plt.show()


def plot_events_pdf(
    disc_agg_df,
    events_agg_df,
    models,
    colors,
    avg_metric,
    title,
    xmin,
    xmax,
    output_file,
):
    """Plot discovery vs. event year densities, one panel per model.

    Each panel overlays two Gaussian-KDE curves of the
    ``f"{avg_metric}_year"`` column -- one from ``disc_agg_df`` ("Future
    Discoveries") and one from ``events_agg_df`` ("Outlandish Events") --
    plus a dashed vertical line at 2023. All panels share the same y-limit,
    derived from the tallest curve. The figure is saved to ``output_file``
    and then shown.

    Args:
        disc_agg_df: DataFrame with ``"model"`` and ``f"{avg_metric}_year"``
            columns for the discoveries series.
        events_agg_df: Same layout, for the events series.
        models: Model labels; one subplot is drawn per entry.
        colors: Accepted for interface compatibility but NOT used -- the
            palette below has always replaced it.
        avg_metric: Prefix of the column to plot (e.g. ``"mean"``).
        title: Figure super-title.
        xmin: Lower x-axis limit and start of the evaluation grid.
        xmax: Upper x-axis limit and end of the evaluation grid.
        output_file: Path passed to ``plt.savefig``.
    """
    plt.style.use(["science", "bright"])

    n_panels = len(models)
    # squeeze=False keeps a 2-D array of axes even for a single model, so the
    # indexing below works for any number of panels.
    fig, axes = plt.subplots(1, n_panels, figsize=(14, 5), squeeze=False)
    ax = axes[0]

    # NOTE(review): this deliberately keeps the historical behaviour of
    # ignoring the ``colors`` argument; honouring it would change existing
    # figures.
    colors = ["#BBBBBB", "#66CCEE", "#4477AA", "#CCBB44", "#228833"]
    year_col = f"{avg_metric}_year"
    series = (
        (disc_agg_df, colors[2], "Future Discoveries"),
        (events_agg_df, colors[0], "Outlandish Events"),
    )

    # Legend goes on the third panel and the x-label on the second, clamped
    # to the last panel when there are fewer than three / two.
    legend_panel = min(2, n_panels - 1)
    xlabel_panel = min(1, n_panels - 1)

    x_range = np.linspace(xmin, xmax, int(xmax - xmin))
    max_y = 0
    for i, model in enumerate(models):
        for source_df, color, label in series:
            years = source_df.loc[source_df["model"] == model, year_col]
            pdf_smooth = gaussian_kde(years).evaluate(x_range)
            # Track the peak of every curve so none is clipped by the
            # shared y-limit.
            max_y = max(max_y, pdf_smooth.max())

            ax[i].plot(x_range, pdf_smooth, linewidth=2, color=color, label=label)

        ax[i].axvline(
            x=2023, color="#EE6677", linestyle="--", linewidth=1, label="Future Cutoff"
        )

        ax[i].set_title(model, fontsize=16)
        ax[i].set_xlim(xmin, xmax)
        if i == legend_panel:
            ax[i].legend(fontsize=11, loc="upper right")

        ax[i].tick_params(axis="x", labelsize=12)
        ax[i].tick_params(axis="y", labelsize=12)

        # Two-decimal tick labels on the density axis.
        ax[i].yaxis.set_major_formatter(FuncFormatter(format_sigfigs))

    for panel in ax:
        panel.set_ylim(0, max_y + 0.01)
    fig.suptitle(title, fontsize=20)
    plt.tight_layout()
    ax[xlabel_panel].set_xlabel("Year", fontsize=13)
    ax[0].set_ylabel("Density", fontsize=13)
    plt.savefig(output_file)
    plt.show()


def plot_pres_predictions(
    pol_agg_df,
    fic_agg_df,
    gen_agg_df,
    models,
    colors,
    avg_metric,
    title,
    xmin,
    xmax,
    output_file,
):
    """Plot politician / fictional / generic year densities, one panel per model.

    Each panel overlays three Gaussian-KDE curves of the
    ``f"{avg_metric}_year"`` column -- "Current Politicians" from
    ``pol_agg_df``, "Fictional Characters" from ``fic_agg_df`` and "Generic
    Names" from ``gen_agg_df`` -- plus a dashed vertical line at 2023. All
    panels share the same y-limit, derived from the tallest curve. The figure
    is saved to ``output_file`` and then shown.

    Args:
        pol_agg_df: DataFrame with ``"model"`` and ``f"{avg_metric}_year"``
            columns for the politicians series.
        fic_agg_df: Same layout, for the fictional-characters series.
        gen_agg_df: Same layout, for the generic-names series.
        models: Model labels; one subplot is drawn per entry.
        colors: Accepted for interface compatibility but NOT used -- the
            palette below has always replaced it.
        avg_metric: Prefix of the column to plot (e.g. ``"mean"``).
        title: Figure super-title.
        xmin: Lower x-axis limit and start of the evaluation grid.
        xmax: Upper x-axis limit and end of the evaluation grid.
        output_file: Path passed to ``plt.savefig``.
    """
    plt.style.use(["science", "bright"])

    n_panels = len(models)
    # squeeze=False keeps a 2-D array of axes even for a single model, so the
    # indexing below works for any number of panels.
    fig, axes = plt.subplots(1, n_panels, figsize=(14, 5), squeeze=False)
    ax = axes[0]

    # NOTE(review): this deliberately keeps the historical behaviour of
    # ignoring the ``colors`` argument; honouring it would change existing
    # figures.
    colors = ["#BBBBBB", "#66CCEE", "#4477AA", "#CCBB44", "#228833"]
    year_col = f"{avg_metric}_year"
    series = (
        (pol_agg_df, colors[2], "Current Politicians"),
        (fic_agg_df, colors[0], "Fictional Characters"),
        (gen_agg_df, colors[1], "Generic Names"),
    )

    # Legend goes on the third panel and the x-label on the second, clamped
    # to the last panel when there are fewer than three / two.
    legend_panel = min(2, n_panels - 1)
    xlabel_panel = min(1, n_panels - 1)

    x_range = np.linspace(xmin, xmax, int(xmax - xmin))
    max_y = 0
    for i, model in enumerate(models):
        for source_df, color, label in series:
            years = source_df.loc[source_df["model"] == model, year_col]
            pdf_smooth = gaussian_kde(years).evaluate(x_range)
            # Track the peak of every curve so none is clipped by the
            # shared y-limit.
            max_y = max(max_y, pdf_smooth.max())

            ax[i].plot(x_range, pdf_smooth, linewidth=2, color=color, label=label)

        ax[i].axvline(
            x=2023, color="#EE6677", linestyle="--", linewidth=1, label="Future Cutoff"
        )

        ax[i].set_title(model, fontsize=16)
        ax[i].set_xlim(xmin, xmax)
        if i == legend_panel:
            ax[i].legend(fontsize=11, loc="upper right")

        ax[i].tick_params(axis="x", labelsize=12)
        ax[i].tick_params(axis="y", labelsize=12)

        # Two-decimal tick labels on the density axis.
        ax[i].yaxis.set_major_formatter(FuncFormatter(format_sigfigs))

    for panel in ax:
        panel.set_ylim(0, max_y + 0.01)
    fig.suptitle(title, fontsize=20)
    plt.tight_layout()
    ax[xlabel_panel].set_xlabel("Year", fontsize=13)
    ax[0].set_ylabel("Density", fontsize=13)
    plt.savefig(output_file)
    plt.show()
