import argparse
import json
import os
import torch
import logging
import numpy as np
from PIL import Image
import random
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, LlavaForConditionalGeneration, Qwen2VLForConditionalGeneration, AutoProcessor, BlipForConditionalGeneration
import csv
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
from qwen_vl_utils import process_vision_info
import io
from io import BytesIO
# import uuid
import re
import inflect
import hashlib
from utils.utils import *


class Qwen2VL():
    """Wrapper around a Qwen2-VL checkpoint for image-conditioned generation.

    The constructor loads the tokenizer, model and processor from
    ``cfg.model_path``. The remaining methods build the prompt, prepare
    processor inputs for one image, run ``generate`` and offer a manual
    token-by-token sampling loop that also reports token probabilities.
    """

    def __init__(self, cfg):
        """Load model, tokenizer and processor.

        Args:
            cfg: Configuration object; ``cfg.model_path`` and ``cfg.device``
                are read.
        """
        super().__init__()
        self.device = cfg.device
        self.name = "Qwen2VL"
        logging.info("Loading model and tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.model_path, trust_remote_code=True)
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            cfg.model_path, torch_dtype=torch.float16, device_map=cfg.device
        ).eval()
        self.processor = AutoProcessor.from_pretrained(cfg.model_path)
        logging.info("Model and tokenizer loaded successfully.")
        # Directory where images received as raw bytes are written to disk.
        self.temp_image_dir = "./datasets/figures"
        self.use_image_token = True
        self.prompt = None
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(self.temp_image_dir, exist_ok=True)

    def is_image_normal(self, file_path):
        """Return True if ``file_path`` is a non-empty file that PIL can verify.

        Args:
            file_path: Path of the image file to check.

        Returns:
            bool: False when the file is missing, empty, or fails
            ``Image.verify()``; True otherwise.
        """
        if not os.path.isfile(file_path) or os.path.getsize(file_path) == 0:
            logging.warning("File %s is empty or does not exist.", file_path)
            return False

        try:
            with Image.open(file_path) as img:
                img.verify()  # Verify if image is corrupted
                return True
        except (IOError, SyntaxError) as e:
            logging.error("Error opening image %s: %s", file_path, e)
            return False

    def generate_prompt(self, prompt=None):
        """Normalise ``prompt`` to start with exactly one ``<image>`` tag.

        Any ``<image>`` occurrences already in the text are removed, the text
        is stripped, and a single ``<image> `` prefix is added. The result is
        stored on ``self.prompt`` and returned.

        Args:
            prompt: Prompt text. ``None`` is treated as an empty prompt.

        Returns:
            str: The normalised prompt.
        """
        text = "" if prompt is None else prompt
        text = text.replace("<image>", "").strip()

        self.prompt = f"<image> {text}"

        return self.prompt

    def get_inputs(self, image_path):
        """Build processor inputs for one image plus ``self.prompt``.

        Args:
            image_path: Image reference placed in the user message and handed
                to ``process_vision_info``.

        Returns:
            The object returned by ``self.processor`` (PyTorch tensors,
            padded), not yet moved to any device.
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_path},
                    {"type": "text", "text": self.prompt},
                ],
            }
        ]

        # The chat template iterates over messages (dicts with "role" and
        # "content"); passing the flattened content items instead drops the
        # image placeholder and the user text from the rendered prompt.
        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        return inputs

    def save_temp_image(self, image, image_path):
        """Write ``image`` to ``image_path`` as a JPEG."""
        image.save(image_path, format="JPEG")

    def get_answer(self, image_path):
        """Run generation for one image using the current ``self.prompt``.

        Args:
            image_path: Either a file path, used as-is, or the raw encoded
                image bytes. Bytes are decoded, converted to RGB, shrunk to
                fit within 256x256 and cached as a JPEG in
                ``self.temp_image_dir`` under an MD5-derived file name.

        Returns:
            tuple: ``(answer, inputs, logits)`` where ``answer`` is the first
            decoded sequence passed through ``extract_Qwen_assistant_answers``,
            ``inputs`` are the processor inputs moved to ``self.device`` and
            ``logits`` come from a separate forward pass over those inputs.
        """
        if isinstance(image_path, bytes):
            image_hash = hashlib.md5(image_path).hexdigest()
            temp_image_path = os.path.join(self.temp_image_dir, f"bytes_{image_hash}.jpg")
            # Identical bytes map to the same file name, so decode and write
            # only when the cached file is not there yet.
            if not os.path.exists(temp_image_path):
                with Image.open(io.BytesIO(image_path)) as raw:
                    image = raw.convert("RGB")
                max_size = 256
                image.thumbnail((max_size, max_size))
                self.save_temp_image(image, temp_image_path)
        else:
            # Decode once so an unreadable file fails here, before the model
            # is involved; the context manager releases the file handle.
            with Image.open(image_path) as raw:
                raw.load()
            temp_image_path = image_path

        with torch.no_grad():
            inputs = self.get_inputs(temp_image_path)
            inputs = inputs.to(self.device)

            generated_ids = self.model.generate(**inputs, max_new_tokens = 640)
            output = self.processor.batch_decode(generated_ids, skip_special_tokens = True, clean_up_tokenization_spaces=False)
            mo = self.model(**inputs)

        return extract_Qwen_assistant_answers(output[0]), inputs, mo.logits

    def decode_outputs(self, input_data, temperature=1.0, do_sample=True, top_k=0, top_p=0.3, max_new_tokens=512):
        '''
        Generate text token by token and report the joint token probability.

        Each step runs a full forward pass over the sequence produced so far
        and picks the next token from the logits of the last position. The
        loop uses ``.item()`` on the chosen token, so a batch size of 1 is
        required.

        Parameters:
            input_data (dict): must provide ``input_ids``, ``attention_mask``,
                ``pixel_values`` and ``image_grid_thw``
            temperature (float): softmax temperature
            do_sample (bool): sample with top-k/top-p filtering when True,
                pick the most probable token (greedy) when False
            top_k (int): top-k cut-off used when sampling; 0 disables it
            top_p (float): nucleus (top-p) threshold used when sampling
            max_new_tokens (int): maximum number of generated tokens

        Returns:
            (str, object): decoded text and the value of ``joint_prob``
            applied to the list of per-token probabilities
        '''

        input_ids = input_data['input_ids']
        attention_mask = input_data['attention_mask']
        pixel_values = input_data['pixel_values']
        image_grid_thw = input_data['image_grid_thw']

        generated_tokens = []
        generated_probs = []

        # Stop on "<|endoftext|>" and on the tokenizer's own EOS token, so the
        # loop also ends at end-of-turn when the tokenizer defines a
        # different EOS (as `generate` would).
        stop_ids = {
            self.tokenizer.convert_tokens_to_ids("<|endoftext|>"),
            getattr(self.tokenizer, "eos_token_id", None),
        }
        stop_ids.discard(None)

        for _ in range(max_new_tokens):
            model_inputs = {
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'pixel_values': pixel_values,
                'image_grid_thw': image_grid_thw
            }

            with torch.no_grad():
                outputs = self.model(**model_inputs)
            logits = outputs.logits[:, -1, :]

            if do_sample:
                next_token, confidence = self.decode_with_top_p(
                    logits, top_k=top_k, top_p=top_p, temperature=temperature
                )
            else:
                # Greedy decoding: most probable token and its probability.
                probs = F.softmax(logits.float() / temperature, dim=-1)
                confidence, next_token = probs.max(dim=-1)

            token_id = next_token.item()
            if token_id in stop_ids:
                break

            generated_tokens.append(token_id)
            generated_probs.append(confidence.item())

            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=-1)
            if attention_mask is not None:
                attention_mask = torch.cat([attention_mask, torch.ones_like(next_token.unsqueeze(0))], dim=-1)

        decoded_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)

        jointprob = joint_prob(generated_probs)

        return decoded_text, jointprob

    def decode_with_top_p(self, logits, top_k=0, top_p=0.9, temperature=1.0):
        """Sample one token per row with temperature, top-p and optional top-k.

        Args:
            logits: 1-D or 2-D tensor of scores over the vocabulary (last
                dimension), as accepted by ``torch.multinomial``.
            top_k: When greater than 0, keep at most this many of the most
                probable tokens. 0 disables the cut-off.
            top_p: Nucleus threshold. Tokens are kept in descending
                probability order up to and including the first one whose
                cumulative probability exceeds ``top_p``.
            temperature: Positive softmax temperature.

        Returns:
            tuple: ``(token_index, confidence)`` with the trailing sample
            dimension squeezed away; ``confidence`` is the probability of the
            chosen token after filtering and renormalisation.

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError("temperature must be positive")

        # Work in float32: the model runs in float16, which is too coarse for
        # a cumulative sum over the whole vocabulary.
        probs = F.softmax(logits.float() / temperature, dim=-1)

        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        # top-p: shift the mask right by one so the token that crosses the
        # threshold is kept, and always keep the most probable token.
        sorted_mask = cumulative_probs <= top_p
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
        sorted_mask[..., 0] = True

        # top-k: additionally drop everything past the k most probable tokens.
        if top_k and top_k > 0:
            sorted_mask[..., top_k:] = False

        filtered_probs = sorted_probs * sorted_mask
        filtered_probs = filtered_probs / filtered_probs.sum(dim=-1, keepdim=True)

        # multinomial draws a position in the sorted order; map it back to
        # the vocabulary index.
        next_token = torch.multinomial(filtered_probs, num_samples=1)
        token_index = sorted_indices.gather(-1, next_token)

        confidence = filtered_probs.gather(-1, next_token)

        return token_index.squeeze(-1), confidence.squeeze(-1)