import os
import subprocess

from PIL import Image

from nesim.vis.video import generate_video

# Render one softmax map per training checkpoint of a single run (by invoking
# generate_map_softmax.py once per checkpoint), then stitch the resulting PNG
# frames into an MP4 showing how the map changes over training.

start_train_step = 10
end_train_step = 450  # exclusive: range() stops before this step
train_step_interval = 10  # spacing between rendered checkpoints
skip_existing = False  # if True, reuse frames already present on disk

run_name = "apply_nesim_every_n_steps_10_nesim_config_scale_0.08_shrink_factor_[9.0]_layer_names_index_10_checkpoint_every_n_steps_10_num_warmup_steps_300_batch_size_128_context_length_256"
# Alternate run (shrink factor 5.0); swap in to render that run instead.
# run_name = "apply_nesim_every_n_steps_10_nesim_config_scale_0.08_shrink_factor_[5.0]_layer_names_index_10_checkpoint_every_n_steps_10_num_warmup_steps_300_batch_size_128_context_length_256"
layer_name = "transformer.h.10.mlp.c_fc"
target_category = "science"

output_folder = f"maps/softmax/{target_category}_{layer_name}_{run_name}"
# os.makedirs instead of `mkdir -p` via a shell: run_name contains "[9.0]",
# which an unquoted shell word would interpret as a glob pattern.
os.makedirs(output_folder, exist_ok=True)

video_output_folder = "./videos"
dataset_filename = "dataset.json"
checkpoints_root = "../../../training/gpt_neo_125m/checkpoints/"
filename = "pytorch_model.bin"

video_filename = f"softmax_start_{start_train_step}_stop_{end_train_step}_{layer_name}_{target_category}_{run_name}.mp4"

frame_filenames = []
train_step_indices = range(start_train_step, end_train_step, train_step_interval)
for frame_idx, train_step in enumerate(train_step_indices):
    # 1-based counter so the final frame is reported as N/N.
    print(f"\nGenerating frame: {frame_idx + 1}/{len(train_step_indices)}")
    checkpoint_folder = f"checkpoint-{train_step}"
    image_filename = f"{train_step}_{layer_name}_{run_name}.png"

    checkpoint_filename = os.path.join(
        checkpoints_root, run_name, checkpoint_folder, filename
    )
    output_filename = os.path.join(output_folder, image_filename)
    # Every step contributes a frame, whether freshly rendered or reused.
    frame_filenames.append(output_filename)

    if skip_existing and os.path.exists(output_filename):
        print(f"Already exists: {output_filename}")
        continue

    # Pass an argument list (no shell) so paths containing "[" / "]" reach the
    # generator verbatim; check=True aborts instead of silently continuing
    # with a missing or stale frame when the generator fails.
    subprocess.run(
        [
            "python3",
            "generate_map_softmax.py",
            "--dataset-filename",
            dataset_filename,
            "--checkpoint-filename",
            checkpoint_filename,
            "--layer-name",
            layer_name,
            "--output-filename",
            output_filename,
            "--target-category",
            target_category,
        ],
        check=True,
    )

# Ensure the destination folder exists before the video is written.
os.makedirs(video_output_folder, exist_ok=True)
video_path = os.path.join(video_output_folder, video_filename)

# Load each frame into memory and close its file; Image.open on its own keeps
# the file handle open until the pixel data is read.
frames = []
for frame_filename in frame_filenames:
    with Image.open(frame_filename) as image:
        frames.append(image.copy())

generate_video(
    list_of_pil_images=frames,
    framerate=10,
    size=(1500, 1100),
    filename=video_path,
)
print(f"Saved: {video_path}")
