import argparse
import pathlib
import torch
import av
import numpy as np
import pickle
from transformers import LlavaOnevisionProcessor,AutoProcessor,LlavaOnevisionForConditionalGeneration, LlavaNextVideoForConditionalGeneration, BitsAndBytesConfig
from tqdm.auto import tqdm
import os
from transformers import PretrainedConfig
import transformers
from natsort import natsorted
from prompts import ALL_PROMPTS
from huggingface_hub import hf_hub_download

# Cache / model-hub locations.
# NOTE(review): these environment variables are assigned *after* `transformers`
# and `huggingface_hub` were imported at the top of the file; those libraries
# likely read them at import time, so the explicit
# `cache_dir=HUGGINGFACE_CACHE_DIR` arguments further down are what reliably
# pin the download location — confirm before relying on HF_HOME here.
# `PretrainedConfig.cache_dir` is set as a plain class attribute; whether
# transformers consults it is not shown here — TODO confirm it has any effect.
PretrainedConfig.cache_dir = './'
os.environ['HF_HOME'] = './'
os.environ['HF_DATASETS_CACHE']='./'
HUGGINGFACE_CACHE_DIR = "./"
os.environ['TRANSFORMERS_CACHE'] = "./"
# Model and quantization setup
MODEL_ID = "llava-hf/llava-onevision-qwen2-7b-ov-hf"
# NOTE(review): CONFIG_CLASS / MODEL_CLASS / PROCESSOR_CLASS are not referenced
# in the code visible here; the loaders below use the directly imported classes.
CONFIG_CLASS = transformers.BitsAndBytesConfig
MODEL_CLASS = transformers.LlavaOnevisionForConditionalGeneration
PROCESSOR_CLASS = transformers.LlavaOnevisionProcessor

# Filesystem-safe form of MODEL_ID ("/" and " " replaced by "_"); the CLI uses
# it as the output sub-directory name.
MODEL_NAME = MODEL_ID.replace("/", "_").replace(" ", "_")

# NOTE(review): not referenced by process_videos, which builds its own
# generation kwargs from the CLI arguments. It also sets both "max_length" and
# "max_new_tokens" — confirm which limit is intended if this is ever used.
GENERATE_KWARGS = {
    "do_sample": True,
    "top_p":0.9, "top_k":2,"max_new_tokens":100,
    "max_length": 512,
}

# 4-bit BitsAndBytes settings.
# NOTE(review): this object is never passed to from_pretrained below (there is
# no `quantization_config=` argument), so the model is loaded in float16
# without quantization.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)
# NOTE(review): min_pixels / max_pixels are not referenced in the code visible
# here; presumably leftovers from a Qwen-VL variant of this script.
min_pixels = 256*28*28
max_pixels = 1024*28*28 
# Processor and model are loaded at import time (module-level side effect):
# importing this file loads the full checkpoint in float16.
processor = LlavaOnevisionProcessor.from_pretrained(MODEL_ID, cache_dir=HUGGINGFACE_CACHE_DIR)
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
    MODEL_ID, torch_dtype=torch.float16, device_map="auto", cache_dir=HUGGINGFACE_CACHE_DIR
)

# Helper Functions
def read_video_pyav(container, indices):
    """
    Decode only the requested frames of a video with PyAV.

    The container is rewound to the start and its first video stream is
    decoded sequentially; a frame is kept when its position appears in
    ``indices``. Decoding stops as soon as the position passes
    ``indices[-1]``, so ``indices`` is expected to be sorted ascending.
    A position listed more than once still yields a single frame.

    Args:
        container: An opened PyAV input container.
        indices: Non-empty, ascending sequence of frame positions.

    Returns:
        The ``rgb24`` ndarrays of the kept frames stacked along a new first
        axis (presumably ``(num_kept, height, width, 3)``).
    """
    container.seek(0)
    first_wanted = indices[0]
    last_wanted = indices[-1]
    kept = []
    for position, frame in enumerate(container.decode(video=0)):
        if position > last_wanted:
            # Every requested position has been passed; stop decoding early.
            break
        if position < first_wanted or position not in indices:
            continue
        kept.append(frame)
    arrays = [frame.to_ndarray(format="rgb24") for frame in kept]
    return np.stack(arrays)

