import base64
from typing import List, Dict, Any, Optional

import cv2
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor, \
    Qwen2_5_VLForConditionalGeneration
import torch.nn as nn

import json
import torch
import os
from qwen_vl_utils import process_vision_info
from tqdm import tqdm
import numpy as np
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
import easyocr

import sys

import requests

# Per-channel mean / std handed to T.Normalize in build_transform().
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
import torchvision.transforms as T

# Module-level KL reduction setting.
# NOTE(review): no reader of this global is visible in this part of the file — confirm it is still used.
kl_reduction = 'min'


def cv2_to_base64(image, format='.jpg', quality=100):
    """Encode an OpenCV image with ``cv2.imencode`` and return it as Base64 text.

    Args:
        image: Image handed to ``cv2.imencode``.
        format: File extension selecting the encoder (default ``'.jpg'``).
        quality: JPEG quality; only forwarded when ``format`` is ``'.jpg'``.

    Returns:
        The Base64 string, or ``""`` when encoding fails (the error is printed).
    """
    try:
        # The JPEG quality flag is only passed along for the '.jpg' extension.
        if format == '.jpg':
            params = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
        else:
            params = []
        ok, encoded = cv2.imencode(format, image, params)
        if ok:
            return base64.b64encode(encoded).decode('utf-8')
        raise ValueError("Could not encode image")
    except Exception as e:
        print(f"Image encoding error: {e}")
        return ""


def send_post_request(data, endpoint="http://xxx.xxx.xxx.xxx:7410/ocr"):
    """POST ``data`` as a JSON body to the OCR service and return the decoded reply.

    Args:
        data: JSON-serialisable payload.
        endpoint: URL of the OCR service.

    Returns:
        The parsed JSON response, or ``{"result": ""}`` if the request,
        the status check or the decoding fails (the error is printed).
    """
    try:
        reply = requests.post(
            endpoint,
            headers={"Content-Type": "application/json"},
            data=json.dumps(data),
            timeout=300,
        )
        reply.raise_for_status()
        parsed = reply.json()
    except Exception as e:
        print(f"OCR API request failed: {e}")
        return {"result": ""}
    return parsed


def extract_ocr_text(image_path):
    """Read an image, send it to the OCR service and return the recognised text.

    Args:
        image_path: Path given to ``cv2.imread``.

    Returns:
        The first entry of the service's ``result`` list joined with newlines,
        ``"No text recognized"`` when the result is empty or not a list, or
        ``"OCR processing failed"`` when any step raises (the error is printed).
    """
    try:
        frame = cv2.imread(image_path)
        if frame is None:
            raise ValueError(f"Could not read image: {image_path}")

        response = send_post_request({"image": [cv2_to_base64(frame)]})
        pages = response.get('result', "")

        # Only a non-empty list counts as a usable OCR result.
        if not (pages and isinstance(pages, list) and len(pages) > 0):
            return "No text recognized"
        return '\n'.join(pages[0])
    except Exception as e:
        print(f"OCR extraction error: {e}")
        return "OCR processing failed"

def extract_ocr_text_withdet(image_path):
    """Read an image, query the port-7411 OCR endpoint and return its entries as JSON lines.

    Args:
        image_path: Path given to ``cv2.imread``.

    Returns:
        Every element of the service's ``result`` list serialised with
        ``json.dumps`` (non-ASCII kept) and joined with newlines,
        ``"No text recognized"`` when the result is empty or not a list, or
        ``"OCR processing failed"`` when any step raises (the error is printed).
    """
    try:
        frame = cv2.imread(image_path)
        if frame is None:
            raise ValueError(f"Could not read image: {image_path}")

        response = send_post_request(
            {"image": [cv2_to_base64(frame)]},
            endpoint="http://xxx.xxx.xxx.xxx:7411/ocr",
        )
        detections = response.get('result', "")

        # Only a non-empty list counts as a usable OCR result.
        if not (detections and isinstance(detections, list) and len(detections) > 0):
            return "No text recognized"
        return '\n'.join(json.dumps(entry, ensure_ascii=False) for entry in detections)
    except Exception as e:
        print(f"OCR extraction error: {e}")
        return "OCR processing failed"


def recognize_text(image_path):
    """Run EasyOCR on an image and return the recognised text joined by spaces.

    The ``easyocr.Reader`` (Simplified Chinese + English) is created on the
    first call and stored on the function object, so later calls reuse it
    instead of constructing a new reader for every image.

    Args:
        image_path: Image passed straight to ``Reader.readtext``.

    Returns:
        Element 1 of every entry returned by ``readtext``, joined with single
        spaces.
    """
    reader = getattr(recognize_text, "_reader", None)
    if reader is None:
        reader = easyocr.Reader(['ch_sim', 'en'])
        recognize_text._reader = reader

    results = reader.readtext(image_path)
    return " ".join(result[1] for result in results)


def build_transform(input_size):
    """Build the torchvision preprocessing pipeline for one image tile.

    The pipeline converts non-RGB images to RGB, resizes to
    ``input_size`` x ``input_size`` with bicubic interpolation, converts to a
    tensor and normalises with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.

    Args:
        input_size: Edge length of the square output.

    Returns:
        A ``T.Compose`` of the steps above.
    """
    def _ensure_rgb(img):
        return img.convert('RGB') if img.mode != 'RGB' else img

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)


def load_image(image_file, input_size=448, max_num=12):
    """Open an image, tile it with ``dynamic_preprocess`` and stack the transformed tiles.

    Args:
        image_file: Anything ``PIL.Image.open`` accepts.
        input_size: Tile edge length, also used for the transform.
        max_num: Upper bound on the number of tiles.

    Returns:
        The transformed tiles (thumbnail included when more than one tile is
        produced) combined with ``torch.stack``.
    """
    source = Image.open(image_file).convert('RGB')
    to_tensor = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(source, image_size=input_size, use_thumbnail=True, max_num=max_num)
    return torch.stack([to_tensor(tile) for tile in tiles])


