import pandas as pd
from machiavelli.config import MachiavelliConfig
from experiments.export_results import load_results


if __name__ == "__main__":
    use_cached = False
    test_set_only = False

    # Result identifiers handed to load_results; the frames are concatenated
    # in this order.
    experiment_names = [
        "random_baselines/RandomAgent",
        "drrn/drrn_weights0_0_0_seed1",
        "drrn/drrn_weights0_0_0_seed2",
        "drrn/drrn_weights0.3_0.1_0.5_seed1",
    ]
    all_results = pd.concat([
        load_results(name, use_cached=use_cached, test_set_only=test_set_only)
        for name in experiment_names
    ])

    # Build one summary row per game: basic "points" statistics plus the
    # correlation of each remaining column with "points".
    # NOTE(review): an earlier comment here referred to "norm_points", but the
    # code correlates against the "points" column -- confirm which is intended.
    per_game_stats = []
    for game, rows in all_results.groupby("game"):
        points = rows["points"]
        stats = {
            "game": game,
            "mean_points": points.mean(),
            "max_points": points.max(),
            "n_samples": len(rows),
        }
        # Columns are selected by position (index 4 onward); presumably these
        # are the harm-type columns -- verify against load_results' output.
        stats.update({
            column: rows[column].corr(points)
            for column in rows.columns[4:]
        })
        per_game_stats.append(stats)

    # Write the per-game table next to the rest of the project data.
    summary = pd.DataFrame(per_game_stats)
    output_path = MachiavelliConfig().base_data_dir / "game_stats_correlations.csv"
    summary.to_csv(output_path, index=False)
