import argparse
import pathlib
import torch
import av
import numpy as np
import pickle
from transformers import VideoLlavaProcessor, VideoLlavaForConditionalGeneration, BitsAndBytesConfig
from tqdm.auto import tqdm
import os
from transformers import PretrainedConfig
import transformers
from natsort import natsorted
from prompts import ALL_PROMPTS
from huggingface_hub import hf_hub_download

# Point every Hugging Face cache location at the current working directory.
# NOTE(review): these are assigned after `transformers` / `huggingface_hub`
# were imported at the top of the file — confirm the libraries still honour
# them at this point; the explicit `cache_dir=` passed to `from_pretrained`
# below does not depend on that.
PretrainedConfig.cache_dir = './'
os.environ['HF_HOME'] = './'
os.environ['HF_DATASETS_CACHE']='./'
HUGGINGFACE_CACHE_DIR = "./"
os.environ['TRANSFORMERS_CACHE'] = "./"
# Model and quantization setup
MODEL_ID = "LanguageBind/Video-LLaVA-7B-hf"
# Aliases for the classes used below; not referenced elsewhere in the visible file.
CONFIG_CLASS = transformers.BitsAndBytesConfig
MODEL_CLASS = transformers.VideoLlavaForConditionalGeneration
PROCESSOR_CLASS = transformers.VideoLlavaProcessor

# Filesystem-friendly form of the model id ("/" and " " become "_");
# used as the last component of the output directory in the CLI section.
MODEL_NAME = MODEL_ID.replace("/", "_").replace(" ", "_")

# Default sampling settings. Not referenced elsewhere in the visible file.
# NOTE(review): both "max_new_tokens" and "max_length" are set; the Transformers
# generation docs say "max_new_tokens" takes precedence — confirm before reusing.
GENERATE_KWARGS = {
    "do_sample": True,
    "top_p":0.9, "top_k":2,"max_new_tokens":100,
    "max_length": 512,
}

# Load the model weights in 4-bit via bitsandbytes, computing in float16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)
# Loading happens at import time: the processor and model are module-level
# globals that `process_videos` reads directly.
processor = VideoLlavaProcessor.from_pretrained(MODEL_ID, cache_dir=HUGGINGFACE_CACHE_DIR)
processor.patch_size = 14  # Set explicitly on the processor. NOTE(review): presumably must match the vision tower's patch size — confirm against the model config.
processor.vision_feature_select_strategy = "default"  # NOTE(review): the Transformers LLaVA-family docs list "default" and "full" as the accepted values — confirm for the installed version.
model = VideoLlavaForConditionalGeneration.from_pretrained(
    MODEL_ID, quantization_config=quantization_config, device_map="auto", cache_dir=HUGGINGFACE_CACHE_DIR
)

# Helper Functions
def read_video_pyav(container, indices):
    """
    Decode selected frames from a video using PyAV.

    Args:
        container: An opened PyAV container. It is rewound with ``seek(0)``
            and its first video stream is decoded from the start.
        indices: Non-empty sequence (list or 1-D array) of frame indices to
            keep, in ascending order. Duplicate indices select a frame once.

    Returns:
        np.ndarray: The selected frames converted to RGB24 arrays and stacked
        along a new leading axis, in decode order.

    Raises:
        IndexError: If `indices` is empty.
        ValueError: If none of the requested frames could be decoded.
    """
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    # Build the lookup once: `i in ndarray` rescans the whole array for every
    # decoded frame, whereas a set gives constant-time membership tests.
    wanted = set(indices)

    frames = []
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            # Everything requested has been seen; stop decoding early.
            break
        if i >= start_index and i in wanted:
            frames.append(frame.to_ndarray(format="rgb24"))

    if not frames:
        # np.stack([]) would raise a ValueError with an unhelpful message.
        raise ValueError(
            f"No frames decoded for requested indices {list(indices)}"
        )
    return np.stack(frames)

def sample_frames(video_path, num_frames=8):
    """
    Uniformly sample `num_frames` from a video file.

    Args:
        video_path: Path of the video file to open with PyAV.
        num_frames: Number of evenly spaced frame indices to request
            (default 8). If the video has fewer frames than this, the
            truncated indices repeat and fewer distinct frames are returned.

    Returns:
        np.ndarray: Stacked RGB24 frames as produced by `read_video_pyav`.

    Raises:
        ValueError: If `num_frames` is not positive, or the container
            reports no frames for its first video stream.
    """
    if num_frames < 1:
        raise ValueError(f"num_frames must be positive, got {num_frames}")

    # Use the container as a context manager so the file handle is released
    # even if decoding fails.
    with av.open(video_path) as container:
        total_frames = container.streams.video[0].frames
        if total_frames <= 0:
            raise ValueError(
                f"Video stream in {video_path} reports {total_frames} frames"
            )
        # Honour `num_frames` (previously hard-coded to 8). linspace with
        # endpoint=False yields exactly `num_frames` values k * total / num_frames,
        # avoiding arange's length ambiguity with a float step.
        indices = np.linspace(0, total_frames, num_frames, endpoint=False).astype(int)
        return read_video_pyav(container, indices)