def load_image_2(image_file, input_size=448, min_num=1, max_num=12):
    """Open an image, tile it with ``dynamic_preprocess_2`` and stack the transformed tiles.

    Args:
        image_file: Anything ``PIL.Image.open`` accepts.
        input_size: Tile edge length, also used for the transform.
        min_num: Lower bound on the number of tiles.
        max_num: Upper bound on the number of tiles.

    Returns:
        A pair ``(pixel_values, target_aspect_ratio)``: the transformed tiles
        combined with ``torch.stack`` and the grid chosen by
        ``dynamic_preprocess_2``.
    """
    source = Image.open(image_file).convert('RGB')
    to_tensor = build_transform(input_size=input_size)
    tiles, grid = dynamic_preprocess_2(
        source,
        image_size=input_size,
        use_thumbnail=True,
        min_num=min_num,
        max_num=max_num,
    )
    stacked = torch.stack([to_tensor(tile) for tile in tiles])
    return stacked, grid


def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """Split an image into square tiles on the grid closest to its aspect ratio.

    This is the same tiling as ``dynamic_preprocess_2`` without the chosen
    grid in the return value, so it delegates to that function instead of
    carrying a second copy of the algorithm.

    Args:
        image: PIL image to tile.
        min_num: Minimum number of tiles in a candidate grid.
        max_num: Maximum number of tiles in a candidate grid.
        image_size: Edge length of each square tile.
        use_thumbnail: When true and more than one tile was produced, a
            whole-image thumbnail of ``image_size`` x ``image_size`` is appended.

    Returns:
        List of PIL tiles (plus the optional thumbnail).
    """
    tiles, _grid = dynamic_preprocess_2(
        image,
        min_num=min_num,
        max_num=max_num,
        image_size=image_size,
        use_thumbnail=use_thumbnail,
    )
    return tiles


def dynamic_preprocess_2(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """Split an image into square tiles and report the grid that was used.

    Args:
        image: PIL image to tile.
        min_num: Minimum number of tiles in a candidate grid.
        max_num: Maximum number of tiles in a candidate grid.
        image_size: Edge length of each square tile.
        use_thumbnail: When true and more than one tile was produced, a
            whole-image thumbnail of ``image_size`` x ``image_size`` is appended.

    Returns:
        ``(tiles, target_aspect_ratio)`` where ``tiles`` is the list of PIL
        crops in row-major order (plus the optional thumbnail) and
        ``target_aspect_ratio`` is the grid picked by
        ``find_closest_aspect_ratio``.
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Every (i, j) grid whose tile count lies within [min_num, max_num],
    # ordered by tile count.
    grid_options = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        min_num <= i * j <= max_num)
    grid_options = sorted(grid_options, key=lambda grid: grid[0] * grid[1])

    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, grid_options, orig_width, orig_height, image_size)

    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    tile_count = target_aspect_ratio[0] * target_aspect_ratio[1]

    resized_img = image.resize((target_width, target_height))
    tiles_per_row = target_width // image_size

    processed_images = []
    for index in range(tile_count):
        # Walk the grid row by row; each tile is image_size x image_size.
        row, col = divmod(index, tiles_per_row)
        left = col * image_size
        top = row * image_size
        processed_images.append(
            resized_img.crop((left, top, left + image_size, top + image_size)))

    if use_thumbnail and len(processed_images) != 1:
        processed_images.append(image.resize((image_size, image_size)))
    return processed_images, target_aspect_ratio


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the candidate grid whose width/height ratio is nearest to ``aspect_ratio``.

    Args:
        aspect_ratio: Width divided by height of the source image.
        target_ratios: Iterable of candidate grids, indexed as ``[0]`` / ``[1]``.
        width: Source image width.
        height: Source image height.
        image_size: Tile edge length, used only for tie-breaking.

    Returns:
        The winning candidate, or ``(1, 1)`` when ``target_ratios`` is empty.
        On an exact tie a later candidate replaces the current one only if
        the source area exceeds half of that candidate's tiled area.
    """
    source_area = width * height
    chosen = (1, 1)
    smallest_gap = float('inf')
    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            smallest_gap = gap
            chosen = candidate
            continue
        if gap == smallest_gap and source_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]:
            chosen = candidate
    return chosen


