import argparse, os, sys, glob, shutil, time 
import warnings;warnings.filterwarnings('ignore')
import sys
sys.path.append("..")
# from segment_anything import sam_model_registry, SamPredictor
import torch, torchvision
import numpy as np
from omegaconf import OmegaConf
from PIL import Image, ImageDraw, ImageFont
import subprocess
import re, io
import json 
import time 
import cv2, random 
import base64
from tqdm import tqdm, trange
from pprint import pprint
from copy import deepcopy
import openai 
from openai import AzureOpenAI
from dotenv import load_dotenv
import matplotlib.pyplot as plt
import cv2
from functools import reduce
import matplotlib.pyplot as plt 
from lightning_fabric import seed_everything
from concurrent.futures import ThreadPoolExecutor
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from transformers import AutoProcessor, Blip2ForConditionalGeneration
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.bleu.bleu import Bleu

# videocrafter2 
from funcs import load_model_checkpoint, load_prompts, load_image_batch, get_filelist, save_videos
from funcs import batch_ddim_sampling
from utils.utils import instantiate_from_config


# Shared Azure OpenAI client used by every GPT helper in this module.
# Credentials are read from the environment (optionally populated from a
# local .env file) so no secret has to be written into the source.
#   AZURE_OPENAI_ENDPOINT - resource endpoint URL
#   AZURE_OPENAI_API_KEY  - API key
#   OPENAI_API_VERSION    - API version string
load_dotenv()
client = AzureOpenAI(
    azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
    api_key=os.getenv("AZURE_OPENAI_API_KEY"),
    api_version=os.getenv("OPENAI_API_VERSION"),
)

def save_video(
    vid_tensor, metadata: dict, root_path="./", fps=16
):
    """Write one video tensor to an H.264 ``.mp4`` file and return its path.

    The file is named ``<prompt with newlines removed>_<idx>.mp4`` (both taken
    from ``metadata``) and placed in ``root_path``. Values are clamped to
    [-1, 1], rescaled to 0..255 ``uint8`` and passed to
    ``torchvision.io.write_video``.
    """
    prompt_part = metadata['prompt'].replace('\n', '')
    file_name = prompt_part + '_' + str(metadata['idx']) + '.mp4'
    print('unique_name :', file_name)
    target_path = os.path.join(root_path, file_name)
    print('Save path : ', target_path)

    # [-1, 1] float -> [0, 255] uint8
    pixels = vid_tensor.detach().cpu().float().clamp(-1.0, 1.0)
    pixels = ((pixels + 1.0) / 2.0 * 255).to(torch.uint8)
    # Move the leading axis to the end (the axis order write_video receives);
    # equivalent to permute(1, 0, 2, 3) followed by permute(0, 2, 3, 1).
    pixels = pixels.permute(1, 2, 3, 0)

    torchvision.io.write_video(
        target_path, pixels, fps=fps, video_codec="h264", options={"crf": "10"}
    )
    return target_path

def save_videos(
    video_array, metadata: dict, fps: int = 16
):
    """Save every clip in ``video_array`` and return the path of the first one.

    Clips are written concurrently with ``save_video`` into
    ``metadata["save_path"]`` (created if missing). All clips share the same
    ``metadata``, so they also share the same target file name.

    Note: this definition rebinds the ``save_videos`` name imported from
    ``funcs`` at the top of the file.
    """
    out_dir = metadata["save_path"]
    os.makedirs(out_dir, exist_ok=True)

    clip_count = len(video_array)
    with ThreadPoolExecutor() as pool:
        results = pool.map(
            save_video,
            video_array,
            [metadata] * clip_count,
            [out_dir] * clip_count,
            [fps] * clip_count,
        )
        written_paths = list(results)
    return written_paths[0]


def T2VTurbo_from_each_prompt_wo_bottleneck_onebyone(
    outpath,
    prompt_text,
    seeds,
    pipeline,
):
    """Generate one video for ``prompt_text`` with ``pipeline`` and save it.

    Args:
        outpath: Directory the video is written to (created if missing).
        prompt_text: Text prompt; also used to build the output file name.
        seeds: Seed passed to ``torch.manual_seed`` and ``seed_everything``.
        pipeline: Callable invoked with the keyword arguments shown below.

    Returns:
        Whatever ``save_videos`` returns for the generated result (the path
        of the first saved video).
    """
    os.makedirs(outpath, exist_ok=True)

    torch.manual_seed(seeds)
    seed_everything(seeds)

    # Sampling settings, defined once so the pipeline call and the saved
    # metadata cannot drift apart.
    fps = 16
    guidance_scale = 7.5
    num_inference_steps = 4

    result = pipeline(
        prompt=prompt_text,
        frames=16,
        fps=fps,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_videos_per_prompt=1,
    )
    return save_videos(
        result,
        metadata={
            "prompt": prompt_text,
            "seed": seeds,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "save_path": outpath,
            "idx": 0,    # NOTE tentative
        },
        fps=fps,
    )

