from __future__ import annotations

import argparse
import os
import glob
import csv
from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt


def parse_csv(path: str) -> tuple[list[float], list[float]]:
    """Collect per-episode return and cost values from a metrics CSV.

    Each non-empty row is read as one leading field, followed by alternating
    ``key, value`` pairs. The leading field is presumably the step counter;
    it is ignored here. Values logged under ``"ep/return"`` and ``"ep/cost"``
    are collected in file order. Cells that cannot be parsed as floats are
    skipped rather than aborting the parse.

    Args:
        path: Path to the CSV file.

    Returns:
        A ``(ep_returns, ep_costs)`` tuple of float lists. Either list may be
        empty.
    """
    ep_returns: list[float] = []
    ep_costs: list[float] = []
    # Metric key -> list that receives its parsed values.
    targets = (("ep/return", ep_returns), ("ep/cost", ep_costs))
    # newline="" is required by the csv module so quoted fields containing
    # newlines are parsed correctly.
    with open(path, "r", newline="") as f:
        for row in csv.reader(f):
            if not row:
                continue
            # row[0] is deliberately not parsed. Converting it with int()
            # served no purpose and crashed on header/non-integer rows.
            kv = dict(zip(row[1::2], row[2::2]))
            for key, bucket in targets:
                if key not in kv:
                    continue
                try:
                    bucket.append(float(kv[key]))
                except ValueError:
                    # Non-numeric cell (e.g. a header or "nan"-less garbage):
                    # skip it, keeping the best-effort behaviour.
                    pass
    return ep_returns, ep_costs


def main() -> None:
    """Scatter-plot final return vs. cost for every run under ``--runs``.

    Every immediate child of ``--runs`` that contains a ``metrics.csv`` is
    parsed with ``parse_csv``. A run that logged at least one episode return
    and one episode cost contributes one point. Its x value is the mean of
    its last ``--last`` returns and its y value is the mean of its last
    ``--last`` costs. Points are grouped by agent, where the agent name is
    the run directory's basename up to its first underscore. The figure is
    saved to ``--out``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--runs", type=str, default="runs/gdc")
    parser.add_argument("--out", type=str, default="results.png")
    parser.add_argument(
        "--last",
        type=int,
        default=10,
        help="number of trailing episodes averaged per run (default: 10)",
    )
    args = parser.parse_args()
    if args.last < 1:
        # A non-positive window would silently average the whole history
        # (xs[-0:] is the full list), so reject it explicitly.
        parser.error("--last must be a positive integer")

    # Agent name -> list of (mean return, mean cost), one entry per run.
    data = defaultdict(list)
    # glob order is filesystem-dependent; sort so legend order and colour
    # assignment are reproducible between invocations.
    for agent_dir in sorted(glob.glob(os.path.join(args.runs, "*"))):
        metrics_csv = os.path.join(agent_dir, "metrics.csv")
        if not os.path.isfile(metrics_csv):
            continue
        ep_returns, ep_costs = parse_csv(metrics_csv)
        if not ep_returns or not ep_costs:
            continue
        # Average the trailing window to smooth out episode-level noise.
        R = np.mean(ep_returns[-args.last:])
        C = np.mean(ep_costs[-args.last:])
        agent = os.path.basename(agent_dir).split("_")[0]
        data[agent].append((R, C))

    fig, ax = plt.subplots(figsize=(8, 6))
    for agent, points in data.items():
        arr = np.array(points)
        ax.scatter(arr[:, 0], arr[:, 1], label=agent)
    ax.set_xlabel(f"Return (avg last {args.last} eps)")
    ax.set_ylabel(f"Cost (avg last {args.last} eps)")
    if data:
        ax.legend()
    else:
        # Without any labelled artists, legend() only emits a warning.
        print(f"No runs with episode returns and costs found under {args.runs}")
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(args.out)
    plt.close(fig)
    print(f"Saved plot to {args.out}")


if __name__ == "__main__":
    # Script entry point: parse CLI arguments, build the scatter plot, save it.
    main()
