from google import genai
from google.genai import types
from PIL import Image
import os
import shutil

# Single Gen AI client created at import time and shared by every helper below
# (each helper has its own per-call construction commented out).
# NOTE(review): constructed with no arguments, so credentials/config presumably
# come from the environment — confirm before running elsewhere.
client = genai.Client()

def generate_image_from_raw(img_path, default_color, save_path):
    """Ask Gemini to produce a realistic stop-sign scene (with a truck) from an image.

    The image at ``img_path`` is sent together with a fixed text prompt to
    ``gemini-2.5-flash-image``. Text parts of the response are printed; an
    inline image part is written to ``save_path``.

    Args:
        img_path: Path of the input image to condition the generation on.
        default_color: Unused by this function; kept so existing callers work.
        save_path: Destination file for the generated image.
    """
    # One sentence per tuple element. The trailing commas matter: without one,
    # Python silently glues adjacent string literals together with no space.
    prompt = (
        "Generate real-world stop traffic sign like the given image.",
        "Add one truck moving towards right.",
        "Keep others unchanged.",
    )

    # Context manager closes the source file handle once the request is done.
    with Image.open(img_path) as image:
        response = client.models.generate_content(
            model="gemini-2.5-flash-image",
            contents=[prompt, image],
        )

    # `response.parts` can be None when the response carries no candidate
    # content; treat that as "no parts" instead of crashing on iteration.
    saved = False
    for part in response.parts or []:
        if part.text is not None:
            print(part.text)
        elif part.inline_data is not None:
            part.as_image().save(save_path)
            saved = True

    if not saved:
        print(f"Warning: model returned no image for {img_path}; {save_path} was not written.")


def modify_image_color(color, img_path):
    """Ask Gemini to recolor the traffic sign in an image.

    Args:
        color: Color name inserted into the prompt and into the output filename.
        img_path: Path of the source image.

    Returns:
        The output path ``<img_path without extension>_<color>_0.png``. The
        path is returned even when the model produced no image; in that case a
        warning is printed and the file is not written.
    """
    prompt = (
        f"Change the color of the traffic sign to '{color}' in the given image",
        "Keep other components in the background unchanged.",
    )
    # Alternative prompts tried earlier:
    #   f"Change the color of the traffic sign to {color}."
    #   f"Change the color of the pedestrian sign to {color}, and the rectangular plate with text \"AHEAD\" to {color}.",
    #   f"Change the black part of the traffic sign to {color} and change the arrow to grey color."
    #   "Keep the edge color to be black.",

    # Context manager closes the source file handle once the request is done.
    with Image.open(img_path) as image:
        response = client.models.generate_content(
            model="gemini-2.5-flash-image",
            contents=[prompt, image],
        )

    # Strip only the final extension. The previous `split('.png')` cut at the
    # first ".png" anywhere in the path and left non-".png" paths untouched.
    stem, _ = os.path.splitext(img_path)
    save_file = f"{stem}_{color}_0.png"

    # `response.parts` can be None when the response carries no candidate
    # content; treat that as "no parts" instead of crashing on iteration.
    saved = False
    for part in response.parts or []:
        if part.text is not None:
            print(part.text)
        elif part.inline_data is not None:
            part.as_image().save(save_file)
            saved = True

    if not saved:
        print(f"Warning: model returned no image for {img_path}; {save_file} was not written.")
    return save_file


