import sys
sys.path.append("the relative path")
import os
import torch
import logging
from PIL import Image
from tqdm import tqdm
from transformers import AutoTokenizer, LlavaForConditionalGeneration, AutoProcessor
import io
from utils.Sample import decode_with_top_k, decode_with_greedy
from utils.utils import *

# Module-wide logging setup: INFO level, each record prefixed with a timestamp.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
)

class LLaVA_7B():
    """Wrapper around a LLaVA checkpoint loaded with ``LlavaForConditionalGeneration``.

    The model is loaded in float16 and put in eval mode.  The class offers
    prompt normalisation, answer generation for one image, a manual
    token-by-token decoding loop, and per-token NLL / Brier scoring of a
    given word.
    """

    def __init__(self, cfg):
        """Load the tokenizer, model and processor described by ``cfg``.

        Args:
            cfg: Configuration object.  This constructor reads ``cfg.device``,
                ``cfg.temperature`` and ``cfg.model_path``.
        """
        super().__init__()
        self.device = cfg.device
        self.name = "LLaVA_7B"
        # A value of 0 (or below) disables sampling in get_answer().
        self.temperature = cfg.temperature
        logging.info("Loading model and tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.model_path, trust_remote_code=True)
        self.model = LlavaForConditionalGeneration.from_pretrained(
            cfg.model_path,
            torch_dtype=torch.float16,
            device_map=cfg.device,
            trust_remote_code=True,
        ).eval()
        self.processor = AutoProcessor.from_pretrained(cfg.model_path, trust_remote_code=True)
        logging.info("Model and tokenizer loaded successfully.")
        self.use_image_token = False
        # Set by generate_prompt(); get_answer() passes it to the processor.
        self.prompt = None

    def generate_prompt(self, prompt=None):
        """Store and return ``prompt`` normalised to start with one ``<image>`` tag.

        Every ``<image>`` occurrence is removed from the text, surrounding
        whitespace is stripped, and ``"<image> "`` is prepended.

        Args:
            prompt: Prompt text.  ``None`` (the default) is treated as an
                empty string instead of raising ``AttributeError``.

        Returns:
            The normalised prompt, which is also stored on ``self.prompt``.
        """
        text = "" if prompt is None else prompt
        text = text.replace("<image>", "").strip()
        self.prompt = f"<image> {text}"
        return self.prompt

    def get_answer(self, image_path, max_new_tokens=100000):
        """Generate an answer for one image using the stored prompt.

        Args:
            image_path: Either raw image bytes or anything ``PIL.Image.open``
                accepts (e.g. a file path).
            max_new_tokens: Upper bound on generated tokens.  The default
                keeps the value that used to be hard-coded here.

        Returns:
            A tuple ``(description, inputs, logits)`` where ``description`` is
            the decoded generation after ``extract_assistant_answers``,
            ``inputs`` is the processor output moved to the model device, and
            ``logits`` come from one extra forward pass over ``inputs``.
        """
        source = io.BytesIO(image_path) if isinstance(image_path, bytes) else image_path
        # The context manager releases the file handle; convert() returns an
        # independent, fully loaded image so it stays usable afterwards.
        with Image.open(source) as raw_image:
            image = raw_image.convert("RGB")

        generate_kwargs = {"max_new_tokens": max_new_tokens}
        if self.temperature > 0:
            # Forward the configured temperature; previously sampling was
            # switched on but the temperature itself was never passed.
            generate_kwargs["do_sample"] = True
            generate_kwargs["temperature"] = self.temperature
        else:
            generate_kwargs["do_sample"] = False

        with torch.no_grad():
            inputs = self.processor(
                images=image, text=self.prompt, return_tensors="pt"
            ).to(self.model.device)
            outputs = self.model.generate(**inputs, **generate_kwargs)
            # Single prompt in, so the first row is the whole generation.
            description = self.processor.decode(outputs[0], skip_special_tokens=True)
            mo = self.model(**inputs)

        description = extract_assistant_answers(description)
        return description, inputs, mo.logits

    def decode_outputs(self, input_data, temperature=1, max_length=100):
        """Decode token by token with ``decode_with_greedy`` and score the result.

        The full sequence is re-fed to the model at every step (no cache is
        used), and decoding stops at the tokenizer's EOS id or after
        ``max_length`` steps.

        Args:
            input_data: Mapping with ``'input_ids'`` and optionally
                ``'attention_mask'`` and ``'pixel_values'``.
                Assumes a batch of one, since ``.item()`` is called on each
                selected token.
            temperature: Forwarded to ``decode_with_greedy``.
            max_length: Maximum number of decoding steps (default 100, the
                value that used to be hard-coded).

        Returns:
            ``(decoded_text, joint_probability)`` where the second item is
            ``joint_prob`` applied to the per-step confidences.
        """
        input_ids = input_data['input_ids']
        attention_mask = input_data.get('attention_mask', None)
        pixel_values = input_data.get('pixel_values', None)
        if pixel_values is None:
            print("pure text output")

        generated_tokens = []
        generated_probs = []
        eos_token_id = self.tokenizer.eos_token_id

        for _ in range(max_length):
            with torch.no_grad():
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    pixel_values=pixel_values,
                )
            # Only the distribution at the last position decides the next token.
            logits = outputs.logits[:, -1, :]
            next_token, confidence = decode_with_greedy(logits, temperature=temperature)

            token_value = next_token.item()
            if token_value == eos_token_id:
                break

            generated_tokens.append(token_value)
            generated_probs.append(confidence.item())

            step = next_token.unsqueeze(0)
            input_ids = torch.cat([input_ids, step], dim=-1)
            if attention_mask is not None:
                attention_mask = torch.cat([attention_mask, torch.ones_like(step)], dim=-1)

        decoded_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
        return decoded_text, joint_prob(generated_probs)

    def get_word_nll_and_brier(self, word, input_data):
        """Score ``word`` as a forced continuation of ``input_data``.

        The word is tokenised without special tokens; each token is scored
        against the model's next-token distribution and then appended to the
        context before the next token is scored.

        Args:
            word: Text to score.
            input_data: Mapping with ``'input_ids'`` and optionally
                ``'attention_mask'`` and ``'pixel_values'``.
                Assumes a batch of one (row 0 of the distribution is used).

        Returns:
            ``(total_nll, avg_brier, avg_brier2, token_probs, tokens_str)``:
            the summed negative log-probability, the mean multi-class Brier
            score ``1 - 2p + sum(q**2)``, the mean of ``|p - 1|``, the
            per-token probabilities, and the decoded tokens.

        Raises:
            ValueError: If ``word`` tokenises to zero tokens (the averages
                would otherwise fail with a division by zero).
        """
        input_ids = input_data['input_ids']
        attention_mask = input_data.get('attention_mask', None)
        pixel_values = input_data.get('pixel_values', None)

        # Follow the device of the context ids: they are concatenated below,
        # and self.device may hold a device_map string rather than a device.
        word_token_ids = self.tokenizer.encode(
            word, add_special_tokens=False, return_tensors="pt"
        ).to(input_ids.device)
        num_tokens = word_token_ids.shape[1]
        if num_tokens == 0:
            raise ValueError(f"word {word!r} produced no tokens to score")

        token_probs, token_nlls, token_briers, token_briers2 = [], [], [], []

        for i in range(num_tokens):
            token_id = word_token_ids[0, i]
            with torch.no_grad():
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    pixel_values=pixel_values,
                )
                # The model runs in float16, where 1e-12 underflows to zero;
                # float32 keeps the epsilon below effective so the NLL is
                # always finite.
                probs = torch.softmax(outputs.logits[:, -1, :].float(), dim=-1)

            prob = probs[0, token_id]
            token_probs.append(prob.item())
            token_nlls.append(-torch.log(prob + 1e-12).item())

            sum_p2 = torch.sum(probs[0] ** 2)  # ||p||_2^2
            token_briers.append((1.0 - 2.0 * prob + sum_p2).item())
            token_briers2.append((prob - 1).abs().item())

            # Teacher-force the scored token into the context.
            step = token_id.view(1, 1)
            input_ids = torch.cat([input_ids, step], dim=-1)
            if attention_mask is not None:
                attention_mask = torch.cat([attention_mask, torch.ones_like(step)], dim=-1)

        tokens_str = [self.tokenizer.decode([tid]) for tid in word_token_ids[0].tolist()]
        total_nll = sum(token_nlls)
        avg_brier = sum(token_briers) / len(token_briers)
        avg_brier2 = sum(token_briers2) / len(token_briers2)
        return total_nll, avg_brier, avg_brier2, token_probs, tokens_str

        