def T2VTurbo_from_each_prompt_MultiDiffusion(
    outpath,
    round_num,
    noise_map,
    prompt_text,
    seeds,
    pipeline,
    all_prompt,
    local_seed,
    mask_path,
    suffix,
):
    """Generate one mask-guided video with ``pipeline`` and save it.

    Args:
        outpath: Directory the video is written to (created if missing).
        round_num, noise_map, all_prompt, mask_path, local_seed: Forwarded
            unchanged to ``pipeline`` as keyword arguments.
        prompt_text: Text prompt; also used to build the output file name.
        seeds: Seed passed to ``torch.manual_seed`` and ``seed_everything``,
            and forwarded to ``pipeline`` as ``origin_seed``.
        pipeline: Callable invoked with the keyword arguments shown below.
        suffix: Stored as ``idx`` in the metadata, i.e. the file-name suffix.

    Returns:
        Whatever ``save_videos`` returns for the generated result (the path
        of the first saved video).
    """
    os.makedirs(outpath, exist_ok=True)

    torch.manual_seed(seeds)
    seed_everything(seeds)

    print('Re-randomization seed: ', local_seed)

    # Sampling settings, defined once so the pipeline call and the saved
    # metadata cannot drift apart.
    fps = 16
    guidance_scale = 7.5
    num_inference_steps = 4

    result = pipeline(
        prompt=prompt_text,
        all_prompt=all_prompt,
        mask_path=mask_path,
        round_num=round_num,
        noise_map=noise_map,
        frames=16,
        fps=fps,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_videos_per_prompt=1,
        origin_seed=seeds,
        local_seed=local_seed,
    )
    return save_videos(
        result,
        metadata={
            "prompt": prompt_text,
            "seed": seeds,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "save_path": outpath,
            "idx": suffix,    # NOTE tentative
        },
        fps=fps,
    )

def VideoCrafter_from_each_prompt(
    outpath,
    prompts_path,
    seeds=123,
):
    """Run the VideoCrafter inference script as a subprocess.

    Creates ``outpath``, launches
    ``VideoCrafter/scripts/evaluation/inference.py`` with fixed sampling
    options, and prints the child's stdout on success or stderr on failure.
    Nothing is returned and a failing child does not raise.

    Args:
        outpath: Directory passed to the script as ``--savedir``.
        prompts_path: Path passed to the script as ``--prompt_file``.
        seeds: Seed passed to the script as ``--seed``.
    """
    os.makedirs(outpath, exist_ok=True)

    # Fixed checkpoint / config locations for VideoCrafter2.
    checkpoint_file = '/checkpoints/base_512_v2/model.ckpt'
    config_file = '/configs/inference_t2v_512_v2.0.yaml'

    # Insertion order of this dict is the order of flags on the command line.
    cli_options = {
        "--seed": str(seeds),
        "--mode": "base",
        "--ckpt_path": checkpoint_file,
        "--config": config_file,
        "--savedir": outpath,
        "--n_samples": "1",
        "--bs": "1",
        "--height": "320",
        "--width": "512",
        "--unconditional_guidance_scale": "12.0",
        "--ddim_steps": "50",
        "--ddim_eta": "1.0",
        "--prompt_file": prompts_path,
        "--fps": "28",
    }
    command = ["python3", "VideoCrafter/scripts/evaluation/inference.py"]
    for flag, value in cli_options.items():
        command += [flag, value]
    print(command)

    completed = subprocess.run(command, capture_output=True, text=True)
    if completed.returncode != 0:
        print("Command execution failed with error:")
        print(completed.stderr)
        return
    print("Command executed successfully with output:")
    print(completed.stdout)

