# -*- coding: utf-8 -*-
import os
import pandas as pd

folder_path = "./runs/KANs"

_INPUT_SUFFIX = "_KANs_results.csv"
_OUTPUT_SUFFIX = "_KANs_results_wide.csv"
_VALUE_COLUMNS = ["global_step", "episodic_length", "episodic_return"]


def _parse_bracketed(series, dtype):
    """Strip surrounding '[' / ']' characters from text values and cast to *dtype*.

    Columns that pandas already parsed as numeric are cast directly, because
    the ``.str`` accessor raises on non-string columns.
    """
    if not pd.api.types.is_numeric_dtype(series):
        series = series.str.strip("[]")
    return series.astype(dtype)


def _collapse_duplicate_steps(seed_data):
    """Average rows of one seed that share a ``global_step``.

    Without this, a left merge against a frame with repeated keys duplicates
    rows, and each later seed's merge multiplies the duplication.
    NOTE(review): averaging is a choice; confirm it matches how repeated
    episodes at one step should be summarised.
    """
    if seed_data["global_step"].duplicated().any():
        seed_data = seed_data.groupby("global_step", as_index=False, sort=False).mean()
    return seed_data


def to_wide(df):
    """Pivot a long per-seed results frame into one row per ``global_step``.

    Args:
        df: Frame with ``global_step``, ``seed``, ``episodic_length`` and
            ``episodic_return`` columns. The last two may hold bracketed text
            such as ``"[12]"``. The frame is not modified.

    Returns:
        A new frame sorted by ``global_step``, with
        ``episodic_length (seed S)`` / ``episodic_return (seed S)`` columns per
        seed and an ``episodic_return_mean`` column. The mean is taken across
        seeds and skips seeds with no value at that step.
    """
    df = df.copy()
    df["episodic_length"] = _parse_bracketed(df["episodic_length"], int)
    df["episodic_return"] = _parse_bracketed(df["episodic_return"], float)

    # Sort so steps first logged by a later seed are not appended at the end.
    steps = df["global_step"].drop_duplicates().sort_values(ignore_index=True)
    wide_df = pd.DataFrame({"global_step": steps})

    for seed in df["seed"].unique():
        seed_data = df.loc[df["seed"] == seed, _VALUE_COLUMNS]
        seed_data = _collapse_duplicate_steps(seed_data)
        seed_data = seed_data.rename(
            columns={
                "episodic_length": f"episodic_length (seed {seed})",
                "episodic_return": f"episodic_return (seed {seed})",
            }
        )
        wide_df = pd.merge(wide_df, seed_data, on="global_step", how="left")

    return_columns = [col for col in wide_df.columns if col.startswith("episodic_return (seed ")]
    wide_df["episodic_return_mean"] = wide_df[return_columns].mean(axis=1)
    return wide_df


def main(folder=folder_path):
    """Write a ``*_KANs_results_wide.csv`` next to every ``*_KANs_results.csv`` in *folder*."""
    for file_name in sorted(os.listdir(folder)):
        if not file_name.endswith(_INPUT_SUFFIX):
            continue
        wide_df = to_wide(pd.read_csv(os.path.join(folder, file_name)))

        # Replace only the trailing suffix, not every occurrence in the name.
        output_file_name = file_name[: -len(_INPUT_SUFFIX)] + _OUTPUT_SUFFIX
        output_file_path = os.path.join(folder, output_file_name)
        wide_df.to_csv(output_file_path, index=False)

        print(f"{file_name} processed. Output saved to {output_file_name}")


if __name__ == "__main__":
    main()