class LLaVA_13B():
    """Wrapper around a LLaVA checkpoint loaded with ``LlavaForConditionalGeneration``.

    The model is loaded in float16 and put in eval mode.  The class offers
    prompt normalisation, answer generation for one image, a manual
    token-by-token decoding loop, and per-token NLL / Brier scoring of a
    given word.
    """

    def __init__(self, cfg):
        """Load the tokenizer, model and processor described by ``cfg``.

        Args:
            cfg: Configuration object.  This constructor reads ``cfg.device``,
                ``cfg.temperature`` and ``cfg.model_path``.
        """
        super().__init__()
        self.device = cfg.device
        self.name = "LLaVA_13B"
        # Stored for callers; get_answer() in this class does not forward it.
        self.temperature = cfg.temperature
        logging.info("Loading model and tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.model_path, trust_remote_code=True)
        self.model = LlavaForConditionalGeneration.from_pretrained(
            cfg.model_path,
            torch_dtype=torch.float16,
            device_map=cfg.device,
            trust_remote_code=True,
        ).eval()
        self.processor = AutoProcessor.from_pretrained(cfg.model_path, trust_remote_code=True)
        logging.info("Model and tokenizer loaded successfully.")
        self.use_image_token = False
        # Set by generate_prompt(); get_answer() passes it to the processor.
        self.prompt = None

    def generate_prompt(self, prompt=None):
        """Store and return ``prompt`` normalised to start with one ``<image>`` tag.

        Every ``<image>`` occurrence is removed from the text, surrounding
        whitespace is stripped, and ``"<image> "`` is prepended.

        Args:
            prompt: Prompt text.  ``None`` (the default) is treated as an
                empty string instead of raising ``AttributeError``.

        Returns:
            The normalised prompt, which is also stored on ``self.prompt``.
        """
        text = "" if prompt is None else prompt
        text = text.replace("<image>", "").strip()
        self.prompt = f"<image> {text}"
        return self.prompt

    def get_answer(self, image_path, max_new_tokens=1000):
        """Generate an answer for one image using the stored prompt.

        Generation uses the model's default decoding settings; only
        ``max_new_tokens`` is passed explicitly.

        Args:
            image_path: Either raw image bytes or anything ``PIL.Image.open``
                accepts (e.g. a file path).
            max_new_tokens: Upper bound on generated tokens.  The default
                keeps the value that used to be hard-coded here.

        Returns:
            A tuple ``(description, inputs, logits)`` where ``description`` is
            the decoded generation after ``extract_assistant_answers``,
            ``inputs`` is the processor output moved to the model device, and
            ``logits`` come from one extra forward pass over ``inputs``.
        """
        source = io.BytesIO(image_path) if isinstance(image_path, bytes) else image_path
        # The context manager releases the file handle; convert() returns an
        # independent, fully loaded image so it stays usable afterwards.
        with Image.open(source) as raw_image:
            image = raw_image.convert("RGB")

        with torch.no_grad():
            inputs = self.processor(
                images=image, text=self.prompt, return_tensors="pt"
            ).to(self.model.device)
            outputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
            # Single prompt in, so the first row is the whole generation.
            description = self.processor.decode(outputs[0], skip_special_tokens=True)
            mo = self.model(**inputs)

        description = extract_assistant_answers(description)
        return description, inputs, mo.logits

    def decode_outputs(self, input_data, temperature=1, max_length=100):
        """Decode token by token with ``decode_with_greedy`` and score the result.

        The full sequence is re-fed to the model at every step (no cache is
        used), and decoding stops at the tokenizer's EOS id or after
        ``max_length`` steps.

        Args:
            input_data: Mapping with ``'input_ids'`` and optionally
                ``'attention_mask'`` and ``'pixel_values'``.
                Assumes a batch of one, since ``.item()`` is called on each
                selected token.
            temperature: Forwarded to ``decode_with_greedy``.
            max_length: Maximum number of decoding steps (default 100, the
                value that used to be hard-coded).

        Returns:
            ``(decoded_text, joint_probability)`` where the second item is
            ``joint_prob`` applied to the per-step confidences.
        """
        input_ids = input_data['input_ids']
        attention_mask = input_data.get('attention_mask', None)
        pixel_values = input_data.get('pixel_values', None)
        if pixel_values is None:
            print("pure text output")

        generated_tokens = []
        generated_probs = []
        eos_token_id = self.tokenizer.eos_token_id

        for _ in range(max_length):
            with torch.no_grad():
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    pixel_values=pixel_values,
                )
            # Only the distribution at the last position decides the next token.
            logits = outputs.logits[:, -1, :]
            # The former top-k alternative sat behind a flag that was always
            # False, so only this path was ever taken.
            next_token, confidence = decode_with_greedy(logits, temperature=temperature)

            token_value = next_token.item()
            if token_value == eos_token_id:
                break

            generated_tokens.append(token_value)
            generated_probs.append(confidence.item())

            step = next_token.unsqueeze(0)
            input_ids = torch.cat([input_ids, step], dim=-1)
            if attention_mask is not None:
                attention_mask = torch.cat([attention_mask, torch.ones_like(step)], dim=-1)

        decoded_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
        return decoded_text, joint_prob(generated_probs)

    def get_word_nll_and_brier(self, word, input_data):
        """Score ``word`` as a forced continuation of ``input_data``.

        The word is tokenised without special tokens; each token is scored
        against the model's next-token distribution and then appended to the
        context before the next token is scored.

        Args:
            word: Text to score.
            input_data: Mapping with ``'input_ids'`` and optionally
                ``'attention_mask'`` and ``'pixel_values'``.
                Assumes a batch of one (row 0 of the distribution is used).

        Returns:
            ``(total_nll, avg_brier, avg_brier2, token_probs, tokens_str)``:
            the summed negative log-probability, the mean multi-class Brier
            score ``1 - 2p + sum(q**2)``, the mean of ``|p - 1|``, the
            per-token probabilities, and the decoded tokens.

        Raises:
            ValueError: If ``word`` tokenises to zero tokens (the averages
                would otherwise fail with a division by zero).
        """
        input_ids = input_data['input_ids']
        attention_mask = input_data.get('attention_mask', None)
        pixel_values = input_data.get('pixel_values', None)

        # Follow the device of the context ids: they are concatenated below,
        # and self.device may hold a device_map string rather than a device.
        word_token_ids = self.tokenizer.encode(
            word, add_special_tokens=False, return_tensors="pt"
        ).to(input_ids.device)
        num_tokens = word_token_ids.shape[1]
        if num_tokens == 0:
            raise ValueError(f"word {word!r} produced no tokens to score")

        token_probs, token_nlls, token_briers, token_briers2 = [], [], [], []

        for i in range(num_tokens):
            token_id = word_token_ids[0, i]
            with torch.no_grad():
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    pixel_values=pixel_values,
                )
                # The model runs in float16, where 1e-12 underflows to zero;
                # float32 keeps the epsilon below effective so the NLL is
                # always finite.
                probs = torch.softmax(outputs.logits[:, -1, :].float(), dim=-1)

            prob = probs[0, token_id]
            token_probs.append(prob.item())
            token_nlls.append(-torch.log(prob + 1e-12).item())

            sum_p2 = torch.sum(probs[0] ** 2)  # ||p||_2^2
            token_briers.append((1.0 - 2.0 * prob + sum_p2).item())
            token_briers2.append((prob - 1).abs().item())

            # Teacher-force the scored token into the context.
            step = token_id.view(1, 1)
            input_ids = torch.cat([input_ids, step], dim=-1)
            if attention_mask is not None:
                attention_mask = torch.cat([attention_mask, torch.ones_like(step)], dim=-1)

        tokens_str = [self.tokenizer.decode([tid]) for tid in word_token_ids[0].tolist()]
        total_nll = sum(token_nlls)
        avg_brier = sum(token_briers) / len(token_briers)
        avg_brier2 = sum(token_briers2) / len(token_briers2)
        return total_nll, avg_brier, avg_brier2, token_probs, tokens_str

