from glob import glob
from pathlib import Path

import pandas as pd

# Evaluation subsets (column-name prefixes inside each CSV) and the short
# labels used for them in the table; the two lists are paired by position.
subsets = ["test_unfiltered", "valid_medium", "hard"]
pretty_subsets = ["Unfil", "Med", "Hard"]

# Number of thinking steps compared against the 00-step baseline columns, and
# the numeric suffix selecting which run's column is read (presumably a seed
# index -- TODO confirm against the CSV export).
stt = 6
seed_idx = 3

# One results CSV per model; its file name (minus extension) is looked up in
# pretty_map below.
peformance_csvs = glob("data/*.csv")

# File stem -> model name shown in the table. The table rows are later
# reindexed to this order.
pretty_map = dict(drc_33="DRC(3, 3)", drc_11="DRC(1, 1)", resnet="ResNet")

# Filled per model as {display name: {"<Subset>_<No Thinking|Thinking>": percent}}.
performance_dict = dict()
# Read each model's CSV and record its final solve rate (in percent) on every
# subset, both for the 00-step baseline and for the `stt`-step thinking run.
# Files that cannot be used are reported and skipped (best effort).
print("Processing CSVs...")
for csv_path in peformance_csvs:
    # The file stem (e.g. "drc_33" for data/drc_33.csv) selects the display
    # name. Path.stem works with either path separator, unlike split("/").
    model_key = Path(csv_path).stem
    if model_key not in pretty_map:
        print(f"Warning: No display name for {csv_path} (stem '{model_key}'). Skipping.")
        continue
    col_name = pretty_map[model_key]

    try:
        df = pd.read_csv(csv_path)

        # Collect the whole row before storing it, so a column missing
        # part-way through cannot leave a half-filled entry (NaN cells).
        row = {}
        for pretty_name, subset in zip(pretty_subsets, subsets):
            baseline_col = f"{subset}/00_episode_successes_{seed_idx}"
            stt_col = f"{subset}/{stt:02d}_episode_successes_{seed_idx}"
            # The last row of each column is taken as the final value and
            # scaled by 100 (presumably a 0-1 success fraction -> percent).
            row[pretty_name + "_No Thinking"] = df[baseline_col].iloc[-1] * 100
            row[pretty_name + "_Thinking"] = df[stt_col].iloc[-1] * 100

    except FileNotFoundError:
        print(f"Warning: File not found {csv_path}. Skipping.")
    except pd.errors.EmptyDataError:
        print(f"Warning: File {csv_path} is empty. Skipping.")
    except KeyError as e:
        print(f"Warning: Column {e} missing in {csv_path}. Skipping.")
    except Exception as e:
        # Deliberate best effort: one bad file should not abort the table.
        print(f"Warning: Error processing {csv_path}: {e}. Skipping.")
    else:
        performance_dict[col_name] = row

# Rows = models, columns = "<Subset>_<No Thinking|Thinking>".
performance_df = pd.DataFrame(performance_dict).T
print("DataFrame created. Rows:", len(performance_df))

# Turn each flat "<Subset>_<Type>" label (e.g. "Unfil_No Thinking") into a
# (Type, Subset) tuple so the LaTeX table gets a two-row header. The split is
# on the LAST underscore, so subset labels may themselves contain "_".
original_columns = performance_df.columns
multi_index_tuples = []

for label in original_columns:
    head, sep, tail = label.rpartition("_")
    if not sep:
        # No underscore at all: keep the label under a catch-all group.
        print(f"Warning: Column name '{label}' did not split as expected.")
        multi_index_tuples.append(("unknown", label))
        continue
    multi_index_tuples.append((tail, head))

performance_df.columns = pd.MultiIndex.from_tuples(multi_index_tuples, names=["Model", None])

# Group by type first ("No Thinking" sorts before "Thinking"), then by subset
# label; both levels are sorted as strings, so subsets appear as Hard, Med, Unfil.
performance_df = performance_df.sort_index(axis=1, level=[0, 1])

# Row order is fixed by the pretty_map reindex at the end of this block, so no
# value-based sort is applied here. A sort keyed on ("thinking",
# "test_unfiltered") never matched the real column labels, which look like
# ("Thinking", "Unfil"). The reindex would also have discarded its order.
performance_df.index.name = None

# ResNet has no thinking-step results, so its Thinking cells show "-".
# - Each Thinking column is cast to object first: writing a string into a
#   float64 column is a deprecated silent upcast in recent pandas.
# - The membership guard stops .loc from appending a NaN-filled ResNet row
#   when no ResNet CSV was loaded.
if "ResNet" in performance_df.index:
    for col in performance_df.columns:
        if col[0] == "Thinking":
            performance_df[col] = performance_df[col].astype(object)
            performance_df.loc["ResNet", col] = "-"

# Order rows as in pretty_map. A model whose CSV was not loaded gets "-"
# placeholders instead of a fabricated 0.0 solve rate.
performance_df = performance_df.reindex(list(pretty_map.values()), fill_value="-")  # type: ignore


# Render the table as LaTeX. The layout is one left-aligned index column
# followed by one right-aligned column per (type, subset) pair. It is derived
# from the frame, so it stays correct if subsets are added or removed.
latex_string = performance_df.to_latex(
    float_format="%.1f",
    escape=False,  # emit labels/cells verbatim, without LaTeX escaping
    multicolumn=True,
    multirow=True,
    multicolumn_format="c",
    header=True,
    index=True,
    column_format="l" + "r" * len(performance_df.columns),
    label="tab:thinking_performance",
    # "\\%" yields a literal backslash-percent for LaTeX. A bare "\%" is an
    # invalid escape sequence in a non-raw string literal.
    caption=f"Solve rate (\\%) of different models without and with {stt} thinking steps on held out sets of varying difficulty.",
)
print(latex_string)
