from PIL import Image
from nesim.vis.video import generate_video
import argparse
import os

# Command-line interface: choose which training-step checkpoints to render.
# Checkpoint indices are taken from range(start, stop, step).
parser = argparse.ArgumentParser(
    description=(
        "Renders saved neuron-atlas frames for a range of training-step "
        "checkpoints into an MP4 video"
    )
)

parser.add_argument("--start", type=int, help="start index", default=0)
parser.add_argument("--stop", type=int, help="stop index", default=11970)
parser.add_argument("--step", type=int, help="step size", default=100)

args = parser.parse_args()

# range() rejects a zero step with a bare ValueError; report it as a CLI error.
if args.step == 0:
    parser.error("--step must be non-zero")

# Directory holding one pre-rendered atlas image per checkpoint,
# named "<checkpoint filename>.jpg".
frames_root = "./neuron_atlas_video/atlases"


# One checkpoint name per selected training step, e.g. "train_step_idx_100.pth".
checkpoint_filenames = list(
    map("train_step_idx_{}.pth".format, range(args.start, args.stop, args.step))
)
print(f"start: {args.start} stop: {args.stop} step: {args.step}")

# Full paths to those checkpoints on disk (used only for the existence check).
checkpoint_paths = [
    os.path.join("./step_checkpoints", checkpoint_name)
    for checkpoint_name in checkpoint_filenames
]

# Every requested checkpoint must exist; fail loudly (not via `assert`, which
# is stripped under `python -O`) and name the offending file.
for path in checkpoint_paths:
    if not os.path.exists(path):
        raise FileNotFoundError(f"checkpoint not found: {path}")

# Each checkpoint's atlas frame is stored as "<checkpoint filename>.jpg".
frame_filenames = [
    os.path.join(frames_root, f"{filename}.jpg") for filename in checkpoint_filenames
]


if not frame_filenames:
    raise SystemExit(
        f"no frames selected by start={args.start} stop={args.stop} "
        f"step={args.step}; nothing to render"
    )

# Load each frame fully and close its file right away; bare Image.open()
# would otherwise keep one file handle open per frame.
list_of_pil_images = []
for name in frame_filenames:
    with Image.open(name) as img:
        list_of_pil_images.append(img.copy())

generate_video(
    list_of_pil_images=list_of_pil_images,
    framerate=5,
    filename="neuron_atlas_vis_during_training.mp4",
    size=(1024, 1024),
)
