from PIL import Image, ImageDraw, ImageFont
from furniture_bench_api.utils.pose_utils import relative_to_base
import numpy as np

# Font files tried in order for label text; the first one that loads is used.
_LABEL_FONT_NAMES = ("arial.ttf", "DejaVuSans.ttf")
_LABEL_FONT_SIZE = 20
# Holds the resolved label font so font files are not re-read for every label.
_label_font_cache = {}


def _get_label_font():
    """Return the label font, resolving it on first use and reusing it afterwards."""
    font = _label_font_cache.get("font")
    if font is None:
        for font_name in _LABEL_FONT_NAMES:
            try:
                font = ImageFont.truetype(font_name, size=_LABEL_FONT_SIZE)
                break
            except OSError:
                # Font file not available on this system; try the next candidate.
                continue
        else:
            # None of the TrueType fonts could be loaded; use PIL's built-in font.
            font = ImageFont.load_default()
        _label_font_cache["font"] = font
    return font


def draw_text(V, P, width, height, draw, obj_pos, text: str):
    """Project a 3D position into the image and draw a text label there.

    The position is transformed as ``P @ (V @ [x, y, z, 1])``, divided by the
    resulting ``w`` component, and mapped to pixel coordinates. The label is
    drawn horizontally centred on the projected point, on top of a
    semi-transparent black rectangle.

    Args:
        V: View matrix applied to the homogeneous position (``V @ p``).
        P: Projection matrix applied to the view-space position.
        width: Width in pixels of the image being drawn on.
        height: Height in pixels of the image being drawn on.
        draw: Drawing object providing ``textbbox``, ``rectangle`` and ``text``
            (e.g. a ``PIL.ImageDraw.ImageDraw``).
        obj_pos: Sequence of three coordinates for the point to label.
        text: Label text.

    Returns:
        None. Nothing is drawn when the projection is degenerate (``w`` is zero
        or any projected image coordinate is not finite).
    """
    obj_pos_hom = np.concatenate([obj_pos, [1]])  # homogeneous coordinates
    cam_pos_hom = V @ obj_pos_hom
    clip_space = P @ cam_pos_hom

    # Guard the perspective divide: a zero/non-finite w has no valid image
    # position, and NaN/inf pixel coordinates must not reach the draw calls.
    w = clip_space[3]
    if not np.isfinite(w) or abs(w) < 1e-8:
        return
    ndc = clip_space[:3] / w
    if not np.all(np.isfinite(ndc[:2])):
        return

    # NOTE: this mapping is empirical. Image x is taken from ndc[1] and image y
    # from ndc[0] with a 0.5 offset; the textbook mapping
    # ((ndc + 1) * 0.5 * size) was tried for y and did not line up.
    # NOTE(review): possibly the matrices are stored for row-vector
    # multiplication (p @ V @ P) and the swap compensates for that -- confirm
    # against the simulator's camera API before changing.
    x = (ndc[1] + 1) * 0.5 * width
    y = (ndc[0] + 0.5) * height

    font = _get_label_font()

    # Measure the text so the background box can be sized and centred.
    bbox = draw.textbbox((0, 0), text, font=font)
    text_width = bbox[2] - bbox[0]
    text_height = bbox[3] - bbox[1]

    # Centre horizontally on the projected point; the top edge sits at y.
    text_x, text_y = x - text_width / 2, y
    padding = 4
    draw.rectangle(
        [text_x - padding, text_y - padding, text_x + text_width + padding, text_y + text_height + padding],
        fill="#00000075"  # black with partial alpha
    )

    draw.text((text_x, text_y), text, fill="white", font=font)

def get_image_with_labels(furniture_bench_env):
    """Return the ``color_image2`` camera frame annotated with part and robot labels.

    Every furniture part is labelled ``"<name>:part"`` at its projected
    position and the robot flange is labelled ``"arm:robot"``. Labels are drawn
    on a transparent overlay and alpha-composited onto the frame.

    Args:
        furniture_bench_env: Wrapped environment. It must provide
            ``get_observation()`` and ``get_current_pose()``, and its
            ``.env.env.env`` must provide ``furniture.parts``,
            ``get_front_projection_view_matrix()`` and ``get_observation()``.

    Returns:
        PIL.Image.Image: The annotated frame in RGB mode.
    """
    env = furniture_bench_env.env.env.env

    part_names = [p.name for p in env.furniture.parts]
    # One row of 7 values per part; only the first three (position) are used.
    # Presumably the rest is an orientation quaternion -- not needed here.
    part_poses = furniture_bench_env.get_observation()["parts_poses"][0].reshape(-1, 7).cpu().numpy()

    P, V = env.get_front_projection_view_matrix()
    V = np.asarray(V)

    image = Image.fromarray(env.get_observation()["color_image2"][0].cpu().numpy()).convert("RGBA")
    # Scale label positions by the size of the image actually being annotated,
    # so labels stay on target even if the observation was resized relative
    # to the camera configuration.
    width, height = image.size

    # Transparent overlay so the semi-transparent label backgrounds blend
    # with the frame instead of overwriting it.
    canvas = Image.new("RGBA", image.size, (255, 255, 255, 0))
    draw = ImageDraw.Draw(canvas)

    for part_name, part_pose in zip(part_names, part_poses):
        draw_text(V, P, width, height, draw, part_pose[:3], f"{part_name}:part")

    robot_pose = furniture_bench_env.get_current_pose(at_flange=True)[:7]
    robot_pose = relative_to_base(robot_pose, furniture_bench_env, inverse=True)
    # Copy before editing: .numpy() on a CPU tensor shares memory with it, so
    # an in-place offset would otherwise write into the pose tensor.
    robot_pos = robot_pose[0, :3].cpu().numpy().copy()
    robot_pos[2] += 0.15  # FIXME: hack to move the label a bit up
    draw_text(V, P, width, height, draw, robot_pos, "arm:robot")

    image = Image.alpha_composite(image, canvas)

    return image.convert("RGB")