def sample_frames(video_path, num_frames=8):
    """
    Uniformly sample ``num_frames`` frames from a video file.

    Sampled positions are ``floor(i * total_frames / num_frames)`` for
    ``i = 0 .. num_frames - 1``: evenly spaced, starting at the first frame.
    When the video has fewer frames than ``num_frames`` some positions repeat;
    ``read_video_pyav`` keeps each decoded frame once, so fewer than
    ``num_frames`` frames are returned in that case.

    Args:
        video_path: Path of the video file to open with PyAV.
        num_frames: Number of evenly spaced positions to sample (default 8).

    Returns:
        The stacked frame array produced by ``read_video_pyav``.

    Raises:
        ValueError: If ``num_frames`` is not positive, or the first video
            stream does not report a frame count.
    """
    if num_frames <= 0:
        raise ValueError(f"num_frames must be positive, got {num_frames}")
    # ``with`` closes the container (and its file handle) even if decoding
    # fails; previously one open container leaked per call.
    with av.open(video_path) as container:
        total_frames = container.streams.video[0].frames
        if total_frames <= 0:
            # PyAV reports 0 when the container does not store a frame count;
            # the old float-step computation then divided by zero.
            raise ValueError(f"Cannot determine frame count for {video_path!r}")
        # Honour ``num_frames`` (it used to be ignored in favour of a literal 8).
        # Integer arithmetic gives exactly ``num_frames`` positions, avoiding
        # float-step rounding in np.arange; for num_frames=8 the positions are
        # identical to the previous implementation.
        indices = (np.arange(num_frames) * total_frames) // num_frames
        return read_video_pyav(container, indices)

def batchify(data, batch_size):
    """
    Lazily yield consecutive slices of ``data`` holding ``batch_size`` items.

    The final slice is shorter when ``len(data)`` is not a multiple of
    ``batch_size``. ``data`` must support ``len()`` and slicing; each
    yielded chunk has the same type as a slice of ``data``.
    """
    starts = range(0, len(data), batch_size)
    yield from (data[start:start + batch_size] for start in starts)

# Main processing function
def _mean_generated_hidden_states(hidden_states):
    """
    Average per-layer hidden states over the incremental generation steps.

    ``hidden_states`` is ``outputs.hidden_states`` from ``model.generate(...,
    output_hidden_states=True, return_dict_in_generate=True)``: one entry per
    generation step, each a sequence of per-layer tensors. Step 0 (the prompt
    forward pass) is skipped; for every later step the first batch element and
    first sequence position of each layer tensor is taken, and the result is
    averaged over steps.

    Returns:
        An array of shape (num_layer_tensors, last_dim), or ``None`` when there
        is no step after the prompt pass (nothing to average; the previous
        inline code raised ``IndexError`` in that case).
    """
    steps = hidden_states[1:]
    if not steps:
        return None
    stacked = np.array([
        [layer.cpu().float().numpy() for layer in step]
        for step in steps
    ])
    return np.average(stacked[:, :, 0, 0, :], axis=0)


