import torch
# from transformers import AutoProcessor
from models.llava import LlavaForConditionalGeneration, LlavaProcessor
from models.llama.modeling_llama import replace_llama_modality_adaptive
from baukit import TraceDict
from functools import partial
import json
import os

from PIL import Image

from models.base import Mllm

class Llava(Mllm):
    """LLaVA wrapper for plain generation, activation interventions and
    activation extraction.

    The model is loaded in float16 with ``device_map="auto"``; processor
    outputs are moved to ``self.device`` (``cuda`` when available, else
    ``cpu``). All methods work on a single example: decoding keeps only the
    first sequence and image-token offsets are read with ``.item()``.
    """

    def __init__(self, model_name_or_path, **kwargs):
        # Patches the project's Llama modeling code (models.llama.modeling_llama);
        # presumably it must run before the model is instantiated, so it stays first.
        replace_llama_modality_adaptive()
        self.path = model_name_or_path
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model, self.processor = self._load()

    # ------------------------------------------------------------------ #
    # private helpers
    # ------------------------------------------------------------------ #
    def _load(self):
        """Load and return ``(model, processor)`` from ``self.path``."""
        model = LlavaForConditionalGeneration.from_pretrained(
            self.path, low_cpu_mem_usage=True, torch_dtype=torch.float16, device_map="auto")
        processor = LlavaProcessor.from_pretrained(self.path)
        return model, processor

    @staticmethod
    def _chat_prompt(prompt):
        """Wrap *prompt* in the single-image chat template used by the evaluators.

        Ablation variants that were toggled by hand here: a blank prompt
        (``"<image>\\nASSISTANT:"``) and a blank white image of the same size.
        """
        return f"<image>\nUSER: {prompt}\nASSISTANT:"

    @staticmethod
    def _identity(head_output, layer_name):
        """No-op TraceDict edit hook used when no interventions are requested."""
        return head_output

    def _encode(self, text, filepath=None):
        """Run the processor on *text* (plus the image at *filepath*, if given)
        and move the resulting tensors to ``self.device``.

        The image is opened in a ``with`` block so the file handle is closed
        as soon as the processor has consumed the pixels.
        """
        if filepath is None:
            return self.processor(text=text, return_tensors='pt').to(self.device)
        with Image.open(filepath) as image:
            return self.processor(text=text, images=image, return_tensors='pt').to(self.device)

    def _image_token_offset(self, inputs):
        """Return ``(#image tokens) - (sequence length)`` per row of ``input_ids``.

        The result is negative. It was labelled "end of image" and is used as
        the end of the ``'img'`` slice handed to intervention hooks; presumably
        it indexes the last image position counted from the end of the
        sequence after the model expands the ``<image>`` placeholder.
        NOTE(review): verify against the model's image-merge logic.
        """
        mask = inputs['input_ids'] == self.model.config.image_token_index
        return torch.sum(mask, dim=-1) - mask.shape[-1]

    def _decode(self, generate_ids):
        """Decode the first generated sequence and return the text after the answer marker.

        Tries ``'ASSISTANT: '`` first (the normal case), then the bare
        ``'ASSISTANT:'`` (presumably what remains when the model emits no
        answer tokens). If neither marker is present, the decoded text is
        returned unchanged. A bare ``find() + len()`` would instead return a
        mangled slice of the prompt whenever the marker is missing.
        """
        output = self.processor.batch_decode(
            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        for marker in ('ASSISTANT: ', 'ASSISTANT:'):
            idx = output.find(marker)
            if idx != -1:
                return output[idx + len(marker):]
        return output

    def _generate_traced(self, inputs, layers, edit_fn, **generate_kwargs):
        """Run ``generate`` while baukit's TraceDict applies *edit_fn* to the
        outputs of the modules named in *layers*; return the decoded answer."""
        with TraceDict(self.model, layers, edit_output=edit_fn):
            generate_ids = self.model.generate(**inputs, **generate_kwargs)
        return self._decode(generate_ids)

    def _evaluate_multiple(self, prompt, filepath, interventions, intervention_fn):
        """Shared body of the ``evaluate_with_multiple_intervention*`` methods."""
        inputs = self._encode(self._chat_prompt(prompt), filepath)
        if interventions:
            img_end = self._image_token_offset(inputs).item()
            # 'img' covers the positions from index 1 up to the image-token
            # offset; 'lt' is the last token.
            edit_fn = partial(intervention_fn,
                              special_tokens_location={'img': slice(1, img_end), 'lt': -1})
            layers = list(interventions.keys())
        else:
            edit_fn = self._identity
            layers = []
        return self._generate_traced(inputs, layers, edit_fn, max_length=128)

    def _collect_activations(self, *model_args):
        """Forward *model_args* with ``output_hidden_states=True`` while tracing every
        layer's ``self_attn.head_out`` and ``mlp`` module.

        Returns:
            ``(hidden_states, head_wise, mlp_wise)`` as NumPy arrays. Each is
            stacked along a new leading axis (one entry per hidden state /
            traced layer) and then ``squeeze()``-d.
        """
        num_layers = self.model.language_model.config.num_hidden_layers
        heads = [f"language_model.model.layers.{i}.self_attn.head_out" for i in range(num_layers)]
        mlps = [f"language_model.model.layers.{i}.mlp" for i in range(num_layers)]

        with torch.no_grad():
            with TraceDict(self.model, heads + mlps) as ret:
                output = self.model(*model_args, output_hidden_states=True)
            hidden_states = torch.stack(output.hidden_states, dim=0).squeeze().detach().cpu().numpy()
            head_wise = torch.stack(
                [ret[h].output.squeeze().detach().cpu() for h in heads], dim=0).squeeze().numpy()
            mlp_wise = torch.stack(
                [ret[m].output.squeeze().detach().cpu() for m in mlps], dim=0).squeeze().numpy()
        return hidden_states, head_wise, mlp_wise

    # ------------------------------------------------------------------ #
    # public API
    # ------------------------------------------------------------------ #
    def reset(self):
        """Drop the current model and processor, release cached GPU memory, and
        reload both from the original checkpoint path."""
        import gc

        self.model = None
        self.processor = None
        # Collect reference cycles first so the old weights are really freed
        # before the CUDA caching allocator is asked to release its blocks.
        gc.collect()
        torch.cuda.empty_cache()
        self.model, self.processor = self._load()

    def evaluate(self, prompt, filepath):
        """Answer *prompt* about the image at *filepath* (``max_length=128``)."""
        inputs = self._encode(self._chat_prompt(prompt), filepath)
        generate_ids = self.model.generate(**inputs, max_length=128)
        return self._decode(generate_ids)

    def evaluate_of_caption(self, prompt, filepath, caption):
        """Answer *prompt* from the textual *caption* alone; no image is given to the model.

        *filepath* is accepted for interface parity with the other evaluators
        but is not read. The image was never used here, so it is no longer
        opened (opening it leaked a file handle).
        """
        text = (
            f"USER: The given image depicts the following scene: {caption}\n "
            "Please directly answer the following question from the image description, "
            "without guessing or reasoning. Question: "
            f"{prompt}\nASSISTANT:"
        )
        inputs = self._encode(text)
        generate_ids = self.model.generate(**inputs, max_new_tokens=32)
        return self._decode(generate_ids)

    def evaluate_of_caption_img(self, prompt, filepath, caption):
        """Answer *prompt* given both the image at *filepath* and its *caption*
        (``max_new_tokens=32``)."""
        text = f"<image>\nUSER: The given image describe that {caption} {prompt}\nASSISTANT:"
        inputs = self._encode(text, filepath)
        generate_ids = self.model.generate(**inputs, max_new_tokens=32)
        return self._decode(generate_ids)

    def evaluate_with_intervention(self, prompt, filepath, interventions, intervention_fn):
        """Generate an answer while *intervention_fn* edits the traced module outputs.

        Args:
            prompt: user question, wrapped in the chat template.
            filepath: path of the image to condition on.
            interventions: mapping whose keys are module names to trace. An
                empty mapping (or ``None``) disables editing.
            intervention_fn: edit hook. It is bound with
                ``start_edit_location='lt'`` and then called by TraceDict.
        """
        if interventions:
            edit_fn = partial(intervention_fn, start_edit_location='lt')
            layers = list(interventions.keys())
        else:
            edit_fn = self._identity
            layers = []
        inputs = self._encode(self._chat_prompt(prompt), filepath)
        return self._generate_traced(inputs, layers, edit_fn, max_length=128)

    def evaluate_with_i2t(self, prompt, filepath, intervention):
        """Generate while forwarding *intervention* to ``self.model.generate``.

        The ``intervention`` keyword is consumed by the project's model class.
        """
        inputs = self._encode(self._chat_prompt(prompt), filepath)
        generate_ids = self.model.generate(**inputs, max_length=128, intervention=intervention)
        return self._decode(generate_ids)

    def evaluate_with_multiple_intervention(self, prompt, filepath, interventions, intervention_fn=None):
        """Generate while *intervention_fn* edits the traced modules, using the
        image span and the last token as edit locations.

        *intervention_fn* receives ``special_tokens_location={'img': slice, 'lt': -1}``.
        It is only required when *interventions* is non-empty.
        """
        return self._evaluate_multiple(prompt, filepath, interventions, intervention_fn)

    def evaluate_with_multiple_intervention2(self, prompt, filepath, interventions, intervention_fn=None):
        """Identical to :meth:`evaluate_with_multiple_intervention`; kept for existing callers."""
        return self._evaluate_multiple(prompt, filepath, interventions, intervention_fn)

    def get_activations(self, prompt, filepath):
        """Run one forward pass on (*prompt*, image) and capture activations.

        *prompt* is passed to the processor as-is; no chat template is added.

        Returns:
            ``(hidden_states, head_wise_hidden_states, mlp_wise_hidden_states,
            [image_token_offset])``. The first three are NumPy arrays (see
            ``_collect_activations``). The last is the ``_image_token_offset``
            value wrapped in a one-element list.
        """
        inputs = self._encode(prompt, filepath)
        offset = self._image_token_offset(inputs)
        hidden_states, head_wise, mlp_wise = self._collect_activations(
            inputs['input_ids'], inputs['pixel_values'])
        return hidden_states, head_wise, mlp_wise, [offset.item()]

    def get_activations_only_text(self, prompt):
        """Text-only counterpart of :meth:`get_activations`.

        Returns ``(hidden_states, head_wise, mlp_wise)`` as NumPy arrays.
        """
        inputs = self._encode(prompt)
        return self._collect_activations(inputs['input_ids'])

    def get_projected_activations(self, prompt, filepath):
        """Return the image and text activations from the model's
        ``get_projected_activation`` as NumPy arrays, plus the ``input_ids``
        tensor (left on ``self.device``)."""
        inputs = self._encode(prompt, filepath)
        with torch.no_grad():
            img_activation, text_activation = self.model.get_projected_activation(
                inputs['input_ids'], inputs['pixel_values'])
        return (img_activation.detach().cpu().numpy(),
                text_activation.detach().cpu().numpy(),
                inputs['input_ids'])