import glob
import json
import os
import shlex
import shutil

import numpy as np
import pandas as pd
import yaml
from omegaconf import OmegaConf, DictConfig
from tqdm import tqdm

from template_bash import bash_heads

# Load the caption table and keep only the rows whose 'Animal' column is
# neither the literal string 'None' nor Python None.
# NOTE(review): the path ends in 'csv' but is read with read_parquet —
# confirm the on-disk format.
df = pd.read_parquet('/path/to/data/csv')
df_animal = df[~df['Animal'].isin(['None', None])]
vid_path = '/path/to/source/video/mp4'
frame_path = '/path/to/source/video/frames'

# Parallel per-row lists: index i refers to the same row in every list.
captions = list(df_animal['caption'])
prompts = list(df_animal['CE_animal'])
keys = list(df_animal['key'])
# Each 'Animal' cell is split on ', ' into a list of names.
animals = [item.split(', ') for item in df_animal['Animal']]
editing_entity = list(df_animal['CE_animal_editing_entity'])

# Per-row source locations derived from the row key.
video_files = [os.path.join(vid_path, f'{key}.mp4') for key in keys]
frame_files = [os.path.join(frame_path, key) for key in keys]

# Build, for every selected video:
#   * a preprocess (inversion) command   -> preprocess_list
#   * a per-video YAML config on disk    -> <work_path>/batch_configs_ce_animals/<key>.yaml
#   * a run command pointing at it       -> prompt_bashes
# '$$$d1$$$' / '$$$d2$$$' are placeholders meant to be substituted with a
# device index when the commands are assigned to a shard.
# NOTE(review): preprocess_list is collected here but nothing in the visible
# script writes it out — confirm whether that is intended.
preprocess_list = []
prompt_bashes = []
device = 'cuda:6'
work_path = '/path/to/TokenFlow'
run_string = bash_heads['tokenflow']

# The template config is the same for every video, so read it once.
with open(os.path.join(work_path, 'configs/config_pnp.yaml'), "r") as f:
    base_config = yaml.safe_load(f)

# Make sure the output directory exists before writing per-video configs.
os.makedirs(os.path.join(work_path, 'batch_configs_ce_animals'), exist_ok=True)

for key, video_file, caption, prompt in zip(keys, video_files, captions, prompts):
    # shlex.quote keeps captions containing quotes, '$' or backticks from
    # breaking (or being expanded by) the shell.
    cur_preprocess = (
        f"python3.9 preprocess.py --data_path {video_file} "
        f"--inversion_prompt {shlex.quote(str(caption))} "
        f"--sd_version 2.1 --n_frames 25 --device cuda:$$$d1$$$"
    )
    preprocess_list.append(cur_preprocess)

    # Only top-level keys are overridden, so a shallow copy of the template
    # is enough to keep iterations independent.
    default_config = dict(base_config)
    # The run command sets CUDA_VISIBLE_DEVICES, so the config always
    # addresses device index 0.
    default_config['device'] = 'cuda:0'
    default_config['data_path'] = f'data/{key}'
    default_config['n_frames'] = 25
    default_config['prompt'] = f"{prompt}"
    with open(os.path.join(work_path, f'batch_configs_ce_animals/{key}.yaml'), "w") as f:
        yaml.dump(default_config, f)

    # The config path is relative — presumably the generated script runs from
    # work_path; verify against the bash header in bash_heads.
    cur_bash = f'export CUDA_VISIBLE_DEVICES=$$$d2$$$\npython3.9 run_tokenflow_pnp.py --config_path batch_configs_ce_animals/{key}.yaml'
    prompt_bashes.append(cur_bash)

# Split the run commands into `devide_count` contiguous shards and write one
# shell script per shard. Shard i substitutes its own index for the device
# placeholders, so shard i runs on device i.
devide_count = 1
total_len = len(prompt_bashes)
length_per_device = total_len // devide_count

for i in range(devide_count):
    start = i * length_per_device
    # The last shard absorbs the remainder left by the integer division.
    end = (i + 1) * length_per_device if i != (devide_count - 1) else total_len
    # Only '$$$d2$$$' occurs in the run commands built above; the '$$$d1$$$'
    # substitution is kept for safety but is a no-op on them.
    new_list = [
        cmd.replace('$$$d1$$$', str(i)).replace('$$$d2$$$', str(i))
        for cmd in prompt_bashes[start:end]
    ]
    # The with-block closes the file; no explicit close() is needed.
    with open(f'run_animal_tokenflow_p{i}.sh', 'w') as s:
        s.write(run_string + '\n')
        s.write('\n\n'.join(new_list))