def VideoCrafter_Multidiffusion(
    outpath,
    all_prompt,
    mask_path,
    seeds,
):
    """Run the multidiffusion VideoCrafter inference script as a subprocess.

    Creates ``outpath``, launches
    ``scripts/evaluation/inference_multidiffusion.py`` with fixed sampling
    options, and prints the child's stdout on success or stderr on failure.
    Nothing is returned and a failing child does not raise.

    Args:
        outpath: Directory passed to the script as ``--savedir``.
        all_prompt: Path passed to the script as ``--prompt_file``.
        mask_path: Path passed to the script as ``--mask_path``.
        seeds: Seed passed to the script as ``--seed``.
    """
    os.makedirs(outpath, exist_ok=True)

    script = "scripts/evaluation/inference_multidiffusion.py"
    # (flag, value) pairs in command-line order.
    flag_values = (
        ("--seed", str(seeds)),
        ("--mode", "base"),
        ("--ckpt_path", 'checkpoints/base_512_v2/model.ckpt'),
        ("--config", 'configs/inference_t2v_512_v2.0.yaml'),
        ("--savedir", outpath),
        ("--n_samples", "1"),
        ("--bs", "2"),
        ("--height", "320"),
        ("--width", "512"),
        ("--unconditional_guidance_scale", "12.0"),
        ("--ddim_steps", "50"),
        ("--ddim_eta", "1.0"),
        ("--prompt_file", all_prompt),
        ("--fps", "28"),
        ("--mask_path", mask_path),
    )
    command = ["python3", script]
    for pair in flag_values:
        command.extend(pair)
    print(command)

    outcome = subprocess.run(command, capture_output=True, text=True)
    succeeded = outcome.returncode == 0
    if succeeded:
        banner, stream = "Command executed successfully with output:", outcome.stdout
    else:
        banner, stream = "Command execution failed with error:", outcome.stderr
    print(banner)
    print(stream)


def encode_video(video_path) :
    '''
    Read every frame of the video at ``video_path`` and return them as a list
    of base64 strings, one JPEG-encoded frame per entry (in playback order).
    Prints how many frames were read. An unreadable video yields an empty list.
    '''
    capture = cv2.VideoCapture(video_path)
    encoded_frames = []

    while capture.isOpened():
        ok, frame = capture.read()
        if ok:
            jpeg_buffer = cv2.imencode(".jpg", frame)[1]
            encoded_frames.append(base64.b64encode(jpeg_buffer).decode("utf-8"))
        else:
            break  # end of stream (or read failure)

    capture.release()
    print(len(encoded_frames), "frames read.")
    return encoded_frames


def extract_first_frame(video_path, cur_gen_dir) :
    """Save the first frame of a video as a JPEG and return it as a PIL image.

    Only the first frame is decoded; it is written once to
    ``<cur_gen_dir>/first_frame.jpg`` (overwriting any previous file there)
    and reopened with PIL.

    Args:
        video_path: Path of the video to read.
        cur_gen_dir: Directory that receives ``first_frame.jpg``.

    Returns:
        ``PIL.Image`` opened from the written JPEG file.

    Raises:
        IndexError: If no frame can be read from ``video_path``. The type is
            kept from the earlier list-indexing implementation so existing
            handlers continue to work.
    """
    capture = cv2.VideoCapture(video_path)
    try:
        ok, frame = capture.read()
    finally:
        capture.release()

    if not ok:
        raise IndexError(f"no frame could be read from video: {video_path!r}")

    first_frame_path = os.path.join(cur_gen_dir, 'first_frame.jpg')
    cv2.imwrite(first_frame_path, frame)

    return Image.open(first_frame_path)

def encode_gpt4_input(pil_image_object) :
    """Return ``pil_image_object`` serialized as JPEG, base64-encoded to text.

    The image is saved into an in-memory buffer via its ``save`` method and
    the resulting bytes are returned as a UTF-8 base64 string.
    """
    with io.BytesIO() as jpeg_buffer:
        pil_image_object.save(jpeg_buffer, format="JPEG")
        jpeg_bytes = jpeg_buffer.getvalue()
    encoded = base64.b64encode(jpeg_bytes)
    return encoded.decode("utf-8")


def asking_gpt4o(system_prompt, task_prompt, gpt4_input_image) :
    """Ask the ``gpt-4o`` deployment one text+image question; return the reply text.

    Args:
        system_prompt: Content of the system message.
        task_prompt: Text part of the user message.
        gpt4_input_image: Base64 string embedded in a ``data:image/jpeg`` URL.

    Returns:
        ``message.content`` of the first choice (at most 100 tokens requested).
    """
    text_part = {"type": "text", "text":  task_prompt}
    image_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/jpeg;base64,{gpt4_input_image}"},
    }
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": [text_part, image_part]},
    ]

    completion = client.chat.completions.create(
        model="gpt-4o",           # "gpt-4o-new"
        messages=conversation,
        max_tokens=100,
    )
    return completion.choices[0].message.content


