import argparse
import pathlib
import torch
import av
import numpy as np
import pickle
from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor, AutoTokenizer, BitsAndBytesConfig
from qwen_vl_utils import process_vision_info
from tqdm.auto import tqdm
import os
from transformers import PretrainedConfig
import transformers
from natsort import natsorted
from prompts import ALL_PROMPTS
from huggingface_hub import hf_hub_download
from pathlib import Path
import cv2
from PIL import Image
import librosa

# Redirect Hugging Face caches to the current working directory.
# NOTE(review): this is a class attribute set on PretrainedConfig; it is unclear
# whether transformers actually reads it — the explicit cache_dir= arguments below
# are what this file relies on. TODO confirm or remove.
PretrainedConfig.cache_dir = './'
# NOTE(review): these environment variables are assigned after `transformers` has
# already been imported at the top of this file; values read at import time may not
# pick them up. Move them above the imports if they are meant to take effect.
os.environ['HF_HOME'] = './'
os.environ['HF_DATASETS_CACHE']='./'
HUGGINGFACE_CACHE_DIR = "./"
os.environ['TRANSFORMERS_CACHE'] = "./"
# Model and quantization setup
MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct"
# NOTE(review): CONFIG_CLASS / MODEL_CLASS / PROCESSOR_CLASS are not referenced in
# the code visible here; the concrete classes are used directly below.
CONFIG_CLASS = transformers.BitsAndBytesConfig
MODEL_CLASS = transformers.Qwen2AudioForConditionalGeneration
PROCESSOR_CLASS = transformers.AutoProcessor

# Filesystem-safe model name; used as a directory component for the outputs.
MODEL_NAME = MODEL_ID.replace("/", "_").replace(" ", "_")

# NOTE(review): not referenced in the code visible here — process_audios builds its
# own generation kwargs from its arguments.
GENERATE_KWARGS = {
    "do_sample": True,
    "top_p":0.9, "top_k":2,"max_new_tokens":100,
    "max_length": 512,
}

# 4-bit bitsandbytes quantization settings.
# NOTE(review): this config is built but NOT passed to from_pretrained below, so the
# model is loaded without 4-bit quantization. Pass quantization_config=... if intended.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# Load processor and model once at import time; both are module-level globals used
# by process_audios.
processor = AutoProcessor.from_pretrained(MODEL_ID, cache_dir=HUGGINGFACE_CACHE_DIR)
model = Qwen2AudioForConditionalGeneration.from_pretrained(
    MODEL_ID, torch_dtype="auto",device_map="auto", cache_dir=HUGGINGFACE_CACHE_DIR
)

# Helper Functions
def read_video_pyav(container, indices):
    """
    Decode the frames at the requested positions of a container's first video stream.

    :param container: an opened PyAV container; it is rewound with ``seek(0)`` first
    :param indices: ascending, indexable collection of 0-based frame positions to keep
        (duplicates are allowed but each frame is kept only once)
    :return: numpy array stacking each kept frame converted via
        ``to_ndarray(format="rgb24")``
    """
    container.seek(0)
    lo, hi = indices[0], indices[-1]

    picked = []
    for position, frame in enumerate(container.decode(video=0)):
        # Positions only grow, so nothing past the last wanted index can match.
        if position > hi:
            break
        if position < lo or position not in indices:
            continue
        picked.append(frame)

    arrays = [frm.to_ndarray(format="rgb24") for frm in picked]
    return np.stack(arrays)

def sample_frames(video_path, num_frames=8):
    """
    Uniformly sample `num_frames` frames from a video file.

    :param video_path: path to the video file
    :param num_frames: number of evenly spaced frame positions to sample (>= 1).
        If the video has fewer frames than this, positions repeat and the
        result holds fewer than `num_frames` frames.
    :return: numpy array of the sampled RGB frames, as returned by read_video_pyav
    :raises ValueError: if `num_frames` < 1 or the stream reports no frame count
    """
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames}")

    # Use the container as a context manager so the file handle is always closed;
    # read_video_pyav converts frames to numpy arrays before we return.
    with av.open(video_path) as container:
        total_frames = container.streams.video[0].frames
        if total_frames <= 0:
            # The stream may not record a frame count; without it we cannot space
            # samples uniformly (the old code died with a division error here).
            raise ValueError(f"cannot determine frame count of {video_path}")
        # k * total / num_frames for k in [0, num_frames): exactly num_frames
        # positions. This replaces the hard-coded 8 that ignored `num_frames`, and
        # avoids the extra element a float-step np.arange can produce from rounding.
        indices = (np.arange(num_frames) * total_frames / num_frames).astype(int)
        return read_video_pyav(container, indices)

