# Basics 
import argparse, os, sys, glob, shutil, time 
import ast
import random
import warnings;warnings.filterwarnings('ignore')
import torch, torchvision
import numpy as np
from omegaconf import OmegaConf
from PIL import Image, ImageDraw
import subprocess
import re, io
import json 
import time 
from functools import reduce
import base64
from tqdm import tqdm, trange
from pprint import pprint
from copy import deepcopy
import openai 
from openai import AzureOpenAI
from dotenv import load_dotenv
import matplotlib.pyplot as plt 
from lightning_fabric import seed_everything
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from transformers import AutoProcessor, Blip2ForConditionalGeneration

# Our utils 
from main_utils import * 

# NOTE added to fix t2v-turbo random seed 
from lightning_fabric import seed_everything

# T2V-turbo 
from utils.lora import collapse_lora, monkeypatch_remove_lora
from utils.lora_handler import LoraHandler
from utils.common_utils import load_model_checkpoint
from utils.utils import instantiate_from_config
from scheduler.t2v_turbo_scheduler import T2VTurboScheduler
from pipeline.t2v_turbo_vc2_pipeline import T2VTurboVC2Pipeline
from pipeline.multidiffusion import *             #Multidiffusion_T2VTurboVC2Pipeline    # NOTE ours 
import nltk
nltk.download('punkt_tab')
nltk.download('averaged_perceptron_tagger_eng')

# python -m pip install 'git+https://github.com/MaureenZOU/detectron2-xyz.git'


