import os
import json
import torch
import gc
from PIL import Image
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from tqdm import tqdm
import re
# Remove Pillow's pixel-count cap (its decompression-bomb guard) —
# presumably so oversized dataset images load without raising.
Image.MAX_IMAGE_PIXELS = None

# Configuration
MODEL_NAME = "Qwen/Qwen2-VL-7B-Instruct"  # checkpoint loaded by setup_model()
DATASET_PATH = "../dataset/dataset_vqa/"  # default root searched by find_media_resources()
INPUT_JSON = "e5v_results.json"  # read by main(); its "results" entry holds the samples
OUTPUT_JSON = "qwen2vl_e5v_results.json"  # rewritten after every sample and once more at the end
DEVICE = "cuda:0"  # device the model and its inputs are placed on
# Extensions accepted when collecting images from an item folder.
IMAGE_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.gif']
# NOTE(review): not referenced anywhere in the visible code;
# find_media_resources() hard-codes the same extension set in its regex.
MEDIA_EXTENSIONS = ['.mp4', '.jpg', '.png', '.wav', '.mp3', '.pdf', '.jpeg', '.gif', '.txt']

# Model configuration
# Pixel budget bounds handed to AutoProcessor as min_pixels / max_pixels.
MIN_PIXELS = 224 * 28 * 28
MAX_PIXELS = 336 * 28 * 28
MAX_NEW_TOKENS = 128  # generation length cap passed to model.generate()
TEMPERATURE = 0.4  # sampling temperature passed to model.generate()
MAX_SEQUENCE_LENGTH = 30000  # upper bound on prompt tokens checked in process_sample()
VIDEO_FPS = 0.5  # "fps" value attached to every video message entry
# Images are requested at IMAGE_SIZE x IMAGE_SIZE; videos get a
# max_pixels budget of IMAGE_SIZE * IMAGE_SIZE.
IMAGE_SIZE = 280

# Import-time side effects: announce the device and make CUDA device 0 current.
print(f"Using device: {DEVICE}")
torch.cuda.set_device(0)

# Model setup and initialization
def setup_model():
    """Load the Qwen2-VL model and its processor.

    The model is loaded in bfloat16 straight onto ``DEVICE``; the processor
    is configured with the ``MIN_PIXELS`` / ``MAX_PIXELS`` budget.

    Returns:
        tuple: ``(model, processor)`` ready for use by ``process_sample``.
    """
    print(f"Loading {MODEL_NAME}")

    # device_map already places the weights on DEVICE while loading, so a
    # follow-up model.to(DEVICE) would only be a redundant extra pass.
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        device_map=DEVICE,
        low_cpu_mem_usage=True,
    )

    processor = AutoProcessor.from_pretrained(
        MODEL_NAME,
        min_pixels=MIN_PIXELS,
        max_pixels=MAX_PIXELS,
    )

    print(f"Model loaded to {DEVICE}")
    # Report memory for the device the model actually lives on.
    print(f"GPU memory usage: {torch.cuda.memory_allocated(DEVICE) / 1024**3:.2f} GB")

    return model, processor

# File search and handling utilities
def find_media_resources(filename, base_dir=DATASET_PATH):
    """Resolve a retrieved item name to the media files that represent it.

    A known media extension is stripped from ``filename`` first. If a folder
    with the resulting name exists under ``base_dir``, all of its images are
    returned. Otherwise the first existing file among the bare name and the
    name with a supported visual extension is returned.

    Args:
        filename: Item name as stored in the retrieval results.
        base_dir: Directory to search in.

    Returns:
        list[str]: Paths of the media files found, or ``[]`` when nothing
        matches.
    """
    base_name = re.sub(r'\.(mp4|jpg|png|wav|mp3|pdf|jpeg|gif|txt)$', '', filename)

    # With an empty name os.path.join() resolves to base_dir itself, which
    # would return every image in the dataset root as this item's media.
    if not base_name:
        return []

    folder_path = os.path.join(base_dir, base_name)
    if os.path.isdir(folder_path):
        return get_all_images_from_folder(folder_path)

    # '.gif' is probed last: the regex above strips it and the folder lookup
    # accepts it, so without it a single .gif item could never be resolved.
    for ext in ('', '.mp4', '.jpg', '.png', '.jpeg', '.gif'):
        file_path = os.path.join(base_dir, base_name + ext)
        if os.path.isfile(file_path):
            return [file_path]

    return []

