import os
import json
import torch
import numpy as np
from PIL import Image
from decord import VideoReader, cpu
import whisper
import librosa
import time
import argparse
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
Image.MAX_IMAGE_PIXELS = None

# Configuration
# The first five values are the defaults for the command-line flags parsed in
# main(); CONV_MODE is the key used to pick a template from ola's
# conv_templates when the prompt is built.
MODEL_NAME = 'THUdyh/Ola-7b'  # default for --model_path
DATASET_PATH = '../dataset/dataset_vqa'  # default for --dataset_dir
INPUT_JSON = 'e5v_results.json'  # default for --input_json
OUTPUT_JSON = 'ola_results.json'  # default for --output_json
DEFAULT_NUM_GPUS = 1  # default for --num_gpus
CONV_MODE = "qwen_1_5"

# File extensions
# VIDEO_EXTENSIONS decides whether a retrieved file is handled as video.
# MEDIA_EXTENSIONS lists the suffixes stripped from a retrieved filename, and
# SUPPORTED_EXTENSIONS the suffixes then probed on disk, in this order.
# IMAGE_EXTENSIONS and AUDIO_EXTENSIONS are not referenced by the code visible
# in this file.
IMAGE_EXTENSIONS = ['.jpg', '.png', '.jpeg']
VIDEO_EXTENSIONS = ['.mp4']
AUDIO_EXTENSIONS = ['.wav', '.mp3']
MEDIA_EXTENSIONS = ['.mp4', '.jpg', '.png', '.wav', '.mp3', '.pdf', '.jpeg', '.gif', '.txt']
SUPPORTED_EXTENSIONS = ['.jpg', '.png', '.jpeg', '.mp4']

# Audio processing configuration
AUDIO_SAMPLE_RATE = 16000  # Hz; sample rate requested when loading audio
AUDIO_CHUNK_LIMIT = 480000  # samples per chunk (30 s at 16 kHz)
MEL_SPECTOGRAM_BINS = 128  # n_mels for the log-mel spectrogram (name keeps its original spelling)
MAX_MEL_CHUNKS = 25  # chunks beyond this count are dropped by load_audio

# Image processing configuration
MAX_IMAGE_SIZE = 1536  # images whose longest side exceeds this are downscaled
IMAGE_RESIZE_METHOD = Image.LANCZOS
PLACEHOLDER_IMAGE_SIZE = (224, 224)  # spatial size of the zero tensor used when an image fails to load
PLACEHOLDER_VIDEO_SIZE = (384, 384)  # spatial size of the zero frame used when a video fails to load

# Video processing configuration
MAX_VIDEO_FRAMES = 64  # upper bound on uniformly sampled frames per video
# NOTE(review): the two frame dimensions below are not referenced by the code
# visible in this file.
VIDEO_FRAME_HEIGHT = 384
VIDEO_FRAME_WIDTH = 384

# Model configuration
TORCH_DTYPE = torch.bfloat16  # dtype the model and visual/speech tensors are cast to
MAX_NEW_TOKENS = 1024
GENERATION_TEMPERATURE = 0.4  # sampling is enabled whenever this is > 0
TOP_P = None  # forwarded unchanged as top_p to model.generate
NUM_BEAMS = 1
PAD_TOKEN_ID = 151643  # positions equal to this id are zeroed in the attention mask
MAX_RETRIEVED_ITEMS = 1  # how many entries of item["top_5_retrieved"] are actually used

# Speech processing configuration
# run_inference never loads real audio; it passes zero-filled tensors built
# from these placeholder shapes/values to model.generate.
SPEECH_PLACEHOLDER_SHAPE = (1, 3000, 128)
SPEECH_LENGTH_PLACEHOLDER = 3000
SPEECH_WAV_PLACEHOLDER_SHAPE = (1, 480000)
SPEECH_CHUNKS_PLACEHOLDER = 1

