import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from PIL import Image
# Raise Pillow's decompression-bomb size guard so that large input images can be
# opened. This is a process-wide setting: it applies to every ``Image.open``
# performed after this module has been imported.
Image.MAX_IMAGE_PIXELS = 123_550_720

import copy

import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path

from .base import BaseVLInference

# Maps a supported model path (the value expected in ``self.model_path``) to the
# name of the ``llava.conversation.conv_templates`` entry used to build its prompt.
CONV_TEMPLATE = dict(
    [
        ("liuhaotian/llava-v1.6-vicuna-7b", "vicuna_v1"),
        ("lmms-lab/llama3-llava-next-8b", "llava_llama_3"),
        ("lmms-lab/llava-onevision-qwen2-7b-ov", "qwen_1_5"),
    ]
)
class LLaVAInference(BaseVLInference):
    """Image+question answering with LLaVA-family checkpoints via the ``llava`` package.

    ``self.model_path`` (presumably set by ``BaseVLInference``) must be one of the
    keys of ``CONV_TEMPLATE``; it selects both the checkpoint and its chat template.
    """

    # Images whose longer side is strictly greater than this many pixels are downscaled...
    RESIZE_THRESHOLD = 699
    # ...so that their longer side becomes this many pixels (aspect ratio preserved).
    RESIZE_TARGET = 640

    def load_model_and_processor(self):
        """Load tokenizer, model and image processor, and select the chat template.

        Raises:
            KeyError: if ``self.model_path`` has no entry in ``CONV_TEMPLATE``.
        """
        # Resolve the template first so an unsupported model fails before the
        # expensive checkpoint load.
        try:
            self.conv_template = CONV_TEMPLATE[self.model_path]
        except KeyError:
            raise KeyError(
                f"No conversation template registered for {self.model_path!r}; "
                f"supported models: {sorted(CONV_TEMPLATE)}"
            ) from None

        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path=self.model_path,
            model_base=None,
            model_name=get_model_name_from_path(self.model_path),
        )

        # Generation is capped at 2 new tokens, which only fits very short answers
        # (e.g. a single option letter) — presumably intentional; confirm.
        self.max_new_tokens = 2

    def run_batch_inference(self, batch):
        """Answer every (image, question) pair in ``batch``.

        Args:
            batch: mapping with two parallel sequences: ``batch['image']`` (anything
                ``PIL.Image.open`` accepts, e.g. file paths) and ``batch['text']``
                (question strings).

        Returns:
            dict with ``"question"`` — the query actually sent to the model (the
            original text prefixed with the image token(s)), one per item — and
            ``"answer"`` — the decoded generations, one per item.

        Raises:
            ValueError: if the image and text sequences differ in length.
        """
        images, texts = batch['image'], batch['text']
        if len(images) != len(texts):
            raise ValueError(
                f"batch['image'] and batch['text'] differ in length "
                f"({len(images)} vs {len(texts)})"
            )

        questions, answers = [], []
        # Items are processed one at a time: each pair gets its own generate call.
        for image_source, query_text in zip(images, texts):
            question, answer = self._answer_one(image_source, query_text)
            questions.append(question)
            answers.extend(answer)

        return {"question": questions, "answer": answers}

    def _answer_one(self, image_source, query_text):
        """Run a single generate call; return (prompted question, decoded answers list)."""
        img = self._load_image(image_source)
        image_tensor = process_images([img], self.image_processor, self.model.config)
        image_tensor = [_image.to(dtype=torch.float16, device='cuda') for _image in image_tensor]

        question = self._with_image_token(query_text)
        input_ids = self._build_input_ids(question)

        with torch.no_grad():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor,
                # Size of the (possibly downscaled) image actually fed to the processor.
                image_sizes=[img.size],
                max_new_tokens=self.max_new_tokens,
            )
        answer = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return question, answer

    def _load_image(self, image_source):
        """Open ``image_source`` as RGB, downscaling it when its longer side is too large."""
        # ``convert`` returns a new, fully loaded image, so the file can be closed
        # right away instead of leaking a handle per call.
        with Image.open(image_source) as raw:
            img = raw.convert('RGB')

        width, height = img.size
        longest = max(width, height)
        if longest > self.RESIZE_THRESHOLD:
            scale_factor = self.RESIZE_TARGET / longest
            # Clamp to 1 so very thin images never truncate to a zero-sized side.
            scale_width = max(1, int(width * scale_factor))
            scale_height = max(1, int(height * scale_factor))
            # torchvision's Resize takes (height, width).
            resize = T.Resize((scale_height, scale_width), interpolation=InterpolationMode.BICUBIC)
            img = resize(img)
        return img

    def _with_image_token(self, query_text):
        """Prefix ``query_text`` with the image placeholder token(s) the model config expects."""
        if self.model.config.mm_use_im_start_end:
            return DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + query_text
        return DEFAULT_IMAGE_TOKEN + '\n' + query_text

    def _build_input_ids(self, question):
        """Render ``question`` through the chat template and tokenize it onto the GPU."""
        # Work on a copy: append_message mutates the conversation, and the
        # template object in ``conv_templates`` is shared module state.
        conv = conv_templates[self.conv_template].copy()
        # NOTE(review): presumably required by templates that render via the
        # tokenizer's chat template (e.g. llava_llama_3) — confirm.
        conv.tokenizer = self.tokenizer
        conv.append_message(conv.roles[0], question)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        return tokenizer_image_token(
            prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
        ).unsqueeze(0).cuda()

        
        