def get_all_images_from_folder(folder_path):
    """Return the sorted paths of the image files directly inside a folder.

    Only regular files whose lower-cased extension appears in
    ``IMAGE_EXTENSIONS`` are kept; sub-directories are not descended into.
    """
    def _is_image(entry):
        # Check the file type first so directories never reach the
        # extension test.
        if not os.path.isfile(os.path.join(folder_path, entry)):
            return False
        suffix = os.path.splitext(entry)[1].lower()
        return suffix in IMAGE_EXTENSIONS

    return sorted(
        os.path.join(folder_path, entry)
        for entry in os.listdir(folder_path)
        if _is_image(entry)
    )

# Sample processing
def _build_result(sample, response):
    """Assemble the output record for one sample with the given response text."""
    return {
        "question": sample["question"],
        "positive": sample.get("positive", []),
        "top_5_retrieved": sample["top_5_retrieved"],
        "answer": sample.get("answer", ""),
        "response": response,
    }


def _build_model_inputs(processor, media_files, question):
    """Build the single-turn chat message for the media + question and encode it."""
    content = []
    for file_path in media_files:
        if file_path.endswith('.mp4'):
            content.append({
                "type": "video",
                "video": file_path,
                "fps": VIDEO_FPS,
                "max_pixels": IMAGE_SIZE * IMAGE_SIZE
            })
        else:
            content.append({
                "type": "image",
                "image": f"file://{file_path}",
                "resized_height": IMAGE_SIZE,
                "resized_width": IMAGE_SIZE
            })
    # The question goes last, after all media entries.
    content.append({
        "type": "text",
        "text": question
    })
    messages = [{"role": "user", "content": content}]

    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    return processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt"
    )


