from transformers import (
    MllamaForConditionalGeneration, Qwen2VLForConditionalGeneration, AutoProcessor,
    Qwen2_5_VLForConditionalGeneration
)
import os
from openai import OpenAI
from qwen_vl_utils import process_vision_info
import torch
from time import sleep
import numpy as np
import io
from PIL import Image
import base64


def get_ensemble_model_func(model_name):
    """Pick the ensemble answer function for a model family.

    Matching is by substring and is checked in the order ``'llama'``,
    ``'qwen'``, ``'gpt'``; the first keyword found wins.

    Args:
        model_name: Model identifier supporting the ``in`` operator
            (e.g. ``'qwen_big'``).

    Returns:
        The matching ``answer_query_*_ensemble`` callable.

    Raises:
        ValueError: If none of the family keywords occurs in ``model_name``.
    """
    # Guard clauses keep the lookup lazy: only the selected function name is
    # ever resolved.
    if 'llama' in model_name:
        return answer_query_llama_ensemble
    if 'qwen' in model_name:
        return answer_query_qwen_ensemble
    if 'gpt' in model_name:
        return answer_query_gpt_ensemble
    raise ValueError('Invalid model name')

def get_model(model_name):
    """Instantiate a model, its processor and the matching answer function.

    Supported names: ``'llama'``, ``'llama_big'``, ``'qwen'``, ``'qwen_new'``,
    ``'qwen_new_small'``, ``'qwen_big'``, ``'qwen_big_budget'`` and ``'gpt'``.

    Args:
        model_name: Short name selecting the checkpoint / backend.

    Returns:
        A ``(model, processor, model_answer_func)`` tuple. For ``'gpt'`` the
        model is an ``OpenAI`` client and the processor is ``None``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # The API-backed model needs neither local weights nor a processor.
    if model_name == 'gpt':
        client = OpenAI(
            api_key=os.environ.get("OPENAI_API_KEY"),
        )
        return client, None, answer_query_gpt

    # (short name, model family, Hugging Face checkpoint id)
    checkpoints = (
        ('llama', 'mllama', "meta-llama/Llama-3.2-11B-Vision-Instruct"),
        ('llama_big', 'mllama', "meta-llama/Llama-3.2-90B-Vision-Instruct"),
        ('qwen_big', 'qwen2_5', "Qwen/Qwen2.5-VL-72B-Instruct"),
        ('qwen_big_budget', 'qwen2_5', "Qwen/Qwen2.5-VL-72B-Instruct-AWQ"),
        ('qwen', 'qwen2', "Qwen/Qwen2-VL-7B-Instruct"),
        ('qwen_new', 'qwen2_5', "Qwen/Qwen2.5-VL-7B-Instruct"),
        ('qwen_new_small', 'qwen2_5', "Qwen/Qwen2.5-VL-3B-Instruct"),
    )
    selected = next(
        (entry for entry in checkpoints if model_name == entry[0]),
        None,
    )
    if selected is None:
        raise ValueError('Invalid model name')
    _, family, model_id = selected

    # Every local checkpoint is loaded in bfloat16 and sharded automatically.
    load_kwargs = {"torch_dtype": torch.bfloat16, "device_map": "auto"}

    if family == 'mllama':
        model = MllamaForConditionalGeneration.from_pretrained(
            model_id, **load_kwargs
        )
        processor = AutoProcessor.from_pretrained(model_id)
        return model, processor, answer_query_llama

    # Qwen2-VL and Qwen2.5-VL use different model classes but share the same
    # processor settings and answer function.
    if family == 'qwen2':
        qwen_cls = Qwen2VLForConditionalGeneration
    else:
        qwen_cls = Qwen2_5_VLForConditionalGeneration
    model = qwen_cls.from_pretrained(model_id, **load_kwargs)
    processor = AutoProcessor.from_pretrained(
        model_id, max_pixels=1280 * 28 * 28
    )
    return model, processor, answer_query_qwen


def img_to_base64_str(img):
    """Serialize an image to a base64 PNG data URL.

    Args:
        img: Object exposing ``save(fp, format=...)``; it is asked to write
            itself as PNG into an in-memory buffer.

    Returns:
        The string ``"data:image/png;base64,"`` followed by the base64
        encoding of the bytes written by ``img.save``.
    """
    with io.BytesIO() as png_buffer:
        img.save(png_buffer, format="PNG")
        # getvalue() returns the whole buffer regardless of stream position.
        encoded = base64.b64encode(png_buffer.getvalue())
    return "data:image/png;base64," + encoded.decode()


def encode_imgs(img_list, img_type="image"):
    """Encode images as base64 PNG data URLs wrapped in message-content dicts.

    Args:
        img_list: Iterable of images. ``numpy.ndarray`` items are first
            converted with ``Image.fromarray``; every other item is handed to
            ``img_to_base64_str`` unchanged.
        img_type: Value stored under the ``"type"`` key of every entry.

    Returns:
        A list with one ``{"type": img_type, "image": <data URL>}`` dict per
        input image, in input order. An empty input yields an empty list.
    """
    # ``Image`` is the module-level PIL import; re-importing it on every call
    # was redundant.
    return [
        {
            "type": img_type,
            "image": img_to_base64_str(
                Image.fromarray(img) if isinstance(img, np.ndarray) else img
            ),
        }
        for img in img_list
    ]


def encode_imgs_gpt(img_list):
    """Encode images as ``image_url`` content parts for a chat-completions call.

    Args:
        img_list: Iterable of images. ``numpy.ndarray`` items are first
            converted with ``Image.fromarray``; every other item is handed to
            ``img_to_base64_str`` unchanged.

    Returns:
        A list with one
        ``{"type": "image_url", "image_url": {"url": <data URL>}}`` dict per
        input image, in input order. An empty input yields an empty list.
    """
    # ``Image`` is the module-level PIL import; re-importing it on every call
    # was redundant.
    return [
        {
            "type": "image_url",
            "image_url": {
                "url": img_to_base64_str(
                    Image.fromarray(img) if isinstance(img, np.ndarray) else img
                ),
            },
        }
        for img in img_list
    ]


def answer_query_qwen(model, processor, prompt, images):
    """Answer a text prompt about one or more images with a Qwen-VL model.

    Args:
        model: Loaded Qwen-VL generation model (must expose ``generate`` and
            ``device``).
        processor: The processor paired with ``model``.
        prompt: Text of the user question.
        images: A single image or a list of images; a non-list value is
            wrapped into a one-element list.

    Returns:
        The ``batch_decode`` result: a list holding the decoded completion of
        the single prompt, with the prompt tokens removed.
    """
    if not isinstance(images, list):
        images = [images]
    messages = [
        {
            "role": "user",
            # Images come first, followed by the textual question.
            "content": encode_imgs(images) + [
                {"type": "text", "text": prompt},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Follow the model's own device instead of assuming "cuda", matching
    # answer_query_llama and working when the model lives elsewhere.
    inputs = inputs.to(model.device)
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=10000)
    # generate() returns prompt + completion; keep only the completion.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    torch.cuda.empty_cache()
    return output_text


def answer_query_llama(model, processor, prompt, image):
    """Answer a text prompt about a single image with a Llama vision model.

    Args:
        model: Loaded generation model (must expose ``generate`` and
            ``device``).
        processor: The processor paired with ``model``.
        prompt: Text of the user question.
        image: The image to ask about; a ``numpy.ndarray`` is converted with
            ``Image.fromarray`` first.

    Returns:
        The decoded completion as a stripped string. Special tokens are
        skipped, so no ``<|eot_id|>`` marker is left at the end.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    messages = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": prompt},
        ]}
    ]
    input_text = processor.apply_chat_template(
        messages, add_generation_prompt=True
    )
    inputs = processor(
        image,
        input_text,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(model.device)

    output = model.generate(**inputs, max_new_tokens=1000)
    # generate() echoes the prompt tokens; drop them by position rather than
    # splitting the decoded text on the assistant header, which left the
    # end-of-turn token in the answer.
    prompt_len = inputs["input_ids"].shape[-1]
    completion_ids = output[0][prompt_len:]
    return processor.decode(completion_ids, skip_special_tokens=True).strip()


def get_gpt_response(client, classification_prompt, image_msg, temperature=1.0):
    """Send one text-plus-images user message to ``gpt-4o-mini``.

    Args:
        client: Object exposing ``chat.completions.create`` (an ``OpenAI``
            client in this file).
        classification_prompt: Text placed first in the user message.
        image_msg: List of image content parts appended after the text.
        temperature: Sampling temperature forwarded to the API.

    Returns:
        The message content of the first returned choice.
    """
    text_part = {
        "type": "text",
        "text": classification_prompt,
    }
    # A single user turn; no system message is sent.
    user_turn = {
        "role": "user",
        "content": [text_part] + image_msg,
    }
    completion = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[user_turn],
        max_tokens=10000,
        temperature=temperature,
    )
    # Fixed one-second pause after every request (presumably to stay under
    # API rate limits).
    sleep(1.0)
    first_choice = completion.choices[0]
    return first_choice.message.content