def batchify(iterable, batch_size):
    """
    Split a sliceable sequence into consecutive chunks of `batch_size` items.

    The last chunk may be shorter. Each chunk has the same type as a slice of the
    input (a list for a list, a str for a str, ...).

    :param iterable: any sequence supporting ``len()`` and slicing
    :param batch_size: maximum chunk length; must be >= 1
    :return: list of chunks (empty for an empty input)
    :raises ValueError: if `batch_size` < 1
    """
    # A single definition: previously a generator version was defined first and then
    # silently shadowed by this list version. Negative sizes used to return [] silently.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    return [iterable[i : i + batch_size] for i in range(0, len(iterable), batch_size)]

def extract_frames(video_path, frame_rate=1):
    """
    Extract frames from a video, keeping roughly `frame_rate` frames per second.

    :param video_path: path to the video file (str or path-like)
    :param frame_rate: target frames to keep per second of video; must be > 0
    :return: list of RGB PIL.Image objects, or None if no frame could be read
    :raises ValueError: if `frame_rate` <= 0
    """
    if frame_rate <= 0:
        raise ValueError(f"frame_rate must be > 0, got {frame_rate}")

    video_capture = cv2.VideoCapture(str(video_path))
    frames = []
    try:
        fps = video_capture.get(cv2.CAP_PROP_FPS)
        # Keep every `frame_interval`-th frame. Clamp to >= 1: when the requested
        # rate exceeds the source fps, int(fps / frame_rate) is 0 and `count % 0`
        # would raise ZeroDivisionError. A falsy fps (unknown) keeps every frame.
        frame_interval = max(1, int(fps / frame_rate)) if fps else 1

        count = 0
        success, frame = video_capture.read()
        while success:
            if count % frame_interval == 0:
                # OpenCV decodes to BGR; PIL expects RGB.
                frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
            success, frame = video_capture.read()
            count += 1
    finally:
        # Release the capture even if decoding or conversion raises.
        video_capture.release()

    return frames or None

# Main processing function
def _load_audios(conversation, sampling_rate):
    """Load every ``{"type": "audio", "path": ...}`` item of a chat conversation.

    Each file is read with ``librosa.load`` resampled to `sampling_rate`; the
    waveforms are returned in the order they appear in the conversation.
    """
    waveforms = []
    for message in conversation:
        content = message["content"]
        if not isinstance(content, list):
            continue
        for item in content:
            if item["type"] == "audio":
                waveforms.append(librosa.load(item["path"], sr=sampling_rate)[0])
    return waveforms


def _pool_generated_hidden_states(hidden_states):
    """Average per-layer hidden states over the generated tokens of the first sequence.

    :param hidden_states: ``outputs.hidden_states`` from ``model.generate(...,
        output_hidden_states=True, return_dict_in_generate=True)`` — one entry per
        generation step, each a sequence of per-layer tensors.
    :return: float32 array of shape (num_layers, hidden_size), or None when fewer
        than two steps were recorded.
    """
    # Step 0 covers the prompt (transformers generate() convention); only steps 1..
    # correspond to newly generated tokens. With a single step there is nothing to
    # average — the old code raised IndexError in that case.
    if not hidden_states or len(hidden_states) < 2:
        return None
    per_step = np.stack([
        # Assumes each layer tensor is (batch, seq_len, hidden_size) per the HF
        # convention. [0, -1] takes the newest position of the first batch item; with
        # the KV cache seq_len is 1 here, so this equals the old [0, 0].
        np.stack([layer[0, -1].cpu().float().numpy() for layer in step])
        for step in hidden_states[1:]
    ])  # (num_generated_steps, num_layers, hidden_size)
    return per_step.mean(axis=0)


