import os
import openai
import json
import time
import base64
import requests
import copy
from PIL import Image, ImageDraw
 
# The API key is read from the environment rather than hard-coded: a key
# committed to source control is exposed to anyone who can read the repo
# (any key that was previously committed here should be revoked/rotated).
openai.api_key = os.environ.get("OPENAI_API_KEY")
model_name = "gpt-4o"  # alternative tried earlier: "gpt-4-turbo"
temperature_set = 0  # sent as "temperature" in every request
max_tokens = 2000  # sent as "max_tokens" in every request
 
def read_text_file(file_path):
    """
    Return the whole contents of the text file at *file_path*.

    The file is decoded as UTF-8 and read in one go.
    """
    with open(file_path, mode='r', encoding='utf-8') as handle:
        return handle.read()
 
def get_image_list(image_folder):
    """
    Return the paths of the ``.png`` entries directly inside *image_folder*.

    Each path is *image_folder* joined with the entry name. The result is
    sorted by entry name, because ``os.listdir`` returns entries in arbitrary
    order. Sorting keeps frame files such as ``000.png, 001.png, ...`` in a
    stable, sequential order.

    Only names ending exactly in ``.png`` (case-sensitive) are included, and
    the folder is not searched recursively.
    """
    return [
        os.path.join(image_folder, name)
        for name in sorted(os.listdir(image_folder))
        if name.endswith(".png")
    ]
 
def combine_images_side_by_side(image_paths, output_path):
    """
    Concatenate images horizontally into one RGB image and save it.

    Images are pasted left to right in the order given, each aligned to the
    top edge. The canvas width is the sum of the input widths and its height
    is the tallest input's height.

    Args:
        image_paths: Iterable of image file paths (at least one).
        output_path: Path the combined image is saved to.

    Returns:
        list[str]: ``[output_path]``, a one-element list of image paths.

    Raises:
        ValueError: if *image_paths* is empty.
    """
    paths = list(image_paths)
    if not paths:
        raise ValueError("combine_images_side_by_side() requires at least one image path")

    images = []
    try:
        for path in paths:
            images.append(Image.open(path))

        # Canvas wide enough for all images and as tall as the tallest one.
        total_width = sum(img.width for img in images)
        max_height = max(img.height for img in images)
        combined_image = Image.new("RGB", (total_width, max_height))

        # Paste each image immediately to the right of the previous one.
        x_offset = 0
        for img in images:
            combined_image.paste(img, (x_offset, 0))
            x_offset += img.width

        combined_image.save(output_path)
    finally:
        # Image.open is lazy and keeps the underlying file open until the
        # image is closed; release every handle even if saving fails.
        for img in images:
            img.close()

    print(f"Combined image saved at {output_path}")
    # list(output_path) would split the path string into single characters.
    return [output_path]
 
def encode_image_func(image_path):
    """
    Read the file at *image_path* as raw bytes and return them base64-encoded.

    The result is a text string, ready to be embedded in a data URL for a
    VLM query.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode('utf-8')
   
# Data-URL MIME types for the image extensions format_prompt recognizes;
# any other extension falls back to "image/png".
_IMAGE_MIME_TYPES = {
    ".png": "image/png",
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".gif": "image/gif",
    ".webp": "image/webp",
}


def format_prompt(text_prompt, image_list):
    """
    Build the ``content`` list of one user message for the OpenAI API.

    The first element is the text part. Each path in *image_list* is then
    appended, in order, as an ``image_url`` part whose URL is a base64 data
    URL (``data:<mime>;base64,<payload>``) produced by ``encode_image_func``.
    The MIME type is chosen from the file extension, falling back to
    ``image/png``.

    Args:
        text_prompt: Text placed before the images.
        image_list: Iterable of image file paths.

    Returns:
        list[dict]: ``[{"type": "text", ...}, {"type": "image_url", ...}, ...]``.

    Raises:
        SystemExit: with code 1 if an image path does not exist. An error
            message is printed to stderr first.
    """
    import sys

    formatted_prompt = [{"type": "text", "text": text_prompt}]
    for img_path in image_list:
        if not os.path.exists(img_path):
            print(f"Error: Image doesn't exist at provided path: {img_path}", file=sys.stderr)
            # The builtin ``exit`` is an interactive helper installed by the
            # ``site`` module and is not guaranteed to exist; raising
            # SystemExit directly is what it does anyway.
            raise SystemExit(1)

        extension = os.path.splitext(img_path)[1].lower()
        mime_type = _IMAGE_MIME_TYPES.get(extension, "image/png")

        base64_image = encode_image_func(img_path)
        formatted_prompt.append(
            {
                "type": "image_url",
                "image_url": {"url": f"data:{mime_type};base64,{base64_image}"},
            }
        )

    return formatted_prompt
 
def get_response(text_prompt, image_list, *, timeout=180):
    """
    Send one text-plus-images user message to the OpenAI chat-completions
    endpoint and return the reply text.

    The request uses the module-level settings ``model_name``,
    ``temperature_set`` and ``max_tokens``, and authenticates with
    ``openai.api_key`` as a Bearer token.

    Args:
        text_prompt: Text part of the message.
        image_list: Paths of images to attach (encoded by ``format_prompt``).
        timeout: Seconds passed to ``requests.post``, so a stalled connection
            cannot block forever.

    Returns:
        str: ``choices[0].message.content`` from the JSON response.

    Raises:
        RuntimeError: if no API key is configured.
        requests.HTTPError: if the API answers with a non-2xx status. The
            message includes the response body, which carries the API's
            error description.
        requests.Timeout: if the request exceeds *timeout*.
    """
    if not openai.api_key:
        raise RuntimeError(
            "No OpenAI API key configured: set the OPENAI_API_KEY environment "
            "variable or assign openai.api_key."
        )

    formatted_prompt = format_prompt(text_prompt, image_list)

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {openai.api_key}",
    }

    payload = {
        "model": model_name,
        "temperature": temperature_set,
        "messages": [{"role": "user", "content": formatted_prompt}],
        "max_tokens": max_tokens,
    }

    response = requests.post(
        "https://api.openai.com/v1/chat/completions",
        headers=headers,
        json=payload,
        timeout=timeout,
    )
    # Without this check an error reply (which has no "choices" key) would
    # surface as an unhelpful KeyError below.
    if not response.ok:
        raise requests.HTTPError(
            f"OpenAI API request failed with HTTP {response.status_code}: {response.text}",
            response=response,
        )
    return response.json()["choices"][0]["message"]["content"]
 
if __name__ == "__main__":
    # Earlier two-pass experiment (disabled, kept here for reference):
    #   pass 1: robomimic_prompts/tool_hang_new_prompt.txt + robomimic_prompts/tool_hang.txt
    #           with image robomimic_images/000.png -> prediction
    #   pass 2: robomimic_prompts/tool_hang_new_prompt2.txt + "\nPrediction:\n" + prediction
    #           with image robomimic_images/125.png
    #   optional: combine_images_side_by_side(get_image_list("robomimic_images"),
    #                                         "robomimic_images/combined_image.png")

    # Single-pass ("naive") method: the two instruction files concatenated
    # in order, sent together with a single rollout frame.
    prompt_files = (
        "instruction/tool_hang_1pass_prompt.txt",
        "instruction/tool_hang_openai.txt",
    )
    text_prompt = "".join(read_text_file(path) for path in prompt_files)
    frame_path = "tests/figure/tool_hang/ph/mlp/rollout_001/125.png"
    answer = get_response(text_prompt=text_prompt, image_list=[frame_path])
    print(answer)
 
 