import argparse
import csv
import os
from typing import Optional

def _seed_sort_key(seed: str):
    """Return a sort key that orders integer-like seeds numerically first.

    Plain string sorting would place "10" before "2". Values that are not
    integers (e.g. layout names) sort lexically after all integer seeds. The
    original string is the final tie-breaker so values such as "01" and "1"
    still sort deterministically.
    """
    try:
        return (0, int(seed), seed)
    except ValueError:
        return (1, 0, seed)


def print_evaluate_seed(
    csv_path: str,
    out_path: Optional[str] = None,
    split: str = "evaluate",
):
    """
    Extract the unique ``seed_or_layout`` values of one split from a split CSV.

    Reads a CSV file generated by inspect_minari.py. It collects the non-empty
    ``seed_or_layout`` values of every row whose ``split`` column matches
    *split*, ignoring case and surrounding whitespace. The de-duplicated
    values are written to a text file as a bracketed, comma-separated list::

        [
        1, 2, 10
        ]

    Integer seeds are ordered numerically, and any non-integer values follow
    in lexical order. A short summary and a preview of at most 200 values are
    printed to stdout.

    Args:
        csv_path: Path to the split CSV; it is expected to contain ``split``
            and ``seed_or_layout`` columns.
        out_path: Destination text file. Defaults to
            ``minigrid_<split>_seed.txt`` (``minigrid_evaluate_seed.txt`` for
            the default split) in the directory of *csv_path*.
        split: Name of the split to extract. Defaults to ``"evaluate"``.

    Returns:
        None. Problems (missing input, unreadable CSV, no matching seeds,
        unwritable output) are reported on stdout rather than raised.
    """
    if not os.path.isfile(csv_path):
        print(f"Error: File not found at '{csv_path}'")
        return

    target_split = split.strip().lower()
    seeds = set()
    try:
        # utf-8-sig transparently drops a leading BOM (e.g. CSVs saved by
        # Excel), which would otherwise corrupt the first header name.
        with open(csv_path, mode='r', newline='', encoding='utf-8-sig') as f:
            for row in csv.DictReader(f):
                # DictReader fills missing trailing fields with None, hence `or ''`.
                if (row.get('split') or '').strip().lower() != target_split:
                    continue
                seed = (row.get('seed_or_layout') or '').strip()
                if seed:
                    seeds.add(seed)
    except (OSError, csv.Error, UnicodeDecodeError) as e:
        print(f"An error occurred while reading the file: {e}")
        return

    unique_seeds = sorted(seeds, key=_seed_sort_key)
    if not unique_seeds:
        print(f"No {target_split} seeds found in the specified file.")
        return

    if out_path is None:
        base_dir = os.path.dirname(os.path.abspath(csv_path))
        out_path = os.path.join(base_dir, f"minigrid_{target_split}_seed.txt")

    try:
        with open(out_path, "w", encoding="utf-8") as out_f:
            out_f.write("[\n" + ", ".join(unique_seeds) + "\n]")
    except OSError as e:
        print(f"Failed to write seeds to '{out_path}': {e}")
        return

    print(f"Wrote {len(unique_seeds)} unique {target_split} seeds/layouts to: {out_path}")
    preview_limit = 200
    preview = ", ".join(unique_seeds[:preview_limit])
    if len(unique_seeds) > preview_limit:
        preview += ", ..."
    print(f"Preview (up to {preview_limit}): {preview}")

if __name__ == "__main__":
    # CLI entry point: extract the evaluate-split seeds from the given CSV.
    parser = argparse.ArgumentParser(
        description="Extract and save the unique evaluate-split seeds from a split CSV file."
    )
    parser.add_argument("csv_file", type=str, help="Path to the episode_splits.csv file.")
    parser.add_argument(
        "--out",
        type=str,
        default=None,
        help=(
            "Path to output .txt file (seeds are written as a bracketed, "
            "comma-separated list). Defaults to minigrid_evaluate_seed.txt "
            "next to the CSV file."
        ),
    )
    args = parser.parse_args()

    print_evaluate_seed(args.csv_file, args.out)
