from transformers import  AutoProcessor, AutoTokenizer
from qwen_vl_utils import process_vision_info
import torch

from PIL import Image
Image.MAX_IMAGE_PIXELS = 123550720

import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode

from .base import BaseVLInference

class QwenVLInference(BaseVLInference):
    """Batched image+text inference for Qwen2-VL and Qwen2.5-VL checkpoints.

    The methods read ``self.model_type``, ``self.model_path``, ``self.device``
    and ``self.max_new_tokens``. None of them is assigned in this class, so
    they are presumably provided by ``BaseVLInference`` -- TODO confirm.
    """

    # Longest image side (in pixels) allowed before an image is downscaled.
    MAX_IMAGE_SIDE = 640

    def load_model_and_processor(self):
        """Load the model, processor and tokenizer from ``self.model_path``.

        Sets ``self.model``, ``self.processor`` and ``self.tokenizer``.

        Raises:
            ValueError: If ``self.model_type`` is neither ``"qwen2_5vl"`` nor
                ``"qwen2vl"``.
        """
        # Imported lazily so that only the selected architecture has to exist
        # in the installed transformers version.
        if self.model_type == "qwen2_5vl":
            from transformers import Qwen2_5_VLForConditionalGeneration as model_cls
        elif self.model_type == "qwen2vl":
            from transformers import Qwen2VLForConditionalGeneration as model_cls
        else:
            # Fail here rather than later with an AttributeError on self.model.
            raise ValueError(
                f"Unsupported model_type {self.model_type!r}; "
                "expected 'qwen2_5vl' or 'qwen2vl'"
            )

        self.model = model_cls.from_pretrained(
            self.model_path,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )

        # Left padding keeps every prompt directly adjacent to its generated
        # tokens when prompts of different lengths are batched together.
        self.processor = AutoProcessor.from_pretrained(self.model_path, padding_side="left")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)

    def _load_and_resize_image(self, source):
        """Open an image and cap its longest side at ``MAX_IMAGE_SIDE``.

        Args:
            source: Anything accepted by ``PIL.Image.open``.

        Returns:
            The opened image unchanged when it already fits; otherwise an RGB
            copy downscaled with bicubic interpolation, aspect ratio kept.
        """
        img = Image.open(source)
        width, height = img.size
        longest_side = max(width, height)
        if longest_side <= self.MAX_IMAGE_SIDE:
            return img

        scale_factor = self.MAX_IMAGE_SIDE / longest_side
        # int() truncates, so a very elongated image could otherwise end up
        # with a zero-sized side and make the resize fail.
        scale_width = max(1, int(width * scale_factor))
        scale_height = max(1, int(height * scale_factor))

        if img.mode != 'RGB':
            img = img.convert('RGB')
        resize = T.Resize((scale_height, scale_width), interpolation=InterpolationMode.BICUBIC)
        return resize(img)

    def run_batch_inference(self, batch):
        """Generate one answer for every (image, text) pair in ``batch``.

        Args:
            batch: Mapping with an ``'image'`` entry (items accepted by
                ``PIL.Image.open``) and a ``'text'`` entry (prompt strings).
                The two are paired by position.

        Returns:
            dict: ``{"question": batch['text'], "answer": answers}`` where
            ``answers`` is the list of decoded continuations with the prompt
            tokens removed.
        """
        imgs = batch['image']
        query_texts = batch['text']

        # One single-turn conversation per sample: the image followed by the text.
        messages = [
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": self._load_and_resize_image(img)},
                        {"type": "text", "text": text},
                    ],
                }
            ]
            for img, text in zip(imgs, query_texts)
        ]

        if not messages:
            # Nothing to run; keep the return shape callers expect.
            return {"question": query_texts, "answer": []}

        texts = [
            self.processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
            for msg in messages
        ]
        image_inputs, video_inputs = process_vision_info(messages)

        inputs = self.processor(
            text=texts,
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.device)

        with torch.no_grad():
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            # generate() returns prompt + continuation; drop the prompt part.
            trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)]
            answers = self.processor.batch_decode(trimmed, skip_special_tokens=True)

        return {"question": query_texts, "answer": answers}
    

    # def run_batch_inference_with_logit(self, batch):
    #     imgs = batch['image']
    #     query_texts = batch['text']

    #     yes_token_id = self.text_processor.encode("yes", add_special_tokens=False)[0]
    #     Yes_token_id = self.text_processor.encode("Yes", add_special_tokens=False)[0]
    #     no_token_id = self.text_processor.encode("no", add_special_tokens=False)[0]
    #     No_token_id = self.text_processor.encode("No", add_special_tokens=False)[0]

    #     messages = []
    #     for img, text in zip(imgs, query_texts): 
    #         img  = Image.open(img)
    #         width, height = img.size
    #         if max(width, height)>699:
    #             scale_factor = 640 / max(width, height)
    #             scale_width = int(width * scale_factor)
    #             scale_height = int(height * scale_factor)
    #             transform = T.Compose([
    #                 T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
    #                 T.Resize((scale_height, scale_width), interpolation=InterpolationMode.BICUBIC)
    #             ])
    #             img = transform(img)
            
    #         message = [
    #                             {
    #                                 "role": "user",
    #                                 "content": [
    #                                     {"type": "image", "image": img},
    #                                     {"type": "text", "text": text}]
    #                             }
    #                         ]

    #         messages.append(message)

    #     texts = [self.processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in messages]
    #     image_inputs, video_inputs = process_vision_info(messages)

    #     inputs = self.processor(text=texts, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
    #     inputs = inputs.to(self.device)

        
    #     with torch.no_grad():
    #         generated_ids = self.model.generate(**inputs, max_new_tokens=self.max_new_tokens,do_sample=False, output_scores=True, return_dict_in_generate=True)
    #         trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids.sequences)]

    #         # print(generated_ids)

    #         # print(generated_ids.scores[0].shape)


    #         # for k in range(generated_ids.scores[0].shape[0]):
    #         #     print(generated_ids.scores[0][k][yes_token_id], generated_ids.scores[0][k][Yes_token_id],generated_ids.scores[0][k][no_token_id], generated_ids.scores[0][k][No_token_id])

    #         logits = generated_ids.scores
        
    #         yes_logits = [max(logits[0][j, yes_token_id].item(), logits[0][j, Yes_token_id].item()) for j in range(len(imgs))]
    #         # Yes_logits = [logits[0][j, Yes_token_id].item() for j in range(len(imgs))]
    #         no_logits = [max(logits[0][j, no_token_id].item(), logits[0][j, No_token_id].item()) for j in range(len(imgs))]
    #         # No_logits = [logits[0][j, No_token_id].item() for j in range(len(imgs))]

    #         probs = [torch.softmax(logits[0][j, :], dim=-1) for j in range(len(imgs))]
    #         yes_probs = [max(prob[yes_token_id].item(), prob[Yes_token_id].item()) for prob in probs]
    #         # Yes_probs = [prob[Yes_token_id].item() for prob in probs]


    #         answers = self.processor.batch_decode(trimmed, skip_special_tokens=True)

    #     # return {"question":query_texts, "answer": answers, "yes_logits": yes_logits, "Yes_logits": Yes_logits, "yes_probs": yes_probs, "Yes_probs": Yes_probs, "no_logits": no_logits, "No_logits": No_logits}
    #     return {"question":query_texts, "answer": answers, "yes_logits": yes_logits, "yes_probs": yes_probs, "no_logits": no_logits}

