import os
import re
import numpy as np
from collections import defaultdict

BASE_DIR = "."  # root directory: one sub-directory per algorithm, one per seed below it

# Metrics to extract
METRICS = ["reward", "collision", "prey_tagged"]

# Storage: algorithm name -> list of per-log-file metric dicts
results = defaultdict(list)


def find_log_files(base_dir):
    """Yield ``(algo, path)`` for every ``log*.txt`` file under *base_dir*.

    The layout walked is ``base_dir/<algo>/<seed>/log*.txt``; entries that are
    not directories at the algorithm and seed levels are ignored. Every level
    is visited in sorted order because ``os.listdir`` order is arbitrary, and
    the report printed from ``results`` should be reproducible.
    """
    for algo in sorted(os.listdir(base_dir)):
        algo_path = os.path.join(base_dir, algo)
        if not os.path.isdir(algo_path):
            continue
        for seed in sorted(os.listdir(algo_path)):
            seed_path = os.path.join(algo_path, seed)
            if not os.path.isdir(seed_path):
                continue
            for filename in sorted(os.listdir(seed_path)):
                if filename.startswith("log") and filename.endswith(".txt"):
                    yield algo, os.path.join(seed_path, filename)


def parse_log_file(filepath, metrics=None):
    """Extract the last logged value of each metric from one log file.

    Two kinds of (stripped) lines are recognised:

    * ``reward: <number>`` sets ``"reward"``;
    * a line starting with ``{`` and containing ``}`` is evaluated, and if it
      yields a dict, every key listed in *metrics* is converted to ``float``.

    Later lines overwrite earlier ones, so each key holds the last value seen.
    Lines or individual values that cannot be parsed are skipped. A bad value
    for one key no longer discards the remaining keys on the same line.

    Args:
        filepath: path of the log file to read.
        metrics: metric names to pick from dict lines; defaults to ``METRICS``.

    Returns:
        dict mapping metric name -> float, empty if nothing was found.
    """
    if metrics is None:
        metrics = METRICS
    result = {}
    # errors="replace" so a stray undecodable byte cannot abort the whole scan.
    with open(filepath, "r", encoding="utf-8", errors="replace") as f:
        for raw_line in f:
            line = raw_line.strip()
            if line.startswith("reward:"):
                try:
                    result["reward"] = float(line.partition(":")[2].strip())
                except ValueError:
                    continue
            elif line.startswith("{") and "}" in line:
                try:
                    # NOTE(review): eval() executes arbitrary code taken from the
                    # log file. ast.literal_eval would be safe but rejects
                    # non-literal reprs (e.g. numpy scalars) that these logs may
                    # contain. Only run this script on logs you produced yourself.
                    info = eval(line)
                except Exception:  # best-effort: any unevaluable line is skipped
                    continue
                if not isinstance(info, dict):
                    continue
                for key in metrics:
                    if key in info:
                        try:
                            result[key] = float(info[key])
                        except (TypeError, ValueError):
                            continue  # non-numeric value: skip just this key
    return result


for algo, filepath in find_log_files(BASE_DIR):
    current_result = parse_log_file(filepath)
    if current_result:  # skip logs that contained none of the metrics
        results[algo].append(current_result)

# Aggregate and print results: mean and population std (numpy default, ddof=0)
# of each metric over all runs of each algorithm.
# Column widths grow to fit the longest name so rows stay aligned; "prey_tagged"
# alone is already wider than a fixed 10-character metric column.
algo_width = max([10, *map(len, results)])
metric_width = max([10, *map(len, METRICS)])
header = (
    f"{'Algorithm':<{algo_width}} {'Metric':<{metric_width}} "
    f"{'Mean':<15} {'Std':<15}"
)
print(header)
print("=" * max(50, len(header.rstrip())))
for algo, runs in results.items():
    for metric in METRICS:
        # Runs whose log never reported this metric are left out of its stats.
        vals = [run[metric] for run in runs if metric in run]
        if vals:
            mean = np.mean(vals)
            std = np.std(vals)
            print(
                f"{algo:<{algo_width}} {metric:<{metric_width}} "
                f"{mean:<15.2f} {std:<15.2f}"
            )