if __name__ == '__main__' : 
    # Command-line interface of the iterative pipeline:
    # generate a video -> score it with DSG questions -> regenerate the failing regions.
    parser = argparse.ArgumentParser(description="DSG grounding")
    parser.add_argument("--eval_section", type=str, default='count', help="evalcrafter section=[action, amp, color, count, face, none, text]")
    parser.add_argument("--model", type=str, default='t2vturbo', help="t2v generation model")
    # Every output path is joined onto this root; without it the script would only
    # fail later with a TypeError inside os.path.join, so require it at parse time.
    parser.add_argument("--output_root", type=str, required=True, help="save path root")
    parser.add_argument("--seed", type=int, default=123, help="generation seed")
    parser.add_argument("--dsg_type", type=str, default='dependency', help="['dependency', 'icl_count', 'origin']")                
    parser.add_argument("--load_molmo", action ="store_true", help="if you wanna use Molmo for pointing") 
    # Only these three benchmarks have a DSG question file layout known to this script.
    parser.add_argument("--data", default ="evalcrafter", choices=['evalcrafter', 'vbench', 'compbench'], help="benchmark") 

    parser.add_argument("--selection_score", default ="dsg_blip", help="candidate selection score: ['dsg_blip', 'blip_only']") 
    parser.add_argument("--reverse", action ="store_true", help="reverse order for the fast generation") 
    parser.add_argument("--not_initprompt_for_background", action = 'store_true', help='use background prompt as paraphrased prompt')

    parser.add_argument("--round", type=int, default=1, help="iteration round")
    parser.add_argument("--k", type=int, default=5, help="# of video candidates")
    parser.add_argument("--div_seeds", action ="store_true", help="use diverse seeds for multi-round") 


    args = parser.parse_args()

    # Short aliases for the CLI values that are used throughout the script.
    EVAL_SECTION = args.eval_section ; SEED = args.seed  ; ROUND_NUM = args.round 
    print('< ', EVAL_SECTION, ' section is start!! >')

    ## Load dsg questions (previously saved)
    # evalcrafter files live directly under the DSG root with a 'dsg_' prefix;
    # the other benchmarks use one sub-directory each and no prefix.
    dsg_root = 'datasets/our_dsg_depend_v2/'
    if args.data == 'evalcrafter' : 
        dsg_json_path = dsg_root + 'dsg_' + EVAL_SECTION + '.json'
    elif args.data in ('vbench', 'compbench') : 
        dsg_json_path = dsg_root + args.data + '/' + EVAL_SECTION + '.json'
    else : 
        # Previously an unknown benchmark left dsg_json_path unbound and crashed with a NameError below.
        raise ValueError(f"Unknown --data value {args.data!r}; expected 'evalcrafter', 'vbench' or 'compbench'")

    with open(dsg_json_path, 'r') as json_file:
        meta_data = json.load(json_file)

    ## for score selection model 
    # BLIP-2 is only needed when the selection score mentions 'blip'; it is loaded in fp16 on CUDA.
    if 'blip' in args.selection_score : 
        device_blip = torch.device('cuda')
        blip2_processor = AutoProcessor.from_pretrained("/checkpoints/blip2-opt-2.7b")
        blip2_model = Blip2ForConditionalGeneration.from_pretrained("/checkpoints/blip2-opt-2.7b", torch_dtype=torch.float16).to(device_blip)#.to('cuda')


    ## Root setting 
    # All results of this run go below <output_root>/<eval_section>/ (trailing slash kept on purpose).
    cur_result_root = os.path.join(args.output_root, EVAL_SECTION) + '/'
    print(cur_result_root)
    # exist_ok avoids the check-then-create race when several jobs share one output root,
    # and matches the other makedirs calls in this script.
    os.makedirs(cur_result_root, exist_ok=True)

    ## Model setup 
    # Builds the T2V-Turbo generator: VideoCrafter2 base weights + a LoRA-adapted UNet,
    # wrapped in both the plain pipeline and the multidiffusion pipeline.
    # NOTE(review): only 't2vturbo' loads anything here. For any other --model value
    # `pipeline` / `multidiffusion_pipeline` stay undefined, and the 'vico' branch further
    # down reads a `pipe` that this script never assigns -- confirm where those come from.
    if args.model == 't2vturbo' :      
        # Seed before instantiating the model so any randomness during construction is reproducible.
        torch.manual_seed(SEED)
        seed_everything(SEED)  

        start_time = time.time()
        # Checkpoint locations, relative to the working directory.
        unet_dir         = 'unet_lora.pt'
        videocrafter_dir = 'base_512_v2/model.ckpt'

        config = OmegaConf.load("configs/inference_t2v_512_v2.0.yaml")    # change to your path
        # Falls back to an empty config when the yaml has no "model" section.
        model_config = config.pop("model", OmegaConf.create())
        pretrained_t2v = instantiate_from_config(model_config)
        pretrained_t2v = load_model_checkpoint(pretrained_t2v, videocrafter_dir)

        # Build a second UNet from the same config, with time_cond_proj_dim overridden to 256,
        # and initialise it from the pretrained diffusion model.
        unet_config = model_config["params"]["unet_config"]
        unet_config["params"]["time_cond_proj_dim"] = 256
        unet = instantiate_from_config(unet_config)
        # strict=False: keys that exist on only one side are ignored instead of raising.
        unet.load_state_dict(
            pretrained_t2v.model.diffusion_model.state_dict(), strict=False
        )
        use_unet_lora = True
        lora_manager = LoraHandler(
            version="cloneofsimo",
            use_unet_lora=use_unet_lora,
            save_for_webui=True,
            unet_replace_modules=["UNetModel"],
        )
        # Attach the LoRA weights stored at `unet_dir` to the UNet.
        lora_manager.add_lora_to_model(
            use_unet_lora,
            unet,
            lora_manager.unet_replace_modules,
            lora_path=unet_dir,
            dropout=0.1,
            r=64,
        )
        unet.eval()
        # presumably merges the LoRA deltas into the base weights and then strips the LoRA
        # wrappers -- inferred from the helper names, verify in utils.lora
        collapse_lora(unet, lora_manager.unet_replace_modules)
        monkeypatch_remove_lora(unet)

        # Swap the adapted UNet into the pretrained model.
        pretrained_t2v.model.diffusion_model = unet
        scheduler = T2VTurboScheduler(
            linear_start=model_config["params"]["linear_start"],
            linear_end=model_config["params"]["linear_end"],
        )
        # initial model pipeline load 
        # Both pipelines are constructed from the same model and scheduler objects.
        pipeline = T2VTurboVC2Pipeline(pretrained_t2v, scheduler, model_config).to("cuda")
        multidiffusion_pipeline = Multidiffusion_T2VTurboVC2Pipeline(pretrained_t2v, scheduler, model_config).to("cuda")   # NOTE multidiffusion pipeline is added  

        # Re-seed after loading so generation starts from the seed regardless of
        # how much randomness the setup above consumed.
        torch.manual_seed(args.seed)
        seed_everything(args.seed)      

        # Floor-divided, i.e. reported in whole minutes.
        print('T2V turbo model loading time (min) : ', (time.time() - start_time)//60)



    # Optional Molmo pointing model: processor and causal LM are loaded from the
    # same checkpoint with identical loading options.
    if args.load_molmo:
        molmo_checkpoint = 'allenai/MolmoE-1B-0924'    # alternative: 'allenai/Molmo-7B-D-0924'
        molmo_load_options = {
            'trust_remote_code': True,
            'torch_dtype': 'auto',
            'device_map': 'auto',
        }
        processor = AutoProcessor.from_pretrained(molmo_checkpoint, **molmo_load_options)
        molmo = AutoModelForCausalLM.from_pretrained(molmo_checkpoint, **molmo_load_options)

    # Walk the prompt list back-to-front when requested (in-place reversal).
    if args.reverse:
        print('Reverse manner to make fast!')
        meta_data.reverse()

    # Per-prompt durations (seconds) and the start of the whole prompt loop.
    cur_times_zip = []
    start_time = time.time()

    ## per prompt 
    for INIT_PROMPT_IDX, cur_meta in enumerate(meta_data):

        # Every prompt starts from the same RNG state.
        torch.manual_seed(SEED)
        seed_everything(SEED)

        # Pull the bookkeeping fields out of the entry; what remains is the DSG payload.
        init_prompt = cur_meta.pop('origin_prompt')
        unique_id = cur_meta.pop('idx')

        # NOTE(review): this hard-coded filter skips every prompt except the one containing
        # this phrase -- it looks like a leftover debugging restriction; confirm before a full run.
        if 'mesmerizing Northern Lights' not in init_prompt:
            continue

        if args.dsg_type != 'dependency':
            # NOTE(review): qid2tuple / qid2dependency are not set on this path, but later
            # steps read them unconditionally -- non-dependency modes look unsupported here.
            qid2question = cur_meta
        else:
            qid2question = cur_meta['qid2question']
            qid2tuple = cur_meta['qid2tuple']
            qid2dependency = cur_meta['qid2dependency']            # dependency between questions

        print('< Initial prompt > : ', init_prompt)

        # Directory that collects the best video of each round for this prompt.
        cur_gen_best_dir = os.path.join(cur_result_root, str(unique_id), 'best_videos_' + args.selection_score)
        os.makedirs(cur_gen_best_dir, exist_ok=True)

        # A dsg_log.txt marks a prompt that was already processed: skip it.
        dsg_log_marker = os.path.join(cur_gen_best_dir, 'dsg_log.txt')
        if os.path.exists(dsg_log_marker):
            print(str(unique_id), ' is already existed')
            continue

        DSG_logs = []
        cur_start_time = time.time()
        
        ## per round 
        for cur_round in range(ROUND_NUM) : 

            print(f'{cur_round} round is started')

            # NOTE diverse seed was added 
            # Candidate seeds for this round (one regenerated video per seed).
            if args.div_seeds: 
                # Draw a field of k * ROUND_NUM seeds and take this round's slice of it.
                seed_fields = [random.randint(0, 100000) for _ in range(args.k * ROUND_NUM)] 
                SEED_LISTS = seed_fields[args.k * cur_round : args.k * (cur_round+1)]

            else : 
                SEED_LISTS = [random.randint(0, 100000) for _ in range(args.k)] #[random.randint(1, 1000000)] 
            print('SEED_LISTS : ', SEED_LISTS)


            prompt_root = os.path.join(cur_result_root, str(unique_id))
            cur_gen_dir  = os.path.join(prompt_root, str(cur_round) + '_round')    # directory per prompts  
            os.makedirs(cur_gen_dir, exist_ok=True) 

            # if first round 
            if cur_round == 0 : 
                init_video_path = os.path.join(cur_gen_dir, init_prompt + '_0.mp4')
                noise_path = None 
            else : 
                # The previous round copied its best candidate into this directory. Only look at
                # "<name>_<seed>.mp4" files: an interrupted earlier run may also have left
                # first_frame.jpg, logs.txt, ... here, and os.listdir() order is arbitrary.
                seeded_video = re.compile(r'_(\d+)\.mp4$')
                carried_videos = sorted(name for name in os.listdir(cur_gen_dir) if seeded_video.search(name))
                if not carried_videos : 
                    raise FileNotFoundError(
                        f'No "<name>_<seed>.mp4" video found in {cur_gen_dir}; '
                        f'round {cur_round - 1} did not hand over a best candidate.'
                    )
                init_video_path = carried_videos[0]
                noise_seed = seeded_video.search(init_video_path).group(1)
                # Build the sibling directory explicitly; str.replace on the whole path could also
                # rewrite an unrelated "<n>_round" substring in the output root.
                prev_gen_dir = os.path.join(prompt_root, str(cur_round-1) + '_round')
                noise_path = os.path.join( prev_gen_dir,  'init_latent_' + str(noise_seed) + '.pt')

            # import pdb;pdb.set_trace()

            # NOTE if there are DSG answers, pass 
            # The mask file is written late in a round, so its presence marks a round
            # that already reached the mask step in an earlier run.
            if os.path.exists(os.path.join(cur_gen_dir, 'binary_yes_mask.png')) :    
                print(str(unique_id), ' is already existed')
                continue 
            

            # Round 0 only: synthesise the initial video from the raw prompt,
            # unless a file already exists at init_video_path.
            if cur_round == 0:
                print(f'---------------- 1. T2V generation ----------------')
                need_initial_video = not os.path.exists(init_video_path)

                if need_initial_video and args.model == 't2vturbo':
                    T2VTurbo_from_each_prompt_wo_bottleneck_onebyone(
                        outpath=cur_gen_dir,
                        prompt_text=init_prompt,
                        seeds=SEED,
                        pipeline=pipeline,
                    )

                elif need_initial_video and args.model == 'videocrafter2':
                    # VideoCrafter reads its prompt from a text file.
                    init_prompt_file = os.path.join(cur_gen_dir, 'init_prompt.txt')
                    with open(init_prompt_file, 'w', encoding='utf-8') as prompt_handle:
                        prompt_handle.write(init_prompt)

                    VideoCrafter_from_each_prompt(
                        outpath=cur_gen_dir,
                        prompts_path=init_prompt_file,
                        seeds=SEED,
                    )

                elif need_initial_video and args.model == 'vico':
                    # NOTE(review): `pipe` is never assigned in this script -- confirm it is
                    # provided by one of the star-imports before using --model vico.
                    Vico_from_each_prompt(
                        outpath=init_video_path,
                        prompt_text=init_prompt,
                        seeds=SEED,
                        pipe=pipe,
                    )



            print(f'---------------- 2. DSG scoring w/ GPT4o ----------------')    

            # NOTE key-object extraction from DSG entities 
            # The tuple output is parsed line by line: the category is the text between the
            # first '|' and the following '-', and the key object is the first comma-separated
            # item inside the first parenthesised group.
            object_string = qid2tuple['custom_0']['output']
            lines = object_string.split("\n")
            Q_type = [] ; key_objects_from_Q = []

            # Question type stacking 
            for line in lines:
                # A blank line (e.g. a trailing newline) carries no tuple; indexing the split
                # result below would raise IndexError on it.
                if not line.strip():
                    continue

                category = line.split('|')[1].split('-')[0].strip()
                Q_type.append(category)

                # Key-object extraction 
                if category in ['other', 'entity']:
                    match = re.search(r'\(([^)]+)\)', line)
                    if match:
                        nouns = match.group(1).split(',')[0].strip()  
                        key_objects_from_Q.append(nouns)

            print('Question types: ', Q_type)
            print('Key objects: ', key_objects_from_Q)

            # Exactly one entry per question (first matching key object, else None).
            # This list is indexed by question position later on; appending one entry per
            # *match* shifted every following question whenever a question mentioned
            # more than one key object.
            key_objects_in_questions = [
                next((obj for obj in key_objects_from_Q if obj in q), None)
                for q in qid2question.values()
            ]


            # DSG question asking 
            # In round 0 init_video_path is already a full path; in later rounds it is a bare file name.
            if cur_round == 0 : 
                video_path = init_video_path
            else : 
                video_path = os.path.join(cur_gen_dir, init_video_path )

            # Only the first frame of the video is shown to GPT-4o.
            first_frame_img = extract_first_frame(video_path, cur_gen_dir)  
            first_frame_img_gpt = encode_gpt4_input(first_frame_img)

            system_prompt = f'You are an expert at answering questions about the content of a given image.'

            # NOTE question-wise gpt4o calling 
            dsg_answers = []
            for i in range(len(qid2question)) : 
                cur_question = qid2question[str(i+1)]   
                cur_question_type = Q_type[i]
                key_objects = key_objects_in_questions[i]         

                # Devide count prompt & non-count prompt 
                # NOTE(review): item 3 says "initial prompt" but interpolates the question text,
                # not init_prompt -- confirm this is intended.
                count_prompt = f'''
                                1. Given the question: "{cur_question}", provide a brief reasoning (up to two sentences) to determine the accurate answer.
                                2. Respond to the question using binary values: 1.0 for "Yes" and 0.0 for "No". If the answer is uncertain or unnatural due to image distortion or other issues, respond with 0.0 ("No").
                                3. Return the number of "{key_objects}" (as an integer) mentioned in the initial prompt "{cur_question}". 
                                4. Return the number of "{key_objects}" (as an integer) in the provided image.

                                Return the result as a dictionary in the following format (not in JSON format):
                                {{
                                    "Q": "<question>",
                                    "A": <binary answer>,
                                    "reasoning": "<brief reasoning>",
                                    "obj_in_prompt": <number of key object mentioned in the initial prompt>,
                                    "obj_in_img": <number of key object in the image>,
                                }}

                                Example: 
                                {{
                                    "Q": "Is there one robot?",
                                    "A": 0.0,
                                    "reasoning": "There are two visible robots in the image.",
                                    "obj_in_prompt": 1,
                                    "obj_in_img": 2,
                                }}

                                Please provide only the dictionary as the output without any additional text or explanation.
                                '''

                # The trailing backslashes are line continuations. One of them used to be followed
                # by spaces, which made it an invalid escape and put a literal backslash in the prompt.
                non_count_prompt = f'''
                                    Respond to "{cur_question}" using binary values: 1.0 for Yes and 0.0 for No. If the answer is uncertain due to image distortion or other issues, respond with 0.0 (No). \
                                    Return the result as a dictionary in the following format (not in JSON format): \
                                    {{"Q": "<question>", "A": <binary answer>}} \
                                    (e.g., {{"Q": "Is there one robot?", "A": 0.0}}) \
                                    Provide only the dictionary as the output, without any additional text or explanations.
                                    '''     # no reasoning 

                # Only 'other'-type questions get the counting prompt.
                user_prompt = count_prompt if cur_question_type == 'other' else non_count_prompt

                # Retry until the API call succeeds. Catch Exception, not everything, so that
                # Ctrl-C / SystemExit can still stop the otherwise endless loop.
                while True:
                    try : 
                        answer = asking_gpt4o(system_prompt, user_prompt, first_frame_img_gpt)
                    except Exception as err : 
                        print('ERROR..', err)
                        time.sleep(9)
                    else : 
                        print(answer) ; print('*' * 5)
                        break

                # Strip markdown fences (with their optional language tag), newlines and backslashes.
                # The tagged fences must go together with the bare one: removing '```' first used to
                # leave a stray 'python' prefix behind and made such answers unparsable.
                cleaned_answer = re.sub(r'```(?:json|python)?', '', answer).replace('\n', '').replace('\\', '')

                # The text comes from an LLM, so parse it as a literal instead of eval()-ing it.
                try : 
                    answer_dict = ast.literal_eval(cleaned_answer)
                except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as err : 
                    print(f'Could not parse the answer to question {i+1}, skipping it: {err}')
                    continue 

                # Later steps index the answer by key, so anything but a dict is unusable.
                if not isinstance(answer_dict, dict) : 
                    print(f'Answer to question {i+1} is not a dictionary, skipping it.')
                    continue 

                dsg_answers.append(answer_dict)
                
            # consider DSG dependency  
            if args.dsg_type == 'dependency' : 
                try : 
                    qid2scores, qid2validity, dsg_answers = filter_DSG_answer_w_dependency(dsg_answers, qid2dependency) 
                    print('Updated DSG score: ', qid2scores)        
                    print('Updated logs: ', qid2validity)
                except Exception as err : 
                    # Best effort: on failure keep the raw answers (dependencies not applied).
                    print('Dependency filtering failed, using the unfiltered DSG answers: ', err)

            # NOTE Object-wise question collection 
            # An answer belongs to an object when its question contains the object string, or
            # any single word of it as a whole word (case-insensitive).
            object_wise_dict = {}
            for obj in key_objects_from_Q : 
                # re.escape: object names come from free text and may contain regex metacharacters.
                word_pattern = re.compile(
                    '|'.join(r'\b' + re.escape(word) + r'\b' for word in obj.split()),
                    re.IGNORECASE,
                )
                object_wise_dict[obj] = [
                    cur_qa for cur_qa in dsg_answers
                    if (obj in cur_qa['Q']) or word_pattern.search(cur_qa['Q'])
                ]
            
            q_count = sum(len(qa_list) for qa_list in object_wise_dict.values())
            # Mean of the binary answers; 0.0 when no answer could be parsed (was a ZeroDivisionError).
            if dsg_answers : 
                dsg_score = sum(float(qa['A']) for qa in dsg_answers) / len(dsg_answers)
            else : 
                dsg_score = 0.0

            # write dsg answers 
            # 'w' starts a fresh logs.txt for this round; later steps append to it.
            with open(os.path.join(cur_gen_dir, 'logs.txt'), 'w') as f : 
                f.write('\n'.join(list(map(str, dsg_answers))))

            # print(f'---------------- 3. Object-wise preserving area! ----------------') 
            preserve_object = None ; preserve_num = None      
            preserve_prompts = [] ; local_prompts = []

            # object_wise_dict = sort_priority_of_object(object_wise_dict)

            # Q1. Which object we should preserve? 
            # NOTE gpt4o ver update 
            # Both interpolating fragments need the f prefix; without it the literal text
            # "{key_objects_from_Q}" was sent to the model instead of the candidate list.
            task_prompt_key = (
                # "Our goal is to preserve this correctly-generated object and regenerate the others in the scene. "
                f"Given the generated image and the list of question-answer pairs for each object, represented as {object_wise_dict}, "
                f"choose the most accurately or visibly generated object from the list {key_objects_from_Q}. "
                "Prioritize selecting objects with a high number of answers rated 1.0 for each question."
                "Select the object that is both large and clearly visible, prioritizing prominent objects (such as animals, humans, or specific items) over background elements (like ocean or city). "
                "Return only the name of the best object to keep from the list, without additional explanation (e.g., 'dog')."

            )

            # Ask GPT-4o; give up (preserve_object = None) after more than 3 failed attempts.
            stop = False ; error_count = 0 
            while not stop:
                try : 
                    local_response = client.chat.completions.create( 
                                                    # model="gpt-4-0125",
                                                    model="gpt-4o",    
                                                    messages=[
                                                        {
                                                            "role": "user",
                                                            "content": [
                                                                {"type": "text", "text":  task_prompt_key}, 
                                                                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{first_frame_img_gpt}"}},
                                                            ],
                                                        }
                                                    ],
                                                    max_tokens=100,
                                                )
                    raw_choice = local_response.choices[0].message.content  
                    if raw_choice is None : 
                        preserve_object = None 
                    else : 
                        # The reply is used as a dictionary key below, so drop surrounding whitespace,
                        # quotes and a trailing period, and map it back to the exact spelling of a
                        # known key object when it only differs in case.
                        cleaned_choice = raw_choice.strip().strip('\'"`.').strip()
                        known_objects = {name.lower(): name for name in key_objects_from_Q}
                        preserve_object = known_objects.get(cleaned_choice.lower(), cleaned_choice)
                    stop = True 

                except Exception as err : 
                    # Exception (not a bare except) so that Ctrl-C still interrupts the retries.
                    print('ERROR..', err)
                    # import pdb;pdb.set_trace()
                    time.sleep(9)
                    error_count += 1 
                    if error_count > 3 :           # NOTE if error count > 3 : stop 
                        preserve_object = None 
                        stop = True   

            '''
            Q2. How many object we should preserve / add / delete? 
                => Use DSG existence questions 
            '''
            count_priority = None 
            preserve_questions = [item['Q'] for item in object_wise_dict.get(preserve_object, [])]
            # NOTE local_questions = [item['Q'] for key, items in object_wise_dict.items() if key != preserve_object for item in items]
            local_questions = [q for q in qid2question.values() if q not in preserve_questions]

            try : 
                q_logs = object_wise_dict[preserve_object]
            except : 
                continue      

            for log in q_logs :  

                # count question 
                if (len(log) > 2) : 
                    # Yes-answered 
                    if (log['obj_in_prompt'] == log['obj_in_img'])  :          
                        preserve_num = log['obj_in_img']
                        count_priority = 1 
                        print(f'* [Priority 1] Preserved object : {preserve_object} | Preserved num : {preserve_num}')

                    # No-answered 
                    elif (log['obj_in_prompt'] < log['obj_in_img']) :        
                        preserve_num = log['obj_in_prompt'] 
                        count_priority = 2
                        print(f'* [Priority 2] Preserved object : {preserve_object} | Preserved num : {preserve_num}')

                    elif (log['obj_in_prompt'] > log['obj_in_img']) and (log['obj_in_img'] > 0) :  
                        preserve_num = log['obj_in_img'] 
                        lack_num_of_object = log['obj_in_prompt'] -  log['obj_in_img'] 
                        local_questions.append(f'{lack_num_of_object} {preserve_object}')
                        count_priority = 3 

                        print(f'* [Priority 3] Preserved object : {preserve_object} | Preserved num : {preserve_num}')
                        print(f'* {lack_num_of_object} number of object {preserve_object} need to generate more')

                    # log['obj_in_img'] == 0 -> Need to generate all video again 
                    else : 
                        local_questions += preserve_questions
                        count_priority = 4 

            '''
            Q3. How to localize which object need to delete? 
                => Use pointing prompt, pointing models 
            '''
            pointing_prompt = None 
            if preserve_object != None : 

                if preserve_num != None : 
                    pointing_prompt = f'Point the biggest {preserve_num} {preserve_object}.'
                    # pointing_prompt = f'Point {preserve_num} {preserve_object}.'     # the biggest 
                else : 
                    pointing_prompt = f'Point the biggest {preserve_object}.'
                    

            # logging 
            output = '\n' + '=' * 50 + '\n' 
            output += f'* [DSG score]: {dsg_score}\n'
            output += f'* [Object decision] Preserved object : {preserve_object} | Preserved num : {preserve_num} | Priority: {count_priority}\n'
            output += f'* [Pointing prompt]: {pointing_prompt}\n'
            output += f'* [DSG questions for preserving prompts]: [{", ".join(preserve_questions)}]\n'
            output += f'* [DSG questions for local prompts]: [{", ".join(local_questions)}]\n'
            output += '=' * 50 + '\n'

            file_path = os.path.join(cur_gen_dir, 'logs.txt')
            with open(file_path, 'a') as f:
                f.write(output)

            print(output)

            DSG_logs.append(dsg_score)

            if dsg_score == 1.0 :         # NOTE DSG== 1.0 pass  
                # continue 
                break 


            print(f'---------------- 4. Pointing and Mask generation! ----------------') 
            # Key objects other than the preserved one.
            remaining_object = [name for name in key_objects_from_Q if name != preserve_object]

            # Molmo pointing 
            # NOTE(review): `processor` and `molmo` only exist when --load_molmo was passed.
            viz_path = os.path.join(cur_gen_dir, 'molmo_point.png')
            object_points_list = ask_molmo(processor, molmo, first_frame_img, pointing_prompt, viz_path)     # Ask Molmo pointing 
            print('* Object point: ', object_points_list)

            has_points = len(object_points_list) > 0
            has_local_questions = len(local_questions) != 0
            has_other_objects = len(remaining_object) != 0

            if has_points and has_local_questions and has_other_objects:
                # Turn the points into masks, OR them together, and write 0 on the union
                # and 255 everywhere else.
                mask_stack_path = os.path.join(cur_gen_dir, 'ssam_mask_stack.npy')
                point2mask_semanticsam(
                    img_path=os.path.join(cur_gen_dir, 'first_frame.jpg'),
                    point_lists=object_points_list,
                    mask_save_path=mask_stack_path,
                    img_width=512,
                    img_height=320,
                )
                all_masks = np.load(mask_stack_path)
                total_mask = reduce(np.logical_or, all_masks).squeeze()     # mask combine (n, 320, 512)
                mask_array = np.where(total_mask, 0, 255).astype(np.uint8)
            else:
                # Nothing to preserve: an all-255 mask, and the preserve questions are
                # put in front of the local ones.
                mask_array = np.full((320, 512), 255, dtype=np.uint8)
                local_questions = preserve_questions + local_questions

            mask_paths = [os.path.join(cur_gen_dir, 'binary_yes_mask.png')]
            total_binary_mask = Image.fromarray(mask_array)
            total_binary_mask.save(mask_paths[0])


            print(f'---------------- 5. Background / Local prompt generation! ----------------') 

            # Prompt built from the questions about the preserved object.
            preserve_prompt_black = prompt_generator_from_Q_v4(question_list=preserve_questions)

            # With points, other objects and local questions all available, the regeneration
            # prompt is built from the local questions (the same call for every count priority);
            # otherwise the first paraphrase of the initial prompt is used.
            localizable = (
                len(object_points_list) != 0
                and len(remaining_object) != 0
                and len(local_questions) != 0
            )
            if localizable:
                change_prompt_white = prompt_generator_from_Q_v4(question_list=local_questions)
            else:
                change_prompt_white = paraphrasing_prompt(init_prompt)[0]

            rule = '=' * 50 + '\n'
            # t += f'* Preserving prompt : {preserve_prompt_black}\n'
            t = rule + f'* Regenerating prompt : {change_prompt_white}\n' + rule
            print(t)

            # Append the chosen prompt to this round's log.
            file_path = os.path.join(cur_gen_dir, 'logs.txt')
            with open(file_path, 'a') as log_handle:
                log_handle.write(t)

            print(f'---------------- 6. Regenerate with MultiDiffusion manner ----------------') 

            # The literal string 'None' is treated as "no usable regeneration prompt": fall back
            # to the initial prompt. This must happen BEFORE the prompt lists and the file-name
            # suffix are derived; previously the fallback came too late to reach the generator
            # and left `suffix` unassigned (NameError in round 0, stale value afterwards).
            if change_prompt_white == 'None' :         
                change_prompt_white = init_prompt

            all_prompt = [preserve_prompt_black, change_prompt_white]           # all_prompt = ['kitchen', 'bear making pizza']     
            all_prompt_origin = [init_prompt, change_prompt_white]    

            # filename length limits 
            # Prompts longer than 30 characters are left out of the file name.
            suffix = change_prompt_white if len(change_prompt_white) <= 30 else ""

            file_names = []

            # One candidate video per seed of this round.
            for ith_video, new_seed in enumerate(SEED_LISTS) : 

                print(f'{ith_video} th video generation')

                cur_file = init_prompt + '_' + suffix+'_'+ str(new_seed)
                file_names.append(cur_file)
                
                if args.model == 't2vturbo' : 

                    # Test origin prompt 
                    T2VTurbo_from_each_prompt_MultiDiffusion(outpath = cur_gen_dir, 
                                                    round_num = cur_round, 
                                                    noise_map = noise_path, 
                                                    prompt_text = init_prompt,            # NOTE background prompt 
                                                    seeds = SEED, 
                                                    pipeline = multidiffusion_pipeline, 
                                                    all_prompt = all_prompt_origin,    # NOTE local prompt  
                                                    local_seed = new_seed, 
                                                    mask_path = mask_paths,                    # layout generation mask 
                                                    suffix = suffix+'_'+ str(new_seed)     # file name suffix 
                                                    )

                elif args.model == 'videocrafter2' : 
                    
                    # all_prompt = all_prompt_origin

                    # save all_prompt to txt file 
                    prompts_path = os.path.join(cur_gen_dir, 'all_prompts.txt')
                    with open(prompts_path, "w") as file:
                        for prompt in all_prompt_origin:
                        # for prompt in all_prompt:
                            file.write(f"{prompt}\n")

                    # NOTE(review): new_seed is not passed here, so every candidate of the round
                    # is requested with the same arguments -- confirm this is intended.
                    VideoCrafter_Multidiffusion(outpath = cur_gen_dir, 
                                                all_prompt = prompts_path,    
                                                # local_seed = new_seed,
                                                mask_path = mask_paths[0], 
                                                seeds = SEED, 
                                                )



            # Overview image of the candidates (first frames).
            concatenate_video_1st_frames(cur_gen_dir=cur_gen_dir, video_paths=file_names, output_path='whole_' + suffix + '.png')


            print(f'---------------- 7. Scoring + Selection ----------------') 
            # Every .mp4 in the round directory is a candidate (sorted for a stable order).
            filtered_videos = sorted([f for f in os.listdir(cur_gen_dir) if f.endswith('.mp4')])
            if not filtered_videos : 
                raise FileNotFoundError(f'No candidate .mp4 files found in {cur_gen_dir}')
            
            if args.selection_score == 'dsg_blip' : 

                # Primary criterion: DSG score. Ties are broken with the BLIP score.
                candidates_dsg_scores = automatic_scoring_w_dsg(filtered_videos, cur_gen_dir, qid2question, init_prompt, qid2dependency)
                max_value = max(candidates_dsg_scores)
                max_indices = [i for i, score in enumerate(candidates_dsg_scores) if score == max_value]

                # Default to the first tied candidate, so that a tie-break in which no BLIP score
                # exceeds 0 still yields a valid index (it used to leave max_index = None).
                max_index = max_indices[0]

                # select using blip 
                if len(max_indices) > 1 : 
                    max_clip_score = 0
                    for idx in max_indices : 
                        cur_video_path = os.path.join(cur_gen_dir, filtered_videos[idx])
                        score = calculate_blip_bleu(cur_video_path, init_prompt, blip2_model, blip2_processor)
                        print(f'{idx} video blip score: {score}')
                        if score > max_clip_score : 
                            max_clip_score = score
                            max_index = idx 

            elif args.selection_score == 'blip_only' : 
                blip_scores = []
                for cur_video_path in filtered_videos : 
                    score = calculate_blip_bleu(os.path.join(cur_gen_dir, cur_video_path), init_prompt, blip2_model, blip2_processor)
                    blip_scores.append(score)
                max_value = max(blip_scores)
                # First candidate that reaches the best score.
                max_index = blip_scores.index(max_value)

            else : 
                # Previously fell through and failed with a NameError on max_index.
                raise ValueError(f"Unknown --selection_score {args.selection_score!r}; expected 'dsg_blip' or 'blip_only'")

    
            max_score_file = filtered_videos[max_index]
            full_max_score_file = os.path.join(cur_gen_dir, max_score_file)
            print(full_max_score_file)

            # copy next round dir 
            # Build the sibling directory explicitly; str.replace on the whole path could also
            # rewrite an unrelated "<n>_round" substring in the output root.
            next_cur_gen_dir = os.path.join(cur_result_root, str(unique_id), str(cur_round+1) + '_round')
            os.makedirs(next_cur_gen_dir, exist_ok=True) 
            # The winner seeds the next round and is archived as this round's best video.
            shutil.copy(full_max_score_file, os.path.join(next_cur_gen_dir, max_score_file  ))
            shutil.copy(full_max_score_file, os.path.join(cur_gen_best_dir, str(cur_round) + '_best_video.mp4'   ))

            is_best_video = True 

        # Wall-clock time spent on this prompt (all rounds).
        cur_end_time = time.time()
        cur_time_seconds = cur_end_time - cur_start_time
        cur_times_zip.append(cur_time_seconds)
        print('Cur_time_second: ', cur_time_seconds)

        # Rewrite the timing file with every duration collected so far, one per line.
        timing_file = os.path.join(cur_result_root, 'ours_timeconsump.txt')
        with open(timing_file, "w") as timing_handle:
            timing_handle.writelines(f"{seconds}\n" for seconds in cur_times_zip)

        # write dsg 
        # One DSG score per scored round; this file also marks the prompt as done.
        dsg_log_file = os.path.join(cur_gen_best_dir, 'dsg_log.txt')
        with open(dsg_log_file, 'w') as dsg_handle:
            dsg_handle.writelines(f"{round_score}\n" for round_score in DSG_logs)

        # break 
    
    print('Process are done!')

    # Total wall-clock time since the prompt loop started, as whole hours and minutes.
    end_time = time.time()
    total_time_seconds = end_time - start_time
    total_time_minutes = total_time_seconds / 60
    hours, minutes = (int(part) for part in divmod(total_time_minutes, 60))
    print(f"Total Execution time: {hours} h {minutes} min")