def answer_query_qwen_ensemble(model, processor, prompt, images, ensemble_num=1, visual_instruction=None,
                               **decoding_kwargs):
    """Generate ``ensemble_num`` answers to one prompt in a single batch.

    The same prompt (and images) is repeated ``ensemble_num`` times and
    generated as one batch, so the answers only differ if ``decoding_kwargs``
    enables sampling.

    Args:
        model: Loaded Qwen-VL generation model (must expose ``generate`` and
            ``device``).
        processor: The processor paired with ``model``.
        prompt: Text of the user question.
        images: A single image or a list of images; a non-list value is
            wrapped into a one-element list.
        ensemble_num: Number of copies of the prompt in the batch.
        visual_instruction: Optional extra image placed before ``images``.
        **decoding_kwargs: Extra keyword arguments for ``model.generate``.
            ``max_new_tokens`` defaults to 1000 and may be overridden here.

    Returns:
        A list of ``ensemble_num`` decoded completions with the prompt tokens
        removed.
    """
    if not isinstance(images, list):
        images = [images]
    if visual_instruction is not None:
        images = [visual_instruction] + images
    messages = [
        {
            "role": "user",
            "content": encode_imgs(images) + [
                {"type": "text", "text": prompt},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, _ = process_vision_info(messages)
    # image_inputs can be None when the message carries no images; None
    # cannot be repeated with `*`, so pass it through unchanged.
    if image_inputs is not None:
        batched_images = image_inputs * ensemble_num
    else:
        batched_images = None
    inputs = processor(
        text=[text] * ensemble_num,
        images=batched_images,
        videos=None,
        padding=True,
        return_tensors="pt",
    )
    # Merge instead of passing both, so a caller-supplied max_new_tokens
    # overrides the default rather than raising a duplicate-keyword error.
    generation_kwargs = {"max_new_tokens": 1000, **decoding_kwargs}
    with torch.no_grad():
        # Follow the model's own device instead of assuming "cuda".
        inputs = inputs.to(model.device)
        generated_ids = model.generate(**inputs, **generation_kwargs)
    # Every row holds the same prompt, so one prompt length fits all rows.
    prompt_len = inputs.input_ids.shape[1]
    generated_ids_trimmed = [
        out_ids[prompt_len:] for out_ids in generated_ids
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return output_text



def answer_query_llama_ensemble(model, processor, prompt, image, ensemble_num=1,
                                **decoding_kwargs):
    """Generate ``ensemble_num`` independent answers to one image prompt.

    The prompt is built once, from a single user turn and a single image, and
    ``model.generate`` is called ``ensemble_num`` times on it. The answers
    only differ if ``decoding_kwargs`` enables sampling.

    Args:
        model: Loaded generation model (must expose ``generate`` and
            ``device``).
        processor: The processor paired with ``model``.
        prompt: Text of the user question.
        image: The image to ask about; a ``numpy.ndarray`` is converted with
            ``Image.fromarray`` first.
        ensemble_num: Number of generations to run.
        **decoding_kwargs: Extra keyword arguments for ``model.generate``.
            ``max_new_tokens`` defaults to 1000 and may be overridden here.

    Returns:
        A list of ``ensemble_num`` stripped completion strings, decoded with
        special tokens skipped.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    messages = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": prompt},
        ]}
    ]
    # One user turn and one image: repeating the message list would build a
    # single conversation asking the same question ensemble_num times, not
    # ensemble_num samples of the same question.
    input_text = processor.apply_chat_template(
        messages, add_generation_prompt=True
    )
    inputs = processor(
        image,
        input_text,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]
    # Merge so a caller-supplied max_new_tokens overrides the default rather
    # than raising a duplicate-keyword error.
    generation_kwargs = {"max_new_tokens": 1000, **decoding_kwargs}

    outputs = []
    for _ in range(ensemble_num):
        generated = model.generate(**inputs, **generation_kwargs)
        # Drop the echoed prompt tokens by position and skip special tokens
        # so no end-of-turn marker ends up in the answer.
        answer = processor.decode(
            generated[0][prompt_len:], skip_special_tokens=True
        ).strip()
        outputs.append(answer)
    return outputs


def answer_query_gpt(model, processor, prompt, image):
    """Answer a text prompt about a single image through the OpenAI client.

    Args:
        model: Client forwarded to ``get_gpt_response``.
        processor: Unused; present so the signature matches the other
            ``answer_query_*`` functions.
        prompt: Text of the user question.
        image: The image to ask about.

    Returns:
        Whatever ``get_gpt_response`` returns for the request.
    """
    return get_gpt_response(model, prompt, encode_imgs_gpt([image]))

def answer_query_gpt_ensemble(model, processor, prompt, image, ensemble_num=1,
                               **decoding_kwargs):
    """Collect ``ensemble_num`` answers to one image prompt from the OpenAI client.

    Args:
        model: Client forwarded to ``get_gpt_response``.
        processor: Unused; present so the signature matches the other
            ensemble functions.
        prompt: Text of the user question.
        image: The image to ask about.
        ensemble_num: Number of requests to send.
        **decoding_kwargs: Accepted but not used; every request is sent with
            ``temperature=0.6``.

    Returns:
        A list of ``ensemble_num`` responses, in request order.
    """
    # The image is encoded once and reused for every request.
    image_parts = encode_imgs_gpt([image])
    return [
        get_gpt_response(model, prompt, image_parts, temperature=0.6)
        for _ in range(ensemble_num)
    ]
