from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration, AutoModelForCausalLM
import torch
import base64
import argparse
from PIL import Image, ImageDraw, ImageFont
import openai
from io import BytesIO
from transformers import AutoModel, AutoTokenizer
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info

import math
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig

from google import genai
from google.genai import types
import anthropic

import os

def split_model(model_name):
    """Build a per-module device map that spreads the LLM layers over all GPUs.

    The layer count is read from ``config.llm_config.num_hidden_layers`` of
    the checkpoint's config. GPU 0 also hosts the vision tower, so it is
    counted as half a GPU and receives roughly half as many decoder layers
    as the others.

    Args:
        model_name: Identifier or path handed to ``AutoConfig.from_pretrained``.

    Returns:
        dict mapping module names to GPU indices.
    """
    gpu_count = torch.cuda.device_count()
    cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    total_layers = cfg.llm_config.num_hidden_layers

    # Since the first GPU will be used for ViT, treat it as half a GPU.
    per_gpu = math.ceil(total_layers / (gpu_count - 0.5))
    quotas = [per_gpu] * gpu_count
    quotas[0] = math.ceil(quotas[0] * 0.5)

    # One entry per decoder layer slot, holding the GPU that owns that slot.
    owners = [gpu for gpu, quota in enumerate(quotas) for _ in range(quota)]
    device_map = {
        f'language_model.model.layers.{layer_idx}': gpu
        for layer_idx, gpu in enumerate(owners)
    }

    # Everything outside the decoder stack is pinned to GPU 0.
    pinned_modules = (
        'vision_model',
        'mlp1',
        'language_model.model.tok_embeddings',
        'language_model.model.embed_tokens',
        'language_model.output',
        'language_model.model.norm',
        'language_model.model.rotary_emb',
        'language_model.lm_head',
    )
    for module_name in pinned_modules:
        device_map[module_name] = 0

    # The final decoder layer is placed on GPU 0 as well.
    device_map[f'language_model.model.layers.{total_layers - 1}'] = 0

    return device_map


