import torch
from PIL import Image
from abc import abstractproperty
import sys
import os.path as osp
from ..base import BaseModel
from ...smp import *
from ...dataset import DATASET_TYPE
import copy


from .save_attention_features import *
import warnings
# Process-wide side effect at import time: every UserWarning from any module is
# ignored for the rest of the run. This includes plain warnings.warn(...) calls,
# whose default category is UserWarning.
warnings.simplefilter('ignore', UserWarning)


class Eagle(BaseModel):
    """VLMEvalKit wrapper around an NVLabs Eagle checkpoint (default: Eagle-X5-7B).

    The model is loaded in bfloat16 and prompted through the ``vicuna_v1``
    conversation template. Decoding defaults to greedy search; keyword arguments
    passed to the constructor override those defaults and are forwarded to
    ``model.generate``.
    """

    INSTALL_REQ = True
    INTERLEAVE = True

    def __init__(self,
                 model_path='../../checkpoints/Eagle-X5-7B',
                 **kwargs):
        """Load tokenizer, model and image processor and set up generation options.

        Args:
            model_path: Local checkpoint directory, or an identifier for which
                ``splitlen`` returns 2 (presumably a hub-style ``org/name``).
            **kwargs: Generation options merged over the defaults below and
                stored in ``self.kwargs``.
        """
        try:
            from vlmeval.vlm.eagle.model.builder import load_pretrained_model
            from vlmeval.vlm.eagle.mm_utils import get_model_name_from_path
        except Exception:
            logging.critical('''Please install eagle before using Eagle,
            you can install it from "https://github.com/NVlabs/EAGLE.git"''')
            raise

        # logging rather than warnings.warn: this module installs a global
        # "ignore UserWarning" filter, which would silently drop these messages.
        logging.warning('Please install the latest version of eagle from github before you evaluate the Eagle model.')
        assert osp.exists(model_path) or splitlen(model_path) == 2
        model_name = get_model_name_from_path(model_path)

        self.tokenizer, self.model, self.image_processor, self.context_len = (
            load_pretrained_model(model_path, None, model_name, False, False, device_map="auto", torch_dtype=torch.bfloat16)
        )
        self.model.to(torch.bfloat16)

        self.model.eval()
        self.conv_mode = 'vicuna_v1'

        default_kwargs = dict(
            do_sample=False,
            temperature=0,
            top_p=0.5,
            num_beams=1,
            max_new_tokens=512,
            use_cache=True
        )

        default_kwargs.update(kwargs)
        self.kwargs = default_kwargs
        logging.warning(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
        torch.cuda.empty_cache()

    def generate_inner(self, message, dataset=None):
        """Generate a response for one interleaved image/text message.

        Args:
            message: List of ``{'type': 'image' | 'text', 'value': ...}`` dicts.
                Image values are opened with ``PIL.Image.open``; text values are
                concatenated in order.
            dataset: Not used by this method.

        Returns:
            The first decoded sequence, with special tokens skipped and
            surrounding whitespace stripped.
        """
        try:
            from vlmeval.vlm.eagle.constants import (IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN,
                                                     DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN)
            from vlmeval.vlm.eagle.conversation import conv_templates
            from vlmeval.vlm.eagle.mm_utils import tokenizer_image_token, process_images
        except Exception:
            logging.critical('''Please install eagle before using Eagle,
            you can install it from "https://github.com/NVlabs/EAGLE.git"''')
            raise

        image_paths = [item['value'] for item in message if item['type'] == 'image']
        question = ''.join(item['value'] for item in message if item['type'] == 'text')

        # All image placeholders go in front of the text, one per image.
        # NOTE(review): a message without images yields an empty placeholder and
        # an empty list for process_images; text-only input is not special-cased.
        image_tokens = DEFAULT_IMAGE_TOKEN * len(image_paths)
        if self.model.config.mm_use_im_start_end:
            image_tokens = DEFAULT_IM_START_TOKEN + image_tokens + DEFAULT_IM_END_TOKEN
        question = image_tokens + '\n' + question

        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], question)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        images = [Image.open(path).convert('RGB') for path in image_paths]

        image_tensor = process_images(images, self.image_processor, self.model.config)
        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
        input_ids = input_ids.to(device='cuda', non_blocking=True)
        image_tensor = image_tensor.to(dtype=torch.bfloat16, device='cuda', non_blocking=True)

        # Merge into one dict so a user-supplied key cannot collide with an explicit
        # keyword ("got multiple values for keyword argument" TypeError).
        # return_dict_in_generate is forced because output['sequences'] is read below.
        # Attention maps are NOT requested by default: keeping every layer's attention
        # for every decoding step is very memory-hungry and nothing here reads it.
        # Pass output_attentions=True in the constructor kwargs for attention analysis
        # (e.g. with the helpers in .save_attention_features).
        gen_kwargs = dict(self.kwargs)
        gen_kwargs['return_dict_in_generate'] = True

        with torch.inference_mode():
            output = self.model.generate(
                input_ids.unsqueeze(0),
                images=image_tensor,
                image_sizes=[img.size for img in images],
                **gen_kwargs
            )

        return self.tokenizer.batch_decode(output['sequences'], skip_special_tokens=True)[0].strip()

    def use_custom_prompt(self, dataset):
        """Return whether ``build_prompt`` should be used for ``dataset``.

        MMMU datasets keep the default prompt; MCQ-type datasets and MMVet use
        the custom prompt built by this class.
        """
        assert dataset is not None
        if listinstr(['MMMU'], dataset):
            return False
        return DATASET_TYPE(dataset) == 'MCQ' or dataset == 'MMVet'

    def build_prompt(self, line, dataset=None):
        """Build a message (text first, then images) for one dataset row.

        Args:
            line: Dataset row providing ``question``, optionally ``hint``, and,
                for MCQ, option columns named by single uppercase letters.
            dataset: Dataset name; must satisfy ``use_custom_prompt``.

        Returns:
            A list with one text item followed by one image item per path
            returned by ``self.dump_image``.

        Raises:
            NotImplementedError: defensive fallback for datasets without a
                template here (not reachable when ``use_custom_prompt`` holds).
        """
        assert dataset is None or isinstance(dataset, str)
        assert self.use_custom_prompt(dataset)
        tgt_path = self.dump_image(line, dataset)
        question = line['question']
        if dataset == 'MMVet':
            prompt = question + '\nAnswer the question directly. '
        elif DATASET_TYPE(dataset) == 'MCQ':
            # Keep only the option columns that exist and are non-null.
            options = {
                cand: line[cand]
                for cand in string.ascii_uppercase
                if cand in line and not pd.isna(line[cand])
            }
            options_prompt = ''.join(f'{key}. {item}\n' for key, item in options.items())

            hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
            prompt = f'Hint: {hint}\n' if hint is not None else ''
            prompt += f'{question}\n'
            prompt += (
                f'{options_prompt}\nAnswer with the option’s letter from the given choices directly. '
                if len(options) else 'Answer the question directly. '
            )
        else:
            raise NotImplementedError

        message = [dict(type='text', value=prompt)]
        message.extend(dict(type='image', value=path) for path in tgt_path)
        return message
