import argparse
import json
import os
import torch
torch.manual_seed(42)
import logging
import numpy as np
from PIL import Image
import random
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, LlavaForConditionalGeneration, Qwen2VLForConditionalGeneration, AutoProcessor, BlipForConditionalGeneration
import csv
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
from qwen_vl_utils import process_vision_info
import io
import base64
from io import BytesIO
import uuid
from utils.Sample import decode_with_top_k, logits_processor_decode,decode_with_greedy,decode_with_sample
from utils.utils import *
import hashlib

class QwenVLChat():
    """Wrapper around a Qwen-VL checkpoint for answering and scoring image prompts.

    The model is loaded in float16 with ``trust_remote_code=True``; answers are
    produced through the checkpoint's ``chat`` method (presumably supplied by
    the remote code — TODO confirm for each checkpoint used).
    """

    def __init__(self, cfg):
        """Load the model, tokenizer and processor.

        Args:
            cfg: Config object; ``cfg.model_path`` (checkpoint location) and
                ``cfg.device`` (stored as ``self.device`` and passed as
                ``device_map``) are read.
        """
        super().__init__()
        self.name = "QwenVLChat"
        self.device = cfg.device
        logging.info("Loading model and tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            cfg.model_path,
            torch_dtype=torch.float16,
            device_map=cfg.device,
            trust_remote_code=True,
        ).eval()
        self.processor = AutoProcessor.from_pretrained(cfg.model_path, trust_remote_code=True)
        logging.info("Model and tokenizer loaded successfully.")
        # Directory where images passed as raw bytes are cached as JPEG files.
        self.temp_image_dir = "./datasets/figures"
        self.use_image_token = True
        # Dialogue state: the prompt most recently set by generate_prompt().
        self.prompt = None
        os.makedirs(self.temp_image_dir, exist_ok=True)

    def generate_prompt(self, prompt = None):
        """Normalize ``prompt`` to a single leading ``<image>`` tag and store it.

        Any ``<image>`` tags already in the text are removed, surrounding
        whitespace is stripped, and ``"<image> "`` is prepended. ``None`` is
        treated as an empty prompt.

        Returns:
            The normalized prompt, which is also stored in ``self.prompt``.
        """
        text = "" if prompt is None else prompt
        self.prompt = f"<image> {text.replace('<image>', '').strip()}"
        return self.prompt

    def save_temp_image(self, image, image_path):
        """Save a PIL ``image`` to ``image_path`` in JPEG format."""
        image.save(image_path, format="JPEG")

    def get_answer(self, image_path):
        """Answer ``self.prompt`` about an image and also run a forward pass.

        ``self.prompt`` is used as-is, so call :meth:`generate_prompt` first.

        Args:
            image_path: A filesystem path to an image, or the raw encoded image
                bytes. Bytes are decoded and cached as a JPEG under
                ``self.temp_image_dir`` (named after their MD5 digest); the
                cached path is what the model receives.

        Returns:
            ``(response, inputs, logits)``: the text returned by
            ``self.model.chat``, the processor-encoded inputs (moved to the
            model's device), and the logits of a forward pass over them.
        """
        if isinstance(image_path, bytes):
            with Image.open(io.BytesIO(image_path)) as raw:
                image = raw.convert("RGB")
            # MD5 only serves as a content-addressed cache key, not for security.
            image_hash = hashlib.md5(image_path).hexdigest()
            temp_image_path = os.path.join(self.temp_image_dir, f"bytes_{image_hash}.jpg")
            if not os.path.exists(temp_image_path):
                self.save_temp_image(image, temp_image_path)
        else:
            # The context manager closes the file handle once the RGB copy exists.
            with Image.open(image_path) as raw:
                image = raw.convert("RGB")
            temp_image_path = image_path

        with torch.no_grad():
            query = self.tokenizer.from_list_format([
                {'image': temp_image_path},
                {'text': self.prompt},
            ])
            response, _ = self.model.chat(self.tokenizer, query=query, history=None)
            inputs = self.processor(images=image, text=query, return_tensors="pt").to(self.model.device)
            model_output = self.model(**inputs)

        return response, inputs, model_output.logits

    def decode_outputs(self, input_data, temperature=1.0, do_sample=True, top_k=0, top_p=0.3, max_new_tokens=21):
        """Greedily decode a continuation token by token and score it.

        Every step re-runs a full forward pass over the growing sequence (no
        KV cache) and selects the next token with ``decode_with_greedy``.

        Args:
            input_data: Mapping with ``'input_ids'`` and optionally
                ``'attention_mask'`` / ``'token_type_ids'`` (e.g. the
                ``inputs`` returned by :meth:`get_answer`).
            temperature: Forwarded to ``decode_with_greedy``.
            do_sample, top_k, top_p: Currently ignored — decoding is always
                greedy. Kept so existing callers keep working.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            ``(decoded_text, jointprob)``: the generated text (special tokens
            skipped) and ``joint_prob`` of the per-token confidences.
        """
        input_ids = input_data['input_ids']
        attention_mask = input_data.get('attention_mask', None)
        token_type_ids = input_data.get('token_type_ids', None)

        # Stop tokens: "<|endoftext|>" plus the hard-coded ids 151645 / 151644
        # (presumably Qwen's "<|im_end|>" / "<|im_start|>" — TODO confirm).
        stop_ids = {
            self.tokenizer.convert_tokens_to_ids("<|endoftext|>"),
            151645,
            151644,
        }

        generated_tokens = []
        generated_probs = []
        for _ in range(max_new_tokens):
            with torch.no_grad():
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                )
            last_logits = outputs.logits[:, -1, :]
            next_token, confidence = decode_with_greedy(last_logits, temperature=temperature)

            token_id = next_token.item()
            if token_id in stop_ids:
                break
            generated_tokens.append(token_id)
            generated_probs.append(confidence.item())

            # Append the chosen token and extend the masks to the new length.
            step = next_token.unsqueeze(0)
            input_ids = torch.cat([input_ids, step], dim=-1)
            if attention_mask is not None:
                attention_mask = torch.cat([attention_mask, torch.ones_like(step)], dim=-1)
            if token_type_ids is not None:
                token_type_ids = torch.cat([token_type_ids, torch.zeros_like(step)], dim=-1)

        decoded_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
        jointprob = joint_prob(generated_probs)
        return decoded_text, jointprob

    def get_word_nll_and_brier(self, word, input_data):
        """Score how well the model predicts ``word`` as the prompt's continuation.

        The tokens of ``word`` are scored one at a time with teacher forcing:
        after each token is scored, it is appended to the prompt before the next
        forward pass.

        Args:
            word: Target text; tokenized without special tokens.
            input_data: Mapping with ``'input_ids'`` and optionally
                ``'attention_mask'`` / ``'token_type_ids'``.

        Returns:
            ``(total_nll, avg_brier, avg_brier2, token_probs, tokens_str)``:
            the summed per-token negative log-likelihood (natural log); the mean
            multi-class Brier score ``sum_k (p_k - y_k)^2`` against the one-hot
            target; the mean ``|p_target - 1|``; the per-token target
            probabilities; and each target token decoded on its own. If ``word``
            tokenizes to nothing, ``(0.0, nan, nan, [], [])`` is returned.
        """
        input_ids = input_data['input_ids']
        attention_mask = input_data.get('attention_mask', None)
        token_type_ids = input_data.get('token_type_ids', None)

        # Use the prompt's device so torch.cat below always works. self.device
        # is a device_map spec and need not be a valid tensor device.
        word_token_ids = self.tokenizer.encode(
            word, add_special_tokens=False, return_tensors="pt"
        ).to(input_ids.device)
        if word_token_ids.numel() == 0:
            # Nothing to score; avoid dividing by zero in the averages.
            return 0.0, float("nan"), float("nan"), [], []

        token_probs, token_nlls, token_briers, token_briers2 = [], [], [], []
        for token_id in word_token_ids[0]:
            with torch.no_grad():
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                )
                # Upcast before normalizing. The model runs in float16, where
                # probabilities below ~6e-8 underflow to 0 (and a 1e-12 epsilon
                # rounds to 0 too), which made the NLL infinite. log_softmax
                # also gives the log-probability without log(0) issues.
                log_probs = torch.log_softmax(outputs.logits[0, -1, :].float(), dim=-1)
                probs = log_probs.exp()
                prob = probs[token_id]

                token_probs.append(prob.item())
                token_nlls.append(-log_probs[token_id].item())
                # ||p - onehot||^2 = 1 - 2 * p_target + ||p||_2^2
                token_briers.append((1.0 - 2.0 * prob + probs.pow(2).sum()).item())
                token_briers2.append((prob - 1).abs().item())

            # Teacher forcing: feed the true target token for the next step.
            step = token_id.view(1, 1)
            input_ids = torch.cat([input_ids, step], dim=-1)
            if attention_mask is not None:
                attention_mask = torch.cat([attention_mask, torch.ones_like(step)], dim=-1)
            if token_type_ids is not None:
                token_type_ids = torch.cat([token_type_ids, torch.zeros_like(step)], dim=-1)

        total_nll = sum(token_nlls)
        avg_brier = sum(token_briers) / len(token_briers)
        avg_brier2 = sum(token_briers2) / len(token_briers2)
        tokens_str = [self.tokenizer.decode([tid]) for tid in word_token_ids[0]]

        return total_nll, avg_brier, avg_brier2, token_probs, tokens_str


