import sys
sys.path.append("the relative path")
import argparse
import os
import torch
import logging
from PIL import Image
import patch_cache 
from transformers import AutoTokenizer, AutoModel,AutoProcessor
from modelscope import AutoConfig, AutoModel
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
import io
from utils.utils import *
from utils.Sample import decode_with_top_k, logits_processor_decode,decode_with_greedy,decode_with_sample
import re
import copy
import torchvision.transforms as T
import time
import torch.nn.functional as F
import math

class mPLUG():
    """Wrapper around an mPLUG multimodal model for prompting and scoring.

    Loads the tokenizer, model and processor from ``cfg.model_path`` and
    exposes helpers to run generation for an image + stored prompt
    (:meth:`get_answer`), a manual greedy decoding loop with per-token
    confidences (:meth:`decode_outputs`), and teacher-forced scoring of a
    given word (:meth:`get_word_nll_and_brier`).
    """

    # Matches a run of consecutive image tokens (with any whitespace between
    # or after them) so it can be collapsed into a single marker.
    _IMAGE_TOKEN_RUN = re.compile(r"(<\|image\|>\s*)+")

    # PIL modes the JPEG writer accepts directly; anything else (RGBA, P, LA,
    # ...) is converted to RGB before saving.
    _JPEG_MODES = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

    def __init__(self, cfg):
        """Load tokenizer, model and processor.

        Args:
            cfg: config object read for ``model_path``, ``device`` and
                ``temperature``.
        """
        super().__init__()
        # Display name comes from a "-<digits><letters>-" fragment of the
        # model path; fall back to a generic name rather than crashing on
        # ``None.group`` when the path contains no such fragment.
        short_ver = re.search(r'-(\d+[A-Za-z]+)-', cfg.model_path)
        self.name = f"mPLUG_{short_ver.group(1)}" if short_ver else "mPLUG"
        self.device = cfg.device
        # Stored for callers; no method in this class reads it.
        self.temperature = cfg.temperature
        logging.info("Loading model and tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.model_path, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(
            cfg.model_path,
            attn_implementation='sdpa',
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        # Put the model on the same device that get_answer() moves inputs to;
        # a bare .cuda() ignored cfg.device (e.g. "cuda:1").
        self.model.eval().to(self.device)
        self.processor = self.model.init_processor(self.tokenizer)
        logging.info("Model and tokenizer loaded successfully.")
        self.use_image_token = False
        self.prompt = None

    def generate_prompt(self, prompt=None):
        """Normalise image placeholders in *prompt* and store it.

        Any run of consecutive ``<|image|>`` tokens (and the whitespace after
        them) collapses into a single ``<|image|>`` followed by a newline.

        Returns:
            The cleaned prompt, which is also stored on ``self.prompt`` for
            get_answer().

        Raises:
            TypeError: if *prompt* is None.
        """
        if prompt is None:
            raise TypeError("generate_prompt() requires a prompt string, got None")
        self.prompt = self._IMAGE_TOKEN_RUN.sub("<|image|>\n", prompt)
        return self.prompt

    def save_temp_image(self, image, image_path):
        """Save the PIL *image* to *image_path* in JPEG format.

        JPEG cannot store alpha/palette modes, so such images are converted to
        RGB first instead of letting PIL raise.
        """
        if image.mode not in self._JPEG_MODES:
            image = image.convert("RGB")
        image.save(image_path, format="JPEG")

    def get_answer(self, image_path, max_retries=3, retry_delay=1, max_new_tokens=1000):
        """Run a forward pass and a full generation for the stored prompt.

        Args:
            image_path: raw encoded image ``bytes``, or anything else accepted
                by ``Image.open`` (e.g. a path).
            max_retries: total number of generation attempts.
            retry_delay: seconds to sleep between failed attempts.
            max_new_tokens: passed to ``model.generate``.

        Returns:
            ``(response[0], inference_inputs, logits)`` on success, where
            ``inference_inputs`` is the processed input with
            ``return_dict=True`` (usable by decode_outputs() /
            get_word_nll_and_brier()) and ``logits`` come from a forward pass
            over the prompt. ``None`` if every generation attempt fails.
        """
        if isinstance(image_path, bytes):
            image = Image.open(io.BytesIO(image_path)).convert("RGB")
        else:
            image = Image.open(image_path).convert("RGB")
        image = image.resize((512, 384))

        prompts = [
            {"role": "user", "content": f"<|image|>{self.prompt.strip()}"},
            {"role": "assistant", "content": ""},
        ]
        inputs = self.processor(prompts, images=[image], video=None).to(self.device)

        # Forward pass on an independent copy so the extra 'return_dict' key
        # does not leak into the generation kwargs below.
        inference_inputs = copy.deepcopy(inputs)
        inference_inputs.update({'return_dict': True})
        with torch.no_grad():
            logits = self.model(**inference_inputs).logits

        generate_inputs = copy.deepcopy(inputs)
        generate_inputs.update({
            'tokenizer': self.tokenizer,
            'max_new_tokens': max_new_tokens,
            'decode_text': True,
        })

        last_exception = None
        for attempt in range(1, max_retries + 1):
            try:
                with torch.no_grad():
                    response = self.model.generate(**generate_inputs)
            except Exception as e:  # retry boundary: any generation failure is retried
                last_exception = e
                print(f"[Retry {attempt}/{max_retries}] Generation failed: {e}")
                if attempt < max_retries:  # no point sleeping after the last try
                    time.sleep(retry_delay)
                continue
            return response[0], inference_inputs, logits

        print(f"Generation failed after {max_retries} attempts: {last_exception}")
        return None

    @staticmethod
    def _extend_media_offset(media_offset):
        """Append one placeholder row per batch entry to *media_offset* (None passes through)."""
        if media_offset is None:
            return None
        # [0, 7] is labelled a placeholder in the original code; its meaning is
        # defined by the model's remote code — TODO confirm. Shape [1, 1, 2]
        # expanded to the batch size, concatenated along dim 1.
        pad_off = torch.tensor(
            [[[0, 7]]], dtype=media_offset.dtype, device=media_offset.device
        ).expand(media_offset.shape[0], -1, -1)
        return torch.cat([media_offset, pad_off], dim=1)

    def decode_outputs(self, input_data, temperature=1, max_new_tokens=100):
        """Greedy-decode token by token and return the text and joint confidence.

        The full model is re-run on the growing sequence at every step.

        Args:
            input_data: mapping with 'input_ids' and optionally
                'media_offset', 'pixel_values', 'return_dict' (e.g. the
                inference_inputs returned by get_answer()).
            temperature: forwarded to decode_with_greedy().
            max_new_tokens: maximum number of tokens to generate.

        Returns:
            ``(decoded_text, joint_prob(per_token_confidences))``, or
            ``(None, None)`` if the model raises AssertionError.
        """
        input_ids = input_data['input_ids']
        media_offset = input_data.get('media_offset', None)
        pixel_values = input_data.get('pixel_values', None)
        return_dict = input_data.get('return_dict', None)
        if pixel_values is None:
            print("pure text output")

        generated_tokens = []
        generated_probs = []
        eos_token_id = self.tokenizer.eos_token_id

        for _ in range(max_new_tokens):
            model_inputs = {
                'input_ids': input_ids,
                'media_offset': media_offset,
                'pixel_values': pixel_values,
                'return_dict': return_dict,
            }
            try:
                with torch.no_grad():
                    outputs = self.model(**model_inputs)
            except AssertionError as e:
                print(f"[ERROR] AssertionError {e}")
                return None, None

            next_token, confidence = decode_with_greedy(
                outputs.logits[:, -1, :],
                temperature=temperature,
            )
            token_id = next_token.item()
            if token_id == eos_token_id:
                break

            generated_tokens.append(token_id)
            generated_probs.append(confidence.item())

            # Feed the chosen token back in. next_token's shape is defined by
            # decode_with_greedy — presumably unsqueeze(0) yields [1, 1]; confirm.
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=-1)
            media_offset = self._extend_media_offset(media_offset)

        # Return the generated text and the joint probability of its tokens.
        decoded_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
        return decoded_text, joint_prob(generated_probs)

    def get_word_nll_and_brier(self, word, input_data):
        """Score *word* as a continuation of *input_data* via teacher forcing.

        For each token of *word*, the model is run on the current prefix, the
        next-token distribution at the last position is read, and the true
        token is then appended to the prefix.

        Returns:
            ``(total_nll, avg_brier, avg_brier2, token_probs, tokens_str)``:
            total_nll is the sum of -log p(token); avg_brier is the mean
            multi-class Brier score ``1 - 2*p_true + sum_k p_k**2``;
            avg_brier2 is the mean ``|1 - p_true|``; token_probs holds
            p(token) per token; tokens_str holds each token decoded alone.
            If *word* encodes to no tokens, returns ``(0.0, nan, nan, [], [])``
            instead of dividing by zero.
        """
        input_ids = input_data['input_ids']
        media_offset = input_data.get('media_offset', None)
        pixel_values = input_data.get('pixel_values', None)
        return_dict = input_data.get('return_dict', None)

        word_token_ids = self.tokenizer.encode(
            word, add_special_tokens=False, return_tensors="pt"
        ).to(self.device)

        token_probs = []
        token_nlls = []
        token_briers = []
        token_briers2 = []

        for token_id in word_token_ids[0]:
            model_inputs = {
                'input_ids': input_ids,
                'media_offset': media_offset,
                'pixel_values': pixel_values,
                'return_dict': return_dict,
            }
            with torch.no_grad():
                outputs = self.model(**model_inputs)
                # Upcast before softmax: the model is loaded in bfloat16, whose
                # short mantissa makes probabilities (and NLL/Brier) very coarse.
                probs = torch.softmax(outputs.logits[:, -1, :].float(), dim=-1)[0]
                prob = probs[token_id]
                token_probs.append(prob.item())
                token_nlls.append(-torch.log(prob + 1e-12).item())
                token_briers.append((1.0 - 2.0 * prob + torch.sum(probs ** 2)).item())
                token_briers2.append((1.0 - prob).abs().item())

            # Teacher forcing: append the true token, not the argmax.
            input_ids = torch.cat([input_ids, token_id.view(1, 1)], dim=-1)
            media_offset = self._extend_media_offset(media_offset)

        if not token_nlls:
            nan = float("nan")
            return 0.0, nan, nan, [], []

        total_nll = sum(token_nlls)
        avg_brier = sum(token_briers) / len(token_briers)
        avg_brier2 = sum(token_briers2) / len(token_briers2)
        tokens_str = [self.tokenizer.decode([tid]) for tid in word_token_ids[0].tolist()]

        return total_nll, avg_brier, avg_brier2, token_probs, tokens_str