def modify_image_orientation(angle, img_path):
    """Ask Gemini to rotate the sign and its pillar clockwise by ``angle`` degrees.

    Intended input is a recolored image named ``<base>_0.png`` (angle 0), as
    written by ``modify_image_color``.

    Args:
        angle: Rotation in degrees, inserted into the prompt and output name.
        img_path: Path of the source image.

    Returns:
        The output path: ``img_path`` without its extension and without a
        trailing ``"_0"``, followed by ``_<angle>.png``. The path is returned
        even when the model produced no image; in that case a warning is
        printed and the file is not written.

    NOTE: with ``angle == 0`` and an input named ``..._0.png`` the output path
    equals the input path, so the source image is overwritten.
    """
    prompt = (
        "For the given image, taking the bottom of the pillar as reference,",
        f"and rotate the pillar together with the sign {angle} degree clockwise in the image plane.",
        "Don't add extra texts to the image.",
        "Keep other components in the background unchanged."
    )
    # Alternative prompt lines tried earlier:
    #   f"Rotate the triangle 'worker ahead' sign {angle} degrees clockwise.",
    #   "Don't rotate the backgound.",

    # Context manager closes the source file handle once the request is done.
    with Image.open(img_path) as image:
        response = client.models.generate_content(
            model="gemini-2.5-flash-image",
            contents=[prompt, image],
        )

    # Replace the trailing "_0" (angle-0 marker) with the requested angle.
    # The previous `split('_0.png')` produced e.g. "x.png_20.png" whenever the
    # input did not end in "_0.png".
    stem, _ = os.path.splitext(img_path)
    if stem.endswith("_0"):
        stem = stem[:-len("_0")]
    save_file = f"{stem}_{angle}.png"

    # `response.parts` can be None when the response carries no candidate
    # content; treat that as "no parts" instead of crashing on iteration.
    saved = False
    for part in response.parts or []:
        if part.text is not None:
            print(part.text)
        elif part.inline_data is not None:
            part.as_image().save(save_file)
            saved = True

    if not saved:
        print(f"Warning: model returned no image for {img_path}; {save_file} was not written.")
    return save_file


def resize_images(in_dir, out_dir, size=(128, 128)):
    """Resize every PNG directly inside ``in_dir`` and save it into ``out_dir``.

    Files are selected by a case-insensitive ``.png`` extension check and keep
    their original names. ``out_dir`` is created if it does not exist. One
    progress line is printed per image, then a summary line.

    Args:
        in_dir: Directory scanned (non-recursively) for source images.
        out_dir: Directory the resized PNGs are written to.
        size: Target size passed straight to ``Image.resize`` (width, height).
    """
    os.makedirs(out_dir, exist_ok=True)

    png_names = [name for name in os.listdir(in_dir) if name.lower().endswith(".png")]

    processed = 0
    for processed, name in enumerate(png_names, start=1):
        src_file = os.path.join(in_dir, name)
        dst_file = os.path.join(out_dir, name)
        with Image.open(src_file) as source:
            source.resize(size, Image.Resampling.LANCZOS).save(dst_file, format="PNG")
        print(f"[{processed}] Saved resized image: {dst_file}")

    print(f"Done. {processed} PNG images resized and saved to '{out_dir}'.")


if __name__ == "__main__":
    # ---- Settings for the full recolor + rotate batch (currently disabled) ----
    traffic_type = 'stop'
    img_idx = 0
    img_path = f'./real_images/{traffic_type}/{traffic_type}{img_idx}.png'
    color_list = ['blue', 'green', 'grey', 'orange', 'purple', 'red', 'white', 'yellow']

    # Full batch: recolor the raw image, then rotate each recolored copy.
    # for color in color_list:
    #     img_color = modify_image_color(color, img_path)
    #     for angle in [10, 20, 30, -10, -20, -30]:
    #         img_angle = modify_image_orientation(angle, img_color)

    # ---- One-off fix-up: redo a single rotation of one recolored image ----
    idx, color = 0, 'yellow'
    sign_dir = f'./real_data/{traffic_type}'
    img_path = f'{sign_dir}/{traffic_type}{idx}.png'
    img_tmp = f'{sign_dir}/{traffic_type}{idx}_{color}_0.png'
    print(f"Fix image at :{img_tmp}")

    # Regenerate the recolored base first if needed:
    # img_color = modify_image_color(color, img_path)

    img_angle = modify_image_orientation(20, img_tmp)

    # ---- Other fix-up recipes ----
    # Recolor the base image into every color:
    # for color in color_list:
    #     img_color = modify_image_color(color, img_path)
    #
    # Rotate one recolored image through every test angle:
    # for angle in [-30, -20, -10, 10, 20, 30]:
    #     img_tmp = f'./real_data/{traffic_type}/{traffic_type}{idx}_{color}_0.png'
    #     img_angle = modify_image_orientation(angle, img_tmp)
    #
    # Rotate +/-20 degrees for images 0..5 in every color:
    # for temp_idx in range(6):
    #     for color in color_list:
    #         for angle in [20, -20]:
    #             img_color = f'./real_data/{traffic_type}/{traffic_type}{temp_idx}_{color}_0.png'
    #             img_angle = modify_image_orientation(angle, img_color)