import io
import os
import sys
import torch
import base64
import argparse
from PIL import Image
from io import BytesIO
from bench_utils import resize_image


def get_LVLM_class(args):
    """Return the LVLM wrapper class that matches ``args.lvlm_path``.

    Every returned class exposes the same small interface:

    * ``LVLM(args)`` loads the model/processor (or API client) for the backend.
    * ``LVLM.run(img_bytes, prompt) -> str`` generates a reply to ``prompt``.
      ``img_bytes`` is the raw encoded image (bytes that PIL can open); pass a
      falsy value (``None`` or ``b""``) for a text-only query.

    Backend-specific third-party imports live inside the methods, so they are
    only executed when that backend is actually instantiated/used.

    Args:
        args: Namespace with at least ``lvlm_path`` and ``max_new_tokens``;
            the Together-hosted Llama models additionally read ``API_KEY``.

    Returns:
        The LVLM class (not an instance) for the selected backend.

    Raises:
        NotImplementedError: if ``args.lvlm_path`` is not a supported backend.
    """
    if args.lvlm_path in ["Qwen/Qwen2.5-VL-3B-Instruct",
                          "Qwen/Qwen2.5-VL-7B-Instruct",
                          "Qwen/Qwen2.5-VL-32B-Instruct",
                          "Qwen/Qwen2.5-VL-72B-Instruct"]:
        class LVLM:
            """Qwen2.5-VL loaded locally through Hugging Face transformers."""

            def __init__(self, args):
                from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

                # Keep the generation budget on the instance so run() honours the
                # args this object was built with, not the factory's closure.
                self.max_new_tokens = args.max_new_tokens
                self.processor = AutoProcessor.from_pretrained(
                    args.lvlm_path,
                    torch_dtype=torch.bfloat16,
                    device_map='auto',
                    # Bounds on the image size (in pixels) used by the image processor.
                    min_pixels=224*224,
                    max_pixels=2048*2048,
                )
                self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                    args.lvlm_path,
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                )

            def run(self, img_bytes, prompt):
                from qwen_vl_utils import process_vision_info

                if img_bytes:
                    image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                    messages = [
                        {
                            "role": "user",
                            "content": [
                                {"type": "image", "image": image},
                                {"type": "text", "text": prompt},
                            ],
                        }
                    ]
                    text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                    image_inputs, _ = process_vision_info(messages)
                    inputs = self.processor(
                        text=[text],
                        images=image_inputs,
                        padding=True,
                        return_tensors="pt",
                    ).to(self.model.device)
                else:
                    messages = [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": prompt},
                            ],
                        }
                    ]
                    text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                    inputs = self.processor(
                        text=[text],
                        padding=True,
                        return_tensors="pt",
                    ).to(self.model.device)

                output = self.model.generate(**inputs, max_new_tokens=self.max_new_tokens)
                # generate() returns prompt + completion; drop the prompt tokens.
                generated_tokens = output[0, inputs['input_ids'].size(1):]
                generated_text = self.processor.decode(generated_tokens, skip_special_tokens=True)
                return generated_text

    elif args.lvlm_path == "meta-llama/Llama-3.2-11B-Vision-Instruct":
        class LVLM:
            """Llama 3.2 11B Vision (Mllama) loaded locally through transformers."""

            def __init__(self, args):
                from transformers import AutoProcessor, MllamaForConditionalGeneration

                self.max_new_tokens = args.max_new_tokens
                # NOTE(review): weights are read from this hard-coded local directory,
                # not from args.lvlm_path — confirm the checkpoint lives there.
                lvlm_path = "./Llama-3.2-11B-Vision-Instruct"
                self.processor = AutoProcessor.from_pretrained(
                    lvlm_path,
                    trust_remote_code=True,
                    torch_dtype=torch.bfloat16,
                    device_map='auto',
                )
                self.model = MllamaForConditionalGeneration.from_pretrained(
                    lvlm_path,
                    trust_remote_code=True,
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                )

            def run(self, img_bytes, prompt):
                from transformers import GenerationConfig

                # Raw prompt format (no chat template is applied here).
                if img_bytes:
                    image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                    prompt = "<|image|><|begin_of_text|>" + prompt
                    inputs = self.processor(image, prompt, return_tensors="pt").to(self.model.device)
                else:
                    prompt = "<|begin_of_text|>" + prompt
                    inputs = self.processor(text=prompt, return_tensors="pt").to(self.model.device)

                output = self.model.generate(
                    **inputs,
                    generation_config=GenerationConfig(max_new_tokens=self.max_new_tokens, stop_strings="<|eot_id|>"),
                    # stop_strings needs the tokenizer passed to generate().
                    tokenizer=self.processor.tokenizer,
                )
                generated_tokens = output[0, inputs['input_ids'].size(1):]
                generated_text = self.processor.decode(generated_tokens, skip_special_tokens=True)
                return generated_text

    elif args.lvlm_path == "microsoft/Phi-4-multimodal-instruct":
        class LVLM:
            """Phi-4-multimodal loaded locally (remote code) through transformers."""

            def __init__(self, args):
                from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig

                self.max_new_tokens = args.max_new_tokens
                self.processor = AutoProcessor.from_pretrained(
                    args.lvlm_path,
                    trust_remote_code=True,
                )
                self.model = AutoModelForCausalLM.from_pretrained(
                    args.lvlm_path,
                    device_map="auto",
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True,
                    _attn_implementation='eager',
                )
                # Generation defaults shipped with the checkpoint, reused in run().
                self.generation_config = GenerationConfig.from_pretrained(args.lvlm_path)

            def run(self, img_bytes, prompt):
                # Phi-4 prompt format: <|user|> ... <|end|> followed by <|assistant|>.
                if img_bytes:
                    image = Image.open(io.BytesIO(img_bytes)).convert("RGB")

                    prompt = '<|user|>\n<|image_1|>\n' + prompt + '<|end|>\n<|assistant|>\n'
                    inputs = self.processor(
                        text=prompt,
                        images=image,
                        return_tensors='pt'
                    ).to(self.model.device)
                else:
                    prompt = '<|user|>\n' + prompt + '<|end|>\n<|assistant|>\n'
                    inputs = self.processor(
                        text=prompt,
                        return_tensors='pt'
                    ).to(self.model.device)

                # NOTE(review): num_logits_to_keep is tied to max_new_tokens here;
                # confirm this matches the model card's recommended usage.
                output = self.model.generate(**inputs,
                                             max_new_tokens=self.max_new_tokens,
                                             num_logits_to_keep=self.max_new_tokens,
                                             generation_config=self.generation_config)
                generated_tokens = output[0, inputs['input_ids'].size(1):]
                generated_text = self.processor.decode(generated_tokens, skip_special_tokens=True)
                return generated_text

    elif args.lvlm_path in ["google/gemma-3-1b-it",
                            "google/gemma-3-12b-it",
                            "google/gemma-3-27b-it"]:
        class LVLM:
            """Gemma 3 loaded locally as a causal LM through transformers.

            NOTE(review): this pairs Gemma3ForCausalLM with AutoTokenizer, so
            ``inputs`` only ever holds token ids produced by the tokenizer's chat
            template — no pixel values reach the model, even when an image is
            given. If image understanding is intended for the 12b/27b checkpoints,
            Gemma3ForConditionalGeneration + AutoProcessor is presumably required
            — confirm.
            """

            def __init__(self, args):
                from transformers import AutoTokenizer, Gemma3ForCausalLM

                self.max_new_tokens = args.max_new_tokens
                self.tokenizer = AutoTokenizer.from_pretrained(args.lvlm_path)
                self.model = Gemma3ForCausalLM.from_pretrained(
                    args.lvlm_path,
                    device_map="auto",
                    torch_dtype=torch.bfloat16,
                    _attn_implementation='eager',
                ).eval()

            def run(self, img_bytes, prompt):
                if img_bytes:
                    image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                    messages = [
                        {
                            "role": "user",
                            "content": [
                                {"type": "image", "image": image},
                                {"type": "text", "text": prompt},
                            ],
                        }
                    ]
                    inputs = self.tokenizer.apply_chat_template(
                        messages,
                        add_generation_prompt=True,
                        tokenize=True,
                        return_dict=True,
                        return_tensors="pt",
                    ).to(self.model.device)
                else:
                    messages = [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": prompt},
                            ],
                        }
                    ]
                    inputs = self.tokenizer.apply_chat_template(
                        messages,
                        add_generation_prompt=True,
                        tokenize=True,
                        return_dict=True,
                        return_tensors="pt",
                    ).to(self.model.device)

                output = self.model.generate(**inputs, max_new_tokens=self.max_new_tokens)
                generated_tokens = output[0, inputs['input_ids'].size(1):]
                generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
                return generated_text

    elif args.lvlm_path == "mistralai/Pixtral-12B-2409":
        class LVLM:
            """Pixtral 12B run locally through mistral_inference / mistral_common."""

            def __init__(self, args):
                from mistral_inference.transformer import Transformer
                from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

                self.max_new_tokens = args.max_new_tokens
                # NOTE(review): hard-coded local weights directory, independent of
                # args.lvlm_path — confirm the checkpoint lives there.
                lvlm_path = "./Models/mistral_models/Pixtral"
                self.tokenizer = MistralTokenizer.from_file(f"{lvlm_path}/tekken.json")
                self.model = Transformer.from_folder(lvlm_path, dtype=torch.bfloat16)

            def run(self, img_bytes, prompt):
                from mistral_inference.generate import generate
                from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageChunk
                from mistral_common.protocol.instruct.request import ChatCompletionRequest

                # NOTE(review): temperature=0.35 is non-zero, so decoding is presumably
                # sampled and outputs may differ between runs.
                if img_bytes:
                    image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                    completion_request = ChatCompletionRequest(
                        messages=[UserMessage(content=[ImageChunk(image=image), TextChunk(text=prompt)])])
                    encoded = self.tokenizer.encode_chat_completion(completion_request)
                    images, tokens = encoded.images, encoded.tokens
                    out_tokens, _ = generate([tokens], self.model, images=[images], max_tokens=self.max_new_tokens,
                                             temperature=0.35,
                                             eos_id=self.tokenizer.instruct_tokenizer.tokenizer.eos_id)
                else:
                    completion_request = ChatCompletionRequest(
                        messages=[UserMessage(content=[TextChunk(text=prompt)])])
                    encoded = self.tokenizer.encode_chat_completion(completion_request)
                    tokens = encoded.tokens
                    out_tokens, _ = generate([tokens], self.model, max_tokens=self.max_new_tokens,
                                             temperature=0.35,
                                             eos_id=self.tokenizer.instruct_tokenizer.tokenizer.eos_id)

                generated_text = self.tokenizer.decode(out_tokens[0])
                return generated_text

    elif args.lvlm_path == "Salesforce/xgen-mm-phi3-mini-instruct-interleave-r-v1.5":
        class LVLM:
            """xGen-MM (BLIP-3) phi3-mini loaded locally (remote code) through transformers."""

            def __init__(self, args):
                from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoImageProcessor

                self.max_new_tokens = args.max_new_tokens
                self.tokenizer = AutoTokenizer.from_pretrained(
                    args.lvlm_path,
                    trust_remote_code=True,
                    use_fast=False,
                    legacy=False
                )
                self.model = AutoModelForVision2Seq.from_pretrained(
                    args.lvlm_path,
                    torch_dtype=torch.bfloat16,
                    device_map='auto',
                    trust_remote_code=True
                )
                self.image_processor = AutoImageProcessor.from_pretrained(
                    args.lvlm_path,
                    trust_remote_code=True
                )
                self.tokenizer = self.model.update_special_tokens(self.tokenizer)
                # <|end|> closes each turn in the prompt format used by run().
                self.tokenizer.eos_token = "<|end|>"

            def run(self, img_bytes, prompt):
                if img_bytes:
                    image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                    pixel_values = self.image_processor([image], image_aspect_ratio='anyres')["pixel_values"]
                    inputs = {"pixel_values": [pixel_values.to(self.model.device)],
                              "image_size": [image.size]}
                else:
                    inputs = {"pixel_values": None}

                prompt = '<|user|>\n' + prompt + '<|end|>\n<|assistant|>\n'
                language_inputs = self.tokenizer([prompt], return_tensors="pt")
                inputs.update(language_inputs)
                # Move the tokenizer tensors to the model device (pixel_values were
                # moved above; the list/None entries are left untouched).
                for name, value in inputs.items():
                    if isinstance(value, torch.Tensor):
                        inputs[name] = value.to(self.model.device)

                output = self.model.generate(**inputs,
                                             max_new_tokens=self.max_new_tokens,
                                             pad_token_id=self.tokenizer.pad_token_id,
                                             eos_token_id=self.tokenizer.eos_token_id)
                if img_bytes:
                    # NOTE(review): with image inputs the output is used as-is, which
                    # assumes this model's generate() returns only the new tokens — confirm.
                    generated_tokens = output[0]
                else:
                    generated_tokens = output[0, language_inputs['input_ids'].size(1):]
                # Defensive: cut at the first <|end|> if it survives decoding.
                generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True).split("<|end|>")[0]
                return generated_text

    elif args.lvlm_path in ["llava-hf/llava-v1.6-34b-hf",
                            "llava-hf/llava-v1.6-vicuna-7b-hf",
                            "llava-hf/llava-v1.6-vicuna-13b-hf"]:
        class LVLM:
            """LLaVA-NeXT (v1.6) loaded locally through transformers."""

            def __init__(self, args):
                from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

                self.max_new_tokens = args.max_new_tokens
                self.processor = LlavaNextProcessor.from_pretrained(args.lvlm_path)
                self.model = LlavaNextForConditionalGeneration.from_pretrained(
                    args.lvlm_path,
                    device_map="auto",
                    torch_dtype=torch.bfloat16,
                    low_cpu_mem_usage=True,
                )

            def run(self, img_bytes, prompt):
                if img_bytes:
                    image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                    messages = [
                        {
                            "role": "user",
                            "content": [
                                {"type": "image", "image": image},
                                {"type": "text", "text": prompt},
                            ],
                        }
                    ]
                else:
                    messages = [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": prompt},
                            ],
                        }
                    ]

                inputs = self.processor.apply_chat_template(
                    messages,
                    add_generation_prompt=True,
                    tokenize=True,
                    return_dict=True,
                    return_tensors="pt",
                ).to(self.model.device)

                output = self.model.generate(**inputs, max_new_tokens=self.max_new_tokens)
                generated_tokens = output[0, inputs['input_ids'].size(1):]
                generated_text = self.processor.decode(generated_tokens, skip_special_tokens=True)
                return generated_text

    elif args.lvlm_path in ["meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
                            "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
                            "meta-llama/Llama-4-Scout-17B-16E-Instruct"]:
        class LVLM:
            """Llama models served remotely through the Together chat-completions API."""

            def __init__(self, args):
                from together import Together

                self.client = Together(api_key=args.API_KEY)
                # Stored on the instance so run() does not depend on the factory closure.
                self.lvlm_path = args.lvlm_path
                self.max_new_tokens = args.max_new_tokens

            def run(self, img_bytes, prompt):
                if img_bytes:
                    image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                    # resize_image comes from bench_utils; presumably it limits the
                    # upload size — see its definition.
                    image = resize_image(image)
                    # Send the image inline as a base64 JPEG data URL.
                    with BytesIO() as buffered:
                        image.save(buffered, format="JPEG")
                        img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
                    messages = [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": prompt},
                                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"}},
                            ],
                        }
                    ]
                else:
                    messages = [
                        {
                            "role": "user",
                            "content": [{"type": "text", "text": prompt}],
                        }
                    ]

                response = self.client.chat.completions.create(
                    model=self.lvlm_path,
                    messages=messages,
                    max_tokens=self.max_new_tokens,
                )
                return response.choices[0].message.content

    # elif args.lvlm_path == "Qwen/Qwen2-VL-7B-Instruct":
    #     class LVLM:
    #         def __init__(self, args):
    #             from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
    #
    #             self.processor = AutoProcessor.from_pretrained(
    #                 args.lvlm_path,
    #                 torch_dtype=torch.bfloat16,
    #                 device_map='auto',
    #             )
    #             self.model = Qwen2VLForConditionalGeneration.from_pretrained(
    #                 args.lvlm_path,
    #                 torch_dtype=torch.bfloat16,
    #                 device_map="auto",
    #             )
    #
    #         def run(self, img_bytes, prompt):
    #             if img_bytes:
    #                 image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    #                 messages = [
    #                     {
    #                         "role": "user",
    #                         "content": [
    #                             {"type": "image", "image": image},
    #                             {"type": "text", "text": prompt},
    #                         ],
    #                     }
    #                 ]
    #             else:
    #                 messages = [
    #                     {
    #                         "role": "user",
    #                         "content": [
    #                             {"type": "text", "text": prompt},
    #                         ],
    #                     }
    #                 ]
    #
    #             inputs = self.processor.apply_chat_template(
    #                 messages,
    #                 add_generation_prompt=True,
    #                 tokenize=True,
    #                 return_dict=True,
    #                 return_tensors="pt",
    #             ).to(self.model.device)
    #
    #             output = self.model.generate(**inputs, max_new_tokens=args.max_new_tokens)
    #             generated_tokens = output[0, inputs['input_ids'].size(1):]
    #             generated_text = self.processor.decode(generated_tokens, skip_special_tokens=True)
    #             return generated_text

    # elif args.lvlm_path in ["llava-hf/llava-onevision-qwen2-7b-ov-hf", "llava-hf/llava-onevision-qwen2-72b-ov-hf"]:
    #     class LVLM:
    #         def __init__(self, args):
    #             from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration
    #
    #             self.processor = AutoProcessor.from_pretrained(args.lvlm_path)
    #             self.model = LlavaOnevisionForConditionalGeneration.from_pretrained(
    #                 args.lvlm_path,
    #                 torch_dtype=torch.bfloat16,
    #                 low_cpu_mem_usage=True,
    #                 device_map='auto',
    #             )
    #
    #         def run(self, img_bytes, prompt):
    #             if img_bytes:
    #                 conversation = [
    #                     {
    #                         "role": "user",
    #                         "content": [
    #                             {"type": "text", "text": prompt},
    #                             {"type": "image"},
    #                         ],
    #                     },
    #                 ]
    #                 prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
    #                 image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    #                 processed_inputs = self.processor(images=image, text=prompt, return_tensors='pt')
    #             else:
    #                 conversation = [
    #                     {
    #                         "role": "user",
    #                         "content": [
    #                             {"type": "text", "text": prompt},
    #                         ],
    #                     },
    #                 ]
    #                 prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
    #                 processed_inputs = self.processor(text=prompt, return_tensors='pt')
    #
    #             inputs = {}
    #             for key, value in processed_inputs.items():
    #                 if key in ['input_ids', 'batch_idx', 'image_input_idx']:
    #                     inputs[key] = value.to(device=self.model.device, dtype=torch.long)
    #                 else:
    #                     inputs[key] = value.to(device=self.model.device, dtype=torch.bfloat16)
    #
    #             output = self.model.generate(**inputs, max_new_tokens=args.max_new_tokens, do_sample=False)
    #             generated_tokens = output[0, inputs['input_ids'].size(1):]
    #             generated_text = self.processor.decode(generated_tokens, skip_special_tokens=True)
    #             return generated_text

    # elif args.lvlm_path == "allenai/Molmo-7B-D-0924":
    #     class LVLM:
    #         def __init__(self, args):
    #             from transformers import AutoProcessor, AutoModelForCausalLM
    #
    #             self.processor = AutoProcessor.from_pretrained(
    #                 args.lvlm_path,
    #                 trust_remote_code=True,
    #                 torch_dtype=torch.bfloat16,
    #                 device_map='auto',
    #             )
    #             self.model = AutoModelForCausalLM.from_pretrained(
    #                 args.lvlm_path,
    #                 trust_remote_code=True,
    #                 torch_dtype=torch.bfloat16,
    #                 device_map='auto',
    #             )
    #
    #         def run(self, img_bytes, prompt):
    #             from transformers import GenerationConfig
    #
    #             if img_bytes:
    #                 image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    #                 processed_inputs = self.processor.process(images=image, text=prompt)
    #             else:
    #                 processed_inputs = self.processor.process(text=prompt)
    #             inputs = {}
    #             for key, value in processed_inputs.items():
    #                 if key in ['input_ids', 'batch_idx', 'image_input_idx']:
    #                     inputs[key] = value.to(device=self.model.device, dtype=torch.long).unsqueeze(0)
    #                 else:
    #                     inputs[key] = value.to(device=self.model.device, dtype=torch.bfloat16).unsqueeze(0)
    #
    #             output = self.model.generate_from_batch(
    #                 inputs,
    #                 GenerationConfig(max_new_tokens=args.max_new_tokens, stop_strings="<|endoftext|>"),
    #                 tokenizer=self.processor.tokenizer
    #             )
    #             generated_tokens = output[0, inputs['input_ids'].size(1):]
    #             generated_text = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
    #             return generated_text

    elif args.lvlm_path == "debug":
        class LVLM:
            """No-model stub that always returns a fixed placeholder string."""

            def __init__(self, args):
                pass

            def run(self, img_bytes, prompt):
                generated_text = "***Debug Contents***"
                return generated_text

    else:
        raise NotImplementedError(f"Unsupported lvlm_path: {args.lvlm_path!r}")

    return LVLM


