import json
import os
import torch
import torchvision.transforms as T
from transformers import AutoModel, AutoTokenizer, AutoConfig
from PIL import Image
from decord import VideoReader, cpu
from torchvision.transforms.functional import InterpolationMode
import math
import numpy as np
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
import gc
import time
Image.MAX_IMAGE_PIXELS = None

# Configuration
# Model identifier handed to AutoModel / AutoTokenizer in process_data.
MODEL_NAME = "OpenGVLab/InternVL3-8B"
# Base directory in which find_media_resources looks up retrieved media names.
DATASET_PATH = "../dataset/dataset_vqa/"
# Input JSON file; process_data reads its top-level "results" list.
INPUT_JSON = "e5v_results.json"
# Output JSON file; it is also read back at startup so a run can be resumed.
OUTPUT_JSON = "internvl_e5v_results.json"
# Extensions that process_batch routes to load_image / load_video respectively.
IMAGE_EXTENSIONS = ['.jpg', '.png', '.jpeg', '.gif']
VIDEO_EXTENSIONS = ['.mp4']
# NOTE(review): ALL_EXTENSIONS is not referenced by the code shown in this file.
ALL_EXTENSIONS = IMAGE_EXTENSIONS + VIDEO_EXTENSIONS
# Extensions removed by strip_extension (includes audio/document types).
MEDIA_EXTENSIONS = ['.mp4', '.jpg', '.png', '.wav', '.mp3', '.pdf', '.jpeg', '.gif', '.txt']

# Image processing configuration
# Per-channel mean / std passed to T.Normalize in build_transform.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
# Side length, in pixels, of each square tile produced by the preprocessing.
DEFAULT_INPUT_SIZE = 448
# Default lower / upper bound on the number of tiles in dynamic_preprocess.
DEFAULT_MIN_NUM = 1
DEFAULT_MAX_NUM = 12
# Default for dynamic_preprocess; load_image and load_video pass
# use_thumbnail=True explicitly, so they do not consult this value.
USE_THUMBNAIL = True

# Video processing configuration
# Default number of frames sampled per video by get_index / load_video.
DEFAULT_NUM_SEGMENTS = 8
# Default max_num (tile limit per frame) for load_video.
VIDEO_MAX_NUM = 1
# Number of frames process_batch actually requests from load_video.
VIDEO_SEGMENTS_FOR_PROCESSING = 4
# max_num (tile limit per image) that process_batch passes to load_image.
IMAGE_PATCHES_PER_IMAGE = 8

# Model configuration
# Passed to AutoModel.from_pretrained; TORCH_DTYPE is also the dtype the pixel
# tensors are cast to before being moved to the GPU.
TORCH_DTYPE = torch.bfloat16
LOAD_IN_8BIT = False
LOW_CPU_MEM_USAGE = True
USE_FLASH_ATTN = True
# Passed as use_fast to AutoTokenizer.from_pretrained.
USE_FAST_TOKENIZER = False

# Generation configuration
# Collected into the generation_config dict given to model.chat.
MAX_NEW_TOKENS = 512
DO_SAMPLE = True
TEMPERATURE = 0.4

# Processing configuration
# process_data rewrites OUTPUT_JSON and frees cached GPU memory every
# CHECKPOINT_INTERVAL processed items.
CHECKPOINT_INTERVAL = 10
# NOTE(review): MEMORY_CLEANUP_INTERVAL is not referenced by the code shown in
# this file; cleanup currently happens at the checkpoint interval instead.
MEMORY_CLEANUP_INTERVAL = 5

# Startup banner; these prints run at import time, not only when run as a script.
print(f"Model: {MODEL_NAME}")
print(f"Dataset: {DATASET_PATH}")
print(f"Input: {INPUT_JSON}")
print(f"Output: {OUTPUT_JSON}")

