import os
import subprocess

from nesim.vis.video import generate_video
from PIL import Image

# ---------------------------------------------------------------------------
# Run selection
# ---------------------------------------------------------------------------

# Training-step bounds used to pick which checkpoints are visualised.
start_train_step = 10
end_train_step = 470

# Name of the training run; it is used as the checkpoint sub-directory and is
# embedded in every generated image/video filename.
# Alternative run kept for quick switching:
# run_name = "apply_nesim_every_n_steps_10_nesim_config_scale_0.08_shrink_factor_[9.0]_layer_names_index_10_checkpoint_every_n_steps_10_num_warmup_steps_300_batch_size_128_context_length_256"
run_name = (
    "apply_nesim_every_n_steps_10"
    "_nesim_config_scale_0.08"
    "_shrink_factor_[5.0]"
    "_layer_names_index_10"
    "_checkpoint_every_n_steps_10"
    "_num_warmup_steps_300"
    "_batch_size_128"
    "_context_length_256"
)

# Layer whose map is rendered for each checkpoint.
layer_name = "transformer.h.10.mlp.c_fc"

# ---------------------------------------------------------------------------
# Inputs
# ---------------------------------------------------------------------------

dataset_filename = "dataset.json"
checkpoints_root = "../../../training/gpt_neo_125m/checkpoints/"
# Weights file expected inside each "checkpoint-<step>" folder.
filename = "pytorch_model.bin"

# ---------------------------------------------------------------------------
# Outputs
# ---------------------------------------------------------------------------

output_folder = "maps/multiple_category/"
video_output_folder = "./videos"
# When True, frames whose image file already exists are not regenerated.
skip_existing = False

video_filename = (
    f"multiple_category_start_{start_train_step}_stop_{end_train_step}"
    f"_{layer_name}_{run_name}.mp4"
)

# Render one map image per checkpoint and collect the image paths, in
# training-step order, so they can be stitched into a video afterwards.

# Make sure the image destination exists before the first frame is written.
os.makedirs(output_folder, exist_ok=True)

frame_filenames = []
# NOTE(review): range() excludes end_train_step itself, so the last frame
# comes from checkpoint (end_train_step - 10) -- confirm this is intended.
train_step_indices = range(start_train_step, end_train_step, 10)
for frame_idx, train_step in enumerate(train_step_indices):

    print(f"\nGenerating frame: {frame_idx}/{len(train_step_indices)}")
    checkpoint_folder = f"checkpoint-{train_step}"
    image_filename = f"{train_step}_{layer_name}_{run_name}.png"

    checkpoint_filename = os.path.join(
        checkpoints_root, run_name, checkpoint_folder, filename
    )
    output_filename = os.path.join(output_folder, image_filename)

    # Pass the arguments as a list (no shell): run_name contains "[...]",
    # which a shell would treat as a glob pattern, and paths with spaces or
    # other metacharacters would otherwise be split or misinterpreted.
    # The script path is relative, so this must be run from its directory.
    command = [
        "python3",
        "generate_map_multiple_categories.py",
        "--dataset-filename",
        dataset_filename,
        "--checkpoint-filename",
        checkpoint_filename,
        "--layer-name",
        layer_name,
        "--output-filename",
        output_filename,
    ]

    if skip_existing and os.path.exists(output_filename):
        print(f"Already exists: {output_filename}")
    else:
        # check=True raises subprocess.CalledProcessError on a non-zero exit
        # status, so a failed render stops the run here instead of surfacing
        # later as a missing (or stale) image when the video is assembled.
        subprocess.run(command, check=True)
    frame_filenames.append(output_filename)

# Stitch the rendered frames into a single video file.

# Create the destination directory up front; it is not created anywhere else
# in this script (whether generate_video does so itself is not visible here).
os.makedirs(video_output_folder, exist_ok=True)

frames = [Image.open(frame_path) for frame_path in frame_filenames]
try:
    generate_video(
        list_of_pil_images=frames,
        framerate=10,
        size=(1500, 800),
        filename=os.path.join(video_output_folder, video_filename),
    )
finally:
    # Release the file handles held by the opened images, even if video
    # generation fails.
    for frame in frames:
        frame.close()
