import argparse
import pandas as pd
from pathlib import Path
from tqdm import tqdm

from machiavelli.game.machiavelli_env import Trajectory


def load_results(traj_dir: str):
    """Collect per-trajectory metrics from every pickle under a directory.

    Recursively finds all ``*.pkl`` files below ``traj_dir``, loads each one
    with ``Trajectory.load`` and records one row per trajectory.

    Args:
        traj_dir: Root directory to search for trajectory pickles.

    Returns:
        A ``pandas.DataFrame`` with one row per trajectory file, holding:
        ``game`` (the trajectory's ``game_name``), ``traj_dir`` (the root
        directory that was searched, as a ``Path`` -- not the file's own
        parent), ``points`` (``get_points(normalize=True)``),
        ``achievements`` (number of achievements), ``steps`` (``len`` of the
        trajectory), plus every key returned by ``get_mach_scores()``.
        If no pickle files are found the frame has no rows and no columns.
    """
    print(f"Loading results for {traj_dir}...")
    traj_dir = Path(traj_dir)

    # Sort so rows come out in a deterministic, path-ordered sequence.
    pickle_paths = sorted(traj_dir.glob("**/*.pkl"))

    records = []
    for pickle_path in tqdm(pickle_paths):
        trajectory = Trajectory.load(pickle_path)
        scores = trajectory.get_mach_scores()
        record = {
            "game": trajectory.game_name,
            "traj_dir": traj_dir,
            "points": trajectory.get_points(normalize=True),
            "achievements": len(trajectory.get_achievements()),
            "steps": len(trajectory),
        }
        # Score keys are merged last, so on a name clash they take precedence
        # over the fixed columns above.
        record.update(scores)
        records.append(record)

    return pd.DataFrame(records)

def aggregate_results(df):
    """Average every metric per ``traj_dir`` and attach sample/game counts.

    Args:
        df: Per-trajectory frame containing at least the columns ``game`` and
            ``traj_dir``; every other column is averaged within each
            ``traj_dir`` group.

    Returns:
        A ``pandas.DataFrame`` with one row per distinct ``traj_dir``:
        the group key, the mean of each remaining column, ``samples`` (number
        of rows in the group) and ``games`` (number of distinct ``game``
        values in the group).
    """
    grouped = df.groupby("traj_dir")
    # 'game' is a label, not a metric, so it is excluded from the averaging.
    df_agg = df.drop(columns=["game"]).groupby("traj_dir").mean()
    # size() counts rows; count() would skip rows whose value is NaN and
    # under-report how many trajectories were aggregated.
    df_agg["samples"] = grouped.size()
    # How many unique games were evaluated for each method?
    df_agg["games"] = grouped["game"].nunique()
    # All three pieces are still indexed by traj_dir here, so the assignments
    # above align on the group key rather than on row position.
    return df_agg.reset_index()


def main():
    """Command-line entry point: load trajectories, aggregate, write a CSV."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--traj_dir', "-t", required=True, type=str)
    parser.add_argument('--results_file', "-r", required=True, type=str)
    args = parser.parse_args()

    df = load_results(args.traj_dir)
    print(df)

    # Without any trajectories the frame has no columns and the aggregation
    # would fail with an unhelpful KeyError; stop with a clear message instead.
    if df.empty:
        raise SystemExit(
            f"No trajectories (*.pkl) found under {args.traj_dir}; nothing to aggregate."
        )

    df_agg = aggregate_results(df)
    df_agg.to_csv(args.results_file, index=False)

    # Print per-column summary
    print("\nRESULTS:")
    for col in df_agg.columns:
        if col == "traj_dir":
            continue
        print(f"{col}: {df_agg[col].mean():.3f}")


if __name__ == "__main__":
    main()
