import os
import seaborn as sns
import matplotlib.pyplot as plt
from fire import Fire
import pandas as pd


def _plot_rewards(all_results: pd.DataFrame, env_names: list, path: str) -> None:
    """Save a figure with one ``rewards_mean`` bar-plot subplot per environment.

    Args:
        all_results: Rows with ``env_name``, ``algo`` and ``rewards_mean`` columns.
        env_names: Environments to plot, one subplot each, in this order.
        path: Destination image file.
    """
    ncols = 3
    nrows = (len(env_names) + ncols - 1) // ncols  # 5 environments -> 2x3 grid
    fig = plt.figure(figsize=(15, 10))
    try:
        for i, env in enumerate(env_names):
            env_df = all_results[all_results["env_name"] == env]
            plt.subplot(nrows, ncols, i + 1)
            # NOTE(review): only `hue` is assigned, so algorithms are distinguished by
            # colour/legend rather than x tick labels -- confirm x="algo" wasn't intended.
            sns.barplot(hue="algo", y="rewards_mean", data=env_df)
            plt.xticks(rotation=45)  # diagonal x tick labels
            plt.title(env)
        # Lay out once, after every subplot exists.
        plt.tight_layout()
        fig.savefig(path)
    finally:
        # Release the figure so it is not left open in pyplot's global state.
        plt.close(fig)


def _format_int_cells(table: pd.DataFrame, missing: str = r"\mbox{--}") -> pd.DataFrame:
    """Render already-rounded numbers as integer strings; NaN cells become *missing*."""
    return table.apply(
        lambda col: col.map(lambda v: missing if pd.isna(v) else str(int(v)))
    )


def _write_latex_table(all_results: pd.DataFrame, path: str) -> None:
    """Write a Method x Environment LaTeX table of ``mean \\pm std`` rewards to *path*.

    In each environment column, the cell(s) with the largest rounded mean are
    bolded (both the mean and its std).

    Args:
        all_results: Rows with ``env_name``, ``algo``, ``rewards_mean`` and
            ``rewards_std`` columns.
        path: Destination ``.tex`` file.
    """
    renamed = all_results.rename(columns={"env_name": "Environment", "algo": "Method"})
    # NOTE(review): "std" on rewards_std is the spread *of the per-row stds* and is
    # NaN for single-row groups -- confirm whether "mean" (or the std of
    # rewards_mean) was intended. Kept as-is so reported numbers don't change.
    grouped = (
        renamed.groupby(["Environment", "Method"])
        .agg({"rewards_mean": "mean", "rewards_std": "std"})
        .reset_index()
    )
    mean_table = grouped.pivot(
        index="Method", columns="Environment", values="rewards_mean"
    ).round(0)
    std_table = grouped.pivot(
        index="Method", columns="Environment", values="rewards_std"
    ).round(0)

    # Column-wise maximum of the rounded means; NaN cells never compare equal.
    max_mask = mean_table == mean_table.max()

    # Missing (Method, Environment) pairs and NaN stds are rendered as a dash
    # instead of crashing an integer cast.
    mean_str = _format_int_cells(mean_table)
    std_str = _format_int_cells(std_table)
    bold_mean = mean_str.where(~max_mask, "\\textbf{" + mean_str + "}")
    bold_std = std_str.where(~max_mask, "\\textbf{" + std_str + "}")

    # Raw string: "\p" is not a valid escape sequence in a normal string literal.
    combined = "$" + bold_mean + r" \pm " + bold_std + "$"

    latex_table = combined.to_latex(
        index=True,
        header=True,
        caption="Mean rewards for tc-adversary",
        label="tc-adversary",
        escape=False,
    )
    with open(path, "w", encoding="utf-8") as f:
        f.write(latex_table)


def main(
    results_path: str = "results.csv",
    output_folder: str = "results",
):
    """Plot per-environment rewards and export a LaTeX results table.

    Reads *results_path* (a CSV with ``env_name``, ``algo``, ``rewards_mean`` and
    ``rewards_std`` columns) and writes ``results.png`` and ``results.tex`` into
    *output_folder*, creating the folder if it does not exist.

    Args:
        results_path: Path to the input CSV.
        output_folder: Directory that receives the figure and the ``.tex`` table.
    """
    all_results = pd.read_csv(results_path)
    # Flip the sign of every mean reward before plotting and tabulating.
    # NOTE(review): presumably the CSV stores the opposing side's reward -- confirm.
    all_results["rewards_mean"] *= -1
    env_names = ["Ant", "HalfCheetah", "Hopper", "HumanoidStandup", "Walker"]

    # savefig/open fail on a missing directory, so create it up front.
    os.makedirs(output_folder, exist_ok=True)
    _plot_rewards(all_results, env_names, os.path.join(output_folder, "results.png"))
    _write_latex_table(all_results, os.path.join(output_folder, "results.tex"))


if __name__ == "__main__":
    # Expose `main` as a command-line interface through python-fire.
    Fire(component=main)