def process_sample(model, processor, sample):
    """Answer one sample's question from its retrieved media.

    The first five entries of ``sample["top_5_retrieved"]`` are resolved to
    media files, packed together with the question into one user message and
    passed to ``model.generate``.

    Args:
        model: Generation model returned by ``setup_model``.
        processor: Matching processor returned by ``setup_model``.
        sample: Dict with ``"question"`` and ``"top_5_retrieved"``;
            ``"answer"`` and ``"positive"`` are optional.

    Returns:
        dict: ``question``, ``positive``, ``top_5_retrieved``, ``answer`` and
        ``response``. ``response`` is the generated text, a fixed message when
        no media was found, or ``"Processing error: ..."`` when anything in
        the model pipeline raised.

    Raises:
        KeyError: If ``sample`` lacks ``"question"`` or ``"top_5_retrieved"``.
    """
    question = sample["question"]
    top_5_files = sample["top_5_retrieved"]

    # Collect media files, keeping retrieval order.
    valid_files = []
    for file in top_5_files[:5]:
        valid_files.extend(find_media_resources(file))

    if not valid_files:
        return _build_result(sample, "No relevant media files found")

    print(f"Processing {len(valid_files)} files")

    inputs = generated_ids = None
    try:
        inputs = _build_model_inputs(processor, valid_files, question)
        seq_len = inputs["input_ids"].shape[1]

        # Slicing input_ids cannot shorten a multimodal prompt: the vision
        # placeholder tokens must stay in sync with the pixel features, and
        # attribute assignment on the processor output does not change what
        # **inputs unpacks. Drop the lowest-ranked media and re-encode instead.
        while seq_len > MAX_SEQUENCE_LENGTH:
            if len(valid_files) <= 1:
                raise ValueError(
                    f"Input sequence too long ({seq_len} > {MAX_SEQUENCE_LENGTH}) "
                    "even with a single media file"
                )
            # Shrink proportionally to the overshoot, always by at least one file.
            keep = max(1, min(len(valid_files) - 1,
                              len(valid_files) * MAX_SEQUENCE_LENGTH // seq_len))
            print(f"Warning: Input sequence too long ({seq_len}), "
                  f"keeping {keep} of {len(valid_files)} media files")
            valid_files = valid_files[:keep]
            inputs = None  # release the oversized tensors before re-encoding
            inputs = _build_model_inputs(processor, valid_files, question)
            seq_len = inputs["input_ids"].shape[1]

        # Transfer to GPU
        inputs = inputs.to(DEVICE)
        prompt_length = inputs["input_ids"].shape[1]

        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=True,
                temperature=TEMPERATURE,
                pad_token_id=processor.tokenizer.eos_token_id
            )

        # Strip the prompt using the length that was actually fed to generate().
        response = processor.batch_decode(
            generated_ids[:, prompt_length:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]

        return _build_result(sample, response)
    except Exception as e:
        # Best-effort batch run: record the failure and let the caller continue.
        return _build_result(sample, f"Processing error: {str(e)}")
    finally:
        # Drop tensor references first so empty_cache() can return their memory.
        inputs = generated_ids = None
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        gc.collect()

# Sequential processing
def _save_results_atomically(results, path):
    """Write results as JSON to a temp file, then swap it into place."""
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    # os.replace swaps the file in one step, so an interrupted write can
    # never leave a truncated checkpoint behind.
    os.replace(tmp_path, path)


def sequential_process(samples, model, processor):
    """Run ``process_sample`` over all samples one by one.

    After every sample the accumulated results are checkpointed to
    ``OUTPUT_JSON``. GPU memory usage is printed every 10 samples and the
    CUDA cache is cleared every 5 samples.

    Args:
        samples: Sequence of sample dicts accepted by ``process_sample``.
        model: Generation model returned by ``setup_model``.
        processor: Matching processor returned by ``setup_model``.

    Returns:
        list[dict]: One result record per sample, in input order.
    """
    results = []

    print(f"Starting processing {len(samples)} samples on device: {DEVICE}")

    # Process with progress tracking
    for i, sample in enumerate(tqdm(samples, desc="Processing samples")):
        # Monitor GPU memory
        if i % 10 == 0:
            memory_used = torch.cuda.memory_allocated(0) / 1024**3
            print(f"Current GPU memory usage: {memory_used:.2f} GB")

        # Process current sample
        results.append(process_sample(model, processor, sample))

        # Checkpoint after every sample so an interrupted run keeps its progress.
        _save_results_atomically(results, OUTPUT_JSON)

        # Periodic memory cleanup
        if (i + 1) % 5 == 0:
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            gc.collect()

    return results

# Main processing function
def main():
    """Run the full pipeline: load samples, load the model, generate, save.

    Reads the sample list from the ``"results"`` entry of ``INPUT_JSON``,
    processes every sample sequentially and writes the collected records to
    ``OUTPUT_JSON``.

    Raises:
        RuntimeError: If CUDA is not available.
        KeyError: If the input JSON has no ``"results"`` entry.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA unavailable, cannot use GPU")

    print(f"CUDA device count: {torch.cuda.device_count()}")
    print(f"Current device: {torch.cuda.current_device()}")
    print(f"Device name: {torch.cuda.get_device_name(0)}")
    print(f"Device memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

    torch.cuda.set_device(0)

    # Load dataset. Read as UTF-8 explicitly, matching how results are
    # written, so non-ASCII text does not depend on the platform default.
    with open(INPUT_JSON, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Extract the samples before the slow model load so a malformed input
    # file fails immediately.
    samples = data["results"]

    model, processor = setup_model()

    results = sequential_process(samples, model, processor)

    # Final write; also covers the case of an empty sample list, where no
    # per-sample checkpoint was produced.
    with open(OUTPUT_JSON, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    print(f"Results saved to {OUTPUT_JSON}. Total samples: {len(results)}")

    torch.cuda.empty_cache()
    print(f"Final GPU memory usage: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")

# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()