import sys
import os
import copy
import torch
import argparse
import warnings
import numpy as np
from PIL import Image
import requests
from decord import VideoReader, cpu

# Silence all Python warnings globally for quieter console output
# (note: this also hides genuinely useful warnings).
warnings.filterwarnings("ignore")

# Order CUDA devices by PCI bus ID (the ordering nvidia-smi reports) so that
# device indices are stable and predictable.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Expose GPUs 0-9 by default, but respect a selection the caller already made
# in the environment instead of silently overriding it. This must take effect
# before CUDA is initialized; torch initializes CUDA lazily, so setting it
# after `import torch` is still in time.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7,8,9")

# Make this script's parent directory importable so the `llava` package
# imported below resolves. Guard against duplicate entries on re-import.
script_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.normpath(os.path.join(script_dir, os.path.pardir))
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

# Import LLaVA modules
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from llava.conversation import conv_templates, SeparatorStyle

def _sample_frame_indices(total_frame_num, max_frames_num):
    """Return up to ``max_frames_num`` frame indices spread evenly over a video.

    Indices span the first frame (0) to the last (``total_frame_num - 1``) and
    come from ``np.linspace`` with an integer dtype. The count is capped at
    ``total_frame_num`` so a short video never yields repeated frames.

    Raises:
        ValueError: if the video has no frames or ``max_frames_num`` < 1.
    """
    if total_frame_num < 1:
        raise ValueError("video contains no frames")
    if max_frames_num < 1:
        raise ValueError(f"max_frames_num must be >= 1, got {max_frames_num}")
    # With num <= total the linspace step is >= 1, so the integer indices are distinct.
    num_frames = min(max_frames_num, total_frame_num)
    return np.linspace(0, total_frame_num - 1, num_frames, dtype=int).tolist()


def main(args):
    """Describe a video with a LLaVA-Qwen model and print the generated text.

    Args:
        args: parsed CLI namespace providing ``pretrained`` (model path),
            ``video_path`` (video file opened with decord) and
            ``max_frames_num`` (upper bound on the number of sampled frames).

    Raises:
        ValueError: if the video has no frames or ``max_frames_num`` < 1.
    """
    # Load under the "llava_qwen" model name; device_map="auto" is passed through
    # to the loader (presumably to spread the weights over the visible GPUs).
    tokenizer, model, image_processor, _context_len = load_pretrained_model(
        args.pretrained, None, "llava_qwen", device_map="auto"
    )
    model.eval()

    # All inputs go to the first GPU. NOTE(review): assumes the model's input
    # layers land on cuda:0 under device_map="auto" — confirm.
    device = "cuda:0"

    # Decode a uniformly spaced subset of frames on the CPU, then preprocess
    # them into a half-precision pixel tensor on the GPU.
    vr = VideoReader(args.video_path, ctx=cpu(0))
    frame_idx = _sample_frame_indices(len(vr), args.max_frames_num)
    frames = vr.get_batch(frame_idx).asnumpy()
    video_tensor = image_processor.preprocess(frames, return_tensors="pt")[
        "pixel_values"
    ].to(device, dtype=torch.float16)

    # Build a single-turn prompt: the image placeholder token followed by the
    # instruction, plus an empty assistant turn for the model to complete.
    # Deep-copy the template so appending messages never mutates the shared one.
    conv_template = "qwen_1_5"
    question = DEFAULT_IMAGE_TOKEN + "\nDescribe this video."
    conv = copy.deepcopy(conv_templates[conv_template])
    conv.append_message(conv.roles[0], question)
    conv.append_message(conv.roles[1], None)
    prompt_question = conv.get_prompt()

    # tokenizer_image_token presumably splices IMAGE_TOKEN_INDEX in place of the
    # image placeholder; unsqueeze(0) adds the batch dimension.
    input_ids = tokenizer_image_token(
        prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
    ).unsqueeze(0).to(device)

    # Greedy decoding. NOTE(review): use_cache=False disables the KV cache,
    # which makes long generations much slower — confirm it is required here.
    cont = model.generate(
        input_ids,
        images=[video_tensor],
        modalities=["video"],
        do_sample=False,
        temperature=0,
        max_new_tokens=4096,
        use_cache=False,
    )
    text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)
    print(text_outputs)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process video with LLaVA model.")
    parser.add_argument("--video_path", type=str, required=True, help="Path to the video file.")
    parser.add_argument("--max_frames_num", type=int, required=True, help="Maximum number of frames to process.")
    parser.add_argument("--pretrained", type=str, required=True, help="Path to the pretrained model.")

    args = parser.parse_args()
    # Fail fast with a usage message rather than a traceback from deep inside
    # frame sampling / video decoding after the (slow) model load.
    if args.max_frames_num < 1:
        parser.error("--max_frames_num must be a positive integer")
    if not os.path.isfile(args.video_path):
        parser.error(f"--video_path does not exist or is not a file: {args.video_path}")
    main(args)