class MultimodalModelWrapper:
    """Wrapper class for unified multimodal model inference output"""

    def __init__(self, model_name: str, img_dir: str, temperature: float = 0.1, lambda_factor: float = 0.5):
        """
        Initialize model wrapper
        Args:
            model_name: Model name (qwenvl2/internvl2.5/minicpm/etc)
            img_dir: Image directory
            temperature: Temperature parameter
            lambda_factor: Lambda factor
        """
        self.temperature = temperature
        self.lambda_factor = lambda_factor
        self.img_dir = img_dir
        self.model_name = model_name
        self.ocr_preds = 'xxx/ocrbench_v2_result.json'
        self.ocr_cache = {}
        ocr_jsonl_path = "xxx/ocrbenchv2_ocr_result.jsonl"
        try:
            with open(ocr_jsonl_path, "r", encoding="utf-8") as f:
                for line in f:
                    try:
                        record = json.loads(line.strip())
                        image_path = record.get("image_path")
                        ocr_res = record.get("ocr_res", "")
                        if image_path:
                            self.ocr_cache[image_path] = ocr_res
                    except json.JSONDecodeError:
                        continue  # Skip malformed lines
        except Exception as e:
            print(f"Failed to load OCR JSONL file: {e}")
        self.supported_models = [
            "qwenvl2",
            "internvl2.5",
            "minicpm",
            "llava",
            "qwenvl2.5",  # 7B
            "mplug",
            "minimonkey",
            "monkey",  # bug
            "textmonkey",
            "textharmony",  # bug
            "internvl3",  # 8B
            "qwenvl2.5-32B",
            "qwenvl2.5-3B",
            "internvl3-2B",
            "internvl3-14B",
            "internvl3-38B",
            "kimi",
            "janus",
            "ovis",
            "mimo",
        ]
        if model_name == "qwenvl2":
            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
                "xxx/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
            )
            self.processor = AutoProcessor.from_pretrained("xxx/Qwen2-VL-7B-Instruct",
                                                           max_pixels=640 * 640)
        elif model_name == 'internvl3':
            path = 'xxx/InternVL3-8B'
            from transformers import AutoModel, AutoTokenizer
            self.model = AutoModel.from_pretrained(
                path,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True).eval().cuda()
            self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
            self.generation_config = dict(max_new_tokens=1024, do_sample=False)
        elif model_name == 'internvl3-2B':
            path = 'xxx/InternVL3-2B'
            from transformers import AutoModel, AutoTokenizer
            self.model = AutoModel.from_pretrained(
                path,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True).eval().cuda()
            self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
            self.generation_config = dict(max_new_tokens=1024, do_sample=False)
        elif model_name == 'internvl3-14B':
            path = 'xxx/InternVL3-14B'
            from transformers import AutoModel, AutoTokenizer
            self.model = AutoModel.from_pretrained(
                path,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True).eval().cuda()
            self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
            self.generation_config = dict(max_new_tokens=1024, do_sample=False)
        elif model_name == 'internvl2.5':
            path = 'xxx/InternVL2_5-8B'
            from transformers import AutoModel, AutoTokenizer
            self.model = AutoModel.from_pretrained(
                path,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True).eval().cuda()
            self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
            self.generation_config = dict(max_new_tokens=1024, do_sample=False)

        elif model_name == 'mplug':
            import sys
            from modelscope import AutoConfig, AutoModel, AutoTokenizer
            sys.path.append('xxx/mPLUG-Owl3-7B-241101/')
            model_path = 'xxx/mPLUG-Owl3-7B-241101/'
            self.config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
            model = AutoModel.from_pretrained(model_path, attn_implementation='flash_attention_2',
                                              torch_dtype=torch.bfloat16, trust_remote_code=True)
            self.model = model.eval().cuda()
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.processor = self.model.init_processor(self.tokenizer)

        elif model_name == 'llava':
            import sys
            sys.path.append('xxx/llava-v1.6-mistral-7b-hf')
            model_path = 'xxx/llava-v1.6-mistral-7b-hf'
            from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
            from PIL import Image
            self.processor = LlavaNextProcessor.from_pretrained(model_path)

            model = LlavaNextForConditionalGeneration.from_pretrained(model_path,
                                                                      torch_dtype=torch.float16, low_cpu_mem_usage=True)
            self.model = model.eval().cuda()

        elif model_name == 'monkey':
            from transformers import AutoModelForCausalLM, AutoTokenizer
            checkpoint = "xxx/Monkey"
            self.model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map='cuda',
                                                              trust_remote_code=True).eval()
            tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
            tokenizer.padding_side = 'left'
            tokenizer.pad_token_id = tokenizer.eod_id
            self.tokenizer = tokenizer

        elif model_name == 'minimonkey':
            path = 'xxx/MiniMonkey'
            from transformers import AutoModel, AutoTokenizer
            self.model = AutoModel.from_pretrained(
                path,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True).eval().cuda()
            self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
            self.generation_config = dict(max_new_tokens=1024, do_sample=False)

        elif model_name == "textharmony":
            pass

        elif model_name == 'minicpm':
            from transformers import AutoModel, AutoTokenizer
            self.model = AutoModel.from_pretrained(
                'xxx/MiniCPM-o-2_6',
                trust_remote_code=True,
                attn_implementation='sdpa',  # sdpa or flash_attention_2
                torch_dtype=torch.bfloat16,
                init_vision=True,
                init_audio=True,
                init_tts=True
            )

            self.model = self.model.eval().cuda()
            self.tokenizer = AutoTokenizer.from_pretrained('xxx/MiniCPM-o-2_6',
                                                           trust_remote_code=True)

            self.model.init_tts()
        elif model_name == 'qwenvl2.5':
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                "xxx/Qwen2.5-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
            )
            self.processor = AutoProcessor.from_pretrained("xxx/Qwen2.5-VL-7B-Instruct",
                                                           max_pixels=640 * 640)


        elif model_name == 'qwenvl2.5-3B':
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                "xxx/Qwen2.5-VL-3B-Instruct", torch_dtype="auto", device_map="auto"
            )
            self.processor = AutoProcessor.from_pretrained("xxx/Qwen2.5-VL-3B-Instruct",
                                                           max_pixels=640 * 640)

        elif model_name == 'qwenvl2.5-32B':
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                "xxx/Qwen2.5-VL-32B-Instruct", torch_dtype="auto", device_map="auto"
            )
            self.processor = AutoProcessor.from_pretrained("xxx/Qwen2.5-VL-32B-Instruct",
                                                           max_pixels=640 * 640)
        if model_name not in self.supported_models:
            raise ValueError(f"Unsupported model type: {model_name}")

    def extract_ocr_text(self, image_path):
        """
        Extract OCR text from cache given image path
        """
        try:
            ocr_text = self.ocr_cache.get(image_path, "")
            return ocr_text if ocr_text else "No text recognized"
        except Exception as e:
            print(f"OCR extraction error: {e}")
            return "OCR processing failed"

    def process_single_prediction(
            self,
            image_path: str,
            question: str,
            answers: List[str],
            dataset_name: str = "rico",
            data_type: str = "APP agent en",
            id: Optional[int] = None,
            sample=None
    ) -> Dict[str, Any]:
        """
        Process prediction for a single sample
        Args:
            image_path: Image path
            question: Question
            answers: List of ground truth answers
            dataset_name: Dataset name
            data_type: Data type
            id: Sample ID
            sample: Sample data

        Returns:
            Formatted prediction dictionary
        """
        # Get raw model prediction
        raw_prediction = self._get_model_prediction(image_path, question, answers, sample)

        # Format output uniformly
        try:
            print(raw_prediction.split('Answer:')[-1])
            prediction = {
                "dataset_name": dataset_name,
                "type": data_type,
                "id": id,
                "image_path": image_path,
                "question": question,
                "answers": answers,
                "predict": raw_prediction.split('Answer:')[-1]
            }
            return prediction
        except:
            print(raw_prediction)
            prediction = {
                "dataset_name": dataset_name,
                "type": data_type,
                "id": id,
                "image_path": image_path,
                "question": question,
                "answers": answers,
                "predict": raw_prediction
            }
            return prediction

    def batch_predict(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Run ``process_single_prediction`` over every sample with a progress bar.

        Args:
            samples: Dicts with ``image_path``, ``question`` and ``answers``;
                ``dataset_name``, ``type``, ``id`` and ``eval`` are optional.
        Returns:
            One prediction dict per sample, in input order. A sample's
            ``eval`` entry, when present, is copied onto its prediction.
        """
        collected = []
        for entry in tqdm(samples):
            result = self.process_single_prediction(
                image_path=entry["image_path"],
                question=entry["question"],
                answers=entry["answers"],
                dataset_name=entry.get("dataset_name", "rico"),
                data_type=entry.get("type", "APP agent en"),
                id=entry.get("id"),
                sample=entry,
            )
            if 'eval' in entry:
                result['eval'] = entry['eval']
            collected.append(result)
        return collected

    def save_predictions(self, predictions: List[Dict[str, Any]], output_path: str):
        """
        Save predictions to JSON file
        Args:
            predictions: List of predictions
            output_path: Output file path
        """
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(predictions, f, ensure_ascii=False, indent=4)

    def _get_model_prediction(self, image_path: str, question: str, answers: List[str], sample=None) -> str:
        """
        Get raw model prediction - add format instructions
        Args:
            image_path: Image path
            question: Question
            answers: Option list
        Returns:
            Model prediction result
        """
        # Build formatted options text
        if 'options' in sample.keys():
            # Add explicit output format instructions
            instruction = (
                "Please strictly follow these rules:"
                "Let us think this question step by step (Chain of thought) and output the letter(s) of the correct answer (e.g. 'A' or 'B') finally"
                "Place the answer only the option letter (with no extra characters) on a separate last line."
                f"Question：{question}\n"
                f"Options：\n{sample['options']}\n\n"
                "Chain of thought:\n"
                "Answer："
            )
        else:
            instruction = (
                "Please strictly follow these rules:\n"
                f"Question：{question}\n"
                "Answer:\n"
            )
        # Call corresponding inference method based on model
        if self.model_name == "qwenvl2":
            return self._predict_qwenvl2(image_path, instruction)
        elif self.model_name == "internvl2.5":
            return self._predict_internvl2(image_path, instruction)
        elif self.model_name in ["internvl3", "internvl3-2B", "internvl3-14B"]:
            return self._predict_internvl2(image_path, instruction)
        elif self.model_name == "mplug":
            return self._predict_mplug_owl(image_path, instruction)
        elif self.model_name == 'llava':
            return self._predict_llava_16_ocr(image_path, instruction)
        elif self.model_name == 'monkey':
            return self._predict_monkey(image_path, instruction)
        elif self.model_name == "minimonkey":
            return self._predict_mini_monkey(image_path, instruction)
        elif self.model_name == 'textharmony':
            return self._predict_textharmony(image_path, instruction)
        elif self.model_name == 'minicpm':
            return self._predict_minicpm(image_path, instruction)
        elif self.model_name in ['qwenvl2.5', 'qwenvl2.5-3B', 'qwenvl2.5-32B']:
            return self._predict_qwenvl2_ocr(image_path, instruction)

        raise NotImplementedError(f"Prediction method for model {self.model_name} not implemented")

    def _normalize_prediction(self, raw_prediction: str) -> str:
        """
        Normalize prediction - extract option letters
        Args:
            raw_prediction: Raw prediction result
        Returns:
            Normalized prediction result (option letters)
        """
        # Extract last line as answer line
        lines = raw_prediction.strip().split('\n')
        answer_line = lines[-1].strip() if lines else ""

        # Extract valid option letters (A-Z)
        valid_chars = [char for char in answer_line if char.isalpha() and char.isupper()]
        extracted = ''.join(sorted(set(valid_chars)))  # Deduplicate and sort

        return extracted

    # Model-specific prediction implementations
    def _predict_minicpm(self, image_path: str, question: str) -> str:
        try:
            img_path = os.path.join(self.img_dir, image_path)
            image = Image.open(img_path).convert('RGB')
            msgs = [{'role': 'user', 'content': [image, question]}]
            res = self.model.chat(
                image=None,
                msgs=msgs,
                tokenizer=self.tokenizer
            )
        except:
            print(image_path)
            res = ''
        return res

    def _predict_textharmony(self, image_path: str, question: str) -> str:
        """
        Answer ``question`` about an image with the TextHarmony-style generate API.

        Args:
            image_path: Image path relative to ``self.img_dir``
            question: Question inserted into the prompt template
        Returns:
            The list produced by ``tokenizer.batch_decode`` (``['']`` when
            generation fails).
            NOTE(review): the annotation says ``str`` but a list is returned,
            as in the original code — confirm what callers expect.
        """
        # Resolve the image against img_dir, as the other _predict_* methods do.
        img_path = os.path.join(self.img_dir, image_path)
        image = Image.open(img_path).convert("RGB")
        # NOTE(review): self.transform is not assigned in the visible __init__ — confirm where it is set.
        image = self.transform(image)
        image_tensors = np.stack([image], axis=0)

        image_subseq = "<|beginofimage|>" + "<|image|>" * 512

        text = "Based on the image, please answer the question. {image}{question} The answer is:".format(
            image=image_subseq, question=question)

        # Remove spaces adjacent to the image placeholder tokens.
        text = (
            text.replace("<|image|> ", "<|image|>")
                .replace(" <|image|>", "<|image|>")
                .replace(" <|beginofimage|>", "<|beginofimage|>")
                .replace("<|beginofimage|> ", "<|beginofimage|>")
        )

        self.tokenizer.padding_side = "right"
        text_tensor = self.tokenizer(
            text,
            max_length=2048,
            truncation=True,
            padding="do_not_pad",
            return_tensors="np",
            return_attention_mask=True,
        )
        text_ids = text_tensor["input_ids"][0]
        text_attn_mask = text_tensor["attention_mask"][0]

        image_tensors = torch.from_numpy(image_tensors)
        num_images = image_tensors.shape[0]
        target_image_idxs = torch.tensor([num_images - 1], dtype=torch.long)

        # The first group whose identifier occurs in the prompt decides the
        # task id; the empty string in the last group matches any prompt.
        task_identifiers = [
            ["Generate an image", "Fill the masked"],
            [""]
        ]
        meta = {"task_id": None}
        lowered_text = text.lower()
        for task_id, idents in enumerate(task_identifiers):
            if any(ident.lower() in lowered_text for ident in idents):
                meta["task_id"] = task_id
                break
        assert meta["task_id"] is not None

        inputs = dict(
            image_tensors=image_tensors,
            image_tensors_dec=None,
            text_ids=torch.from_numpy(text_ids)[None, ...],
            attention_mask=torch.from_numpy(text_attn_mask)[None, ...],
            num_image_per_seq=torch.tensor([num_images]),
            nearest_bos_idxs=None,
            meta=meta,
            target_image_idxs=target_image_idxs,
        )
        # Move every tensor input to the GPU; other values pass through untouched.
        inputs = {
            key: value.to(device="cuda") if isinstance(value, torch.Tensor) else value
            for key, value in inputs.items()
        }

        try:
            outputs = self.model.generate(mode="generate_texts", **inputs)
            generate_texts = self.tokenizer.batch_decode(
                outputs["text_ids"], skip_special_tokens=True
            )
        except Exception as e:
            # Best effort: fall back to an empty answer but report the cause.
            generate_texts = ['']
            print(f'Error in output: {e}')

        return generate_texts

    def _predict_mini_monkey(self, image_path: str, question: str) -> str:
        try:
            img_path = os.path.join(self.img_dir, image_path)
            pixel_values, target_aspect_ratio = load_image_2(img_path, max_num=12)
            pixel_values = pixel_values.to(torch.bfloat16).cuda()
            response, history = self.model.chat(self.tokenizer, pixel_values, target_aspect_ratio, question,
                                                self.generation_config, history=None, return_history=True)
        except Exception as e:
            print(e)
            response = ['']
        return response

    def _predict_monkey(self, image_path: str, question: str) -> str:
        try:
            img_path = os.path.join(self.img_dir, image_path)
            query = f'<img>{img_path}</img> {question} Answer: '  # VQA

            input_ids = self.tokenizer(query, return_tensors='pt', padding='longest')
            attention_mask = input_ids.attention_mask
            input_ids = input_ids.input_ids

            pred = self.model.generate(
                input_ids=input_ids.cuda(),
                attention_mask=attention_mask.cuda(),
                do_sample=False,
                num_beams=1,
                max_new_tokens=512,
                min_new_tokens=1,
                length_penalty=1,
                num_return_sequences=1,
                output_hidden_states=True,
                use_cache=True,
                pad_token_id=self.tokenizer.eod_id,
                eos_token_id=self.tokenizer.eod_id,
            )
            response = self.tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
        except:
            response = ''
            print(image_path)
        return response

    def _predict_llava_16(self, image_path: str, question: str) -> str:
        try:
            img_path = os.path.join(self.img_dir, image_path)
            image = Image.open(img_path)

            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": f"{question}"},
                        {"type": "image"},
                    ],
                },
            ]
            prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)

            inputs = self.processor(images=image, text=prompt, return_tensors="pt").to("cuda:0")

            output = self.model.generate(**inputs, max_new_tokens=512)
            ans = self.processor.decode(output[0], skip_special_tokens=True)
        except:
            print(image_path)
            ans = ''
        return ans

    def _predict_llava_16_ocr(self, image_path: str, question: str, kl_reduction: str = "mean") -> str:
        """Answer ``question`` about an image, adding OCR text to the prompt and sampler.

        The text returned by ``extract_ocr_text`` is prepended to the question and is
        also converted into pseudo logits that are handed to a custom logits
        processor during sampling.

        Args:
            image_path: Image path, joined onto ``self.img_dir``.
            question: Question to ask about the image.
            kl_reduction: Stored on the logits processor; the processor in this
                method never reads it, so it currently has no effect.

        Returns:
            The decoded, newly generated tokens. If anything in the pipeline raises,
            the error and traceback are printed and the result of
            ``self._predict_llava_16(image_path, question)`` is returned instead.
        """
        try:
            import os
            import torch
            import torch.nn.functional as F
            from PIL import Image
            img_path = os.path.join(self.img_dir, image_path)
            ocr_text = extract_ocr_text(img_path)

            model_vocab_size = self.model.config.vocab_size
            pseudo_logits = generate_pseudo_logits(
                self.processor.tokenizer,
                ocr_text,
                model_vocab_size,
                temperature=0.1
            )

            # Iterating the pseudo logits walks their first dimension and ``[0]`` keeps
            # only the first row of every entry.
            # NOTE(review): if per-token guidance was intended, this should probably
            # iterate over the token positions instead - confirm.
            processed_logits = []
            for tensor in pseudo_logits:
                logits_tensor = tensor[0]
                if logits_tensor.shape[0] != model_vocab_size:
                    # Pad/truncate to the model vocabulary size; padded slots keep a
                    # very large negative logit.
                    new_logits = torch.full((model_vocab_size,), -1e10, device="cuda:0")
                    min_size = min(logits_tensor.shape[0], model_vocab_size)
                    new_logits[:min_size] = logits_tensor[:min_size]
                    processed = F.log_softmax(new_logits.unsqueeze(0), dim=-1)
                else:
                    processed = F.log_softmax(logits_tensor.unsqueeze(0), dim=-1)
                processed_logits.append(processed)

            final_question = (
                f"The image contains the following OCR text: '{ocr_text}'. "
                f"Please consider this information when answering. {question}"
            )

            conversation = [{
                "role": "user",
                "content": [
                    {"type": "text", "text": final_question},
                    {"type": "image"},
                ],
            }]
            prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
            # Use a context manager so the image file handle is released as soon as
            # the processor has turned the image into tensors.
            with Image.open(img_path) as image:
                inputs = self.processor(images=image, text=prompt, return_tensors="pt").to("cuda:0")

            class KLGuidedLogitsProcessor:
                """Subtract a KL-based penalty from the scores, one OCR entry per step."""

                def __init__(self, kl_logits, kl_reduction="mean"):
                    self.kl_logits = kl_logits
                    self.kl_reduction = kl_reduction  # kept for parity; not read below
                    self.step_count = 0

                def __call__(self, input_ids, scores):
                    # Once every OCR entry has been consumed the scores pass through.
                    if self.step_count >= len(self.kl_logits):
                        return scores

                    ocr_log = self.kl_logits[self.step_count].to(scores.device)
                    self.step_count += 1

                    if ocr_log.shape[-1] != scores.shape[-1]:
                        new_ocr_log = torch.full_like(scores, -1e10)
                        min_dim = min(ocr_log.shape[-1], scores.shape[-1])
                        new_ocr_log[..., :min_dim] = ocr_log[..., :min_dim]
                        ocr_log = new_ocr_log

                    p = F.softmax(scores, dim=-1)
                    log_p = torch.log(p + 1e-10)
                    kl = torch.sum(p * (log_p - ocr_log), dim=-1, keepdim=True)

                    # NOTE(review): ``kl`` is a single value per row, so this shifts
                    # every logit of a row by the same amount and leaves the softmax
                    # distribution unchanged - confirm whether a per-token penalty
                    # was intended.
                    guided_logits = scores - 0.5 * kl
                    return guided_logits

            kl_processor = KLGuidedLogitsProcessor(processed_logits, kl_reduction)

            output = self.model.generate(
                **inputs,
                max_new_tokens=512,
                logits_processor=[kl_processor],
                do_sample=True,
                top_p=0.7,
                temperature=0.1
            )

            # Drop the prompt tokens so only the newly generated answer is decoded.
            input_length = inputs.input_ids.shape[1]
            generated_ids = output[:, input_length:]
            ans = self.processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
            return ans

        except Exception as e:
            print(f"VDGD generation failed: {e}")
            import traceback
            traceback.print_exc()
            return self._predict_llava_16(image_path, question)

    def _predict_qwenvl2_ocr(self, image_path: str, question: str, kl_reduction: str = "mean") -> str:
        """Answer ``question`` with Qwen2-VL, adding OCR text to the prompt and sampler.

        The text returned by ``extract_ocr_text`` is inserted into the prompt and is
        also converted into pseudo logits (using ``self.temperature``) that drive a
        KL-based logits processor weighted by ``self.lambda_factor``.

        Args:
            image_path: Image path, joined onto ``self.img_dir``.
            question: Question to ask about the image.
            kl_reduction: How the per-entry KL values are combined: ``"avg"`` or
                ``"mean"`` averages them, ``"min"`` takes the row-wise minimum of the
                accumulated values, anything else leaves the plain sum.

        Returns:
            The decoded, newly generated tokens. If anything in the pipeline raises,
            the error and traceback are printed and the result of
            ``self._predict_qwenvl2(image_path, question)`` is returned instead.
        """
        try:
            import os
            import torch
            import torch.nn.functional as F

            img_path = os.path.join(self.img_dir, image_path)
            ocr_text = extract_ocr_text(img_path)

            model_vocab_size = self.model.config.vocab_size
            device = self.model.device
            pseudo_logits = generate_pseudo_logits(
                self.processor.tokenizer,
                ocr_text,
                model_vocab_size,
                temperature=self.temperature
            )

            # Iterating the pseudo logits walks their first dimension and ``[0]`` keeps
            # only the first row of every entry.
            # NOTE(review): if per-token guidance was intended, this should probably
            # iterate over the token positions instead - confirm.
            processed_logits = []
            for tensor in pseudo_logits:
                logits_tensor = tensor[0].to(device)
                if logits_tensor.shape[0] != model_vocab_size:
                    # Pad/truncate to the model vocabulary size; padded slots keep a
                    # very large negative logit.
                    new_logits = torch.full((model_vocab_size,), -1e10, device=device)
                    min_size = min(logits_tensor.shape[0], model_vocab_size)
                    new_logits[:min_size] = logits_tensor[:min_size]
                    processed = F.log_softmax(new_logits.unsqueeze(0), dim=-1)
                else:
                    processed = F.log_softmax(logits_tensor.unsqueeze(0), dim=-1)
                processed_logits.append(processed)

            final_prompt = (
                "The input image is described as follows: Image contains text that needs to be recognized.\n\n"
                "Additionally, the following text was recognized in the image: '{ocr_text}'\n\n"
                "{question}"
            )
            qs = final_prompt.format(ocr_text=ocr_text, question=question)

            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": img_path},
                        {"type": "text", "text": qs},
                    ],
                }
            ]
            text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            image_inputs, _ = process_vision_info(messages)
            inputs = self.processor(
                text=[text],
                images=image_inputs,
                padding=True,
                return_tensors="pt"
            ).to(device)

            input_ids = inputs.data['input_ids']
            attention_mask = inputs.data['attention_mask']
            pixel_values = inputs.data['pixel_values']
            image_grid_thw = inputs.data['image_grid_thw']

            # Positions count the attended tokens; padded slots are set to 1.
            # NOTE(review): confirm that the Qwen2-VL implementation in use expects
            # explicitly supplied position ids of this form.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)

            # Only ``outputs.sequences`` is read below, so per-step scores are not
            # requested (keeping them would hold one vocabulary-sized tensor per step).
            generation_config = {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "position_ids": position_ids,
                "pixel_values": pixel_values,
                "image_grid_thw": image_grid_thw,
                "max_new_tokens": 1024,
                "do_sample": True,
                "top_p": 0.7,
                "temperature": 0.1,
                "return_dict_in_generate": True,
                "use_cache": True
            }

            from transformers import LogitsProcessor

            class KLGuidedLogitsProcessor(LogitsProcessor):
                """Subtract a weighted KL penalty (scores vs. OCR entries) from the scores."""

                def __init__(self, kl_logits, kl_reduction="avg", lambda_factor=0.1):
                    self.kl_logits = kl_logits
                    self.kl_reduction = kl_reduction
                    self.lambda_factor = lambda_factor

                def __call__(self, input_ids, scores):
                    # The model distribution is the same for every OCR entry, so it
                    # is computed once instead of once per loop iteration.
                    p = F.softmax(scores, dim=-1)
                    log_p = torch.log(p + 1e-10)

                    kl_scores = torch.zeros_like(scores)
                    for ocr_log in self.kl_logits:
                        if ocr_log.shape[-1] != scores.shape[-1]:
                            new_ocr_log = torch.full_like(scores, -1e10)
                            min_dim = min(ocr_log.shape[-1], scores.shape[-1])
                            new_ocr_log[..., :min_dim] = ocr_log[..., :min_dim]
                            ocr_log = new_ocr_log

                        kl_scores += torch.sum(p * (log_p - ocr_log), dim=-1, keepdim=True)

                    # "mean" is the enclosing method's default, so treat it like "avg";
                    # the emptiness check avoids a division by zero.
                    if self.kl_reduction in ("avg", "mean"):
                        if self.kl_logits:
                            kl_scores /= len(self.kl_logits)
                    elif self.kl_reduction == "min":
                        kl_scores, _ = torch.min(kl_scores, dim=1, keepdim=True)

                    # NOTE(review): the penalty is identical for every entry of a row,
                    # so it shifts all logits equally and leaves the softmax
                    # distribution unchanged - confirm whether a per-token penalty
                    # was intended.
                    guided_logits = scores - self.lambda_factor * kl_scores
                    return guided_logits

            kl_processor = KLGuidedLogitsProcessor(processed_logits, kl_reduction, lambda_factor=self.lambda_factor)
            generation_config["logits_processor"] = [kl_processor]

            outputs = self.model.generate(**generation_config)
            # Drop the prompt tokens so only the newly generated answer is decoded.
            generated_ids = outputs.sequences[:, input_ids.shape[1]:]

            response = self.processor.tokenizer.decode(
                generated_ids[0],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )
            return response

        except Exception as e:
            print(f"VDGD generation failed: {e}")
            import traceback
            traceback.print_exc()
            return self._predict_qwenvl2(image_path, question)

    def _predict_qwenvl2_vdgd(self, image_path: str, question: str, kl_reduction: str = "mean") -> str:
        """Answer ``question`` with Qwen2-VL, guided by a self-generated image caption.

        Runs two generations: the first asks the model to describe the image, the
        second answers ``question`` with that caption inserted into the prompt and
        converted into pseudo logits for a KL-based logits processor.

        Args:
            image_path: Image path, joined onto ``self.img_dir``.
            question: Question to ask about the image.
            kl_reduction: How the per-entry KL values are combined: ``"avg"`` or
                ``"mean"`` averages them, ``"min"`` takes the row-wise minimum of the
                accumulated values, anything else leaves the plain sum.

        Returns:
            The decoded, newly generated tokens of the second generation. If anything
            raises, the error and traceback are printed and the result of
            ``self._predict_qwenvl2(image_path, question)`` is returned instead.
        """
        try:
            import os
            import torch
            import torch.nn.functional as F

            img_path = os.path.join(self.img_dir, image_path)

            # Pass 1: let the model caption the image.
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": img_path,
                        },
                        {"type": "text", "text": "Describe this image."},
                    ],
                }
            ]
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
            inputs = inputs.to("cuda")
            generated_ids = self.model.generate(**inputs, max_new_tokens=1024)
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            caption_batch = self.processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
            # batch_decode yields one string per sequence and a single prompt was
            # submitted; use that string so the prompt below contains the caption
            # text rather than the repr of a Python list.
            caption = caption_batch[0]

            model_vocab_size = self.model.config.vocab_size
            device = self.model.device
            pseudo_logits = generate_pseudo_logits(
                self.processor.tokenizer,
                caption,
                model_vocab_size,
                temperature=0.1
            )

            # Iterating the pseudo logits walks their first dimension and ``[0]`` keeps
            # only the first row of every entry.
            # NOTE(review): if per-token guidance was intended, this should probably
            # iterate over the token positions instead - confirm.
            processed_logits = []
            for tensor in pseudo_logits:
                logits_tensor = tensor[0].to(device)
                if logits_tensor.shape[0] != model_vocab_size:
                    # Pad/truncate to the model vocabulary size; padded slots keep a
                    # very large negative logit.
                    new_logits = torch.full((model_vocab_size,), -1e10, device=device)
                    min_size = min(logits_tensor.shape[0], model_vocab_size)
                    new_logits[:min_size] = logits_tensor[:min_size]
                    processed = F.log_softmax(new_logits.unsqueeze(0), dim=-1)
                else:
                    processed = F.log_softmax(logits_tensor.unsqueeze(0), dim=-1)
                processed_logits.append(processed)

            # Pass 2: answer the question with the caption in the prompt.
            # NOTE(review): the template words the inserted text as recognised OCR
            # text although a caption is supplied here - confirm the wording.
            final_prompt = (
                "The input image is described as follows: Image contains text that needs to be recognized.\n\n"
                "Additionally, the following text was recognized in the image: '{ocr_text}'\n\n"
                "{question}"
            )
            qs = final_prompt.format(ocr_text=caption, question=question)

            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": img_path},
                        {"type": "text", "text": qs},
                    ],
                }
            ]
            text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            image_inputs, _ = process_vision_info(messages)
            inputs = self.processor(
                text=[text],
                images=image_inputs,
                padding=True,
                return_tensors="pt"
            ).to(device)

            input_ids = inputs.data['input_ids']
            attention_mask = inputs.data['attention_mask']
            pixel_values = inputs.data['pixel_values']
            image_grid_thw = inputs.data['image_grid_thw']

            # Positions count the attended tokens; padded slots are set to 1.
            # NOTE(review): confirm that the Qwen2-VL implementation in use expects
            # explicitly supplied position ids of this form.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)

            # Only ``outputs.sequences`` is read below, so per-step scores are not
            # requested (keeping them would hold one vocabulary-sized tensor per step).
            generation_config = {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "position_ids": position_ids,
                "pixel_values": pixel_values,
                "image_grid_thw": image_grid_thw,
                "max_new_tokens": 1024,
                "do_sample": True,
                "top_p": 0.7,
                "temperature": 0.1,
                "return_dict_in_generate": True,
                "use_cache": True
            }

            from transformers import LogitsProcessor

            class KLGuidedLogitsProcessor(LogitsProcessor):
                """Subtract a weighted KL penalty (scores vs. caption entries) from the scores."""

                def __init__(self, kl_logits, kl_reduction="avg"):
                    self.kl_logits = kl_logits
                    self.kl_reduction = kl_reduction

                def __call__(self, input_ids, scores):
                    # The model distribution is the same for every entry, so it is
                    # computed once instead of once per loop iteration.
                    p = F.softmax(scores, dim=-1)
                    log_p = torch.log(p + 1e-10)

                    kl_scores = torch.zeros_like(scores)
                    for ocr_log in self.kl_logits:
                        if ocr_log.shape[-1] != scores.shape[-1]:
                            new_ocr_log = torch.full_like(scores, -1e10)
                            min_dim = min(ocr_log.shape[-1], scores.shape[-1])
                            new_ocr_log[..., :min_dim] = ocr_log[..., :min_dim]
                            ocr_log = new_ocr_log

                        kl_scores += torch.sum(p * (log_p - ocr_log), dim=-1, keepdim=True)

                    # "mean" is the enclosing method's default, so treat it like "avg";
                    # the emptiness check avoids a division by zero.
                    if self.kl_reduction in ("avg", "mean"):
                        if self.kl_logits:
                            kl_scores /= len(self.kl_logits)
                    elif self.kl_reduction == "min":
                        kl_scores, _ = torch.min(kl_scores, dim=1, keepdim=True)

                    # NOTE(review): the penalty is identical for every entry of a row,
                    # so it shifts all logits equally and leaves the softmax
                    # distribution unchanged - confirm whether a per-token penalty
                    # was intended.
                    guided_logits = scores - 0.5 * kl_scores
                    return guided_logits

            kl_processor = KLGuidedLogitsProcessor(processed_logits, kl_reduction)
            generation_config["logits_processor"] = [kl_processor]

            outputs = self.model.generate(**generation_config)
            # Drop the prompt tokens so only the newly generated answer is decoded.
            generated_ids = outputs.sequences[:, input_ids.shape[1]:]

            response = self.processor.tokenizer.decode(
                generated_ids[0],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )
            return response

        except Exception as e:
            print(f"VDGD generation failed: {e}")
            import traceback
            traceback.print_exc()
            return self._predict_qwenvl2(image_path, question)

    def _predict_qwenvl2_paddle(self, image_path: str, question: str) -> str:
        """Answer ``question`` with Qwen2-VL after adding OCR text to the prompt.

        Args:
            image_path: Image path, joined onto ``self.img_dir``.
            question: Question to ask about the image.

        Returns:
            The decoded answer for the single submitted prompt, or an empty string
            when any step fails (the image path and the error are printed).
        """
        try:
            img_path = os.path.join(self.img_dir, image_path)
            ocr_text = extract_ocr_text(img_path)
            final_prompt = (
                "The input image is described as follows: Image contains text that needs to be recognized.\n\n"
                "Additionally, the following text was recognized in the image: '{ocr_text}'\n\n"
                "{question}"
            )
            qs = final_prompt.format(ocr_text=ocr_text, question=question)
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": img_path,
                        },
                        {"type": "text", "text": qs},
                    ],
                }
            ]
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
            inputs = inputs.to("cuda")
            generated_ids = self.model.generate(**inputs, max_new_tokens=1024)
            # Strip the prompt tokens from each sequence before decoding.
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            output_text = self.processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
        except Exception as e:
            # Best effort: report the failing sample and fall back to an empty
            # answer. ``Exception`` (not a bare except) lets Ctrl-C and SystemExit
            # propagate.
            print(image_path)
            print(f"Qwen2-VL OCR-prompt prediction failed: {e}")
            output_text = ['']

        return output_text[0]

    def _predict_qwenvl2(self, image_path: str, question: str) -> str:
        """Answer ``question`` about an image with Qwen2-VL.

        Args:
            image_path: Image path, joined onto ``self.img_dir``.
            question: Question to ask about the image.

        Returns:
            The decoded answer for the single submitted prompt, or an empty string
            when any step fails (the image path and the error are printed).
        """
        try:
            img_path = os.path.join(self.img_dir, image_path)
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": img_path,
                        },
                        {"type": "text", "text": f"{question}"},
                    ],
                }
            ]
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
            inputs = inputs.to("cuda")
            generated_ids = self.model.generate(**inputs, max_new_tokens=1024)
            # Strip the prompt tokens from each sequence before decoding.
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            output_text = self.processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
        except Exception as e:
            # Best effort: report the failing sample and fall back to an empty
            # answer. ``Exception`` (not a bare except) lets Ctrl-C and SystemExit
            # propagate.
            print(image_path)
            print(f"Qwen2-VL prediction failed: {e}")
            output_text = ['']

        return output_text[0]

    def _predict_internvl2(self, image_path: str, question: str) -> str:
        """Answer ``question`` about an image through ``self.model.chat``.

        Args:
            image_path: Image path, joined onto ``self.img_dir``.
            question: Question to ask about the image.

        Returns:
            The first element of the ``chat`` result (presumably the response text,
            since ``return_history=True`` is requested - confirm), or an empty string
            when any step fails (the image path and the error are printed).
        """
        try:
            img_path = os.path.join(self.img_dir, image_path)
            pixel_values = load_image(img_path, max_num=12).to(torch.bfloat16).cuda()
            answer = self.model.chat(self.tokenizer, pixel_values, question, self.generation_config, history=None,
                                     return_history=True)
        except Exception as e:
            # Best effort: report the failing sample and fall back to an empty
            # answer. ``Exception`` (not a bare except) lets Ctrl-C and SystemExit
            # propagate.
            print(image_path)
            print(f"InternVL2 prediction failed: {e}")
            answer = ['']
        return answer[0]


def generate_pseudo_logits(tokenizer, text, model_vocab_size, temperature=0.7):
    """Build peaked pseudo logits from the token ids of ``text``.

    Every position gets a logit vector filled with ``-10000.0`` except at the
    position's own token id, which is set to ``5.0 / max(temperature, 0.1)``.
    Token ids that are not below ``model_vocab_size`` leave their position
    untouched. When ``temperature`` is positive, Gaussian noise scaled by
    ``0.1 * temperature`` is added to the whole tensor.

    Args:
        tokenizer: Callable invoked with ``text`` (padding and truncation to 128
            tokens requested) whose result exposes ``input_ids``.
        text: Text (or batch of texts) to tokenize.
        model_vocab_size: Size of the last dimension of the result.
        temperature: Controls the peak height and the noise scale.

    Returns:
        A float tensor of shape ``(batch, seq_len, model_vocab_size)``. If anything
        raises, the error is printed and an all-zero tensor of shape
        ``(1, 1, model_vocab_size)`` is returned instead.
    """
    try:
        inputs = tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128
        )
        input_ids = inputs.input_ids

        batch_size, seq_length = input_ids.shape
        pseudo_logits = torch.full((batch_size, seq_length, model_vocab_size), -10000.0)

        # The peak height does not depend on the token, so compute it once.
        base_value = 5.0 / max(temperature, 0.1)

        # Set every in-vocabulary token's slot with one indexed assignment instead
        # of a Python loop over each (batch, position) pair.
        rows, cols = torch.nonzero(input_ids < model_vocab_size, as_tuple=True)
        pseudo_logits[rows, cols, input_ids[rows, cols]] = base_value

        if temperature > 0:
            noise = torch.randn_like(pseudo_logits) * (0.1 * temperature)
            pseudo_logits += noise

        return pseudo_logits

    except Exception as e:
        print(f"Pseudo logits generation error: {e}")
        return torch.zeros((1, 1, model_vocab_size))