def filter_DSG_answer_w_dependency(dsg_answers, qid2dependency) :
    """Zero out answers whose parent question was answered "no".

    Args:
        dsg_answers: List of answer dicts, each with an ``'A'`` entry. The
            i-th entry (0-based) is the answer to question id ``str(i + 1)``.
            The list is modified in place.
        qid2dependency: Mapping ``question id -> iterable of parent ids``; a
            parent id of ``0`` means "no dependency". Entries are processed in
            mapping order, so a zeroed score can cascade to later questions.

    Returns:
        ``(qid2scores, qid2validity, dsg_answers)`` where ``qid2scores`` maps
        question id to its (possibly zeroed) score and ``qid2validity`` maps
        question id to ``True`` (kept) or ``False`` (zeroed because a parent
        failed). A question id that cannot be matched to an entry of
        ``dsg_answers`` gets a ``0.0`` score but no validity entry.

    Raises:
        KeyError: If a non-zero parent id has no score in ``qid2scores``.
    """
    qid2scores = {str(pos): qa['A'] for pos, qa in enumerate(dsg_answers, start=1)}
    # e.g., {'1': 0.0, '2': 0.0, '3': 1.0, '4': 1.0}
    qid2validity = {}

    for qid, parent_ids in qid2dependency.items():
        # any() stops at the first failing parent, so later parents are not
        # looked up (mirrors an explicit break).
        parent_answered_no = any(
            parent_id != 0 and qid2scores[str(parent_id)] == 0
            for parent_id in parent_ids
        )

        if not parent_answered_no:
            qid2validity[qid] = True        # no dependency problem
            continue

        qid2scores[qid] = 0.0

        # Locate the matching answer; skip ids that do not map onto the list.
        # TODO dependency 1:1 matching bug
        try:
            answer_pos = int(qid) - 1
        except (TypeError, ValueError):
            continue
        if not 0 <= answer_pos < len(dsg_answers):
            # Reject out-of-range positions explicitly: a negative index
            # would otherwise wrap around and overwrite an unrelated answer.
            continue

        dsg_answers[answer_pos]['A'] = 0.0   # answer overridden by parent
        qid2validity[qid] = False

    return qid2scores, qid2validity, dsg_answers



def point2mask_semanticsam(img_path, point_lists, mask_save_path, img_width=512, img_height=320) :
    """Run ``./SemanticSAM/ssam.py`` as a subprocess to turn points into masks.

    ``point_lists`` is passed on the command line as ``str(point_lists)``.
    The child's stdout is printed on success and its stderr on failure;
    nothing is returned and a failing child does not raise.
    ``img_width`` and ``img_height`` are accepted but not used here.
    """
    command = ["python3", "./SemanticSAM/ssam.py"]      # call pyfile
    command.extend(["--img_path", img_path])
    command.extend(["--point_lists", str(point_lists)])
    command.extend(["--mask_save_path", mask_save_path])
    print(command)

    finished = subprocess.run(command, capture_output=True, text=True)
    if finished.returncode != 0:
        print("Command execution failed with error:")
        print(finished.stderr)
        return
    print("Command executed successfully with output:")
    print(finished.stdout)


def prompt_generator_from_Q_v4(question_list) : 
    """Ask the ``gpt-4-0125`` deployment to merge DSG questions into one phrase.

    Args:
        question_list: Questions to summarize; interpolated into the prompt
            with its default string formatting.

    Returns:
        The model's reply text, or ``None`` if all 4 attempts raised.
    """

    system_prompt_local = (
        "You are an expert in rephrasing prompts for a text-to-video model based on the given questions."
    )

    task_prompt_local = (
        f"Given the following list of questions {question_list}, \
        create a single descriptive sentence that combines the meaning of each question into a natural, affirmative statement that provides a full, concise summary."
        "Your response should be a concise 1 phrase, without additional explanation.  (e.g., 'a small bear')"
        "Examples: "
    )

    examples = """

        - Example 1 
            Question list: ['Is there a bed?', 'Is the bed blue?', 'Are the pillows beige?', 'Are the pillows with the bed?']
            Answer: "Blue bed with beige pillows."

        - Example 2 
            Question list: [Are there three real bears?]
            Answer: "Three real bears."

        - Example 3 
            Question list: [Are there two people?, Are the people making pizza?]
            Answer: "Two people making pizza.

        - Example 4 
            Question list: [Is there a family?, Is there one cat?, Is there a park?, Is the family taking a walk?, Is the cat walking?, Is the family enjoying?, Is the family breathing fresh air?, Is the family exercising?]
            Answer: "A family and a cat are walking in the park."

        - Example 5 
            Question list: [Is there a green bench?, Is there an orange tree?, Is the bench green?, Is the tree orange?]
            Answer: "Green bench and orange tree."

    Your Current Task: Your response should be a concise 1 phrase, without additional explanation (e.g., "a small bear")

    """

    max_attempts = 4        # first try + 3 retries
    retry_delay_s = 9

    for attempt in range(1, max_attempts + 1):
        try:
            local_response = client.chat.completions.create(
                model="gpt-4-0125",
                # model="gpt-4-32k",
                messages=[
                    {"role": "system", "content": system_prompt_local},
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text":  task_prompt_local + examples},
                        ],
                    }
                ],
                max_tokens=100,
            )
            return local_response.choices[0].message.content
        except Exception as err:
            # Exception (not a bare except) so Ctrl-C still interrupts the loop.
            print('ERROR..', err)
            if attempt < max_attempts:
                time.sleep(retry_delay_s)

    return None