# Environment variables configuration
# Exported into os.environ below, before the ola imports that follow this
# section. Presumably the ola package reads them when it is imported or when
# the model is built -- TODO confirm against the ola sources.
ENV_CONFIG = {
    'LOWRES_RESIZE': '384x32',
    'HIGHRES_BASE': '0x32',
    'VIDEO_RESIZE': '0x64',
    'VIDEO_MAXRES': '480',
    'VIDEO_MINRES': '288',
    'MAXRES': '1536',
    'MINRES': '0',
    'FORCE_NO_DOWNSAMPLE': '1',
    'LOAD_VISION_EARLY': '1',
    'PAD2STRIDE': '1'
}

# Set environment variables
for key, value in ENV_CONFIG.items():
    os.environ[key] = value

# Banner printed at import time. It echoes the built-in defaults above, not
# the parsed command-line values, so it can differ from what main() uses.
print(f"Model: {MODEL_NAME}")
print(f"Dataset: {DATASET_PATH}")
print(f"Input: {INPUT_JSON}")
print(f"Output: {OUTPUT_JSON}")
print(f"Default GPUs: {DEFAULT_NUM_GPUS}")

from ola.conversation import conv_templates, SeparatorStyle
from ola.model.builder import load_pretrained_model
from ola.datasets.preprocess import tokenizer_image_token, tokenizer_speech_image_token
from ola.mm_utils import KeywordsStoppingCriteria, process_anyres_video, process_anyres_highres_image
from ola.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN


# Audio processing utilities
def load_audio(audio_file_name):
    """Load an audio file and turn it into chunked log-mel model inputs.

    The waveform is read at AUDIO_SAMPLE_RATE and cut into windows of
    AUDIO_CHUNK_LIMIT samples; a window shorter than that is padded with
    ``whisper.pad_or_trim``. Each window is converted to a log-mel
    spectrogram with MEL_SPECTOGRAM_BINS bins, and at most MAX_MEL_CHUNKS
    windows are kept.

    Returns:
        A tuple ``(mels, speech_length, speech_chunks, speech_wavs)``:
        the stacked per-window spectrograms, a LongTensor repeating the
        frame count once per window, a one-element LongTensor holding the
        number of windows, and the stacked raw waveform windows.
    """
    waveform, _ = librosa.load(audio_file_name, sr=AUDIO_SAMPLE_RATE)
    if waveform.ndim > 1:
        # Keep only the first column of a multi-dimensional signal.
        waveform = waveform[:, 0]
    waveform = waveform.astype(np.float32)

    if len(waveform) <= AUDIO_CHUNK_LIMIT:
        # Short clip: a single window, padded up to the chunk length.
        windows = [whisper.pad_or_trim(waveform)]
    else:
        windows = []
        for start in range(0, len(waveform), AUDIO_CHUNK_LIMIT):
            window = waveform[start:start + AUDIO_CHUNK_LIMIT]
            if len(window) < AUDIO_CHUNK_LIMIT:
                # Only the trailing window can be short.
                window = whisper.pad_or_trim(window)
            windows.append(window)

    speech_wavs = torch.cat(
        [torch.from_numpy(window).unsqueeze(0) for window in windows], dim=0
    )
    # (n_mels, frames) -> (1, frames, n_mels) per window, then stacked.
    mels = torch.cat(
        [
            whisper.log_mel_spectrogram(window, n_mels=MEL_SPECTOGRAM_BINS)
            .permute(1, 0)
            .unsqueeze(0)
            for window in windows
        ],
        dim=0,
    )

    if mels.shape[0] > MAX_MEL_CHUNKS:
        mels = mels[:MAX_MEL_CHUNKS]
        speech_wavs = speech_wavs[:MAX_MEL_CHUNKS]

    num_chunks, frames_per_chunk = mels.shape[0], mels.shape[1]
    speech_length = torch.LongTensor([frames_per_chunk] * num_chunks)
    speech_chunks = torch.LongTensor([num_chunks])
    return mels, speech_length, speech_chunks, speech_wavs

