# fmt:off

import flash_attn
import torch
from PIL import Image
from PIL.ImageFile import ImageFile

from llms.llm_utils import batch_call_llm, call_llm, get_avail_models, get_gen_config_fields
from llms.prompt_utils import get_messages, visualize_prompt
from llms.providers.hugging_face.setup_utils import is_flash_attn_available

# Environment probes, printed one after another in this order:
#   1. the generation-config fields known to the llms package,
#   2. models matching "gpt-4o-" (presumably a name prefix — confirm),
#   3. whether flash-attention is reported as available.
for _probe in (
    get_gen_config_fields,
    lambda: get_avail_models("gpt-4o-"),
    is_flash_attn_available,
):
    print(_probe())
del _probe

# Drop PyTorch's cached-but-unused CUDA memory before any model work.
torch.cuda.empty_cache()

# Content of a single mixed text/image user turn, consisting of:
#   - a system prompt and an instruction,
#   - three labelled items (two images and one text passage),
#   - the required response layout.
# The commented-out entries are alternative orderings (content before label)
# kept for quick experiments.
_prompt_parts = [
    {"role": "system", "text": "You are an intelligent and helpful assistant."},
    "Describe **all** the below items.",
    # [img_input, "Image 1"],
    ["Item (1):", Image.open("test_llm/dog.png")],
    ["Item (2):", Image.open("test_llm/cat.png")],
    ["Item (3):", "Once upon a time, there was a princess who lived in a castle."],
    # [Image.open("test_llm/dog.png"), "Item (1)"],
    # [Image.open("test_llm/cat.png"), "Item (2)"],
    # ["Once upon a time, there was a princess who lived in a castle.", "Item (3)"],
    ["Provide your response as follows: <Title for Item 1> <Description for Item 1> <Title for Item 2> <Description for Item 2> <Title for Item 3> <Description for Item 3>"],
    # "Please be as detailed as possible in your description of the inputs.",
]

# Convert the parts into provider-ready messages for the "user" role,
# keeping text fragments separate (concatenate_text=False).
api_input = get_messages(
    inputs=_prompt_parts,
    role="user",
    name="",
    concatenate_text=False,
)

# visualize_prompt(api_input, "./vis.html")

# Generation settings passed to both call_llm and batch_call_llm in this
# script. Exactly one model / engine / endpoint is active; the commented-out
# alternatives are kept for quick switching.
# NOTE(review): temperature/top_p/top_k are set together with do_sample=False;
# presumably local engines then decode greedily and ignore them — confirm.
gen_args = dict(
    # model="gpt-4o-2024-08-06",
    # model="gemini-2.0-flash-001",
    model="gemini-2.5-flash-preview-04-17",
    # model="Qwen/Qwen2.5-VL-32B-Instruct-AWQ",
    # model="Qwen/Qwen2.5-VL-3B-Instruct-AWQ",
    # model="Qwen/Qwen2.5-VL-3B-Instruct",
    # model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
    engine="server",
    endpoint="localhost:8000",
    # engine="automodel",
    # engine="vllm",
    # endpoint="http://localhost:8000/v1",
    temperature=1,
    top_p=0.5,
    top_k=20,
    max_tokens=1000,
    num_generations=1,
    stop_sequences=None,
    presence_penalty=None,
    frequency_penalty=None,
    modalities=["Text", "Image"],
    flash_attn=True,
    max_model_len=16384,
    # quant_bits=4,
    do_sample=False,
    gpu_memory_utilization=0.98,
    thinking_budget=None,
)

# Single (non-batched) request using the prompt and settings built above,
# with conversation_dir / usage_dir under test_llm/ and a fixed call id.
api_response, model_generations = call_llm(
    gen_args,
    api_input,
    conversation_dir="test_llm/test_llm",
    usage_dir="test_llm/usage",
    call_id="test_llm",
)

# print(api_response)



# load prompt
# with open("test_llm/prompt.pkl", "rb") as f:
#     prompt = pickle.load(f)

# # api_input = prompt
# gen_args["frequency_penalty"] = None
# response, model_generations = call_llm(gen_args, api_input, conversation_dir="test_llm/test_llm", usage_dir="test_llm/usage", call_id="test_llm")

# for generation in model_generations:
#     print(generation.text())

# response[0]["prompt"][3]['content']
# len(response[0]["prompt"][3]['content'])


# ## Sanity check
# sys.exit(0)
# # Pickle the prompt
# with open("test_llm/prompt.pkl", "wb") as f:
#     pickle.dump(response[0]["prompt"], f)


# from transformers import AutoModelForCausalLM, AutoProcessor, Qwen2_5_VLForConditionalGeneration

# model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", device_map="auto", torch_dtype = "bfloat16")
# processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")



# # Load the prompt
# import pickle

# with open("test_llm/prompt.pkl", "rb") as f:
#     prompt = pickle.load(f)

# print(prompt)


# text_prompt = processor.apply_chat_template(prompt, add_generation_prompt=True)

# from qwen_vl_utils import process_vision_info

# image_inputs, video_inputs = process_vision_info(prompt)


# inputs = processor(
#     text=[text_prompt],
#     images=image_inputs,
#     videos=video_inputs,
#     padding=True,
#     return_tensors="pt",
# )