# Image processing utilities
def build_transform(input_size):
    """Build the preprocessing pipeline applied to every image tile.

    The pipeline converts non-RGB images to RGB, resizes them to an
    ``input_size`` x ``input_size`` square with bicubic interpolation,
    converts them to a tensor and normalizes with IMAGENET_MEAN / IMAGENET_STD.

    Args:
        input_size: Side length of the square output, in pixels.

    Returns:
        A ``T.Compose`` callable that maps a PIL image to a normalized tensor.
    """
    def _ensure_rgb(img):
        # Leave images that are already RGB untouched.
        if img.mode == 'RGB':
            return img
        return img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the grid whose width/height ratio is closest to ``aspect_ratio``.

    Args:
        aspect_ratio: Width divided by height of the source image.
        target_ratios: Iterable of ``(columns, rows)`` candidates, examined in
            the order given.
        width: Source image width in pixels.
        height: Source image height in pixels.
        image_size: Side length of one square tile in pixels.

    Returns:
        The chosen candidate from ``target_ratios``, or ``(1, 1)`` when no
        candidate is given. When a later candidate is exactly as close as the
        current best, it wins only if the source image area exceeds half of
        the area that candidate's grid of tiles would cover.
    """
    image_area = width * height
    half_tile_area = 0.5 * image_size * image_size

    smallest_gap = float('inf')
    chosen = (1, 1)
    for candidate in target_ratios:
        cols, rows = candidate[0], candidate[1]
        gap = abs(aspect_ratio - cols / rows)
        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
        elif gap == smallest_gap and image_area > half_tile_area * cols * rows:
            # Equally good fit: prefer the larger grid only for large images.
            chosen = candidate
    return chosen

def dynamic_preprocess(image, min_num=DEFAULT_MIN_NUM, max_num=DEFAULT_MAX_NUM, image_size=DEFAULT_INPUT_SIZE, use_thumbnail=USE_THUMBNAIL):
    """Split an image into square tiles laid out on the best-fitting grid.

    The image is resized to a ``columns x rows`` grid of ``image_size`` tiles,
    where the grid is the candidate closest to the image's aspect ratio, and
    then cut into tiles read left-to-right, top-to-bottom.

    Args:
        image: PIL image to split.
        min_num: Minimum number of tiles a candidate grid may have.
        max_num: Maximum number of tiles a candidate grid may have.
        image_size: Side length of each square tile in pixels.
        use_thumbnail: When true and the grid has more than one tile, a
            whole-image thumbnail of ``image_size`` is appended as a last tile.

    Returns:
        List of PIL images: the tiles, optionally followed by the thumbnail.
    """
    orig_width, orig_height = image.size

    # Every (columns, rows) grid whose tile count lies within
    # [min_num, max_num], ordered by tile count.
    grids = {
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if min_num <= i * j <= max_num
    }
    grids = sorted(grids, key=lambda grid: grid[0] * grid[1])

    cols, rows = find_closest_aspect_ratio(
        orig_width / orig_height, grids, orig_width, orig_height, image_size)

    target_width = image_size * cols
    target_height = image_size * rows
    blocks = cols * rows

    resized_img = image.resize((target_width, target_height))
    tiles_per_row = target_width // image_size

    tiles = []
    for index in range(blocks):
        row, col = divmod(index, tiles_per_row)
        left = col * image_size
        top = row * image_size
        tiles.append(resized_img.crop((left, top, left + image_size, top + image_size)))
    assert len(tiles) == blocks

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles

def load_image(image_file, input_size=DEFAULT_INPUT_SIZE, max_num=DEFAULT_MAX_NUM):
    """Load an image file and return its tiles as one stacked tensor.

    Args:
        image_file: Path (or file object) accepted by ``Image.open``.
        input_size: Side length of each square tile in pixels.
        max_num: Maximum number of tiles produced by ``dynamic_preprocess``
            (a thumbnail tile may be added on top of that).

    Returns:
        A tensor made by stacking the transformed tiles along a new first
        dimension, or ``None`` if the image could not be loaded or processed
        (the error is printed, not raised).
    """
    try:
        # The context manager closes the underlying file as soon as the RGB
        # copy exists; without it, multi-frame formats such as GIF keep the
        # file handle open until the image object is garbage collected.
        with Image.open(image_file) as raw:
            image = raw.convert('RGB')
        transform = build_transform(input_size=input_size)
        tiles = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
        return torch.stack([transform(tile) for tile in tiles])
    except Exception as e:
        # Best effort: callers treat None as "skip this file".
        print(f"Error loading image {image_file}: {e}")
        return None

# Video processing utilities
def get_index(bound, fps, max_frame, first_idx=0, num_segments=DEFAULT_NUM_SEGMENTS):
    """Choose ``num_segments`` frame indices spread over a time window.

    Args:
        bound: Optional ``(start, end)`` pair in seconds; when falsy the
            window is effectively the whole video.
        fps: Frames per second used to convert seconds to frame indices.
        max_frame: Largest frame index the window may reach.
        first_idx: Smallest frame index the window may start at.
        num_segments: Number of equally sized segments; one index near the
            middle of each segment is returned.

    Returns:
        NumPy array holding one integer frame index per segment.
    """
    # A very wide default window gets clamped to [first_idx, max_frame] below.
    start, end = (bound[0], bound[1]) if bound else (-100000, 100000)

    start_idx = max(first_idx, round(start * fps))
    end_idx = min(round(end * fps), max_frame)
    span = float(end_idx - start_idx) / num_segments

    # Start half a segment in, then step by whole (rounded) segments.
    offset = start_idx + (span / 2)
    picks = [int(offset + np.round(span * k)) for k in range(num_segments)]
    return np.array(picks)

def load_video(video_path, bound=None, input_size=DEFAULT_INPUT_SIZE, max_num=VIDEO_MAX_NUM, num_segments=DEFAULT_NUM_SEGMENTS):
    """Sample frames from a video and return their tiles as one tensor.

    Args:
        video_path: Path handed to ``VideoReader``.
        bound: Optional ``(start, end)`` pair in seconds limiting the sampling.
        input_size: Side length of each square tile in pixels.
        max_num: Maximum number of tiles per frame for ``dynamic_preprocess``.
        num_segments: Number of frames to sample.

    Returns:
        ``(pixel_values, num_patches_list)`` where ``pixel_values`` is the
        concatenation of every frame's tile tensor and ``num_patches_list``
        holds the tile count of each frame, or ``(None, None)`` on any error
        (the error is printed, not raised).
    """
    try:
        reader = VideoReader(video_path, ctx=cpu(0), num_threads=1)
        last_frame = len(reader) - 1
        avg_fps = float(reader.get_avg_fps())

        transform = build_transform(input_size=input_size)
        frame_tensors = []
        for frame_index in get_index(bound, avg_fps, last_frame, first_idx=0, num_segments=num_segments):
            frame = Image.fromarray(reader[frame_index].asnumpy()).convert('RGB')
            tiles = dynamic_preprocess(frame, image_size=input_size, use_thumbnail=True, max_num=max_num)
            frame_tensors.append(torch.stack([transform(tile) for tile in tiles]))

        num_patches_list = [tensor.shape[0] for tensor in frame_tensors]
        return torch.cat(frame_tensors), num_patches_list
    except Exception as e:
        print(f"Error loading video {video_path}: {e}")
        return None, None

# File handling utilities
def find_media_resources(filename, base_dir=DATASET_PATH):
    """
    Find media resources supporting single files and folders.

    The known media extension (if any) is stripped from ``filename`` and the
    remaining base name is looked up inside ``base_dir``: first as a folder,
    then as a single file with one of the supported extensions.

    Args:
        filename: Retrieved item name, with or without a media extension.
        base_dir: Directory in which to look.

    Returns:
        - Single file: [file_path]
        - Folder: [folder_path/image1.jpg, folder_path/image2.png, ...]
        - Not found: []
    """
    import re
    # Strip file extensions
    base_name = re.sub(r'\.(mp4|jpg|png|wav|mp3|pdf|jpeg|gif|txt)$', '', filename)

    # A folder with the base name takes precedence over single files.
    folder_path = os.path.join(base_dir, base_name)
    if os.path.isdir(folder_path):
        return get_all_images_from_folder(folder_path)

    # Search for a single file. '.gif' is tried last so that names which
    # previously resolved keep resolving to the same file; it is included
    # because GIFs are accepted everywhere else (folder scan, image loading).
    for ext in ('', '.mp4', '.jpg', '.png', '.jpeg', '.gif'):
        file_path = os.path.join(base_dir, base_name + ext)
        if os.path.isfile(file_path):
            return [file_path]

    # Return empty if not found
    return []

def get_all_images_from_folder(folder_path):
    """Return the sorted paths of the image files directly inside a folder.

    Only regular files whose lower-cased extension is in IMAGE_EXTENSIONS are
    kept; sub-folders are not descended into.
    """
    return sorted(
        os.path.join(folder_path, entry)
        for entry in os.listdir(folder_path)
        if os.path.isfile(os.path.join(folder_path, entry))
        and os.path.splitext(entry)[1].lower() in IMAGE_EXTENSIONS
    )

def strip_extension(filename):
    """Remove a trailing media extension from ``filename``.

    The match is case-insensitive and uses the first entry of
    MEDIA_EXTENSIONS that the name ends with; names without a known
    extension are returned unchanged.
    """
    lowered = filename.lower()
    matched = next((ext for ext in MEDIA_EXTENSIONS if lowered.endswith(ext)), None)
    if matched is None:
        return filename
    return filename[:-len(matched)]

# Main processing pipeline
def _save_results(path, results):
    """Write ``results`` to ``path`` as indented JSON without ever leaving a partial file."""
    # Write to a sibling temp file first, then swap it in, so an interruption
    # mid-write cannot truncate the checkpoint that a later run resumes from.
    tmp_path = f"{path}.tmp"
    with open(tmp_path, 'w') as f:
        json.dump(results, f, indent=2)
    os.replace(tmp_path, path)


def process_data():
    """Run the model over every item of INPUT_JSON and save to OUTPUT_JSON.

    Reads the "results" list from INPUT_JSON, loads the model and tokenizer
    named by MODEL_NAME, and processes each item one at a time with
    ``process_batch``. Items whose "question" already appears in an existing
    OUTPUT_JSON are skipped, which lets an interrupted run be resumed. The
    output file is rewritten every CHECKPOINT_INTERVAL items and once more at
    the end.

    Raises:
        OSError, ValueError, KeyError: If INPUT_JSON cannot be read, is not
            valid JSON, or has no "results" key.
    """
    with open(INPUT_JSON, 'r') as f:
        data = json.load(f)

    results = data['results']

    print(f"Loading {MODEL_NAME}")

    # Model initialization
    model = AutoModel.from_pretrained(
        MODEL_NAME,
        torch_dtype=TORCH_DTYPE,
        load_in_8bit=LOAD_IN_8BIT,
        low_cpu_mem_usage=LOW_CPU_MEM_USAGE,
        use_flash_attn=USE_FLASH_ATTN,
        trust_remote_code=True).eval()

    model = model.cuda()

    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
        use_fast=USE_FAST_TOKENIZER
    )

    generation_config = dict(
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=DO_SAMPLE,
        temperature=TEMPERATURE
    )

    # Results accumulated so far (possibly resumed from a previous run).
    new_results = []

    # Load existing results. An unreadable or malformed file is reported and
    # ignored (the run starts from scratch) instead of being silently
    # swallowed together with every other exception.
    if os.path.exists(OUTPUT_JSON):
        try:
            with open(OUTPUT_JSON, 'r') as f:
                loaded = json.load(f)
        except (OSError, ValueError) as e:
            print(f"Warning: could not load existing results from {OUTPUT_JSON}: {e}")
        else:
            if isinstance(loaded, list):
                new_results = loaded
                print(f"Loaded {len(new_results)} existing results")
            else:
                print(f"Warning: ignoring {OUTPUT_JSON}: expected a JSON list of results")

    # Determine remaining samples; the question text is the resume key.
    processed_ids = set(item.get("question") for item in new_results)
    to_process = [item for item in results if item["question"] not in processed_ids]

    print(f"Processing {len(to_process)} new items")

    for i, item in enumerate(tqdm(to_process)):
        processed_batch = process_batch([item], model, tokenizer, generation_config)
        new_results.extend(processed_batch)

        # Periodic checkpoint plus memory cleanup.
        if (i + 1) % CHECKPOINT_INTERVAL == 0:
            _save_results(OUTPUT_JSON, new_results)
            torch.cuda.empty_cache()
            gc.collect()

    _save_results(OUTPUT_JSON, new_results)

    print(f"Results saved to {OUTPUT_JSON}. Total items: {len(new_results)}")

def _collect_media_files(retrieved_names):
    """Resolve retrieved item names to on-disk media files with normalized extensions."""
    files = []
    for file_name in retrieved_names:
        for file_path in find_media_resources(file_name, DATASET_PATH):
            stem, ext = os.path.splitext(os.path.basename(file_path))
            files.append({
                'original_name': file_name,
                'base_name': stem,
                'path': file_path,
                # Lower-cased so files such as "IMG.JPG", which the
                # case-insensitive folder scan returns, are not silently
                # dropped by the extension checks further down.
                'extension': ext.lower()
            })
    return files


def _load_media_tensors(files):
    """Load pixel tensors and per-image patch counts for the given media files."""
    pixel_values_list = []
    num_patches_list = []
    for file_info in files:
        extension = file_info['extension']
        if extension in IMAGE_EXTENSIONS:
            pixel_values = load_image(file_info['path'], max_num=IMAGE_PATCHES_PER_IMAGE)
            if pixel_values is not None:
                num_patches_list.append(pixel_values.size(0))
                pixel_values_list.append(pixel_values)
        elif extension in VIDEO_EXTENSIONS:
            pixels, patches = load_video(file_info['path'], num_segments=VIDEO_SEGMENTS_FOR_PROCESSING)
            if pixels is not None:
                # A video contributes one patch-count entry per sampled frame.
                pixel_values_list.append(pixels)
                num_patches_list.extend(patches)
    return pixel_values_list, num_patches_list


def process_batch(batch, model, tokenizer, generation_config):
    """Answer each item's question using its top retrieved media files.

    Args:
        batch: Iterable of dicts with the keys "question", "top_5_retrieved"
            and "positive", and optionally "answer".
        model: Object exposing ``chat(tokenizer, pixel_values, prompt,
            generation_config, num_patches_list=...)``.
        tokenizer: Tokenizer forwarded to ``model.chat``.
        generation_config: Generation settings forwarded to ``model.chat``.

    Returns:
        One dict per item with the keys "question", "positive",
        "top_5_retrieved" (at most five names), "answer" and "response".
        "response" is the model output, an empty string when no media could
        be found or loaded, or "Error: ..." when loading or generation raised.
    """
    results = []

    for item in batch:
        question = item['question']
        top_5_retrieved = item['top_5_retrieved'][:5]
        positive = item['positive']
        answer = item.get('answer', '')

        processed_files = _collect_media_files(top_5_retrieved)
        print(f"Processing {len(processed_files)} files")

        response = ""
        if processed_files:
            try:
                pixel_values_list, num_patches_list = _load_media_tensors(processed_files)

                if pixel_values_list:
                    all_pixel_values = torch.cat(pixel_values_list).to(TORCH_DTYPE).cuda()

                    # With several images/frames each one gets its own
                    # numbered <image> placeholder in the prompt.
                    has_multiple = len(num_patches_list) > 1
                    if has_multiple:
                        prompt_prefix = ''.join(
                            f'Image-{i+1}: <image>\n' for i in range(len(num_patches_list)))
                        prompt = prompt_prefix + question
                    else:
                        prompt = '<image>\n' + question

                    response = model.chat(
                        tokenizer,
                        all_pixel_values,
                        prompt,
                        generation_config,
                        num_patches_list=num_patches_list if has_multiple else None
                    )

                    # Release GPU memory before the next item.
                    del all_pixel_values
                    torch.cuda.empty_cache()

            except Exception as e:
                # Best effort: record the failure and continue with the batch.
                print(f"Error processing item with question '{question}': {e}")
                response = f"Error: {str(e)}"

        results.append({
            "question": question,
            "positive": positive,
            "top_5_retrieved": top_5_retrieved,
            "answer": answer,
            "response": response
        })

    return results

# Run the full pipeline only when this file is executed as a script.
if __name__ == "__main__":
    process_data()