from operator import attrgetter
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from llava.conversation import conv_templates, SeparatorStyle

import torch
import cv2
import numpy as np
from PIL import Image
import requests
import copy
import warnings
from decord import VideoReader, cpu

# Silence every Python warning so the evaluation log stays readable.
warnings.filterwarnings("ignore")

# ---- Model setup: LLaVA-OneVision (Qwen2-7B backbone) ----------------------
pretrained = "/path/to/llava-onevision-qwen2-7b-ov"  # checkpoint directory (placeholder path)
model_name = "llava_qwen"
device = "cuda"  # where prompt token ids (and, below, pixel tensors) are placed
device_map = "auto"  # forwarded to load_pretrained_model

tokenizer, model, image_processor, max_length = load_pretrained_model(
    pretrained,
    None,
    model_name,
    device_map=device_map,
    attn_implementation="sdpa",
)

# Switch the model to evaluation mode before running generation.
model.eval()


# Function to extract frames from video
def load_video(video_path, max_frames_num):
    """Uniformly sample ``max_frames_num`` frames from a video file.

    Args:
        video_path: Path to the video (``str`` or ``os.PathLike``), or a
            sequence whose first element is such a path (only that first
            element is read).
        max_frames_num: Number of frames to sample. Indices are spread evenly
            from the first to the last frame and truncated to ``int``, so
            indices repeat when this exceeds the number of frames in the video.

    Returns:
        The sampled frames as returned by decord's ``get_batch(...).asnumpy()``,
        laid out as (frames, height, width, channels).

    Raises:
        ValueError: If the video reports zero frames.
    """
    import os

    # isinstance (not `type(...) ==`) so str subclasses and path objects are
    # treated as a single path instead of being indexed like a sequence.
    if not isinstance(video_path, (str, os.PathLike)):
        # A list/tuple of paths was passed; only the first one is used.
        video_path = video_path[0]
    if isinstance(video_path, os.PathLike):
        video_path = os.fspath(video_path)

    vr = VideoReader(video_path, ctx=cpu(0))
    total_frame_num = len(vr)
    if total_frame_num == 0:
        # linspace(0, -1, n) would yield negative indices here.
        raise ValueError(f"Video has no frames: {video_path}")

    frame_idx = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int).tolist()
    sampled_frames = vr.get_batch(frame_idx).asnumpy()
    return sampled_frames  # (frames, height, width, channels)


# Scoring rubrics, each paired with the key under which the model's answer for
# that rubric is stored. Keeping key and rubric together in one table means the
# two derived lists below cannot drift out of alignment. Each rubric string is
# substituted verbatim into the ###criteria### slot of `text_input`.
_CRITERIA_SPECS = [
    ('Textual_Faithfulness', '''
    Textual Faithfulness
    This measures the degree to which the edited video aligns with the text description provided for editing.
    1: The edited video completely misaligns with the text description.
    2: The edited video mostly misaligns with the text description.
    3: The edited video generally aligns with the text description, but many details are missing.
    4: The edited video aligns with the text description in most aspects, with only a few details not reflected.
    5: The edited video fully aligns with the text description, capturing all details accurately.
    '''),
    ('Frame_Consistency', '''
    Frame Consistency
    This assesses the continuity between adjacent frames in the edited video.
    1: There is no continuity between frames, resulting in a poor viewing experience.
    2: The continuity between frames is poor, with noticeable jumps.
    3: The continuity between frames is average, with minor jumps in some scenes.
    4: The continuity between frames is good, with only minimal jumps in a very few scenes.
    5: The frames flow smoothly and continuously without any noticeable jumps.
    '''),
    ('Video_Fidelity', '''
    Video Fidelity
    This evaluates the realism of the edited video, including factors such as color accuracy, overall visual quality, and viewer experience.
    1: The video suffers from severe color distortion, poor visual quality, and weak overall presentation, leading to a very poor viewing experience.
    2: The video has significant color distortion and overall visual quality issues, with noticeable inconsistencies.
    3: The video has slight color distortion and is generally acceptable, but some unnatural elements are still noticeable.
    4: The video is close to realistic, with good overall quality and only minor imperfections in rare instances.
    5: The video is fully realistic, with excellent visual quality and no noticeable flaws, providing a perfect viewing experience.
    '''),
]
criteria = [rubric for _key, rubric in _CRITERIA_SPECS]
result_keys = [key for key, _rubric in _CRITERIA_SPECS]