# File handling utilities
def find_file_with_extensions(dataset_dir, filename):
    """Resolve a retrieved filename to a path under *dataset_dir*.

    A trailing suffix listed in MEDIA_EXTENSIONS (matched
    case-insensitively, first match wins) is stripped from *filename*, and
    the resulting stem is probed on disk with each SUPPORTED_EXTENSIONS
    suffix in order. The first existing path is returned. When none exists,
    the plain join of *dataset_dir* and the original *filename* is returned
    without checking that it exists.
    """
    lowered = filename.lower()
    known_suffix = next(
        (ext for ext in MEDIA_EXTENSIONS if lowered.endswith(ext.lower())),
        None,
    )
    stem = filename if known_suffix is None else filename[:-len(known_suffix)]

    candidates = (
        os.path.join(dataset_dir, stem + suffix) for suffix in SUPPORTED_EXTENSIONS
    )
    existing = next((path for path in candidates if os.path.exists(path)), None)
    if existing is not None:
        return existing

    return os.path.join(dataset_dir, filename)

# Image processing utilities
def preprocess_images(image_files, image_processor):
    """Load image files and convert them to model input tensors.

    Each image is opened, normalised to RGB, downscaled so its longest side
    is at most MAX_IMAGE_SIZE (aspect ratio preserved), and passed through
    ``process_anyres_highres_image``. Any failure for a file is reported on
    stdout and a zero tensor of PLACEHOLDER_IMAGE_SIZE is substituted, so
    the returned lists always have one entry per input file.

    Args:
        image_files: Iterable of image file paths.
        image_processor: Processor handed to ``process_anyres_highres_image``.
            Its ``do_resize`` and ``do_center_crop`` attributes are set to
            False as a side effect.

    Returns:
        ``(image_tensors, image_highres_tensors, image_sizes)`` -- three
        parallel lists. ``image_sizes`` holds the PIL ``size`` of each image
        after any downscaling, or PLACEHOLDER_IMAGE_SIZE on failure.
    """
    image_tensors = []
    image_highres_tensors = []
    image_sizes = []

    for img_file in image_files:
        try:
            # Image.open is lazy and keeps the file open; the context manager
            # releases the handle once convert() has copied the pixel data.
            # Converting to RGB keeps palette/RGBA/grayscale inputs
            # consistent with the 3-channel placeholder used on failure.
            with Image.open(img_file) as raw_image:
                image = raw_image.convert("RGB")

            longest_side = max(image.size)
            if longest_side > MAX_IMAGE_SIZE:
                ratio = MAX_IMAGE_SIZE / longest_side
                # Clamp to 1 px so extreme aspect ratios cannot produce a
                # zero-sized dimension, which resize() rejects.
                new_size = tuple(max(1, int(side * ratio)) for side in image.size)
                image = image.resize(new_size, IMAGE_RESIZE_METHOD)

            # Resizing/cropping is left to process_anyres_highres_image.
            image_processor.do_resize = False
            image_processor.do_center_crop = False
            image_tensor, image_highres_tensor = process_anyres_highres_image(image, image_processor)

            image_tensors.append(image_tensor)
            image_highres_tensors.append(image_highres_tensor)
            image_sizes.append(image.size)
        except Exception as e:
            # Deliberate best-effort: one unreadable file must not abort the
            # whole item, so fall back to an all-zero image.
            print(f"Error processing image {img_file}: {e}")
            placeholder_tensor = torch.zeros(1, 3, *PLACEHOLDER_IMAGE_SIZE)
            image_tensors.append(placeholder_tensor)
            image_highres_tensors.append(placeholder_tensor)
            image_sizes.append(PLACEHOLDER_IMAGE_SIZE)

    return image_tensors, image_highres_tensors, image_sizes

