import torch
from PIL import Image
from ..base import BaseModel
from ...smp import *
from .save_attention_features import *

IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = '<image>'
DEFAULT_IM_START_TOKEN = '<im_start>'
DEFAULT_IM_END_TOKEN = '<im_end>'


class Cambrian(BaseModel):

    INSTALL_REQ = True
    INTERLEAVE = False

    def __init__(self, model_path='checkpoints/cambrian-8b', **kwargs):
        assert model_path is not None
        try:
            from .conversation import conv_templates
            from .model.builder import load_pretrained_model
            from .mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
        except Exception as e:
            logging.critical('Please install cambrian from https://github.com/cambrian-mllm/cambrian.')
            raise e

        model_name = get_model_name_from_path(model_path)
        tokenizer, model, image_processor, context_len = load_pretrained_model(
            model_path,
            None,
            model_name,
            device_map='auto'
        )

        if '8b' in model_path:
            self.conv_mode = 'llama_3'
        elif '13b' in model_path:
            self.conv_mode = 'vicuna_v1'
        else:
            self.conv_mode = 'chatml_direct'

        self.model_config = model.config
        self.conv_templates = conv_templates
        self.tokenizer_image_token = tokenizer_image_token
        self.process_images = process_images

        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.model = model

    def process(self, image, question):
        if self.model_config.mm_use_im_start_end:
            question = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + question
        else:
            question = DEFAULT_IMAGE_TOKEN + '\n' + question
        conv = self.conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], question)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        image_size = [image.size]
        image_tensor = self.process_images([image], self.image_processor, self.model_config)
        input_ids = self.tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
        input_ids = input_ids.unsqueeze(0).cuda()
        return input_ids, image_tensor, image_size, prompt

    def generate_inner(self, message, dataset=None):
        prompt, image_path = self.message_to_promptimg(message, dataset=dataset)
        image = Image.open(image_path).convert('RGB')
        input_ids, image_tensor, image_sizes, prompt = self.process(image, prompt)
        input_ids = input_ids.to(device='cuda', non_blocking=True)

        num_vision_tokens = 600  # This should be 600
        vision_token_start = len(self.tokenizer(prompt.split("<image>")[0], return_tensors='pt')["input_ids"][0])
        vision_token_end = vision_token_start + num_vision_tokens
        breakpoint()

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor,
                image_sizes=image_sizes,
                do_sample=False,
                temperature=0.2,
                num_beams=1,
                max_new_tokens=3,
                use_cache=True,
                return_dict_in_generate=True,
                output_attentions=True,
            )
        outputs = self.tokenizer.batch_decode(output_ids['sequences'], skip_special_tokens=True)[0].strip()

        # save_vision_attention_scores(
        #     outputs=output_ids,
        #     file_path=f"attention_analysis/cambrian_x4_mme_3.pkl",  # Different file per encoder
        #     vision_token_start=vision_token_start,
        #     vision_token_end=vision_token_end,
        #     num_vision_tokens=num_vision_tokens,
        #     image_path=None,
        #     prompt=prompt,
        #     response=outputs,
        #     FIXED_ATTENTION_POWER=0.3
        # )

        return outputs
