import os


def _list_checkpoints(root, prefix, suffix):
    """Return ``(step, filename)`` pairs for files in *root* named ``prefix + digits + suffix``.

    Files that share the prefix but have a different suffix, or whose middle
    part is not purely ASCII digits, are ignored rather than crashing the sweep.
    The original filename is kept so deletion targets exactly the file that was
    listed, whatever its zero-padding width.
    """
    found = []
    for name in os.listdir(root):
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        digits = name[len(prefix):len(name) - len(suffix)]
        if digits.isascii() and digits.isdigit():
            found.append((int(digits), name))
    return found


def prune_run_dir(root):
    """Delete outdated checkpoints from one training-run directory.

    Finds the latest step that has BOTH a ``network-snapshot-<step>.pkl`` and a
    ``training-state-<step>.pt`` file. It then removes every snapshot and
    training-state file with a smaller step. Files at or after the kept step are
    left alone, so the run stays resumable from a matching snapshot/state pair.

    Directories with at most one file of either kind are skipped, as are
    directories where no step has both files (deleting there could leave no
    resumable pair at all).

    Returns:
        The kept step, or ``None`` if the directory was skipped.
    """
    snapshots = _list_checkpoints(root, "network-snapshot-", ".pkl")
    states = _list_checkpoints(root, "training-state-", ".pt")
    if len(snapshots) <= 1 or len(states) <= 1:
        return None

    # Choose the kept step from the intersection *before* deleting anything.
    # The latest snapshot and the latest state are not necessarily the same
    # step, and using min() of the two maxima could pick a step that only
    # one kind of file has.
    common = {step for step, _ in snapshots} & {step for step, _ in states}
    if not common:
        print(f"skipping {root}: no step has both a network snapshot and a training state")
        return None
    keep = max(common)

    for step, name in snapshots + states:
        if step < keep:
            os.remove(os.path.join(root, name))
    return keep


def prune_training_runs(log_dir="training-runs"):
    """Apply :func:`prune_run_dir` to every subdirectory of *log_dir*.

    Non-directory entries directly inside *log_dir* are ignored.
    """
    for subdir in sorted(os.listdir(log_dir)):
        root = os.path.join(log_dir, subdir)
        if os.path.isdir(root):
            prune_run_dir(root)


if __name__ == "__main__":
    prune_training_runs()