def process_audios(audio_paths, batch_size, output_dir, prompt_number, max_tokens=128, top_p=0.9, top_k=2):
    """
    Run Qwen2-Audio on each audio file and pickle the generated text and hidden states.

    Files are run through the model one at a time. `batch_size` only controls
    checkpointing: every group of `batch_size` results is written to
    ``output_dir / f"batch_{k}.pkl"`` (k starting at 1) as a list of dicts with
    keys ``audio_path``, ``generated_text`` (a list with one decoded answer) and
    ``language_hidden_states`` (see _pool_generated_hidden_states, or None).

    :param audio_paths: sequence of audio file paths (str or Path)
    :param batch_size: number of results per saved pickle; must be >= 1
    :param output_dir: directory for the pickles; created if missing
    :param prompt_number: key into ALL_PROMPTS selecting the text prompt
    :param max_tokens: maximum number of new tokens to generate
    :param top_p: nucleus-sampling threshold
    :param top_k: top-k sampling cutoff
    :raises ValueError: if `prompt_number` is not a key of ALL_PROMPTS or
        `batch_size` < 1
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    if prompt_number not in ALL_PROMPTS:
        raise ValueError(f"Invalid prompt_number {prompt_number}. Choose from {list(ALL_PROMPTS.keys())}.")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    prompt_text = ALL_PROMPTS[prompt_number]
    sampling_rate = processor.feature_extractor.sampling_rate
    generate_kwargs = {
        "max_new_tokens": max_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k,
    }
    # ceil(len / batch_size); the old `len // batch_size + 1` overcounted by one
    # whenever len was an exact multiple of batch_size.
    num_batches = -(-len(audio_paths) // batch_size)

    for batch_num, batch in enumerate(tqdm(batchify(audio_paths, batch_size), total=num_batches)):
        batch_outputs = []

        for audio_path in batch:
            audio_path_str = str(audio_path)
            print(f"Processing: {audio_path_str}")
            absolute_path = os.path.abspath(audio_path_str)

            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "audio", "path": absolute_path},
                        # NOTE(review): the audio placeholder is appended by hand.
                        # Presumably the chat template does not emit one for items keyed
                        # by "path" (official examples use "audio_url"); confirm against
                        # the installed template before changing, since the processor
                        # needs exactly one placeholder per audio.
                        {"type": "text", "text": f"{prompt_text} <|AUDIO|>"},
                    ],
                }
            ]

            # tokenize=False returns the rendered prompt string.
            text = processor.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                tokenize=False,
            )
            audios = _load_audios(conversation, sampling_rate)

            inputs = processor(text=text, audios=audios, return_tensors="pt", padding=True)
            # Move every tensor (ids, masks, audio features) to the model's device.
            # The previous `inputs.input_ids = ...` only set a Python attribute on the
            # BatchFeature; `**inputs` unpacks the underlying mapping, which it left
            # unchanged.
            inputs = inputs.to(model.device)

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    **generate_kwargs,
                    output_hidden_states=True,
                    return_dict_in_generate=True,
                )

            # Generated sequences start with the prompt ids; decode only the new tokens.
            prompt_length = inputs["input_ids"].shape[1]
            generated_text = processor.batch_decode(
                outputs.sequences[:, prompt_length:], skip_special_tokens=True
            )
            print(f"Generated Output: {generated_text}")

            batch_outputs.append({
                "audio_path": str(audio_path),
                "generated_text": generated_text,
                "language_hidden_states": _pool_generated_hidden_states(outputs.hidden_states),
            })

        # Checkpoint after every batch so partial progress survives a crash.
        batch_file = output_dir / f"batch_{batch_num + 1}.pkl"
        with open(batch_file, "wb") as f:
            pickle.dump(batch_outputs, f)

        print(f"✅ Saved batch {batch_num + 1} results to {batch_file}")

    print("✅ Audio processing complete!")

# Command-line interface
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Audio Processing with Qwen2.5 Audio Language Model")
    parser.add_argument("-v","--audio-dir", required=True, type=pathlib.Path, help="Directory containing audio files")
    parser.add_argument("-b","--batch-size", required=True, type=int, help="Batch size for audio processing")
    parser.add_argument("-d","--output-dir", required=True, type=pathlib.Path, help="Directory to save outputs")
    parser.add_argument("-p","--prompt-number", required=True, type=int, help="Prompt number to use")
    parser.add_argument("--max-tokens", default=100, type=int, help="Maximum tokens to generate")
    parser.add_argument("--top-p", default=0.9, type=float, help="Top-p sampling for text generation")
    parser.add_argument("--top-k", default=2, type=int, help="Top-k sampling for text generation")
    args = parser.parse_args()

    if args.batch_size < 1:
        parser.error("--batch-size must be at least 1")

    # Collect the .wav files directly inside --audio-dir, naturally sorted
    # (e.g. 2.wav before 10.wav).
    audio_paths = natsorted(args.audio_dir.glob("*.wav"))
    if not audio_paths:
        # Exit with a usage message instead of an IndexError on audio_paths[0].
        parser.error(f"no .wav files found in {args.audio_dir}")
    print(audio_paths[0])

    # Results are written to <output-dir>/prompt_<N>/<MODEL_NAME>/.
    OUTPUT_DIR = args.output_dir.joinpath(f"prompt_{args.prompt_number}", MODEL_NAME)
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    process_audios(
        audio_paths,
        args.batch_size,
        OUTPUT_DIR,
        args.prompt_number,
        args.max_tokens,
        args.top_p,
        args.top_k,
    )