# Per-channel normalisation statistics applied by build_transform.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    """Return the preprocessing pipeline applied to each image tile.

    The pipeline converts the image to RGB when needed, resizes it to a
    square of ``input_size`` pixels with bicubic interpolation, turns it
    into a tensor and normalises it with IMAGENET_MEAN / IMAGENET_STD.

    Args:
        input_size: Edge length (pixels) of the square output.

    Returns:
        A ``torchvision.transforms.Compose`` instance.
    """
    def _ensure_rgb(img):
        # Leave RGB images untouched; convert any other mode.
        if img.mode != 'RGB':
            return img.convert('RGB')
        return img

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the tiling grid whose aspect ratio is nearest to the image's.

    Args:
        aspect_ratio: Width / height of the source image.
        target_ratios: Iterable of candidate grids; element 0 divided by
            element 1 is the candidate's aspect ratio.
        width: Source image width in pixels.
        height: Source image height in pixels.
        image_size: Edge length of a single square tile.

    Returns:
        The chosen candidate, or ``(1, 1)`` when ``target_ratios`` is empty.
        On an exact tie the later candidate replaces the current choice only
        if the source image covers more than half of that candidate's total
        tile area.
    """
    chosen = (1, 1)
    smallest_gap = float('inf')
    source_area = width * height

    for candidate in target_ratios:
        cols, rows = candidate[0], candidate[1]
        gap = abs(aspect_ratio - cols / rows)
        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
        elif gap == smallest_gap and source_area > 0.5 * image_size * image_size * cols * rows:
            # Tie: prefer this grid only when the image is big enough to fill it.
            chosen = candidate

    return chosen

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """Resize an image onto a tile grid and cut it into square tiles.

    Args:
        image: Source image; must expose ``size``, ``resize`` and ``crop``
            (a PIL image in this file).
        min_num: Minimum number of tiles a candidate grid may have.
        max_num: Maximum number of tiles a candidate grid may have.
        image_size: Edge length (pixels) of each square tile.
        use_thumbnail: When true and more than one tile was produced, append
            a whole-image thumbnail of ``image_size`` x ``image_size``.

    Returns:
        List of tiles in row-major order, optionally followed by the thumbnail.
    """
    src_w, src_h = image.size
    src_aspect = src_w / src_h

    # Enumerate every (cols, rows) grid whose tile count lies in
    # [min_num, max_num], ordered by tile count.
    grid_candidates = sorted(
        {
            (i, j)
            for n in range(min_num, max_num + 1)
            for i in range(1, n + 1)
            for j in range(1, n + 1)
            if min_num <= i * j <= max_num
        },
        key=lambda grid: grid[0] * grid[1],
    )

    # Choose the grid whose aspect ratio best matches the source image.
    best_grid = find_closest_aspect_ratio(
        src_aspect, grid_candidates, src_w, src_h, image_size)

    canvas_w = image_size * best_grid[0]
    canvas_h = image_size * best_grid[1]
    tile_count = best_grid[0] * best_grid[1]

    # Stretch the image to exactly cover the grid, then slice it.
    canvas = image.resize((canvas_w, canvas_h))
    tiles = []
    for idx in range(tile_count):
        row, col = divmod(idx, canvas_w // image_size)
        left = col * image_size
        top = row * image_size
        tiles.append(canvas.crop((left, top, left + image_size, top + image_size)))
    assert len(tiles) == tile_count

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles

def load_image(image, input_size=448, max_num=12):
    """Tile an already-opened image and stack the transformed tiles.

    Args:
        image: Source image, forwarded to ``dynamic_preprocess``.
        input_size: Edge length (pixels) of each tile.
        max_num: Maximum number of tiles (a thumbnail may be added on top).

    Returns:
        Tensor produced by ``torch.stack`` over the transformed tiles.
    """
    to_tensor = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(
        image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    return torch.stack([to_tensor(tile) for tile in tiles])

def crop_image(image:Image.Image):
    """Return the image with its top 750 pixel rows removed.

    Args:
        image: PIL image to crop.

    Returns:
        The cropped PIL image spanning rows 750 to the bottom, full width.
    """
    full_width, full_height = image.size
    # Crop box is (left, upper, right, lower).
    return image.crop((0, 750, full_width, full_height))

class VictimModel:
    """Single entry point for querying the local and API-hosted models under test.

    Attributes:
        model_name: Short key selecting the model (a key of ``model_dict``).
        model_dict: Mapping from short key to checkpoint / API model id.
        processor: Processor or tokenizer for local models, ``None`` for
            API-hosted models.
        model: Loaded model for local models; the API model id string for
            GPT and Claude; ``[client, model_id]`` for Gemini.
    """

    def __init__(self, model_name):
        """Validate ``model_name`` and load the matching model.

        Args:
            model_name: One of the keys of ``model_dict``.

        Raises:
            AssertionError: If ``model_name`` is not a known key.
        """
        self.model_name = model_name
        self.model_dict = {
            "llava-ov": "llava-hf/llava-onevision-qwen2-7b-ov-hf",
            "internvl3-8b": "OpenGVLab/InternVL3-8B",
            "internvl3-14b": "OpenGVLab/InternVL3-14B",
            "internvl3-38b": "OpenGVLab/InternVL3-38B",
            "qwen2.5-vl-7b": "Qwen/Qwen2.5-VL-7B-Instruct",
            "qwen2.5-vl-32b": "Qwen/Qwen2.5-VL-32B-Instruct",
            "qwen2.5-3b": "Qwen/Qwen2.5-3B-Instruct",
            "gemini-2.5-flash": "gemini-2.5-flash",
            "gemini-2.5-pro": "gemini-2.5-pro",
            "gpt-4o-mini": "gpt-4o-mini",
            "gpt-4.1-mini": "gpt-4.1-mini",
            "gpt-4o": "gpt-4o-2024-08-06",
            "gpt-4.1": "gpt-4.1-2025-04-14",
            "claude-sonnet-4": "claude-sonnet-4-20250514",
            "claude-haiku-3.5": "claude-3-5-haiku-20241022",
        }
        # Kept as an assert so callers catching AssertionError keep working;
        # load_model raises ValueError itself if asserts are stripped (-O).
        assert self.model_name in self.model_dict.keys(), f"Model {self.model_name} not found"

        self.processor, self.model = self.load_model(self.model_name)

    def load_model(self, model_name):
        """Load the model (and processor, if any) selected by ``model_name``.

        Args:
            model_name: Short model key.

        Returns:
            ``(processor, model)``; see the class docstring for the shape of
            each element per model family.

        Raises:
            ValueError: If no loader handles ``model_name``.
        """
        if model_name == "llava-ov":
            processor = AutoProcessor.from_pretrained(self.model_dict[model_name])
            # NOTE(review): no device_map / .cuda() is applied here, so the
            # fp16 model stays on the default device - confirm this is intended.
            model = LlavaOnevisionForConditionalGeneration.from_pretrained(
                self.model_dict[model_name],
                torch_dtype=torch.float16,
                # device_map=device
            )
        elif model_name in ("internvl3-8b", "internvl3-14b"):
            # Both sizes are loaded onto the current CUDA device.
            model = AutoModel.from_pretrained(
                self.model_dict[model_name],
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                use_flash_attn=True,
                trust_remote_code=True).eval().cuda()

            processor = AutoTokenizer.from_pretrained(self.model_dict[model_name], trust_remote_code=True, use_fast=False)

        elif model_name == "internvl3-38b":
            # Using original code from internvl3-38b to activate multi-gpu inference
            device_map = split_model(self.model_dict[model_name])
            model = AutoModel.from_pretrained(
                self.model_dict[model_name],
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                use_flash_attn=True,
                trust_remote_code=True,
                device_map=device_map).eval()

            processor = AutoTokenizer.from_pretrained(self.model_dict[model_name], trust_remote_code=True, use_fast=False)

        elif model_name in ("qwen2.5-vl-7b", "qwen2.5-vl-32b"):
            model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                self.model_dict[model_name],
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto")

            # default processor
            processor = AutoProcessor.from_pretrained(self.model_dict[model_name], use_fast=True)

        elif "gemini" in model_name:
            # The "model" is the API client paired with the model id.
            model = [genai.Client(api_key=os.getenv("GEMINI_API_KEY")), self.model_dict[model_name]]
            processor = None

        elif 'gpt' in model_name or "claude" in model_name:
            # API-hosted: only the model id is needed at request time.
            model = self.model_dict[model_name]
            processor = None

        elif model_name == "qwen2.5-3b":
            model = AutoModelForCausalLM.from_pretrained(
                self.model_dict[model_name],
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None,
                trust_remote_code=True
            )

            # Move model to CUDA if available and not using device_map
            if torch.cuda.is_available() and not hasattr(model, 'hf_device_map'):
                model = model.cuda()

            processor = AutoTokenizer.from_pretrained(self.model_dict[model_name])

        else:
            # Previously this fell through to an UnboundLocalError at return.
            raise ValueError(f"Model {model_name} not found")

        print(f"Loaded {model_name} model")
        return processor, model

    @staticmethod
    def image_to_base64(image):
        """Convert PIL Image to base64 string"""
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')

    def get_response(self, image, prompt):
        """Send ``image`` and ``prompt`` to the configured model.

        The expected form of ``image`` depends on the model family:
        a PIL image for llava-ov, InternVL3 and Claude; a file path string
        for Qwen2.5-VL and Gemini; a base64 string or a PIL image for GPT;
        ignored for the text-only qwen2.5-3b.

        Args:
            image: Image in the form required by the selected model.
            prompt: Text prompt.

        Returns:
            The model's text response, or ``None`` if no branch matched or
            the backend returned ``None``.

        Raises:
            AssertionError: If a Qwen2.5-VL model receives a non-string image.
        """
        response = None
        name = self.model_name
        if name == "llava-ov":
            image_base64 = self.image_to_base64(image)
            # send_to_llava_ov_7b takes four arguments; the device is taken
            # from the model itself inside that method.
            response = self.send_to_llava_ov_7b(image_base64, prompt, self.processor, self.model)
        elif name in ("internvl3-8b", "internvl3-14b", "internvl3-38b"):
            response = self.send_to_internvl3(image, prompt, self.processor, self.model)
        elif name in ("qwen2.5-vl-7b", "qwen2.5-vl-32b"):
            assert isinstance(image, str), "Image path must be a string"
            response = self.send_to_qwen2_5_vl(image, prompt, self.processor, self.model)
        elif name in ("gemini-2.5-flash", "gemini-2.5-pro"):
            response = self.send_to_gemini_2_5(image, prompt, self.processor, self.model)
        elif "gpt" in name:
            # send_to_chatgpt embeds its first argument in a base64 data URL,
            # so encode anything that is not already a string.
            if not isinstance(image, str):
                image = self.image_to_base64(image)
            response = self.send_to_chatgpt(image, prompt, self.processor, self.model)
        elif name == "qwen2.5-3b":
            response = self.send_to_qwen2_5_3b(prompt, self.processor, self.model)
        elif "claude" in name:
            response = self.send_to_claude(image, prompt, self.processor, self.model)
        print(f"Response: {response}")
        return response

    @staticmethod
    def send_to_llava_ov_7b(image_base64, prompt, processor, model):
        """Query LLaVA-OneVision with a base64 image and a text prompt.

        Returns:
            The decoded first output sequence (up to 100 new tokens).
        """
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "base64": image_base64},
                    {"type": "text", "text": prompt},
                ],
            },
        ]
        inputs = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt")
        inputs = inputs.to(model.device, torch.float16)

        # autoregressively complete prompt
        output = model.generate(**inputs, max_new_tokens=100)
        output_text = processor.decode(output[0], skip_special_tokens=True)
        return output_text

    @staticmethod
    def send_to_internvl3(image, prompt, processor, model):
        """Query an InternVL3 model through its ``chat`` method.

        The image is tiled with ``load_image`` (at most 12 tiles), cast to
        bfloat16 and moved to CUDA. Generation samples up to 1024 new tokens.
        """
        pixel_values = load_image(image, max_num=12).to(torch.bfloat16).cuda()
        generation_config = dict(max_new_tokens=1024, do_sample=True)

        response = model.chat(processor, pixel_values, prompt, generation_config)

        return response

    @staticmethod
    def send_to_qwen2_5_vl(image_path, prompt, processor, model):
        """Query Qwen2.5-VL with an image file path and a text prompt.

        Returns:
            The decoded newly generated tokens (up to 1024) for the single
            input conversation.
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": f"file://{image_path}",
                    },
                    {"type": "text", "text": f"{prompt}"},
                ],
            }
        ]

        # Preparation for inference
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        # Follow the model's device instead of hard-coding "cuda", so the
        # float32 CPU fallback chosen in load_model actually works.
        inputs = inputs.to(model.device)

        # Inference: Generation of the output
        generated_ids = model.generate(**inputs, max_new_tokens=1024)
        # Strip the prompt tokens so only the completion is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        return output_text[0]

    @staticmethod
    def send_to_gemini_2_5(image_path, prompt, processor, model):
        """Query Gemini with the bytes of an image file and a text prompt.

        Args:
            image_path: Path of the image file to upload.
            prompt: Text prompt.
            processor: Unused.
            model: ``[client, model_id]`` as built by ``load_model``.

        Returns:
            ``response.text`` of the API result.
        """
        print(f"Sending {image_path} to {model[1]}")

        with open(image_path, 'rb') as f:
            image_bytes = f.read()

        # Accept .jpg and .jpeg in any letter case; everything else is sent as PNG.
        is_jpeg = image_path.lower().endswith((".jpg", ".jpeg"))
        mime_type = "image/jpeg" if is_jpeg else "image/png"

        response = model[0].models.generate_content(
            model=model[1],
            contents=[
                types.Part.from_bytes(
                    data=image_bytes,
                    mime_type=mime_type,
                ),
                f'{prompt}'
            ],
            config=types.GenerateContentConfig(
                temperature=0.1,
                # thinking_config=types.ThinkingConfig(thinking_budget=0) # Disables thinking
            )
        )

        return response.text

    @staticmethod
    def send_to_chatgpt(image_base64, prompt, processor, model):
        """Send image and prompt to ChatGPT API

        Args:
            image_base64: Base64-encoded PNG data embedded in a data URL.
            prompt: Text prompt.
            processor: Unused.
            model: API model id.

        Returns:
            The first choice's message content, or ``None`` if the request
            raised (the error is printed, not propagated).
        """
        try:
            response = openai.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": ""},
                    {"role": "user", "content": [
                        {"type": "text", "text": prompt},
                        {"type": "image_url", "image_url": {
                            "url": f"data:image/png;base64,{image_base64}"}
                        }
                    ]}
                ],
                temperature=0.1,
                max_tokens=1000
            )
            output_text = response.choices[0].message.content
            return output_text
        except Exception as e:
            # Deliberate best-effort: report and signal failure with None.
            print(f"Error calling ChatGPT API: {e}")
            return None

    @staticmethod
    def send_to_qwen2_5_3b(prompt, processor, model):
        """Query the text-only Qwen2.5-3B model.

        Returns:
            The decoded completion (up to 1024 new tokens), or a fixed
            apology string if generation raised.
        """
        # Prepare messages with system prompt and the user prompt
        messages = [
            {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ]

        # Apply chat template
        text = processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # Tokenize input and place it wherever the model lives.
        model_inputs = processor([text], return_tensors="pt").to(model.device)

        # Generate response
        try:
            with torch.no_grad():
                generated_ids = model.generate(
                    **model_inputs,
                    max_new_tokens=1024,
                    do_sample=True,
                    temperature=0.1,
                    pad_token_id=processor.eos_token_id
                )

            # Strip the prompt tokens so only the completion is decoded.
            generated_ids = [
                output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
            ]
            response = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

            return response

        except Exception as e:
            print(f"Error generating response using Qwen2.5-3B: {e}")
            return "Sorry, I encountered an error while generating a response."

    @staticmethod
    def send_to_claude(image_PIL, prompt, processor, model):
        """Query Claude with a PIL image (sent as base64 PNG) and a text prompt.

        Args:
            image_PIL: PIL image to upload.
            prompt: Text prompt.
            processor: Unused.
            model: API model id.

        Returns:
            Text of the first content block of the response.
        """
        image_data = VictimModel.image_to_base64(image_PIL)

        client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))

        response = client.messages.create(
            model=model,
            max_tokens=1024,
            temperature=0.1,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "source": {
                                "type": "base64",
                                "media_type": "image/png",
                                "data": image_data
                            }
                        },
                        {
                            "type": "text",
                            "text": prompt
                        }
                    ],
                }
            ],
        )
        return response.content[0].text