if __name__ == '__main__':
    # Smoke test: build the selected backend and run one image or text-only query.
    import requests

    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--lvlm_path', type=str, default="")
    parser.add_argument('--API_KEY', type=str, default="")
    parser.add_argument('--max_new_tokens', type=int, default=1000)
    parser.add_argument('--modality', type=str, default="real", help="[real, synthetic, triple, empty]")
    args = parser.parse_args()

    LVLM = get_LVLM_class(args)
    lvlm_model = LVLM(args)
    if args.modality in ["real", "synthetic"]:
        url = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
        # Bounded timeout so a stalled download cannot hang the smoke test.
        response = requests.get(url, timeout=60)
        if response.status_code != 200:
            raise RuntimeError(f"Failed to download test image (HTTP {response.status_code}): {url}")
        # run() wraps its argument in io.BytesIO itself, so it must receive the raw
        # bytes; passing a BytesIO object makes io.BytesIO(...) raise TypeError.
        img_bytes = response.content
        prompt = "Describe this image."
        generated_text = lvlm_model.run(img_bytes, prompt)
    else:
        # Any other modality ("triple", "empty", ...) is exercised as text-only.
        prompt = "What is China's capital city? "
        generated_text = lvlm_model.run(None, prompt)

    print(f"### Model: {args.lvlm_path}")
    print(f"### Output: \n{generated_text}")