def process_videos(video_paths, batch_size, output_dir, prompt_number, max_tokens, top_p, top_k):
    """
    Run the model on every video and pickle the results batch by batch.

    For each video, a single-turn chat (the video plus
    ``ALL_PROMPTS[prompt_number]``) is rendered with the processor's chat
    template using 8 frames, and text is sampled with the given settings.
    Each result dict holds ``video_path``, ``generated_text`` (the decoded
    list, wrapped in a list) and ``language_hidden_states`` (the pooled array
    from ``_mean_generated_hidden_states``, wrapped in a list). After every
    batch its results are written to ``output_dir / "batch_<n>.pkl"``
    (``n`` is 1-based).

    Args:
        video_paths: Sliceable sequence of video paths (str or Path).
        batch_size: Number of videos per output pickle; must be positive.
        output_dir: ``pathlib.Path`` of the output directory (created if missing).
        prompt_number: Key into ``ALL_PROMPTS``.
        max_tokens: ``max_new_tokens`` passed to generation.
        top_p: Nucleus-sampling threshold.
        top_k: Top-k sampling cutoff.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    output_dir.mkdir(parents=True, exist_ok=True)

    prompt_text = ALL_PROMPTS[prompt_number]
    # Loop-invariant sampling settings, built once rather than per video.
    generate_kwargs = {"max_new_tokens": max_tokens, "do_sample": True, "top_p": top_p, "top_k": top_k}
    # Ceiling division so the progress bar also counts a final partial batch.
    num_batches = (len(video_paths) + batch_size - 1) // batch_size

    for batch_num, batch in enumerate(tqdm(batchify(video_paths, batch_size), total=num_batches)):
        batch_outputs = []

        for video_path in batch:
            video_path_str = str(video_path)
            print(f"Processing: {video_path_str}")
            absolute_path = os.path.abspath(video_path_str)

            # LLaVA-OneVision chat input. apply_chat_template receives only the
            # video path (plus num_frames=8), so frames are not decoded here;
            # the former separate sample_frames call produced an unused array.
            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "video", "path": absolute_path},
                        {"type": "text", "text": prompt_text},
                    ],
                }
            ]

            inputs = processor.apply_chat_template(
                conversation,
                num_frames=8,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            ).to(model.device, torch.float16)

            outputs = model.generate(
                **inputs,
                **generate_kwargs,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )
            # NOTE(review): for decoder-only generation ``sequences`` presumably
            # starts with the prompt tokens, so the decoded text likely contains
            # the rendered prompt too; kept as-is to preserve the output format.
            generated_text = processor.batch_decode(outputs.sequences, skip_special_tokens=True)
            print(generated_text)

            pooled_hidden_states = _mean_generated_hidden_states(outputs.hidden_states)
            if pooled_hidden_states is None:
                print(f"Warning: no generated-token hidden states for {video_path_str}")

            batch_outputs.append({
                "video_path": video_path_str,
                "generated_text": [generated_text],
                "language_hidden_states": [pooled_hidden_states],
            })

        # One pickle per batch, so a crash later in the run keeps finished batches.
        with open(output_dir / f"batch_{batch_num + 1}.pkl", "wb") as f:
            pickle.dump(batch_outputs, f)

# Command-line interface
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Video Processing with LLaVA-Onevision-Video")
    parser.add_argument("-v","--video-dir", required=True, type=pathlib.Path, help="Directory containing video files")
    parser.add_argument("-b","--batch-size", required=True, type=int, help="Batch size for video processing")
    parser.add_argument("-d","--output-dir", required=True, type=pathlib.Path, help="Directory to save outputs")
    parser.add_argument("-p","--prompt-number", required=True, type=int, help="Prompt number to use")
    parser.add_argument("--max-tokens", default=100, type=int, help="Maximum tokens to generate")
    parser.add_argument("--top-p", default=0.9, type=float, help="Top-p sampling for text generation")
    parser.add_argument("--top-k", default=2, type=int, help="Top-k sampling for text generation")
    args = parser.parse_args()

    # Reject invalid input with a usage message instead of a traceback later on.
    if args.batch_size <= 0:
        parser.error(f"--batch-size must be positive, got {args.batch_size}")

    # Prepare inputs
    video_dir = args.video_dir
    batch_size = args.batch_size
    output_dir = args.output_dir
    prompt_number = args.prompt_number
    max_tokens = args.max_tokens
    top_p = args.top_p
    top_k = args.top_k

    # Natural sort so e.g. "video2.mp4" comes before "video10.mp4".
    video_paths = natsorted(video_dir.glob("*.mp4"))
    if not video_paths:
        # Previously an empty directory crashed with IndexError on video_paths[0].
        parser.error(f"no .mp4 files found in {video_dir}")
    print(video_paths[0])

    # Outputs go to <output-dir>/prompt_<n>/<model name>/batch_<k>.pkl
    OUTPUT_DIR = output_dir.joinpath(f"prompt_{args.prompt_number}", MODEL_NAME)

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # Process videos
    process_videos(video_paths, batch_size, OUTPUT_DIR, prompt_number, max_tokens, top_p, top_k)
