import os
import matplotlib.pyplot as plt
import pandas as pd
from lib.result_manager import ResultManager
from lib.constants import FIGURES_DIR


def draw_by_df(df: pd.DataFrame) -> plt.Figure:
    """Draw the convergence of the Isolation Forest on a given dataset.

    For every distinct value of the ``feature`` column, plots the mean of
    ``mse`` for each ``n_trees`` value as a line. Around that line it shades
    a band of +/- 2 sample standard deviations.

    Args:
        df (pd.DataFrame): The dataframe containing the results. It must have
            the columns ``feature``, ``n_trees`` and ``mse``. Row order does
            not matter.

    Returns:
        fig: The figure.
    """
    fig = plt.figure(figsize=(5, 3.5))
    ax = fig.add_subplot(1, 1, 1)
    features = df["feature"].unique()
    handles = []
    for feature in features:
        # A boolean mask (instead of an f-string query) works for any dtype
        # and for values containing quotes.
        df_feature = df[df["feature"] == feature]
        # Aggregate once. The result index is sorted by n_trees, so x and y
        # line up regardless of the input row order.
        stats = df_feature.groupby("n_trees")["mse"].agg(["mean", "std"])
        n_trees = stats.index.to_numpy()
        means = stats["mean"].to_numpy()
        # The sample std is NaN when an n_trees value has a single run; the
        # band is then simply not drawn at that point.
        stds = stats["std"].to_numpy()

        (line,) = ax.plot(n_trees, means, label=feature)
        ax.fill_between(
            n_trees,
            means - 2 * stds,
            means + 2 * stds,
            alpha=0.2,
            # Tie the band colour to its line, rather than relying on the
            # line and patch colour cyclers staying in step.
            color=line.get_color(),
            label='_nolegend_'
        )
        handles.append(line)

    ax.set_xlabel("Number of Trees")
    ax.set_ylabel("Mean Squared Error")
    # ax.set_title(f"Convergence of Isolation Forest on {df['dataset'].unique()[0]}")
    # Explicit handles: the labels-only legend call pairs labels with artists
    # purely by draw order.
    ax.legend(handles, [str(feature) for feature in features])
    fig.tight_layout()
    return fig


def main():
    """Render one convergence figure per dataset in the stored results.

    Reads all results through ``ResultManager().get_results()``. It creates
    ``FIGURES_DIR`` if needed and writes one PDF per distinct value of the
    ``dataset`` column, named ``convergence_<dataset>.pdf``.
    """
    result_manager = ResultManager()
    results_df = result_manager.get_results()
    os.makedirs(FIGURES_DIR, exist_ok=True)

    for dataset in results_df["dataset"].unique():
        # A boolean mask (instead of an f-string query) avoids two failures:
        # a name containing a quote breaks the query parser, and non-string
        # dataset values would silently match nothing.
        df = results_df[results_df["dataset"] == dataset].sort_values(by="n_trees")
        fig = draw_by_df(df)
        try:
            fig.savefig(os.path.join(FIGURES_DIR, f"convergence_{dataset}.pdf"))
        finally:
            # Release the figure even if saving fails, so open figures do not
            # accumulate across datasets.
            plt.close(fig)


if __name__ == "__main__":
    # Generate the figures only when executed as a script, not on import.
    main()