# Video processing utilities
def preprocess_video(video_file, image_processor):
    """Sample frames from a video and convert them to model input tensors.

    Up to MAX_VIDEO_FRAMES frames are taken at uniformly spaced indices
    across the whole video, each is passed through ``process_anyres_video``,
    and the results are concatenated and cast to TORCH_DTYPE. On any
    failure (including a video with no frames) the error is printed and a
    single all-zero frame of PLACEHOLDER_VIDEO_SIZE is returned instead.

    Args:
        video_file: Path of the video to read.
        image_processor: Processor handed to ``process_anyres_video``. Its
            ``do_resize`` and ``do_center_crop`` attributes are set to
            False as a side effect.

    Returns:
        ``((frames, frames), PLACEHOLDER_VIDEO_SIZE, "video")``. The same
        tensor is used for both entries of the first element.
    """
    try:
        vr = VideoReader(video_file, ctx=cpu(0))
        total_frame_num = len(vr)
        if total_frame_num == 0:
            # Fail with a clear message rather than an incidental error
            # from sampling or concatenating zero frames.
            raise ValueError("video contains no frames")

        frame_idx = np.linspace(
            0, total_frame_num - 1, min(MAX_VIDEO_FRAMES, total_frame_num), dtype=int
        ).tolist()
        sampled_frames = vr.get_batch(frame_idx).asnumpy()

        # Resizing/cropping is left to process_anyres_video; the flags are
        # loop-invariant, so set them once.
        image_processor.do_resize = False
        image_processor.do_center_crop = False
        video_processed = [
            process_anyres_video(Image.fromarray(frame), image_processor).unsqueeze(0)
            for frame in sampled_frames
        ]

        video_processed = torch.cat(video_processed, dim=0).to(TORCH_DTYPE)
        return (video_processed, video_processed), PLACEHOLDER_VIDEO_SIZE, "video"
    except Exception as e:
        # Deliberate best-effort: an unreadable video yields a dummy frame
        # so the caller can still run generation.
        print(f"Error processing video {video_file}: {e}")
        dummy_frame = torch.zeros(1, 3, *PLACEHOLDER_VIDEO_SIZE).to(TORCH_DTYPE)
        return (dummy_frame, dummy_frame), PLACEHOLDER_VIDEO_SIZE, "video"

# Model inference
def _to_device_batch(tensors, device):
    """Stack same-shaped tensors into one batch (else keep a list) on *device*."""
    if all(t.shape == tensors[0].shape for t in tensors):
        return torch.stack(tensors, dim=0).to(TORCH_DTYPE).to(device)
    return [t.to(TORCH_DTYPE).to(device) for t in tensors]