def ask_gpt4o_DSG_and_grounding_wo_vprompt(gpt4_input_image, qid2question, init_prompt) : 
    '''
    Ask gpt-4o each DSG question about one image and collect the answer dicts.

    NOTE: a second function with this exact name is defined later in this
    file; at import time that later definition replaces this one.

    Args: 
        - gpt4_input_image: base64 JPEG string of the image to question
          (the first frame of a video)
        - qid2question: DSG questions, keyed by the strings "1".."N"
        - init_prompt: accepted but not used

    Returns:
        A list with one parsed answer dict per question, in question order,
        or None as soon as one question cannot be answered (4 failed API
        attempts) or its reply cannot be parsed as a dict.
    '''
    import ast

    def _parse_answer(raw_answer):
        # Turn the model's reply into a dict without executing it.
        # Code fences are dropped, the outermost {...} is isolated, and the
        # text is parsed as JSON first, then as a Python literal.
        text = re.sub(r'```[A-Za-z]*', '', raw_answer).replace('\n', ' ')
        start, end = text.find('{'), text.rfind('}')
        if start == -1 or end < start:
            return None
        text = text[start:end + 1]
        for parse in (json.loads, ast.literal_eval):
            try:
                parsed = parse(text)
            except (ValueError, SyntaxError, TypeError):
                continue
            if isinstance(parsed, dict):
                return parsed
        return None

    system_prompt = 'You are an expert at answering questions about the content of a given image.'
    max_attempts = 4        # first try + 3 retries
    retry_delay_s = 9
    dsg_answers_with_area = []

    # NOTE ask each question one-by-one, for the first frame only
    for i in range(len(qid2question)) : 
        cur_question = qid2question[str(i+1)]

        task_prompt = f'1. Given the question: "{cur_question}", provide a brief reasoning (up to two sentences) to determine an accurate answer. \
                        2. Respond using binary values: 1.0 for Yes and 0.0 for No. If the answer is uncertain due to image distortion or other issues, respond with 0.0 (No). \
                        Return the result as a dictionary in the following format (not in JSON format): \
                        {{"Q": "<question>", "reasoning": "<brief reasoning>", "A": <binary answer>}} \
                        (e.g., {{"Q": "Is there one robot?", "reasoning": "There are two visible robots in the image. To guarantee a Yes answer, one robot should be removed.", "A": 0.0}}) \
                        Provide only the dictionary as the output, without any additional text or explanations.'

        answer = None
        for attempt in range(1, max_attempts + 1):
            try:
                response = client.chat.completions.create(
                    model="gpt-4o",           # "gpt-4o-new"
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text":  task_prompt},
                                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{gpt4_input_image}"}},
                            ],
                        }
                    ],
                    max_tokens=100,
                )
                answer = response.choices[0].message.content
                break
            except Exception as err:
                # Exception (not a bare except) so Ctrl-C still interrupts.
                print('ERROR..', err)
                if attempt < max_attempts:
                    time.sleep(retry_delay_s)

        if answer is None:          # every attempt failed, or empty reply
            return None
        print(answer)
        print('*' * 5)

        # NOTE post-processing of the received output
        answer_dict = _parse_answer(answer)
        if answer_dict is None:
            return None

        dsg_answers_with_area.append(answer_dict)

    return dsg_answers_with_area


