import argparse
import pathlib
import torch
import av
import numpy as np
import pickle
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor, BitsAndBytesConfig, AutoModel, AutoImageProcessor
from qwen_vl_utils import process_vision_info
from tqdm.auto import tqdm
import os
from transformers import PretrainedConfig
import transformers
from natsort import natsorted
from prompts import ALL_PROMPTS
from huggingface_hub import hf_hub_download
from pathlib import Path
import cv2
from PIL import Image

# Point the Hugging Face cache locations at the current working directory.
# NOTE(review): these environment variables are assigned after `transformers`
# and `huggingface_hub` were imported at the top of the file — confirm the
# libraries still honor them; otherwise only the explicit `cache_dir=`
# arguments passed to `from_pretrained` below take effect.
PretrainedConfig.cache_dir = './'
os.environ['HF_HOME'] = './'
os.environ['HF_DATASETS_CACHE']='./'
HUGGINGFACE_CACHE_DIR = "./"
os.environ['TRANSFORMERS_CACHE'] = "./"
# Model and quantization setup
MODEL_ID = "DAMO-NLP-SG/VideoLLaMA3-7B"
# Aliases for the transformers classes; not referenced anywhere else in the
# visible part of this file.
CONFIG_CLASS = transformers.BitsAndBytesConfig
MODEL_CLASS = transformers.AutoModelForCausalLM
PROCESSOR_CLASS = transformers.AutoProcessor

# Device that the model and every processed input tensor are moved to.
device = "cuda:0"
# MODEL_ID with "/" and " " replaced by "_"; the command-line entry point
# uses it as the last component of the output directory.
MODEL_NAME = MODEL_ID.replace("/", "_").replace(" ", "_")

# Default sampling settings. process_videos() builds its own kwargs from its
# arguments, so this dict is not read by the visible code.
# NOTE(review): both "max_new_tokens" and "max_length" are set — confirm which
# limit is intended before wiring this dict into generate().
GENERATE_KWARGS = {
    "do_sample": True,
    "top_p":0.9, "top_k":2,"max_new_tokens":100,
    "max_length": 512,
}

# 4-bit bitsandbytes quantization with bfloat16 compute and no double
# quantization; the quantization type is left at the library default.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
    #bnb_4bit_quant_type="nf4"  # Use nf4 quantization type
)
# Both loads run at import time, so importing this module requires the model
# files (download or local cache). `trust_remote_code=True` is passed to both
# loaders.
processor = AutoProcessor.from_pretrained(MODEL_ID, cache_dir=HUGGINGFACE_CACHE_DIR, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16, cache_dir=HUGGINGFACE_CACHE_DIR, trust_remote_code=True, quantization_config=quantization_config
)
# NOTE(review): some transformers releases reject `.to()` on
# bitsandbytes-quantized models — confirm this call is accepted by the
# installed version.
model = model.to(device)

# Helper Functions
def read_video_pyav(container, indices):
    """
    Decode selected frames from a video using PyAV.

    :param container: Opened container object; it is rewound with ``seek(0)``
        and its first video stream is decoded from the beginning.
    :param indices: Non-empty sequence of zero-based frame indices to keep.
        The first and last entries bound the scan, so the sequence is
        expected to be in ascending order.
    :return: ``np.ndarray`` that stacks the selected frames (each converted
        with ``to_ndarray(format="rgb24")``) along a new first axis, in
        decode order.
    :raises ValueError: If none of the requested frames could be decoded.
    """
    # A set of plain ints gives O(1) membership per decoded frame instead of
    # scanning the whole index array for every frame.
    wanted = {int(i) for i in indices}
    start_index = int(indices[0])
    end_index = int(indices[-1])

    container.seek(0)
    frames = []
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            # Everything requested has been passed; stop decoding.
            break
        if i >= start_index and i in wanted:
            frames.append(frame.to_ndarray(format="rgb24"))

    if not frames:
        raise ValueError(
            f"No frames decoded for requested indices {sorted(wanted)}; "
            "the video stream may be empty or shorter than expected."
        )
    return np.stack(frames)