def run_inference(model, tokenizer, image_processor, item, dataset_dir, device_id=0):
    """Answer one question using its retrieved media as visual context.

    The first MAX_RETRIEVED_ITEMS entries of ``item["top_5_retrieved"]`` are
    resolved to files under *dataset_dir*. If any resolved file has a video
    extension, the first such file is used as video input; otherwise all
    resolved files are treated as images. Speech inputs are always
    zero-filled placeholders.

    Args:
        model: Loaded model exposing ``generate``.
        tokenizer: Tokenizer matching *model*.
        image_processor: Processor used by the image/video preprocessing.
        item: Mapping with at least ``"question"`` and ``"top_5_retrieved"``.
        dataset_dir: Directory that holds the media files.
        device_id: Index of the CUDA device the inputs are moved to.

    Returns:
        The decoded, stripped model response, or the fixed string
        ``"No media files found for analysis."`` when none of the retrieved
        files exists on disk.

    Raises:
        KeyError: If *item* lacks ``"question"`` or ``"top_5_retrieved"``.
    """
    device = f'cuda:{device_id}'

    question = item["question"]
    retrieved_files = item["top_5_retrieved"][:MAX_RETRIEVED_ITEMS]

    # Collect media files
    media_files = []
    for file in retrieved_files:
        found_file = find_file_with_extensions(dataset_dir, file)
        if os.path.exists(found_file):
            media_files.append(found_file)
        else:
            print(f"File does not exist: {found_file}")

    # Handle no media files case
    if not media_files:
        return "No media files found for analysis."

    # The first video among the resolved files wins; images are only used
    # when there is no video at all.
    video_suffixes = tuple(VIDEO_EXTENSIONS)
    video_file = next(
        (file for file in media_files if file.lower().endswith(video_suffixes)),
        None,
    )

    # Prepare model input
    if video_file is not None:
        video_data = preprocess_video(video_file, image_processor)
        modality = "video"
    else:
        image_tensors, image_highres_tensors, image_sizes = preprocess_images(media_files, image_processor)
        image_tensor = _to_device_batch(image_tensors, device)
        image_highres_tensor = _to_device_batch(image_highres_tensors, device)
        modality = "image"

    qs = DEFAULT_IMAGE_TOKEN + "\n" + question
    conv = conv_templates[CONV_MODE].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)

    attention_masks = input_ids.ne(PAD_TOKEN_ID).long().to(device)
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    # No audio is loaded here: generate() receives zero-filled speech inputs.
    speechs = [torch.zeros(*SPEECH_PLACEHOLDER_SHAPE).to(TORCH_DTYPE).to(device)]
    speech_lengths = [torch.LongTensor([SPEECH_LENGTH_PLACEHOLDER]).to(device)]
    speech_wavs = [torch.zeros(*SPEECH_WAV_PLACEHOLDER_SHAPE).to(device)]
    speech_chunks = [torch.LongTensor([SPEECH_CHUNKS_PLACEHOLDER]).to(device)]

    # Arguments shared by the video and image generate() calls.
    shared_kwargs = {
        "inputs": input_ids,
        "speech": speechs,
        "speech_lengths": speech_lengths,
        "speech_chunks": speech_chunks,
        "speech_wav": speech_wavs,
        "attention_mask": attention_masks,
        "use_cache": True,
        "stopping_criteria": [stopping_criteria],
        "do_sample": GENERATION_TEMPERATURE > 0,
        "temperature": GENERATION_TEMPERATURE,
        "top_p": TOP_P,
        "num_beams": NUM_BEAMS,
        "max_new_tokens": MAX_NEW_TOKENS,
    }

    # Run model inference
    with torch.inference_mode():
        if modality == "video":
            (frames, frames_highres), _, video_modality = video_data
            visual_kwargs = {
                "images": frames.to(device),
                "images_highres": frames_highres.to(device),
                "modalities": video_modality,
            }
        else:  # Image
            visual_kwargs = {
                "images": image_tensor,
                "images_highres": image_highres_tensor,
                "image_sizes": image_sizes,
                "modalities": ['image'] * len(image_sizes),
            }
        output_ids = model.generate(**shared_kwargs, **visual_kwargs)

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
    outputs = outputs.strip()
    # Drop a trailing stop string if the model emitted it verbatim.
    if outputs.endswith(stop_str):
        outputs = outputs[:-len(stop_str)]
    outputs = outputs.strip()

    return outputs

# Main processing pipeline
def _save_results(output_path, results):
    """Write *results* as indented JSON to *output_path* via an atomic rename."""
    tmp_path = f"{output_path}.tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)
    # os.replace swaps the file in one step, so an interrupted run never
    # leaves a half-written results file behind.
    os.replace(tmp_path, output_path)