## NOTE all-in-one 
def ask_gpt4o_DSG_and_grounding_wo_vprompt(gpt4_input_image, qid2question, init_prompt=None) : 
    '''
    Ask gpt-4o each DSG question about one image and collect the answer dicts.

    Args: 
        - gpt4_input_image: base64 JPEG string of the image to question
          (the first frame of a video)
        - qid2question: DSG questions, keyed by the strings "1".."N"
        - init_prompt: optional and unused; accepted for compatibility with
          call sites that pass the initial prompt as a third argument

    Returns:
        A list with one parsed answer dict per question, in question order,
        or None as soon as one question cannot be answered (4 failed API
        attempts) or its reply cannot be parsed as a dict.
    '''
    import ast

    def _parse_answer(raw_answer):
        # Turn the model's reply into a dict without executing it.
        # Code fences are dropped, the outermost {...} is isolated, and the
        # text is parsed as JSON first, then as a Python literal.
        text = re.sub(r'```[A-Za-z]*', '', raw_answer).replace('\n', ' ')
        start, end = text.find('{'), text.rfind('}')
        if start == -1 or end < start:
            return None
        text = text[start:end + 1]
        for parse in (json.loads, ast.literal_eval):
            try:
                parsed = parse(text)
            except (ValueError, SyntaxError, TypeError):
                continue
            if isinstance(parsed, dict):
                return parsed
        return None

    system_prompt = 'You are an expert at answering questions about the content of a given image.'
    max_attempts = 4        # first try + 3 retries
    retry_delay_s = 9
    dsg_answers_with_area = []

    # NOTE ask each question one-by-one, for the first frame only
    for i in range(len(qid2question)) : 
        cur_question = qid2question[str(i+1)]

        task_prompt = f'1. Given the question: "{cur_question}", provide a brief reasoning (up to two sentences) to determine an accurate answer. \
                        2. Respond using binary values: 1.0 for Yes and 0.0 for No. If the answer is uncertain due to image distortion or other issues, respond with 0.0 (No). \
                        Return the result as a dictionary in the following format (not in JSON format): \
                        {{"Q": "<question>", "reasoning": "<brief reasoning>", "A": <binary answer>}} \
                        (e.g., {{"Q": "Is there one robot?", "reasoning": "There are two visible robots in the image. To guarantee a Yes answer, one robot should be removed.", "A": 0.0}}) \
                        Provide only the dictionary as the output, without any additional text or explanations.'

        answer = None
        for attempt in range(1, max_attempts + 1):
            try:
                response = client.chat.completions.create(
                    model="gpt-4o",           # "gpt-4o-new"
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text":  task_prompt},
                                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{gpt4_input_image}"}},
                            ],
                        }
                    ],
                    max_tokens=100,
                )
                answer = response.choices[0].message.content
                break
            except Exception as err:
                # Exception (not a bare except) so Ctrl-C still interrupts.
                print('ERROR..', err)
                if attempt < max_attempts:
                    time.sleep(retry_delay_s)

        if answer is None:          # every attempt failed, or empty reply
            return None
        print(answer)
        print('*' * 5)

        # NOTE post-processing of the received output
        answer_dict = _parse_answer(answer)
        if answer_dict is None:
            return None

        dsg_answers_with_area.append(answer_dict)

    return dsg_answers_with_area



def ask_molmo(processor, molmo, PIL_input_img, input_prompt, viz_path, ori_w=512, ori_h=320) : 
    """Ask the Molmo model to point at something in an image; return pixel points.

    The generated text is searched for point coordinates, which are scaled by
    ``ori_w`` / ``ori_h`` after dividing by 100 (i.e. the coordinates are
    treated as percentages of the image size). The points are drawn over the
    image and the figure is saved to ``viz_path`` as PNG.

    Args:
        processor: Object providing ``process(images=..., text=...)`` and a
            ``tokenizer``.
        molmo: Model providing ``device`` and ``generate_from_batch``.
        PIL_input_img: Image passed to the processor and used as background.
        input_prompt: Text prompt passed to the processor.
        viz_path: Where the visualization PNG is written.
        ori_w, ori_h: Image width / height used for the pixel conversion.

    Returns:
        List of ``[x_pixel, y_pixel]`` pairs; empty if no coordinates are
        found in the generated text.
    """
    # process the image and text
    inputs = processor.process(images=[PIL_input_img], text=input_prompt,)

    # move inputs to the correct device and make a batch of size 1
    inputs = {k: v.to(molmo.device).unsqueeze(0) for k, v in inputs.items()}

    # generate output; maximum 200 new tokens; stop generation when <|endoftext|> is generated
    output = molmo.generate_from_batch(
        inputs,
        GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
        tokenizer=processor.tokenizer
    )

    # only get generated tokens; decode them to text
    generated_tokens = output[0,inputs['input_ids'].size(1):]
    string = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
    print(string)

    try:
        # single point: x="..." y="..."
        x = float(re.search(r'x="([0-9.]+)"', string).group(1))
        y = float(re.search(r'y="([0-9.]+)"', string).group(1))
        points_list = [[x, y]]
    except (AttributeError, ValueError):
        # no plain x/y pair (re.search returned None) or an unparsable number:
        # fall back to numbered points x1="..." y1="..." x2="..." y2="..."
        coordinates = re.findall(r'(x\d+)="([\d.]+)" (y\d+)="([\d.]+)"', string)
        points_list = [[float(x_value), float(y_value)] for _, x_value, _, y_value in coordinates]

    # transform to pixel axis
    pixel_points = [
        [(ori_w * x_pct) / 100, (ori_h * y_pct) / 100]
        for x_pct, y_pct in points_list
    ]

    # visualize
    background = np.array(PIL_input_img)
    fig = plt.figure(figsize=(10, 6))
    plt.imshow(background)
    for point in pixel_points:
        plt.scatter(point[0], point[1], color='white', s=100, edgecolor='blue', linewidths=2)
        plt.text(point[0] + 5, point[1] - 5, f'({point[0]:.2f}, {point[1]:.2f})', color='white', fontsize=12)
    plt.axis('off')
    plt.savefig(viz_path, format='png', bbox_inches='tight', pad_inches=0)
    plt.show()
    # Release the figure; pyplot otherwise keeps every figure alive across calls.
    plt.close(fig)
    return pixel_points


