import pandas as pd
import json
import numpy as np
from omegaconf import OmegaConf, DictConfig
import os
import yaml
from template_bash import bash_heads
from tqdm import tqdm
import glob
import shutil

# Load the caption metadata and keep only rows that carry an animal
# annotation; both the string 'None' and a Python None count as "no animal".
# NOTE(review): placeholder path says "csv" but the file is read as Parquet.
df = pd.read_parquet('/path/to/data/csv')
df_animal = df[~df['Animal'].isin(['None', None])]

# NOTE(review): vid_path / frame_path are not referenced by the config
# generation below, which uses a relative 'sampled_videos/' directory
# instead -- confirm which source location is intended.
vid_path = '/path/to/source/video/mp4'
frame_path = '/path/to/source/video/frames'

# Destination of the generated per-video YAML configs and of training outputs.
config_path = '/save/config/path'
save_path = '/output/path'

# Per-row inputs for config generation; the three lists are positionally
# aligned because they come from the same filtered frame.
captions = list(df_animal['caption'])
editing_prompts = list(df_animal['CE_animal'])
keys = list(df_animal['key'])
model_name = "ckpt/name"

# Training / validation hyperparameters shared by every generated config.
# The `#@param` markers are Colab-style form annotations; in a plain script
# they are ordinary comments.
video_length = 25  #@param {type:"number"}
width = 480  #@param {type:"number"}
height = 480  #@param {type:"number"}
learning_rate = 3e-5  #@param {type:"number"}
train_steps = 500  #@param {type:"number"}

# The YAML files are written into config_path; make sure it exists first,
# otherwise saving the first config fails.
os.makedirs(config_path, exist_ok=True)

# Write one YAML config per selected video: train on its caption, validate
# with the corresponding editing prompt.
for cur_name, caption, editing_prompt in tqdm(
    zip(keys, captions, editing_prompts), total=len(keys)
):
    config_name = f'{config_path}/{cur_name}.yaml'
    # NOTE(review): relative to the working directory; does not use vid_path.
    train_video_path = f'sampled_videos/{cur_name}.mp4'
    train_prompt = f'{caption}'
    prompts = [editing_prompt]
    output_dir = f"{save_path}/{cur_name}"
    config = {
        "pretrained_model_path": model_name,
        "output_dir": output_dir,
        "train_data": {
            "video_path": train_video_path,
            "prompt": train_prompt,
            "n_sample_frames": video_length,
            "width": width,
            "height": height,
            "sample_start_idx": 0,
            "sample_frame_rate": 1,
        },
        "validation_data": {
            "prompts": prompts,
            "video_length": video_length,
            "width": width,
            "height": height,
            "num_inference_steps": 20,
            "guidance_scale": 12.5,
            "use_inv_latent": True,
            "num_inv_steps": 50,
        },
        "learning_rate": learning_rate,
        "train_batch_size": 1,
        "max_train_steps": train_steps,
        "checkpointing_steps": 1000,
        "validation_steps": 100,
        "trainable_modules": [
            "attn1.to_q",
            "attn2.to_q",
            "attn_temp",
        ],
        "seed": 33,
        "mixed_precision": "fp16",
        "use_8bit_adam": False,
        "gradient_checkpointing": True,
        "enable_xformers_memory_efficient_attention": True,
    }
    OmegaConf.save(config, config_name)
# One launch snippet per key. DEVICE_PLACEHOLDER is replaced with a concrete
# GPU id when the snippets are grouped into per-device shell scripts below.
DEVICE_PLACEHOLDER = '$$$d$$$'
prompt_bashes = [
    f'export CUDA_VISIBLE_DEVICES={DEVICE_PLACEHOLDER}\n'
    f'accelerate launch train_tuneavideo.py --config={config_path}/{key}.yaml'
    for key in keys
]

# GPU ids to spread the jobs across; one shell script is written per id.
device_id = [0, 3]
device_count = len(device_id)
total_len = len(prompt_bashes)
run_string = bash_heads['tav']

# Contiguous, near-equal chunks: every device gets `length_per_device` jobs
# and the first `extra` devices get one more. This keeps the load balanced
# (max difference of one job) even when there are fewer jobs than devices.
length_per_device, extra = divmod(total_len, device_count)
start = 0
for idx, dev in enumerate(device_id):
    end = start + length_per_device + (1 if idx < extra else 0)
    new_list = [
        cmd.replace(DEVICE_PLACEHOLDER, str(dev))
        for cmd in prompt_bashes[start:end]
    ]
    start = end
    with open(f'run_ce_animals_tav_p{dev}.sh', 'w', encoding='utf-8') as s:
        s.write(run_string + '\n')
        s.write('\n\n'.join(new_list))
        s.write('\n')  # terminate the script with a newline