import seaborn as sns
import matplotlib.pyplot as plt
from fire import Fire
import pandas as pd
import os


def _plot_rewards(all_results: pd.DataFrame, output_path: str) -> None:
    """Save a 2x3 grid of per-environment reward bar charts to *output_path*.

    Each subplot shows ``rewards_mean`` per ``scheduler_type``, coloured by
    ``algo``, for the rows whose ``env_name`` equals the subplot's environment.
    """
    env_names = ["Ant", "HalfCheetah", "Hopper", "HumanoidStandup", "Walker"]
    fig = plt.figure(figsize=(15, 10))
    for i, env in enumerate(env_names):
        env_df = all_results[all_results["env_name"] == env]
        plt.subplot(2, 3, i + 1)
        sns.barplot(x="scheduler_type", y="rewards_mean", data=env_df, hue="algo")
        # Rotate the x labels so long scheduler names do not overlap.
        plt.xticks(rotation=45)
        plt.title(env)

    # Lay out once, after every subplot exists, rather than on each iteration.
    plt.tight_layout()
    plt.savefig(output_path)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def _format_scheduler_table(scheduler_rows: pd.DataFrame) -> pd.DataFrame:
    """Return a Method x Environment table of ``$mean \\pm std$`` strings.

    Values are rounded to whole numbers. In every Environment column the cell
    whose mean is the column maximum has both its mean and its std wrapped in
    ``\\textbf{...}``.
    """
    mean_table = (
        scheduler_rows.pivot(
            index="Method", columns="Environment", values="rewards_mean"
        )
        .round(0)
        .astype(int)
    )
    std_table = (
        scheduler_rows.pivot(
            index="Method", columns="Environment", values="rewards_std"
        )
        .round(0)
        .astype(int)
    )

    # True where a method holds the best (largest) mean for that environment.
    max_mask = mean_table == mean_table.max()

    mean_str = mean_table.astype(str)
    std_str = std_table.astype(str)
    bold_mean = mean_str.where(~max_mask, "\\textbf{" + mean_str + "}")
    bold_std = std_str.where(~max_mask, "\\textbf{" + std_str + "}")

    # The backslash is escaped explicitly; "\p" is not a valid escape sequence.
    return "$" + bold_mean + " \\pm " + bold_std + "$"


def _build_tables(all_results: pd.DataFrame) -> dict:
    """Aggregate the results and build one formatted table per scheduler type."""
    renamed = all_results.rename(
        columns={"env_name": "Environment", "algo": "Method"}
    )
    # NOTE(review): this takes the standard deviation of the "rewards_std"
    # column itself (not of "rewards_mean"); a group with a single row yields
    # NaN, which the integer cast in _format_scheduler_table rejects. Kept
    # as-is because the intended statistic cannot be determined from here --
    # confirm with the author.
    grouped = (
        renamed.groupby(["Environment", "scheduler_type", "Method"])
        .agg({"rewards_mean": "mean", "rewards_std": "std"})
        .reset_index()
    )

    tables = {}
    for scheduler in grouped["scheduler_type"].unique():
        scheduler_rows = grouped[grouped["scheduler_type"] == scheduler]
        tables[scheduler] = _format_scheduler_table(scheduler_rows)
    return tables


def main(
    results_path: str = "results.csv",
    output_folder: str = "results",
) -> None:
    """Plot reward bar charts and export per-scheduler LaTeX tables.

    Reads the CSV at ``results_path``, which must provide the columns
    ``env_name``, ``scheduler_type``, ``algo``, ``rewards_mean`` and
    ``rewards_std``. Writes ``results.png`` and one ``table_<scheduler>.tex``
    file per distinct ``scheduler_type`` into ``output_folder``.

    Args:
        results_path: Path of the CSV file holding the results.
        output_folder: Directory receiving the outputs; created if missing.
    """
    all_results = pd.read_csv(results_path)

    # Create the destination up front so neither savefig nor open() fails on a
    # fresh checkout. An empty string means "current directory": skip it.
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)

    _plot_rewards(all_results, os.path.join(output_folder, "results.png"))

    for scheduler, table in _build_tables(all_results).items():
        latex = table.to_latex(
            index=True,
            header=True,
            caption=f"Rewards Mean for Scheduler: {scheduler}",
            label=f"tab:{scheduler}",
            # Cells already contain LaTeX markup; do not escape it.
            escape=False,
        )
        tex_path = os.path.join(output_folder, f"table_{scheduler}.tex")
        with open(tex_path, "w", encoding="utf-8") as f:
            f.write(latex)


# Script entry point: python-fire builds the command-line interface from
# main()'s signature, so its parameters can be supplied as flags
# (e.g. --results_path=... --output_folder=...).
if __name__ == "__main__":
    Fire(main)
