import sys
sys.path.append("the relative path")
import argparse
import json
import os
import torch
import logging
import numpy as np
from PIL import Image
import random
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, LlavaForConditionalGeneration, Qwen2VLForConditionalGeneration, AutoProcessor, BlipForConditionalGeneration
import io
from utils.Sample import decode_with_top_k, logits_processor_decode,decode_with_greedy
from utils.utils import *
from llava.mm_utils import tokenizer_image_token
# model_path = "liuhaotian/llava-v1.5-7b"
 
class LLaVA_7B_lora():
    """Wrapper around a LLaVA checkpoint (loaded via ``LlavaForConditionalGeneration``).

    Model weights are loaded from ``cfg.model_path``; the tokenizer and processor
    are loaded from ``self.base_path``. Typical use: call :meth:`generate_prompt`
    to set the question, then :meth:`get_answer` with an image.
    """

    def __init__(self, cfg):
        """Load the model, tokenizer and processor.

        Args:
            cfg: Config object providing ``device``, ``model_path`` and
                ``temperature``. An optional, non-empty ``base_path`` attribute
                overrides where the tokenizer/processor are loaded from.
        """
        super().__init__()
        self.device = cfg.device
        # normpath drops a trailing separator so basename() does not return "".
        version = os.path.basename(os.path.normpath(cfg.model_path))
        self.name = version
        # NOTE(review): stored but not passed to generate() in get_answer.
        self.temperature = cfg.temperature
        logging.info("Loading model and tokenizer...")
        self.model_path = cfg.model_path
        # Tokenizer/processor come from the base checkpoint rather than model_path;
        # cfg may supply the location instead of relying on the placeholder.
        self.base_path = getattr(cfg, "base_path", None) or "the local path of llava"
        # NOTE(review): the tokenizer is not used by the methods of this class;
        # presumably kept for callers.
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_path, trust_remote_code=True)
        self.model = LlavaForConditionalGeneration.from_pretrained(
            self.model_path,
            torch_dtype=torch.float16,
            device_map=self.device,
            trust_remote_code=True
        )
        self.processor = AutoProcessor.from_pretrained(
            self.base_path,
            trust_remote_code=True
        )
        self.model.eval()
        logging.info("Model and tokenizer loaded successfully.")
        self.use_image_token = False
        self.prompt = None
        # Default upper bound on tokens produced by model.generate() in get_answer.
        self.max_new_tokens = 2000

    def generate_prompt(self, prompt=None):
        """Build and store a prompt with exactly one leading ``<image>`` token.

        Any ``<image>`` tokens already in *prompt* are removed and the remaining
        text is stripped before the single placeholder is prepended.

        Args:
            prompt: User text. ``None`` is treated the same as ``""``.

        Returns:
            The prompt string, which is also stored in ``self.prompt``.
        """
        text = "" if prompt is None else prompt
        text = text.replace("<image>", "").strip()
        self.prompt = f"<image> {text}"
        return self.prompt

    @staticmethod
    def _load_rgb_image(image_path):
        """Open *image_path* (encoded bytes, or a path/file accepted by PIL) as RGB.

        The source image is closed after conversion so no file handle is leaked.
        """
        if isinstance(image_path, (bytes, bytearray)):
            image_path = io.BytesIO(image_path)
        with Image.open(image_path) as img:
            # convert() returns a new, fully loaded image that stays valid after close.
            return img.convert("RGB")

    def get_answer(self, image_path, max_new_tokens=None):
        """Generate an answer to ``self.prompt`` for the given image.

        Args:
            image_path: Path to an image file, or the encoded image as bytes.
            max_new_tokens: Generation limit; defaults to ``self.max_new_tokens``.

        Returns:
            Tuple ``(description, inputs, logits)``:
            - ``description``: the decoded generation passed through
              ``extract_assistant_answers``;
            - ``inputs``: the processor outputs, moved to the model's device;
            - ``logits``: the logits of a separate forward pass over ``inputs``.

        Raises:
            ValueError: If no prompt has been set (call :meth:`generate_prompt` first).
        """
        if self.prompt is None:
            raise ValueError("No prompt set; call generate_prompt() before get_answer().")
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens

        with torch.no_grad():
            image = self._load_rgb_image(image_path)
            inputs = self.processor(images=image, text=self.prompt, return_tensors="pt")
            for key in inputs.keys():
                inputs[key] = inputs[key].to(self.model.device)

            # No sampling arguments are passed, so decoding follows the model's
            # generation config (self.temperature is not applied here).
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
            )
            description = self.processor.decode(outputs.squeeze(), skip_special_tokens=True)

            # Forward pass over the prompt inputs only (not the generated tokens);
            # its logits are returned to the caller.
            mo = self.model(**inputs)
            # Helper from utils.utils (defined outside this file).
            description = extract_assistant_answers(description)

        return description, inputs, mo.logits