import ast
import base64
import os
from io import BytesIO

import torch
from PIL import Image, ImageDraw

from models.llava_uground.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from models.llava_uground.conversation import conv_templates
from models.llava_uground.mm_utils import tokenizer_image_token, process_images, pre_resize_by_width
from models.llava_uground.model.builder import load_pretrained_model
from models.llava_uground.utils import disable_torch_init

from extract_trajectory_html import process_html_file

def screenshot_to_file(execution_data, state_idx, file_path):
    """Write the screenshot of one recorded state to ``file_path``.

    Args:
        execution_data: Mapping with a ``'states'`` sequence; each state is
            expected to carry a base64 ``'screenshot'`` entry.
        state_idx: Index of the state whose screenshot is exported.
        file_path: Destination passed straight to ``Image.save``.
    """
    state = execution_data['states'][state_idx]
    b64_to_pil(state['screenshot']).save(file_path)


def b64_to_pil(img_b64: str) -> Image.Image:
    """Decode a base64-encoded image string into a PIL image.

    Both a data URI (``<header>,<payload>``, e.g.
    ``data:image/png;base64,iVBOR...``) and a bare base64 payload without
    any header are accepted.

    Args:
        img_b64: Base64 image data, optionally preceded by a
            comma-terminated data-URI header.

    Returns:
        The image opened from the decoded bytes.

    Raises:
        binascii.Error: If the payload is not valid base64.
        PIL.UnidentifiedImageError: If the decoded bytes are not an image
            format PIL recognises.
    """
    parts = img_b64.split(",")
    # With a header present the payload is the segment right after the first
    # comma; a bare payload has no comma and is used as-is instead of raising
    # IndexError.
    payload = parts[1] if len(parts) > 1 else parts[0]
    img_bytes = base64.b64decode(payload)
    return Image.open(BytesIO(img_bytes))


def draw_circle_on_image(image, coordinates, radius=20, color=(255, 0, 0), width=10):
    """Draw a circle outline centred on a point of a PIL image.

    The image is modified in place and the same object is returned, so the
    call can be chained or its result ignored.

    Args:
        image: PIL image to draw on (mutated in place).
        coordinates: ``(x, y)`` centre of the circle, in pixels.
        radius: Circle radius in pixels.
        color: Outline colour handed to ``ImageDraw.ellipse``.
        width: Outline thickness in pixels.

    Returns:
        ``image`` itself. It is returned untouched when the centre lies
        outside ``[0, image width] x [0, image height]`` (bounds inclusive).
    """
    x, y = coordinates
    img_width, img_height = image.size

    # Skip points that fall outside the image instead of drawing a partial
    # circle for an off-screen prediction.
    if not (0 <= x <= img_width and 0 <= y <= img_height):
        return image

    bounding_box = [(x - radius, y - radius), (x + radius, y + radius)]
    ImageDraw.Draw(image).ellipse(bounding_box, outline=color, width=width)

    return image


# Module-level setup: runs on import/execution and loads the grounding model.
# Release cached, currently unused CUDA memory before loading the model.
torch.cuda.empty_cache()
# NOTE(review): presumably skips default torch weight initialisation to speed
# up model loading — confirm in models.llava_uground.utils.
disable_torch_init()
model_path = 'osunlp/UGround'
# The same identifier is passed as first and third argument with None in
# between; presumably (model_path, model_base, model_name) — verify against
# load_pretrained_model's signature.
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_path)

# Private copy of the 'llava_v1' conversation template; each query below
# copies it again so appended messages never accumulate on the shared one.
prompt_template = conv_templates['llava_v1'].copy()

# IDs of the marked boxes to locate; one (x, y) point is predicted per ID.
target_ids = [38]
coordinates = []
img_path = "screenshot.png"

# The screenshot and its preprocessed tensor are the same for every target,
# so they are prepared once up front. This also guarantees the images exist
# for the drawing step even when target_ids is empty.
image = Image.open(img_path).convert('RGB')
resized_image, pre_resize_scale = pre_resize_by_width(image)
image_tensor, image_new_size = process_images([resized_image], image_processor, model.config)
image_tensor_gpu = image_tensor.half().cuda()

for target_id in target_ids:
    # A free-text description (e.g. 'the element "brown cereal"') can be
    # substituted here instead of referring to a box ID.
    description = f"the element surrounded by the box with ID {target_id}"
    qs = f"In the screenshot, what are the pixel coordinates (x, y) of {description}?"
    qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

    # Fresh conversation per target: user turn with the question, then an
    # empty assistant turn for the model to complete.
    conversation = prompt_template.copy()
    conversation.append_message(conversation.roles[0], qs)
    conversation.append_message(conversation.roles[1], None)
    prompt = conversation.get_prompt()

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

    print(f"Generating coordinates for target {target_id}")
    print(prompt)
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor_gpu,
            image_sizes=[image_new_size],
            temperature=0,
            top_p=None,
            do_sample=False,
            num_beams=1,
            max_new_tokens=16384,
            use_cache=True)

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    # The model's text is parsed as a Python literal such as "(x, y)".
    # literal_eval is used instead of eval so generated text can never be
    # executed as code.
    try:
        resized_x, resized_y = ast.literal_eval(outputs)
    except (ValueError, SyntaxError, TypeError) as err:
        raise ValueError(
            f"Could not parse (x, y) coordinates from model output: {outputs!r}"
        ) from err

    # Undo the pre-resize so the point refers to the original screenshot.
    # NOTE(review): assumes pre_resize_scale is resized/original — confirm
    # against pre_resize_by_width.
    fixed_coordinates = (resized_x / pre_resize_scale, resized_y / pre_resize_scale)
    coordinates.append(fixed_coordinates)


# The stored coordinates are in original-screenshot pixels, so they are drawn
# on a copy of the original image rather than on the resized one.
out_image = image.copy()
for fixed_coordinates in coordinates:
    out_image = draw_circle_on_image(out_image, fixed_coordinates)
out_image.save("out_img_som.png")

# out_image_nosom = Image.open(os.path.join('screenshot_nosom.png')).convert('RGB')
# for fixed_coordinates in coordinates:
#     out_image_nosom = draw_circle_on_image(out_image_nosom, fixed_coordinates)

# out_image_nosom.save("out_img_no_som.png")