##
## (c) Anonymous authors (2026)
##
## > Script to preprocess results exported from Weights & Biases (wandb)
##
##

import os
from functools import reduce

import numpy as np
import pandas as pd
import seaborn as sns


def preprocess_raw_data(raw_data_path: str) -> pd.DataFrame:
    """Reshape a raw wandb export into one return/timestep column pair per run.

    The .csv file is expected to contain a shared ``simulation_timesteps``
    column and, for every run, the columns ``"<run> - _step"`` and
    ``"<run> - performance/avg100_behavior_mean_return"``.

    Args:
        raw_data_path: Path to the raw .csv file.

    Returns:
        DataFrame indexed by ``step`` (sorted ascending) holding the columns
        ``<run>_return`` and ``<run>_timestep`` for every run, in the order in
        which the runs first appear in the file.

    Raises:
        KeyError: If the file has no ``simulation_timesteps`` column.
        ValueError: If no run provides both a step and a return column.
    """
    metric = "performance/avg100_behavior_mean_return"
    timestep_col = "simulation_timesteps"

    # Load .csv file with raw data
    df = pd.read_csv(raw_data_path)

    if timestep_col not in df.columns:
        raise KeyError(timestep_col)

    # Identify run names (prefixes in front of " - "). Columns containing "__"
    # are skipped (presumably wandb's "__MIN"/"__MAX" companions - TODO confirm).
    # dict.fromkeys de-duplicates while keeping the column order of the file;
    # a set would make the output column order depend on string hashing and
    # therefore differ between interpreter runs.
    models = list(dict.fromkeys(
        col.split(" - ")[0] for col in df.columns
        if metric in col and "__" not in col))

    # Collect per-run DataFrames with step, return, and timestep
    model_frames = []

    for model in models:
        step_col = f"{model} - _step"
        return_col = f"{model} - {metric}"

        # Runs lacking either of the two columns are ignored
        if step_col not in df.columns or return_col not in df.columns:
            continue

        # Keep only the rows in which this run actually logged a value
        temp = df[[step_col, return_col, timestep_col]].dropna()
        temp = temp.rename(columns={
            step_col: "step",
            return_col: f"{model}_return",
            timestep_col: f"{model}_timestep"
        })
        model_frames.append(temp.set_index("step"))

    if not model_frames:
        raise ValueError(
            f"No run with both a '_step' and a '{metric}' column found in {raw_data_path!r}")

    # Merge all run-specific DataFrames on the step index
    result_df = reduce(
        lambda left, right: pd.merge(left, right, how='outer', left_index=True, right_index=True),
        model_frames)

    # Sort by step
    return result_df.sort_index()