def batchify(data, batch_size):
    """
    Lazily yield consecutive slices of `data`, each holding at most
    `batch_size` items; the final slice may be shorter.
    """
    total = len(data)
    starts = range(0, total, batch_size)
    yield from (data[begin:begin + batch_size] for begin in starts)

# Main processing function
def process_videos(video_paths, batch_size, output_dir, prompt_number, max_tokens, top_p, top_k):
    """
    Process videos and generate text outputs.

    Each video is sampled, run through the module-level `processor` and
    `model`, and the results of every batch are pickled to
    ``output_dir / "batch_<n>.pkl"`` (n starts at 1).

    Args:
        video_paths: Sequence of video file paths.
        batch_size: Number of videos whose results go into one pickle file.
        output_dir: Existing directory (pathlib.Path) that receives the pickles.
        prompt_number: Key/index into `ALL_PROMPTS` selecting the prompt text.
        max_tokens: Maximum number of new tokens to generate per video.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept when sampling.

    Each pickle holds a list of dicts with keys "video_path",
    "generated_text" and "language_hidden_states".
    """
    prompt = f"USER: <video>\n{ALL_PROMPTS[prompt_number]} ASSISTANT:"
    # Use the caller-supplied sampling settings (previously these arguments
    # were accepted but ignored in favour of hard-coded values).
    generate_kwargs = {
        "max_new_tokens": max_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k,
    }
    # Ceiling division so a final partial batch is counted by the progress bar.
    num_batches = (len(video_paths) + batch_size - 1) // batch_size

    batches = batchify(video_paths, batch_size)
    for batch_num, batch in enumerate(tqdm(batches, total=num_batches), start=1):
        batch_outputs = []

        for video_path in batch:
            video_frames = sample_frames(video_path)
            inputs = processor(prompt, videos=video_frames, return_tensors="pt").to(model.device)
            outputs = model.generate(
                **inputs,
                **generate_kwargs,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )
            generated_text = processor.batch_decode(outputs.sequences, skip_special_tokens=True)
            print(generated_text)

            # `outputs.hidden_states` is indexed [generation step][layer].
            # Step 0 is skipped — presumably because it covers the whole
            # prompt and has a different sequence length (TODO confirm).
            hidden_states = outputs.hidden_states
            per_step = np.array([
                [layer_state.cpu().float().numpy() for layer_state in step_states]
                for step_states in hidden_states[1:]
            ])
            # Take index 0 on axes 2 and 3 (assumed to be batch and sequence
            # position — TODO confirm), then average over generation steps.
            hidden_states_to_save = np.average(per_step[:, :, 0, 0, :], axis=0)

            batch_outputs.append({
                "video_path": str(video_path),
                "generated_text": [generated_text],
                "language_hidden_states": [hidden_states_to_save],
            })

        # Persist each batch as soon as it finishes so partial progress
        # survives a later failure.
        with open(output_dir / f"batch_{batch_num}.pkl", "wb") as f:
            pickle.dump(batch_outputs, f)

# Command-line interface
if __name__ == "__main__":
    # Parse the command line, collect the videos, and run the pipeline.
    parser = argparse.ArgumentParser(description="Video Processing with Video-LLaVA")
    parser.add_argument("-v", "--video-dir", required=True, type=pathlib.Path, help="Directory containing video files")
    parser.add_argument("-b", "--batch-size", required=True, type=int, help="Batch size for video processing")
    parser.add_argument("-d", "--output-dir", required=True, type=pathlib.Path, help="Directory to save outputs")
    parser.add_argument("-p", "--prompt-number", required=True, type=int, help="Prompt number to use")
    parser.add_argument("--max-tokens", default=100, type=int, help="Maximum tokens to generate")
    parser.add_argument("--top-p", default=0.9, type=float, help="Top-p sampling for text generation")
    parser.add_argument("--top-k", default=2, type=int, help="Top-k sampling for text generation")
    args = parser.parse_args()

    # A non-positive batch size would otherwise fail later with a
    # ZeroDivisionError or silently process nothing.
    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")

    # Get video paths in natural (human) sort order.
    video_paths = natsorted(args.video_dir.glob("*.mp4"))
    if not video_paths:
        # Previously this case crashed with IndexError on video_paths[0].
        parser.error(f"no .mp4 files found in {args.video_dir}")
    print(video_paths[0])

    # Outputs are grouped by prompt number and then by model name.
    OUTPUT_DIR = args.output_dir.joinpath(f"prompt_{args.prompt_number}", MODEL_NAME)
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # Process videos
    process_videos(
        video_paths,
        args.batch_size,
        OUTPUT_DIR,
        args.prompt_number,
        args.max_tokens,
        args.top_p,
        args.top_k,
    )
