
import pandas as pd
import json
import numpy as np
from omegaconf import OmegaConf, DictConfig
import os
import yaml
from template_bash import bash_heads
from tqdm import tqdm
import glob
import shutil

# NOTE(review): the placeholder path mentions "csv" but it is read with
# read_parquet — make sure the real path points at a Parquet file.
df = pd.read_parquet('/path/to/data/csv')
# Keep only rows that name an animal. Rows whose 'Animal' is the literal string
# 'None' or a missing value (None/NaN) are dropped; notna() catches every null
# kind explicitly, so no float NaN can reach the str.split below.
df_animal = df[df['Animal'].notna() & (df['Animal'] != 'None')]
vid_path = '/path/to/source/video/mp4'
frame_path = '/path/to/source/video/frames'

# Per-clip columns, all aligned row-for-row with df_animal.
captions = list(df_animal['caption'])
prompts = list(df_animal['CE_animal'])
keys = list(df_animal['key'])
animals = list(df_animal['Animal'])
# 'Animal' holds a ", "-separated string; turn each into a list of names.
animals = [item.split(', ') for item in animals]
editing_entity = list(df_animal['CE_animal_editing_entity'])
# Each clip's key names both its source video (<key>.mp4) and its frame directory.
video_files = [os.path.join(vid_path, f'{key}.mp4') for key in keys]
frame_files = [os.path.join(frame_path, key) for key in keys]

# Directory that receives one FateZero YAML config per clip.
config_path = '/save/config/path'
# FateZero base config used as the template for every generated config.
base_config = '/path/to/FateZero/config/base.yaml'

# Create the output directory up front so saving the first config cannot fail
# with FileNotFoundError on a fresh setup.
os.makedirs(config_path, exist_ok=True)

for frame_dir, caption, prompt, entity, key in zip(
    frame_files, captions, prompts, editing_entity, keys
):
    # Re-load the template for each clip so every config starts from a pristine copy.
    base_conf = OmegaConf.load(base_config)
    base_conf['dataset_config']['path'] = frame_dir
    base_conf['dataset_config']['prompt'] = caption
    # Presumably 25 frames taken at stride 1 — confirm against FateZero's dataset loader.
    base_conf['dataset_config']['n_sample_frame'] = 25
    base_conf['dataset_config']['sampling_rate'] = 1
    base_conf['editing_config']['editing_prompts'] = [prompt]
    # The entity column stores the words as a ", "-separated string; the
    # first p2p config's eq_params expects them as a list.
    base_conf['editing_config']['p2p_config'][0]['eq_params']['words'] = entity.split(", ")
    save_yaml_path = os.path.join(config_path, f'{key}.yaml')
    # Explicit UTF-8: captions/prompts may contain non-ASCII text, and the
    # platform default encoding is not guaranteed to represent it.
    with open(save_yaml_path, 'w', encoding='utf-8') as yaml_file:
        OmegaConf.save(config=base_conf, f=yaml_file)

# One launch command per clip, in the same order as `keys`. The literal
# `$$$d$$$` is a placeholder for the GPU id and is substituted later, when
# the commands are distributed across devices.
prompt_bashes = [
    "export CUDA_VISIBLE_DEVICES=$$$d$$$\n"
    f"accelerate launch test_fatezero.py --config {config_path}/{key}.yaml"
    for key in keys
]

# GPU ids to spread the jobs over; one shell script is emitted per id.
device_id = [1, 2]
devide_count = len(device_id)
total_len = len(prompt_bashes)
# Contiguous chunk size per device. The last device additionally takes the
# remainder (total_len % devide_count), so every command is scheduled once.
length_per_device = total_len // devide_count
# Shared preamble written at the top of every generated script.
run_string = bash_heads['fatezero']

for rank, gpu in enumerate(device_id):
    start = rank * length_per_device
    end = total_len if rank == devide_count - 1 else (rank + 1) * length_per_device
    # Bind each command in this device's chunk to its GPU id.
    new_list = [cmd.replace('$$$d$$$', str(gpu)) for cmd in prompt_bashes[start:end]]
    # The `with` block closes the file; no explicit close() is needed.
    with open(f'run_ce_animals_fatezero_p{gpu}.sh', 'w', encoding='utf-8') as s:
        s.write(run_string + '\n')
        # Terminate with a newline so the script is a well-formed text file.
        s.write('\n\n'.join(new_list) + '\n')