def paraphrasing_prompt(origin_prompt) : 
    """Ask the ``gpt-4-0125`` deployment for a paraphrase of ``origin_prompt``.

    Args:
        origin_prompt: Prompt to paraphrase; interpolated into the request.

    Returns:
        List of the strings found between ``<PROMPT>`` and ``</PROMPT>`` in
        the reply (may be empty if the model ignores the format), or ``None``
        if all 4 attempts raised.
    """
    task_prompt_key = (
        f"Given the prompt: {origin_prompt}, generate 1 paraphrases of the initial prompt which keep the semantic meaning."
        "Respond with each new prompt in between <PROMPT> and </PROMPT>, eg: <PROMPT>paraphrase </PROMPT>. Answer using a single phrase. Do NOT generate any explanation, write only answer."
    )

    max_attempts = 4        # first try + 3 retries
    retry_delay_s = 9

    for attempt in range(1, max_attempts + 1):
        try:
            response = client.chat.completions.create(
                model="gpt-4-0125",
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text":  task_prompt_key},
                        ],
                    }
                ],
                max_tokens=100,
            )
            reply_text = response.choices[0].message.content
            return re.findall(r'<PROMPT>(.*?)</PROMPT>', reply_text)
        except Exception as err:
            # Exception (not a bare except) so Ctrl-C still interrupts the loop.
            print('ERROR..', err)
            if attempt < max_attempts:
                time.sleep(retry_delay_s)

    return None

def automatic_scoring_w_dsg(videos, cur_gen_dir, qid2question, init_prompt, qid2dependency): 
    '''
    Score candidate videos with DSG questions asked about their first frame.

    - Input
        videos: video file names, each joined onto cur_gen_dir
        cur_gen_dir: directory holding the videos (also receives first_frame.jpg)
        qid2question: DSG questions, keyed by the strings "1".."N"
        init_prompt: not used; kept so existing callers keep working
        qid2dependency: question id -> parent ids (0 = no dependency)
    - Output
        list of DSG scores (mean of the 'A' values after dependency filtering),
        one per video that could be scored

    NOTE(review): a video whose questions could not be answered is skipped, so
    the returned list can be shorter than `videos` and its positions then no
    longer line up with `videos` - confirm callers expect this.
    '''
    all_dsg_scores = []

    ## Candidate video evaluation 
    for video_path in videos : 

        first_frame_img = extract_first_frame(os.path.join(cur_gen_dir, video_path), cur_gen_dir)
        first_frame_img_gpt = encode_gpt4_input(first_frame_img)

        dsg_answers = ask_gpt4o_DSG_and_grounding_wo_vprompt(first_frame_img_gpt, qid2question)

        # None -> the helper gave up; empty list -> nothing to average
        if not dsg_answers :
            continue

        # NOTE DSG dependency: a question counts as "no" when any of its
        # parent questions was answered "no"
        try :
            _, _, dsg_answers = filter_DSG_answer_w_dependency(dsg_answers, qid2dependency)
        except (KeyError, TypeError, ValueError) as err :
            # Answers and dependencies do not line up; keep the answers as
            # they are at this point (best effort).
            print('Dependency filtering skipped: ', repr(err))

        dsg_score = sum(float(qa['A']) for qa in dsg_answers) / len(dsg_answers)
        all_dsg_scores.append(dsg_score)      # candidate video score 

        print('=' * 50)
        print('Video path: ', video_path)
        print('DSG score: ', dsg_score)
        print('=' * 50)

    return all_dsg_scores

