import torch
import transformers

import sys

sys.path.append('./')
from videollama2.conversation import conv_templates
from videollama2.constants import DEFAULT_MMODAL_TOKEN, MMODAL_TOKEN_INDEX
from videollama2.mm_utils import get_model_name_from_path, tokenizer_MMODAL_token, process_video, process_image
from videollama2.model.builder import load_pretrained_model

# Inclusive range of dataset rows to evaluate, read from the command line:
#   argv[1] -> first row index, argv[2] -> last row index.
if len(sys.argv) < 3:
    # Fail with a usage hint instead of a bare IndexError on sys.argv.
    sys.exit(f"Usage: python {sys.argv[0]} <start_index> <end_index>")
start_index = int(sys.argv[1])
end_index = int(sys.argv[2])

# paths = ['assets/sora.png']
# questions = ['What is the woman wearing, what is she doing, and how does the image feel?']
# modal_list = ['image']

# Conversation template name used later to look up the prompt format.
conv_mode = 'llama2'

# Checkpoint to evaluate with; the model name is derived from this path.
model_path = '/Path/to/VideoLLaMA2-7B-16F'
model_name = get_model_name_from_path(model_path)

# Load tokenizer, model and processor, then place the model on cuda:0.
tokenizer, model, processor, context_len = load_pretrained_model(
    model_path,
    None,
    model_name,
)
model = model.to('cuda:0')

# Scoring rubrics, one per evaluation dimension. Each entry names the dimension,
# describes what it measures and spells out the meaning of every score from 1 to 5.
# The text is substituted verbatim into the prompt template, so its whitespace
# (leading newline and indentation) is part of what the model receives.
criteria = [
    '''
    Textual Faithfulness
    This measures the degree to which the edited video aligns with the text description provided for editing.
    1: The edited video completely misaligns with the text description.
    2: The edited video mostly misaligns with the text description.
    3: The edited video generally aligns with the text description, but many details are missing.
    4: The edited video aligns with the text description in most aspects, with only a few details not reflected.
    5: The edited video fully aligns with the text description, capturing all details accurately.
    ''',
    '''
    Frame Consistency
    This assesses the continuity between adjacent frames in the edited video.
    1: There is no continuity between frames, resulting in a poor viewing experience.
    2: The continuity between frames is poor, with noticeable jumps.
    3: The continuity between frames is average, with minor jumps in some scenes.
    4: The continuity between frames is good, with only minimal jumps in a very few scenes.
    5: The frames flow smoothly and continuously without any noticeable jumps.
    ''',
    '''
    Video Fidelity
    This evaluates the realism of the edited video, including factors such as color accuracy, overall visual quality, and viewer experience.
    1: The video suffers from severe color distortion, poor visual quality, and weak overall presentation, leading to a very poor viewing experience.
    2: The video has significant color distortion and overall visual quality issues, with noticeable inconsistencies.
    3: The video has slight color distortion and is generally acceptable, but some unnatural elements are still noticeable.
    4: The video is close to realistic, with good overall quality and only minor imperfections in rare instances.
    5: The video is fully realistic, with excellent visual quality and no noticeable flaws, providing a perfect viewing experience.
    '''
]
# Keys under which each criterion's answer is stored in the per-row result;
# the order must match `criteria` (entries are paired by position).
result_keys = [
    'Textual_Faithfulness',
    'Frame_Consistency',
    'Video_Fidelity'
]

# Prompt template sent to the model for every (video, criterion) pair.
# The ###criteria###, ###org_caption### and ###text_prompt### placeholders are
# replaced with the rubric text, the original video's caption and the editing
# prompt before use; the template is stripped of surrounding whitespace first.
text_input = '''
You are given a video that has been edited by a video editing model, alongside its corresponding text condition and the description of the original video. Your task is to watch the video and evaluate it on a scale from 1 to 5 according to the scoring criteria provided below. After generating the score, provide a brief explanation of your reasoning. Answer in the format:

"The score is {generated_score}. Reason: {explanation of why this score was given}."

Criteria: 
###criteria###

Origin Video:
###org_caption###

Video editing text condition:
###text_prompt###
'''

import os
import pandas as pd
import glob
from tqdm import tqdm
import json

# Load the annotation table and pull the text columns out as plain lists.
df = pd.read_csv('../labeled_full.csv')
captions, prompts, entities = (
    list(df[column])
    for column in ('caption', 'editing_prompt', 'editing_entity')
)
keys = df['key']
save_path = 'output_json'

# Map each 'Edited Video' entry to "<video_root_path>/<last path component>.mp4".
video_root_path = '/path/to/video/editing/models/inference/results'
video_dirs_v1 = [
    os.path.join(video_root_path, item.split('/')[-1] + '.mp4')
    for item in df['Edited Video']
]
# NOTE(review): list(df) yields the DataFrame's column labels, not video paths,
# and nothing in the visible code reads this variable -- confirm intent.
video_dirs_v2 = list(df)

# Evaluate every dataset row in [start_index, end_index]: the edited video is
# scored once per criterion and the raw model answers are written to one JSON
# file per row (named by the zero-padded row index).
results = []
output_dir = "VideoLLaMA2"
single_output_dir = os.path.join(output_dir, "single")
# makedirs also creates the parent output_dir, so a single call is enough.
os.makedirs(single_output_dir, exist_ok=True)

# Only videos are evaluated here, so the modality settings are fixed up front.
modal_list = ['video']
default_mm_token = DEFAULT_MMODAL_TOKEN["VIDEO"]
modal_token_index = MMODAL_TOKEN_INDEX["VIDEO"]

# Clamp the requested range so an end_index past the last row ends the run
# cleanly instead of raising IndexError part-way through.
last_index = min(end_index, len(video_dirs_v1) - 1)
if last_index < end_index:
    print(f"end_index {end_index} is past the last row; stopping at {last_index}.")

for i in tqdm(range(start_index, last_index + 1)):
    cur_res = {}
    cur_file = video_dirs_v1[i]

    # Load and preprocess the video once per row: the input is identical for
    # every criterion, so rebuilding it inside the inner loop is wasted work.
    video_tensor = process_video(
        cur_file, processor, model.config.image_aspect_ratio
    ).to(dtype=torch.float16, device='cuda:0', non_blocking=True)

    for result_key, criterion in zip(result_keys, criteria):
        # Fill the template; substitution order matches the placeholders'
        # original handling (prompt, then caption, then rubric).
        cur_text = (
            text_input.strip()
            .replace('###text_prompt###', prompts[i])
            .replace('###org_caption###', captions[i])
            .replace('###criteria###', criterion)
        )

        # Build a single-turn conversation: the multimodal token followed by
        # the question, with an empty assistant slot for the model to fill.
        question = default_mm_token + "\n" + cur_text
        conv = conv_templates[conv_mode].copy()
        conv.append_message(conv.roles[0], question)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        input_ids = tokenizer_MMODAL_token(
            prompt, tokenizer, modal_token_index, return_tensors='pt'
        ).unsqueeze(0).to('cuda:0')

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images_or_videos=[video_tensor],
                modal_list=modal_list,
                do_sample=True,
                temperature=0.2,
                max_new_tokens=1024,
                use_cache=True,
            )
        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        print(outputs[0])
        cur_res[result_key] = outputs[0]

    # Persist each row immediately so finished work survives a later failure.
    single_output_file_path = os.path.join(single_output_dir, f"{str(i).zfill(6)}.json")
    with open(single_output_file_path, "w", encoding="utf-8") as f:
        json.dump(cur_res, f, indent=4)
    results.append(cur_res)