def main():
    """Run inference over every item of the input JSON and save the answers.

    Parses the command-line flags, loads one model copy per usable GPU,
    resumes from an existing output file when present, then processes the
    remaining items one at a time, assigning item ``i`` to GPU
    ``i % available_gpus``. The full result list is rewritten after every
    item so progress survives an interruption. An item whose inference
    raises is stored with a ``"Processing error: ..."`` response.

    Raises:
        RuntimeError: If no CUDA device is usable for the requested
            ``--num_gpus``.
    """
    # Local import keeps this function self-contained without touching the
    # file-level import block.
    from collections import Counter

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default=MODEL_NAME)
    parser.add_argument('--input_json', type=str, default=INPUT_JSON)
    parser.add_argument('--output_json', type=str, default=OUTPUT_JSON)
    parser.add_argument('--dataset_dir', type=str, default=DATASET_PATH)
    parser.add_argument('--num_gpus', type=int, default=DEFAULT_NUM_GPUS)
    args = parser.parse_args()

    # Create output directory
    output_dir = os.path.dirname(args.output_json)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Load dataset
    with open(args.input_json, 'r', encoding='utf-8') as f:
        data = json.load(f)

    items = data.get('results', [])
    print(f"Loaded {len(items)} data items for processing")

    # Check GPU availability
    available_gpus = min(args.num_gpus, torch.cuda.device_count())
    if available_gpus < 1:
        # Fail early with a clear message instead of a ZeroDivisionError or
        # IndexError when the first item is dispatched.
        raise RuntimeError(
            f"No usable GPU: requested {args.num_gpus}, "
            f"detected {torch.cuda.device_count()}"
        )
    print(f"Using {available_gpus} GPUs for processing")

    # Load models on GPUs
    print("Loading models on all GPUs")
    models = []
    for gpu_id in range(available_gpus):
        print(f"Loading model on GPU {gpu_id}")
        device = f'cuda:{gpu_id}'
        torch.cuda.set_device(device)

        tokenizer, model, image_processor, _ = load_pretrained_model(args.model_path, None)
        model = model.to(device).eval()
        model = model.to(TORCH_DTYPE)

        models.append((tokenizer, model, image_processor, gpu_id))
        print(f"Model loaded on GPU {gpu_id}")

    # Initialize results tracking
    results = []
    if os.path.exists(args.output_json):
        try:
            with open(args.output_json, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and decoding errors.
            print(f"Could not load existing results from {args.output_json}: {e}; starting fresh")
        else:
            if isinstance(loaded, list):
                results = loaded
                print(f"Loaded {len(results)} existing results")
            else:
                print(f"Existing results in {args.output_json} are not a list; starting fresh")

    # Find processed items. Results are appended in input order, so the
    # first k items carrying a question are the ones already answered when
    # k saved results share that question.
    answered = Counter(
        entry.get("question") for entry in results if isinstance(entry, dict)
    )
    processed_indices = set()
    for i, item in enumerate(items):
        question = item["question"]
        if answered[question] > 0:
            answered[question] -= 1
            processed_indices.add(i)

    # Process with GPU rotation (items run sequentially; only the device
    # each one uses rotates).
    for i, item in enumerate(tqdm(items, desc="Overall processing progress")):
        # Skip already processed items
        if i in processed_indices:
            print(f"Skipping processed item {i}")
            continue

        # Choose GPU for processing
        tokenizer, model, image_processor, gpu_id = models[i % available_gpus]

        try:
            print(f"Processing item {i} on GPU {gpu_id}")
            response = run_inference(model, tokenizer, image_processor, item, args.dataset_dir, gpu_id)
            succeeded = True
        except Exception as e:
            # Best-effort: record the failure and keep going with the
            # remaining items.
            print(f"Error processing item on GPU {gpu_id}: {e}")
            response = f"Processing error: {str(e)}"
            succeeded = False

        item_copy = item.copy()
        item_copy["response"] = response
        results.append(item_copy)

        # Write incremental results
        _save_results(args.output_json, results)

        if succeeded:
            print(f"Item {i} completed")

    print(f"Results saved to {args.output_json}")

if __name__ == '__main__':
    # Report the wall-clock duration of the whole run.
    started_at = time.time()
    main()
    elapsed = time.time() - started_at
    print(f"Processing time: {elapsed:.2f} seconds")