def compute_max(scorer, gt_prompts, pred_prompts):
    """Return the largest score over every (prediction, reference) pair.

    Each pair is scored on its own with
    ``scorer.compute_score({0: [reference]}, {0: [prediction]})``; the first
    element of the returned pair is the score. ``np.max`` is applied to all
    collected scores.
    """
    def _pair_score(prediction, reference):
        value, _ = scorer.compute_score({0: [reference]}, {0: [prediction]})
        return value

    pair_scores = [
        _pair_score(prediction, reference)
        for prediction in pred_prompts
        for reference in gt_prompts
    ]
    return np.max(pair_scores)

    

def calculate_blip_bleu(video_path, original_text, blip2_model, blip2_processor):
    """Caption a video with BLIP-2 and score the captions against the prompt.

    Every frame is read and resized to 224x224; five frames, evenly spaced by
    index, are captioned. For each of ``Bleu(n=1)`` .. ``Bleu(n=4)`` the best
    caption-vs-``original_text`` value from ``compute_max`` is taken, and the
    four values are averaged.

    Args:
        video_path: Path of the video to caption.
        original_text: Reference text (the generation prompt).
        blip2_model: Model exposing ``generate(**inputs)``.
        blip2_processor: Processor used to build inputs and decode captions.

    Returns:
        Average of the four BLEU-based scores.

    Raises:
        RuntimeError: If no frame can be read from ``video_path``.
    """
    # Extract frames from the video
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # Resize the frame to match the expected input size
            frames.append(cv2.resize(frame, (224, 224)))
    finally:
        cap.release()

    if not frames:
        raise RuntimeError(f"no frames could be read from video: {video_path!r}")

    # Convert numpy arrays to float tensors with the channel axis first.
    # NOTE(review): frames are used in the channel order cv2 returns them -
    # confirm that is what the BLIP-2 processor expects.
    tensor_frames = torch.stack([torch.from_numpy(frame).permute(2, 0, 1).float() for frame in frames])

    # Get five captions for one video, from evenly spaced frames
    num_captions = 5
    indices = torch.linspace(0, len(tensor_frames) - 1, num_captions).long()
    extracted_frames = torch.index_select(tensor_frames, 0, indices)

    captions = []
    for frame in extracted_frames:
        inputs = blip2_processor(images=frame, return_tensors="pt").to('cuda', torch.float16)
        generated_ids = blip2_model.generate(**inputs)
        generated_text = blip2_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        captions.append(generated_text)

    references = [original_text]
    # NOTE(review): compute_max applies np.max to whatever compute_score
    # returns; if Bleu(n) returns one value per n-gram order, each term below
    # is the best value across orders 1..n - confirm this is intended.
    bleu_scores = [
        compute_max(Bleu(n=order), references, captions)
        for order in (1, 2, 3, 4)
    ]

    blip_bleu_caps_avg = sum(bleu_scores) / 4

    return blip_bleu_caps_avg



def concatenate_video_1st_frames(cur_gen_dir, video_paths, output_path):
    """Tile the first frame of several videos side by side into one image.

    For each name in ``video_paths`` the file ``<cur_gen_dir>/<name>.mp4`` is
    opened and its first frame read; videos that yield no frame are skipped.
    Shorter frames are padded at the bottom with black rows up to the tallest
    frame, the frames are stacked horizontally, and the result is written to
    ``<cur_gen_dir>/<output_path>``.
    """
    first_frames = []
    for name in video_paths:
        capture = cv2.VideoCapture(os.path.join(cur_gen_dir, name+'.mp4'))
        ok, frame = capture.read()
        capture.release()
        if ok:
            first_frames.append(frame)

    tallest = max([frame.shape[0] for frame in first_frames])

    def _pad_to_tallest(frame):
        # Append black rows below frames shorter than the tallest one.
        height, width, _ = frame.shape
        if height >= tallest:
            return frame
        filler = np.zeros((tallest - height, width, 3), dtype=np.uint8)
        return np.vstack((frame, filler))

    mosaic = np.hstack([_pad_to_tallest(frame) for frame in first_frames])
    cv2.imwrite(os.path.join(cur_gen_dir, output_path), mosaic)
    print('Saved 1st frames to ' + os.path.join(cur_gen_dir, output_path))