# inputs.to(model.device)
# # Inference: Generation of the output
# generated_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False, temperature=None, top_p=None, top_k=None)
# generated_ids_trimmed = [
#     out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
# ]
# output_text = processor.batch_decode(
#     generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False, 
# )
# print(output_text)


# # Call llm:
# # local huggingface:
# #"The image shows a young, light-colored puppy sitting on a grassy area dotted with small orange flowers. The puppy has a fluffy coat with a mix of golden and cream tones, and its mouth is slightly open as if it's smiling or panting. The background consists of green grass and more orange flowers, suggesting an outdoor setting in a garden or park. The overall mood of the image is cheerful and lively."

# # server huggingface:
# # The image shows a young, light-colored puppy sitting on a grassy area dotted with small orange flowers. The puppy has a fluffy coat with a mix of golden and cream tones, and its mouth is slightly open as if it's smiling or panting. The background consists of green grass and more orange flowers, suggesting an outdoor setting in a garden or park. The overall mood of the image is cheerful and lively.

# # Bare huggingface:
# #"The image shows a young, light-colored puppy sitting on a grassy area dotted with small orange flowers. The puppy has a fluffy coat with a mix of golden and cream tones, and its mouth is slightly open as if it's smiling or panting. The background consists of green grass and more orange flowers, suggesting an outdoor setting in a garden or park. The overall mood of the image is cheerful and lively."
# #"The image shows a young, light-colored puppy sitting on a grassy area dotted with small orange flowers. The puppy has a fluffy coat with a mix of golden and cream tones, and its mouth is slightly open as if it's smiling or panting. The background consists of green grass and more orange flowers, suggesting an outdoor setting in a garden or park. The overall mood of the image is cheerful and lively."



# # 
# # Using VLLM OpenAI chat completion prompt, model perceives correct image-text pairs
# # client = OpenAI(api_key="")

# # client.chat.completions.create(
# #     model='gpt-4o-mini',  # gpt-4o-mini
# #     messages=response[0]["prompt"]
# # )





# Qwen2.5-VL-3B-Instruct, no sampling
# The image features a tabby cat with a mix of dark and light brown stripes. The cat is sitting upright on what appears to be a stone or concrete ledge, with its front paws resting on the edge. The background is out of focus, showing some bare branches against a blue sky, suggesting it might be late autumn or winter. The cat's gaze is directed forward, and its expression is calm and attentive.
# The image features a tabby cat with a mix of dark and light brown stripes. The cat is sitting upright on what appears to be a stone or concrete ledge, with its front paws resting on the edge. The background is out of focus, showing some bare branches against a blue sky, suggesting it might be late autumn or winter. The cat's gaze is directed forward, and its expression is calm and attentive.
# The image features a tabby cat with a mix of dark and light brown stripes. The cat is sitting upright on what appears to be a stone or concrete ledge, with its front paws resting on the edge. The background is out of focus, showing some bare branches against a blue sky, suggesting it might be winter or early spring. The cat's gaze is directed forward, and its expression is calm and attentive.

# Two raw message specs for the batch test. Each is a system-role dict
# followed by user content strings.
# NOTE(review): images are given here as path strings, while api_input above
# passes opened PIL images — confirm the batch path resolves these strings to
# images rather than sending them as literal text.
msg1 = [
    {"role": "system", "text": "You are an intelligent and helpful assistant."},
    "Describe **all** the below items.",
    "test_llm/dog.png",
    "test_llm/cat.png",
]
msg2 = [
    {"role": "system", "text": "You talk like a pirate."},
    "Describe **all** the below items.",
    "test_llm/cat.png",
]

# Per-request settings for msg1 and msg2 respectively.
# NOTE(review): both conversation dirs are the same path, so only the call
# ids distinguish the two messages — confirm this is intended.
conversation_dir_1 = "./test_llm/conversation_dirs/test_llm"
conversation_dir_2 = "./test_llm/conversation_dirs/test_llm"
call_id_1 = "call1"
call_id_2 = "call2"

# The batch is the pair (msg1, msg2) repeated num_repeats times. Every
# per-request list below is index-aligned with all_msgs. List repetition
# reuses the same msg objects, exactly as the previous append loop did.
num_repeats = 2
all_msgs = [msg1, msg2] * num_repeats
conversation_dirs = [conversation_dir_1, conversation_dir_2] * num_repeats
call_ids = [call_id_1, call_id_2] * num_repeats
# NOTE(review): usage is routed into the conversation dirs here, whereas the
# single call_llm above uses "test_llm/usage". Also, each call id repeats
# num_repeats times within one directory. Confirm that neither causes
# records to be overwritten.
usage_dirs = [conversation_dir_1, conversation_dir_2] * num_repeats

# Send every entry of all_msgs as one batch. conversation_dirs, call_ids and
# usage_dirs are index-aligned with all_msgs. The remaining options are
# passed through to batch_call_llm unchanged.
batch_response, batch_model_generations = batch_call_llm(
    gen_args,
    all_msgs,
    max_batch_size=-1,
    conversation_dirs=conversation_dirs,
    call_ids=call_ids,
    usage_dirs=usage_dirs,
    verbose=True,
    max_api_keys=2,
    order_by_payload_size=True,
)