def enrich_results_with_metadata(raw_df: pd.DataFrame, metadata_path: str, n_training_steps: int) -> pd.DataFrame:
    """Enrich the results dataframe with run metadata from a .csv file.

    Every ``<run>_return`` column of ``raw_df`` is turned into long-form rows
    labelled with the algorithm and seed that the metadata file lists for
    that run. ``raw_df`` itself is left unmodified.

    Args:
        raw_df: Wide dataframe with ``<run>_return`` and ``<run>_timestep``
            columns.
        metadata_path: Path to a .csv file with the columns ``Name``,
            ``algo``, ``seed`` and ``env``.
        n_training_steps: Total number of training steps; rows whose timestep
            exceeds ``n_training_steps - 250_000`` are discarded.

    Returns:
        Long-form DataFrame with the columns ``algorithm``, ``seed`` (as
        string), ``return``, ``timestep`` and ``env``. The index is the
        concatenation index with the filtered-out rows removed (not reset).

    Raises:
        ValueError: If ``raw_df`` contains no ``*_return`` column.
    """
    return_suffix = "_return"
    timestep_suffix = "_timestep"
    # Rows within this many steps of the end of training are dropped
    trailing_margin = 250_000

    # Load .csv file with metadata
    meta_df = pd.read_csv(metadata_path)
    # The env of the first metadata row is applied to all runs
    env = meta_df.loc[0, "env"]

    # Run name -> (algorithm, seed), both formatted as strings
    run_info = {
        name: (f"{algo}", f"{seed}")
        for name, algo, seed in zip(meta_df["Name"], meta_df["algo"], meta_df["seed"])
    }
    # Longest names are tried first so that a run called "run-1" never claims
    # the columns of a run called "run-10".
    names_longest_first = sorted(run_info, key=len, reverse=True)

    def identify_run(col):
        """Return the (algorithm, seed) pair belonging to a return column"""
        for name in names_longest_first:
            if col.startswith(name):
                return run_info[name]
        # Column is not listed in the metadata: fall back to reading
        # "<algo>_<seed>" from the column name itself, splitting at the last
        # underscore so that algorithm names may contain underscores.
        algo, sep, seed = col[:-len(return_suffix)].rpartition("_")
        return (algo, seed) if sep else (seed, "")

    # Extracting all timestep columns
    timestep_cols = [col for col in raw_df.columns if col.endswith(timestep_suffix)]
    # Compute row-wise mean of timestep values, ignoring NaNs
    mean_timesteps = raw_df[timestep_cols].mean(axis=1).round().astype(int)

    # Building long-form DataFrame
    long_data = []

    for col in raw_df.columns:
        if not col.endswith(return_suffix):
            continue
        algo, seed = identify_run(col)
        # dropna removes the steps at which this run logged no return
        long_data.append(pd.DataFrame({
            "algorithm": algo,
            "seed": seed,
            "return": raw_df[col],
            "timestep": mean_timesteps,
            "env": env
        }).dropna())

    if not long_data:
        raise ValueError(f"raw_df contains no column ending in '{return_suffix}'")

    # Combining and filtering
    result_df = pd.concat(long_data, ignore_index=True)
    result_df = result_df[result_df["timestep"] <= n_training_steps - trailing_margin]

    return result_df


if __name__ == "__main__":
    # Script entry point: for every environment, read the raw wandb export and
    # its metadata file, preprocess them and write one .csv file per environment.

    envs = ["heavenhell-3", "shopping-5", "car-flag", "cleaner", "7x7-4rooms", "9x9-4rooms"]

    # Number of training steps per environment (also part of the input file names)
    n_training_steps = {
        "heavenhell-3": 6_250_000,
        "shopping-5": 2_250_000,
        "car-flag": 4_250_000,
        "cleaner": 4_250_000,
        "7x7-4rooms": 6_250_000,
        "9x9-4rooms": 6_250_000
    }

    input_dir = "results/wandb"
    output_dir = "results/preprocessed"

    # exist_ok avoids the check-then-create race of a separate os.path.exists test
    os.makedirs(output_dir, exist_ok=True)

    for env in envs:
        # Common stem of the raw and the metadata file of this environment
        run_tag = f"{env}_{n_training_steps[env]}_final_0-19"

        # Preprocessing of raw results
        df_raw_preprocessed = preprocess_raw_data(
            raw_data_path=f"{input_dir}/raw/{run_tag}_raw.csv")

        # Enriching preprocessed results with run metadata
        df_enriched = enrich_results_with_metadata(raw_df=df_raw_preprocessed,
                                                   metadata_path=f"{input_dir}/metadata/{run_tag}_metadata.csv",
                                                   n_training_steps=n_training_steps[env])

        # Saving preprocessed results to .csv file.
        # The env column holds the same value in every row; it is read by
        # position because the timestep filter does not reset the index, so
        # the label 0 is not guaranteed to exist.
        env_id = df_enriched["env"].iloc[0]
        # Last path component of the env id; its first two "."-separated parts
        # (if there are several) are joined with "-" to form the file name
        splitted_env_name = env_id.split("/")[-1].split(".")
        if len(splitted_env_name) > 1:
            env_name = splitted_env_name[0] + "-" + splitted_env_name[1]
        else:
            env_name = splitted_env_name[0]
        df_enriched.to_csv(f"{output_dir}/{env_name}_final_0-19.csv", sep=",")