def sample_frames(video_path, num_frames=8):
    """
    Uniformly sample `num_frames` from a video file.

    :param video_path: Path of the video file, passed to ``av.open``.
    :param num_frames: Number of evenly spaced frames to request (default 8).
    :return: Array of decoded frames as returned by ``read_video_pyav``.
        Short videos can yield fewer frames because repeated indices are
        decoded only once.
    :raises ValueError: If `num_frames` is not positive or the container
        reports no frames for its first video stream.
    """
    if num_frames <= 0:
        raise ValueError(f"num_frames must be positive, got {num_frames}")

    # The context manager closes the container even when decoding fails.
    with av.open(video_path) as container:
        total_frames = container.streams.video[0].frames
        if total_frames <= 0:
            # A zero frame count would make the sampling step zero.
            raise ValueError(f"Video reports no frames: {video_path}")

        step = total_frames / num_frames
        # Slice guards against float rounding producing one extra index.
        indices = np.arange(0, total_frames, step).astype(int)[:num_frames]
        return read_video_pyav(container, indices)

def batchify(iterable, batch_size):
    """
    Split a sliceable sequence into consecutive batches.

    :param iterable: Sequence supporting ``len()`` and slicing (e.g. a list).
    :param batch_size: Maximum number of items per batch; must be positive.
    :return: List of slices of `iterable`; the last one may be shorter.
        An empty input gives an empty list.
    :raises ValueError: If `batch_size` is not a positive integer.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    return [iterable[i : i + batch_size] for i in range(0, len(iterable), batch_size)]

def extract_frames(video_path, frame_rate=1):
    """
    Extract frames from a video at the specified frame rate.

    :param video_path: Path to the video file
    :param frame_rate: Frames to keep per second of video; must be positive
    :return: List of PIL.Image objects in RGB order, or None when no frame
        could be read
    :raises ValueError: If `frame_rate` is not positive
    """
    if frame_rate <= 0:
        raise ValueError(f"frame_rate must be positive, got {frame_rate}")

    video_capture = cv2.VideoCapture(str(video_path))
    frames = []
    try:
        fps = video_capture.get(cv2.CAP_PROP_FPS)
        # Clamp to 1 so a frame_rate above the reported fps keeps every frame
        # instead of yielding a zero interval (modulo by zero below).
        # A falsy fps also falls back to keeping every frame.
        frame_interval = max(1, int(fps / frame_rate)) if fps else 1

        count = 0
        success, frame = video_capture.read()
        while success:
            if count % frame_interval == 0:
                # Frames are converted from BGR to RGB before wrapping in PIL.
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(rgb))
            success, frame = video_capture.read()
            count += 1
    finally:
        # Release the capture handle even if conversion raises.
        video_capture.release()

    return frames if frames else None  # Ensure frames are not empty

def _mean_generated_hidden_states(hidden_states):
    """
    Average per-layer hidden states over the generation steps after the first.

    :param hidden_states: ``outputs.hidden_states`` from ``model.generate``;
        indexed here as one entry per generation step, each an iterable of
        per-layer tensors.
    :return: ``np.ndarray`` averaged over steps, or None when there are fewer
        than two steps (the first step is always skipped, so nothing would
        be left to average).
    """
    if not hidden_states or len(hidden_states) < 2:
        return None

    per_step = np.array([
        [layer.cpu().float().numpy() for layer in step]
        for step in hidden_states[1:]
    ])
    # Element 0 of axes 2 and 3 is selected — presumably the batch entry and
    # the single new token position; TODO confirm against the model output.
    return np.average(per_step[:, :, 0, 0, :], axis=0)


# Main processing function
def process_videos(video_paths, batch_size, output_dir, prompt_number, max_tokens=128, top_p=0.9, top_k=2):
    """
    Run the module-level model over videos and pickle the results per batch.

    Each video is sent through the module-level ``processor`` and ``model``
    with the prompt selected from ``ALL_PROMPTS``. After every batch a list
    of result dicts is written to ``output_dir / "batch_<n>.pkl"``.

    :param video_paths: Sequence of video file paths.
    :param batch_size: Number of videos per output pickle; must be positive.
        Videos are still processed one at a time.
    :param output_dir: Directory for the pickles; created if missing.
    :param prompt_number: Key into ``ALL_PROMPTS``.
    :param max_tokens: Passed to ``generate`` as ``max_new_tokens``.
    :param top_p: Top-p value for sampling.
    :param top_k: Top-k value for sampling.
    :return: None. Each saved dict has the keys ``video_path``,
        ``generated_text`` (the list returned by ``processor.batch_decode``)
        and ``language_hidden_states`` (averaged array or None).
    :raises ValueError: If `prompt_number` is not in ``ALL_PROMPTS`` or
        `batch_size` is not positive.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    if prompt_number not in ALL_PROMPTS:
        raise ValueError(f"Invalid prompt_number {prompt_number}. Choose from {list(ALL_PROMPTS.keys())}.")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    prompt_text = ALL_PROMPTS[prompt_number]

    # Loop-invariant sampling settings.
    generate_kwargs = {
        "max_new_tokens": max_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k
    }

    # Ceiling division so the progress bar total is exact even when the
    # number of videos is a multiple of batch_size.
    num_batches = (len(video_paths) + batch_size - 1) // batch_size

    for batch_num, batch in enumerate(tqdm(batchify(video_paths, batch_size), total=num_batches)):
        batch_outputs = []

        for video_path in batch:
            video_path_str = str(video_path)
            print(f"Processing: {video_path_str}")
            absolute_path = os.path.abspath(video_path_str)
            video_path_formatted = f"file://{absolute_path}"

            # Conversation passed to the processor: a system turn plus a user
            # turn holding the video reference and the text prompt.
            conversation = [
                {"role": "system", "content": "You are a helpful assistant."},
                {
                    "role": "user",
                    "content": [
                        {"type": "video", "video": {"video_path": video_path_formatted, "fps": 1, "max_frames": 128}},
                        {"type": "text", "text": prompt_text}
                    ]
                }
            ]

            inputs = processor(conversation=conversation, add_system_prompt=True, add_generation_prompt=True, return_tensors="pt")
            # Move tensors to the target device; non-tensor entries pass through.
            inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
            if "pixel_values" in inputs:
                # Cast pixel values to bfloat16, the dtype the model is loaded with.
                inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)

            # Generate text
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    **generate_kwargs,
                    output_hidden_states=True,
                    return_dict_in_generate=True,
                )

            generated_text = processor.batch_decode(outputs.sequences, skip_special_tokens=True)
            print(f"Generated Output: {generated_text}")

            batch_outputs.append({
                "video_path": str(video_path),
                "generated_text": generated_text,
                "language_hidden_states": _mean_generated_hidden_states(outputs.hidden_states)
            })

        # Save this batch's results to their own file
        batch_file = output_dir / f"batch_{batch_num + 1}.pkl"
        with open(batch_file, "wb") as f:
            pickle.dump(batch_outputs, f)

        print(f"✅ Saved batch {batch_num + 1} results to {batch_file}")

    print("✅ Video processing complete!")

