from transformers.cache_utils import DynamicCache
# Compatibility shim: if the installed transformers' DynamicCache has no
# `get_max_length` method, install one that returns the instance's `seqlen`
# attribute when present and None otherwise.
# NOTE(review): presumably required by model remote code that still calls
# `get_max_length()` on its cache — confirm which caller depends on it.
if not hasattr(DynamicCache, "get_max_length"):
    DynamicCache.get_max_length = lambda self: getattr(self, "seqlen", None)
import sys
sys.path.append("the relative path")
import argparse
import json
import os
import torch
torch.manual_seed(42)
import logging
import numpy as np
from PIL import Image
import random
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, LlavaForConditionalGeneration, Qwen2VLForConditionalGeneration, AutoProcessor, BlipForConditionalGeneration
import csv
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
from qwen_vl_utils import process_vision_info
import io
import base64
from io import BytesIO
import uuid
from utils.Sample import decode_with_top_k, logits_processor_decode,decode_with_greedy
from utils.utils import *
from peft import AutoPeftModelForCausalLM, PeftModel, PeftModelForCausalLM
import re
import hashlib
import copy
import time

class mPLUGLora():
    """mPLUG-Owl base model with a LoRA adapter merged into its weights.

    Construction loads the tokenizer and base model from ``self.base_path``,
    applies the LoRA weights found at ``cfg.model_path``, merges them into the
    base weights and builds the model's processor. Call :meth:`generate_prompt`
    before :meth:`get_answer`, which reads the stored ``self.prompt``.
    """

    def __init__(self, cfg):
        """Load tokenizer, base model and LoRA adapter.

        Args:
            cfg: config object providing ``model_path`` (the LoRA checkpoint
                directory) and ``device`` (passed to ``.to(...)``).
        """
        super().__init__()
        # The run name is parsed from ".../<run>-YYYYMMDD-HHMMSS/checkpoint-N".
        # For paths that do not follow that layout, fall back to the last path
        # component instead of crashing on ``None.group``.
        match = re.search(r'/([^/]+)-\d{8}-\d{6}/checkpoint-\d+', cfg.model_path)
        run_name = match.group(1) if match else os.path.basename(os.path.normpath(cfg.model_path))
        self.name = f"mPLUGLora_{run_name}"
        self.device = cfg.device
        self.lora_path = cfg.model_path     # LoRA adapter checkpoint
        self.base_path = "mPLUG-Owl path"   # placeholder: location of the base mPLUG-Owl weights
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_path, trust_remote_code=True)
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        # 1. Load the base model.
        base_model = AutoModelForCausalLM.from_pretrained(
            self.base_path,
            torch_dtype=torch.bfloat16,     # or torch.float16
            trust_remote_code=True).to(self.device)

        # 2. Inject the LoRA weights, then fold them into the base weights so
        #    inference runs on the plain (unwrapped) model.
        self.model = PeftModel.from_pretrained(base_model, self.lora_path)
        self.model = self.model.merge_and_unload()
        self.model.eval()
        # NOTE(review): `init_processor` is presumably provided by the model's
        # remote code (trust_remote_code=True) — confirm.
        self.processor = self.model.init_processor(self.tokenizer)
        logging.info("LoRA‑model finish download")

        # Images received as raw bytes are cached here as JPEG files.
        self.temp_image_dir = "./datasets/figures"
        os.makedirs(self.temp_image_dir, exist_ok=True)
        self.prompt = None

    def generate_prompt(self, prompt=None):
        """Normalize *prompt*, store it on ``self.prompt`` and return it.

        All ``<image>`` markers are removed, surrounding whitespace is stripped
        and a single ``"<image> "`` prefix is added. ``None`` is treated as an
        empty prompt.
        """
        text = "" if prompt is None else prompt
        self.prompt = f"<image> {text.replace('<image>', '').strip()}"
        return self.prompt

    def save_temp_image(self, image, image_path):
        """Write *image* to *image_path* in JPEG format."""
        image.save(image_path, format="JPEG")

    def get_answer(self, image_path, max_retries=3, retry_delay=1, max_new_tokens=1000):
        """Answer ``self.prompt`` about a single image.

        Args:
            image_path: path to an image file, or the encoded image as ``bytes``.
                Bytes input is also cached as
                ``<temp_image_dir>/bytes_<md5>.jpg`` if not already present.
            max_retries: number of ``generate`` attempts before giving up.
            retry_delay: seconds to sleep between failed attempts.
            max_new_tokens: generation length limit passed to ``generate``.

        Returns:
            ``(response_text, inference_inputs, logits)``, where
            ``inference_inputs`` are the processor outputs (with
            ``return_dict=True``) used for a plain forward pass and ``logits``
            are that pass's logits. Returns ``None`` if every generation
            attempt raised.
        """
        if isinstance(image_path, bytes):
            with Image.open(io.BytesIO(image_path)) as img:
                image = img.convert("RGB")
            image_hash = hashlib.md5(image_path).hexdigest()
            temp_image_path = os.path.join(self.temp_image_dir, f"bytes_{image_hash}.jpg")
            if not os.path.exists(temp_image_path):
                self.save_temp_image(image, temp_image_path)
        else:
            # The context manager closes the file handle; convert() returns a
            # new, fully loaded image that outlives it.
            with Image.open(image_path) as img:
                image = img.convert("RGB")

        # NOTE(review): self.prompt already begins with "<image> " (from
        # generate_prompt), and this message adds "<|image|>" plus 8 spaces
        # inherited from the original triple-quoted literal. Both are kept
        # verbatim so outputs stay reproducible — confirm the intended template.
        messages = [
            {"role": "user", "content": f"<|image|>\n        {self.prompt}"},
            {"role": "assistant", "content": ""},
        ]
        inputs = self.processor(messages, images=[image], video=None).to(self.device)

        # One plain forward pass to expose the logits to callers.
        inference_inputs = copy.deepcopy(inputs)
        inference_inputs.update({'return_dict': True})
        with torch.no_grad():
            logits = self.model(**inference_inputs).logits

        generate_inputs = copy.deepcopy(inputs)
        generate_inputs.update({
            'tokenizer': self.tokenizer,
            'max_new_tokens': max_new_tokens,
            'decode_text': True,
        })

        # Retry generation; give up (return None) only after every attempt has
        # failed, and do not sleep after the final attempt.
        last_exception = None
        for attempt in range(1, max_retries + 1):
            try:
                with torch.no_grad():
                    response = self.model.generate(**generate_inputs)
            except Exception as e:
                print(f"[Retry {attempt}/{max_retries}] Generation failed: {e}")
                last_exception = e
                if attempt < max_retries:
                    time.sleep(retry_delay)
            else:
                return response[0], inference_inputs, logits
        if last_exception is not None:
            logging.error("Generation failed after %d attempts: %s", max_retries, last_exception)
        return None

    def decode_outputs(self, input_data, temperature=1.0, do_sample=True, top_k=0, top_p=0.3, max_new_tokens=512):
        """Decode a continuation token by token, tracking each token's probability.

        Each step re-runs the full forward pass on the growing sequence (no KV
        cache). The next token is chosen by top-k/top-p sampling
        (``do_sample=True``) or greedily (``do_sample=False``; in that case
        ``temperature``, ``top_k`` and ``top_p`` are ignored).

        Args:
            input_data: mapping with ``input_ids`` and optionally
                ``attention_mask``. Only these two entries are forwarded to the
                model. The per-step ``.item()`` calls require batch size 1.
            temperature, top_k, top_p: sampling controls; see
                :meth:`decode_with_top_p`.
            do_sample: sample when True, otherwise take the argmax.
            max_new_tokens: upper bound on generated tokens.

        Returns:
            ``(decoded_text, jointprob)``, where ``jointprob`` is
            ``joint_prob(...)`` over the chosen tokens' probabilities.
        """
        input_ids = input_data['input_ids']
        attention_mask = input_data.get('attention_mask', None)
        # NOTE(review): any image tensors in input_data are not forwarded here,
        # so the model may decode without visual conditioning — confirm intent.

        # Stop on the literal "<|endoftext|>" token and on the tokenizer's
        # configured EOS token (which may be a different id).
        stop_ids = {self.tokenizer.convert_tokens_to_ids("<|endoftext|>"), self.tokenizer.eos_token_id}
        stop_ids.discard(None)

        generated_tokens = []
        generated_probs = []
        for _ in range(max_new_tokens):
            with torch.no_grad():
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits[:, -1, :]

            if do_sample:
                next_token, confidence = self.decode_with_top_p(
                    logits, top_k=top_k, top_p=top_p, temperature=temperature)
            else:
                next_token, confidence = self._greedy_pick(logits)

            token_id = next_token.item()
            if token_id in stop_ids:
                break
            generated_tokens.append(token_id)
            generated_probs.append(confidence.item())

            step = next_token.unsqueeze(-1)  # (batch,) -> (batch, 1)
            input_ids = torch.cat([input_ids, step], dim=-1)
            if attention_mask is not None:
                attention_mask = torch.cat([attention_mask, attention_mask.new_ones(step.shape)], dim=-1)

        decoded_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
        jointprob = joint_prob(generated_probs)
        return decoded_text, jointprob

    @staticmethod
    def _greedy_pick(logits):
        """Return the argmax token id(s) and their (temperature-1) softmax probability."""
        probs = F.softmax(logits.float(), dim=-1)
        confidence, token_index = probs.max(dim=-1)
        return token_index, confidence

    def decode_with_top_p(self, logits, top_k=0, top_p=0.9, temperature=1.0):
        """Sample one token per row from *logits* with top-k and nucleus filtering.

        Args:
            logits: last-position logits; sampling is along the last dimension.
            top_k: if > 0, keep at most the ``top_k`` most likely tokens.
            top_p: nucleus threshold. A token is kept while the probability mass
                strictly before it is <= ``top_p``, so the token that crosses the
                threshold, and always the most likely token, survive.
            temperature: logits are divided by this value. ``<= 0`` falls back
                to greedy selection.

        Returns:
            ``(token_index, confidence)`` with the last dimension removed. Here
            ``confidence`` is the chosen token's probability under the filtered,
            renormalized distribution; for greedy it is the plain softmax
            probability.
        """
        if temperature <= 0:
            return self._greedy_pick(logits)

        # float32 keeps softmax/cumsum accurate when the model emits bf16/fp16.
        probs = F.softmax(logits.float() / temperature, dim=-1)
        sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
        cumulative = torch.cumsum(sorted_probs, dim=-1)

        keep = torch.ones_like(sorted_probs, dtype=torch.bool)
        keep[..., 1:] = cumulative[..., :-1] <= top_p
        if top_k > 0:
            keep[..., top_k:] = False

        filtered = sorted_probs * keep
        filtered = filtered / filtered.sum(dim=-1, keepdim=True)

        choice = torch.multinomial(filtered, num_samples=1)
        token_index = sorted_indices.gather(-1, choice)
        confidence = filtered.gather(-1, choice)
        return token_index.squeeze(-1), confidence.squeeze(-1)