# Prompt template. The placeholders ###criteria###, ###org_caption### and
# ###text_prompt### are filled per video and per criterion via str.replace.
# The literal "{generated_score}" braces are shown to the model as-is (this is
# not an f-string).
text_input = '''
You are given a video that has been edited by a video editing model, alongside its corresponding text condition and the description of the original video. Your task is to watch the video and evaluate it on a scale from 1 to 5 according to the scoring criteria provided below. After generating the score, provide a brief explanation of your reasoning. Answer in the format:

"The score is {generated_score}. Reason: {explanation of why this score was given}."

Criteria: 
###criteria###

Origin Video:
###org_caption###

Video editing text condition:
###text_prompt###
'''

import os
import pandas as pd
import glob
from tqdm import tqdm
import json

# Load the benchmark annotations; each CSV row describes one edited video.
df = pd.read_csv('../labeled_full.csv')
# pandas reads empty cells as NaN (a float). Normalise the free-text columns to
# strings so the str.replace() prompt templating in the main loop cannot fail
# with a TypeError on a missing caption or editing prompt.
captions = df['caption'].fillna('').astype(str).tolist()
prompts = df['editing_prompt'].fillna('').astype(str).tolist()
entities = list(df['editing_entity'])  # currently unused by this script
keys = df['key']  # currently unused by this script
video_root_path = '/path/to/video/editing/models/inference/results'
# Only the text after the last '/' of each 'Edited Video' entry is kept; the
# video is then looked up as <that name>.mp4 under video_root_path.
video_dirs_v1 = [
    os.path.join(video_root_path, f"{item.split('/')[-1]}.mp4")
    for item in df['Edited Video']
]
frame_counts = list(df['frames'])  # number of frames to sample per video

# Output layout: one JSON file per video under <output_dir>/single/ (written as
# soon as that video is scored) plus every result collected in <output_dir>/full.json.
results = []
output_dir = 'LLaVA-OneVision-7B'
single_output_dir = os.path.join(output_dir, "single")
# makedirs creates missing parent directories, so this also creates output_dir.
os.makedirs(single_output_dir, exist_ok=True)

conv_template = "qwen_1_5"  # key into llava.conversation.conv_templates
prompt_template = text_input.strip()

samples = zip(video_dirs_v1, frame_counts, prompts, captions)
for i, (video_path, frame_count, edit_prompt, org_caption) in enumerate(
    tqdm(samples, total=len(video_dirs_v1))
):
    cur_res = {}

    # Sample and preprocess the frames once per video; they are reused for every criterion.
    video_frames = load_video(video_path, int(frame_count))  # (frames, height, width, channels)
    frames = image_processor.preprocess(video_frames, return_tensors="pt")["pixel_values"].half().to(device)
    image_tensors = [frames]
    # video_frames is a numpy array, so `frame.size` would be each frame's element
    # count rather than a (width, height) pair as for PIL images; pass the real size.
    height, width = video_frames.shape[1:3]
    image_sizes = [(width, height)] * len(video_frames)

    # One generation per scoring criterion; each answer is stored under its matching key.
    for key, criterion in zip(result_keys, criteria):
        question = (
            prompt_template
            .replace('###text_prompt###', edit_prompt)
            .replace('###org_caption###', org_caption)
            .replace('###criteria###', criterion)
        )
        # Fresh copy of the template so turns from earlier queries don't accumulate.
        conv = copy.deepcopy(conv_templates[conv_template])
        conv.append_message(conv.roles[0], f"{DEFAULT_IMAGE_TOKEN}\n{question}")
        conv.append_message(conv.roles[1], None)  # placeholder assistant turn for the model to fill
        prompt_question = conv.get_prompt()

        input_ids = tokenizer_image_token(
            prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
        ).unsqueeze(0).to(device)

        cont = model.generate(
            input_ids,
            images=image_tensors,
            image_sizes=image_sizes,
            do_sample=False,
            temperature=0,
            max_new_tokens=4096,
            modalities=["video"],
        )
        cur_res[key] = tokenizer.batch_decode(cont, skip_special_tokens=True)[0]

    single_output_file_path = os.path.join(single_output_dir, f"{str(i).zfill(6)}.json")
    with open(single_output_file_path, "w", encoding="utf-8") as f:
        json.dump(cur_res, f, indent=4)
    results.append(cur_res)

full_output_file_path = os.path.join(output_dir, "full.json")
with open(full_output_file_path, "w", encoding="utf-8") as f:
    json.dump(results, f, indent=4)