# Command-line interface
if __name__ == "__main__":
    # Parse CLI options, collect the .mp4 files of the input directory in
    # natural sort order, and run them through process_videos(). Results go
    # to <output-dir>/prompt_<n>/<MODEL_NAME>/.
    parser = argparse.ArgumentParser(description="Video Processing with VideoLLaMA3-7B Model")
    parser.add_argument("-v","--video-dir", required=True, type=pathlib.Path, help="Directory containing video files")
    parser.add_argument("-b","--batch-size", required=True, type=int, help="Batch size for video processing")
    parser.add_argument("-d","--output-dir", required=True, type=pathlib.Path, help="Directory to save outputs")
    parser.add_argument("-p","--prompt-number", required=True, type=int, help="Prompt number to use")
    parser.add_argument("--max-tokens", default=100, type=int, help="Maximum tokens to generate")
    parser.add_argument("--top-p", default=0.9, type=float, help="Top-p sampling for text generation")
    parser.add_argument("--top-k", default=2, type=int, help="Top-k sampling for text generation")
    args = parser.parse_args()

    # Prepare inputs
    video_dir = args.video_dir
    batch_size = args.batch_size
    output_dir = args.output_dir
    prompt_number = args.prompt_number
    max_tokens = args.max_tokens
    top_p = args.top_p
    top_k = args.top_k

    if batch_size <= 0:
        parser.error("--batch-size must be a positive integer")

    # Get video paths
    video_paths = natsorted(list(video_dir.glob("*.mp4")))
    if not video_paths:
        # Exit with a usage error instead of an IndexError on video_paths[0].
        parser.error(f"No .mp4 files found in {video_dir}")
    print(video_paths[0])

    OUTPUT_DIR = output_dir.joinpath(f"prompt_{args.prompt_number}", MODEL_NAME)

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # Process videos
    process_videos(video_paths, batch_size, OUTPUT_DIR, prompt_number, max_tokens, top_p, top_k)
