import os
import tkinter as tk
from tkinter import ttk
from tkinter import filedialog
from PIL import Image, ImageTk
import json
import uuid
import mimetypes
import base64
import requests
from PIL import Image
from io import BytesIO
from openai import OpenAI
import argparse
from utils import generate_image,edit_image

# This GUI application is for choosing, filtering and editing images, to avoid low-quality image edits. Generate the normal images before using this app.


# Command-line options:
#   --ccs_content_file  JSON file with the samples to review (required; it is
#                       opened unconditionally further down).
#   --images_dir        root directory of the images and of dataset2.json
#                       (default '': paths resolve relative to the CWD).
parser = argparse.ArgumentParser(
    description="Choose, filter and edit generated images for the dataset."
)
parser.add_argument('--ccs_content_file', type=str, required=True,
                    help="JSON file with the samples (a list of lists of sample dicts).")
parser.add_argument('--images_dir', type=str, default='',
                    help="Root directory of the images and of dataset2.json.")

args = parser.parse_args()
images_dir = args.images_dir
ccs_content_file = args.ccs_content_file

# env.json holds the API credentials: OPENAI_API_KEY is required,
# OPENAI_BASE_URL is optional.
with open("env.json", 'r', encoding='utf-8') as f:
    api_keys = json.load(f)


# An empty-string base_url is used verbatim by the SDK and yields unusable
# relative request URLs; passing None lets the SDK fall back to the
# OPENAI_BASE_URL environment variable or its default endpoint.
client = OpenAI(
    base_url=api_keys.get('OPENAI_BASE_URL') or None,
    api_key=api_keys['OPENAI_API_KEY'],
    timeout=120
)

def encode_image(image_path):
    """Return the contents of an image file or URL as a base64 ``str``.

    If *image_path* is an ``http://`` or ``https://`` URL, the resource is first
    streamed into ``./downloads/`` under a random UUID file name. The extension
    is guessed from the response Content-Type, or ``.download`` if unknown.
    That local copy is then encoded. Any other value is treated as a local path.

    Raises:
        requests.HTTPError: the download returned an error status.
        OSError: the file (or the download destination) cannot be read/written.
    """
    # Match the URL scheme exactly so a local file such as "http_img.jpg" is
    # not mistaken for a URL.
    if image_path.startswith(("http://", "https://")):
        user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
        request_kwargs = {
            "headers": {"User-Agent": user_agent},
            "stream": True,
        }

        # The context manager releases the streamed connection even if
        # writing the file fails.
        with requests.get(image_path, **request_kwargs) as response:
            response.raise_for_status()
            # Content-Type may carry parameters ("image/png; charset=...");
            # mimetypes only recognises the bare media type.
            content_type = response.headers.get("content-type", "").split(";")[0].strip()

            extension = mimetypes.guess_extension(content_type)
            if extension is None:
                extension = ".download"

            fname = str(uuid.uuid4()) + extension
            download_dir = os.path.abspath("downloads")
            os.makedirs(download_dir, exist_ok=True)
            download_path = os.path.join(download_dir, fname)

            with open(download_path, "wb") as fh:
                for chunk in response.iter_content(chunk_size=512):
                    fh.write(chunk)

        image_path = download_path

    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


# HTTP headers for calling an OpenAI-compatible REST API directly.
# NOTE(review): nothing in this file references `headers`, and the key comes
# from the OPENAI_API_KEY environment variable rather than env.json (which
# `client` uses). If the variable is unset the value is "Bearer None".
# Confirm which key source is intended.
headers = dict(
    [
        ("Content-Type", "application/json"),
        ("Authorization", "Bearer " + str(os.getenv('OPENAI_API_KEY'))),
    ]
)

def pil_image_to_base64(image):
    """Encode *image* as PNG in memory and return the bytes base64-encoded as ``str``.

    *image* only needs a PIL-style ``save(fp, format=...)`` method.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format='PNG')
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")


def resize_image(image_path):
    """Save a half-size copy of *image_path* next to it and return the copy's path.

    The copy is named ``resized_<file name>`` and placed in the same directory as
    the source. Each dimension is halved with integer division, clamped to at
    least 1 pixel. The output format is inferred from the file extension.
    """
    with Image.open(image_path) as img:
        width, height = img.size
        half = img.resize((max(1, width // 2), max(1, height // 2)))
    directory, file_name = os.path.split(image_path)
    # Prefix only the file name. Prefixing the whole path would point into a
    # non-existent "resized_<dir>" directory for any path with a folder part.
    new_image_path = os.path.join(directory, f"resized_{file_name}")
    half.save(new_image_path)
    return new_image_path

# Mutable state shared by the GUI callbacks.
# Index into `dataset` of the current sample. NOTE(review): hard-coded start;
# presumably where a previous session stopped. "Next" increments before
# refreshing, so the first sample it shows is global_idx + 1.
global_idx = 2695
# Current sample dict / image path. The *string* "None" means nothing loaded yet.
cnt_sample, cnt_image_path = "None", "None"


def open_image(file_path):
    """Show *file_path* in the preview label, scaled to fit a 400x400 box.

    The aspect ratio is preserved, and images smaller than the box are scaled
    up. Each output dimension is at least 1 px. A reference to the PhotoImage
    is kept on the label so Tk does not garbage-collect it.

    Returns:
        True if the image was displayed; False if *file_path* is empty or
        loading/display failed (the error is printed, not raised).
    """
    if not file_path:
        return False
    try:
        max_size = 400
        # Close the source file deterministically; resize() produces an
        # independent image that stays valid after the file is closed.
        with Image.open(file_path) as src:
            width, height = src.size
            ratio = min(max_size / width, max_size / height)
            # Clamp so very thin images do not scale to a 0-px dimension.
            new_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
            img = src.resize(new_size, Image.LANCZOS)

        photo = ImageTk.PhotoImage(img)
        image_label.config(image=photo)
        image_label.image = photo
        return True
    except Exception as e:
        # Best-effort preview at the GUI boundary: report and keep running.
        print(f"Failed to load the image: {e}")
        return False
def refresh():
    """Redisplay the dataset sample at the current ``global_idx``."""
    current_sample = dataset[global_idx]
    refresh_frame(current_sample)

def refresh_frame(sample, image_name='AUTO', image_dir=os.path.join(images_dir, "original")):
    """Load *sample* into the form fields and preview its image.

    Args:
        sample: dataset record. Reads 'name', 'question', 'ground_truth',
            'hallu_answer' and, when *image_name* is 'AUTO', 'rule'.
        image_name: 'AUTO' shows ``<image_dir>/<name>/<name>.jpeg`` and resets
            the Edit field to sample['rule']. Any other value is used as the
            full image path, and the Edit field is left untouched.
        image_dir: root of the original images. The default is
            ``<images_dir>/original``, bound once at definition time. The
            component must be relative: os.path.join discards everything
            before an absolute component such as "/original".

    Side effects: sets the globals cnt_image_path and cnt_sample, and
    rewrites the Question / Ground Truth / Hallu Answer / Image_name / ID fields.
    """
    global cnt_image_path, cnt_sample
    image_dir = os.path.join(image_dir, sample['name'])
    if image_name == 'AUTO':
        image_name = sample['name'] + '.jpeg'
        cnt_image_path = os.path.join(image_dir, image_name)
        text_entry_image_edit.delete(0, tk.END)
        text_entry_image_edit.insert(0, sample['rule'])
    else:
        cnt_image_path = image_name
    cnt_sample = sample
    text_entry_question.delete(0, tk.END)
    text_entry_question.insert(0, sample['question'])
    text_entry_truth.delete(0, tk.END)
    text_entry_truth.insert(0, sample['ground_truth'])
    text_entry_hallu.delete(0, tk.END)
    text_entry_hallu.insert(0, sample['hallu_answer'])
    text_entry_image_name.delete(0, tk.END)
    text_entry_image_name.insert(0, os.path.basename(image_name))
    text_entry_number.delete(0, tk.END)
    text_entry_number.insert(0, str(global_idx))

    open_image(cnt_image_path)


def _regenerate_with_model(model):
    """Edit the current image with *model* and display the result as ``re_<name>``.

    The edit instruction comes from the Edit field. The current image
    (cnt_image_path) is sent as a base64 data URL labelled image/jpeg. The output
    path is ``<same dir>/re_<basename>``, so repeated edits stack the prefix
    and earlier versions are kept. cnt_image_path is only updated after
    edit_image returns, so a failed call leaves the form consistent.
    """
    global cnt_image_path
    image_name = os.path.basename(cnt_image_path)
    new_image_path = os.path.join(os.path.dirname(cnt_image_path), "re_" + image_name)
    edit_image(
        f"data:image/jpeg;base64,{encode_image(cnt_image_path)}",
        text_entry_image_edit.get(),
        new_image_path,
        model,
    )
    cnt_image_path = new_image_path
    # refresh_frame also sets the Image_name field to basename(new_image_path).
    refresh_frame(cnt_sample, cnt_image_path)


def re_generate_image(sample):
    """Regenerate the current image with the doubao SeedEdit model.

    *sample* is unused (the global cnt_sample is used); kept for the button callback.
    """
    _regenerate_with_model("doubao-seededit-3-0-i2i-250628")


def re_generate_image_v2(sample):
    """Regenerate the current image with the Gemini image-preview model.

    *sample* is unused (the global cnt_sample is used); kept for the button callback.
    """
    _regenerate_with_model("gemini-2.5-flash-image-preview")


def test(sample):
    """Ask the chat model the current question about the current image.

    Sends the image at cnt_image_path (as a base64 data URL) together with the
    Question field text to "gemini-2.5-flash-lite", then shows the reply in the
    test_response box. The call is synchronous, so the GUI is unresponsive
    until it returns (the client timeout is 120 s). *sample* is unused; it is
    kept for the button callback signature.
    """
    message = [
        {
            "role": "user", "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{encode_image(cnt_image_path)}"}
                },
                {
                    "type": "text",
                    "text": text_entry_question.get()
                },
            ]
        }
    ]
    completion = client.chat.completions.create(
        model="gemini-2.5-flash-lite",
        messages=message
    )
    # Message content may be None (a reply without text); show an empty box then.
    answer = completion.choices[0].message.content or ""
    test_response.delete("1.0", tk.END)
    test_response.insert("1.0", answer)
    print("OVER")
def save_instance(sample, dataset_json, image_dir):
    """Copy the current image into *image_dir* and append a record to *dataset_json*.

    The image at cnt_image_path is copied to ``<image_dir>/<Image_name field>``.
    A dict with that path, the form fields, placeholder item/category values and
    the original *sample* is appended to the JSON list in *dataset_json*. The
    list is created if the file does not exist.

    The JSON is written to a temporary file and then swapped in, so an error
    during serialization cannot truncate the records collected so far.
    """
    try:
        with open(dataset_json, 'r', encoding='utf-8') as f:
            records = json.load(f)
    except FileNotFoundError:
        records = []

    image_name = text_entry_image_name.get()
    saved_image_path = os.path.join(image_dir, image_name)

    # Read fully before writing, so saving onto the source path itself is safe.
    with open(cnt_image_path, 'rb') as src_file:
        content = src_file.read()

    # Write the destination file, creating the target directory on first save.
    if image_dir:
        os.makedirs(image_dir, exist_ok=True)
    with open(saved_image_path, 'wb') as dst_file:
        dst_file.write(content)

    records.append({
        'image_path': saved_image_path,
        'item': "Undefined",
        'category': ["Undefined"],
        'question': text_entry_question.get(),
        'ground_truth': text_entry_truth.get(),
        'hallu_answer': text_entry_hallu.get(),
        'edit_command': text_entry_image_edit.get(),
        'original_sample': sample
    })

    tmp_path = dataset_json + ".tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(records, f, ensure_ascii=False, indent=2)
    os.replace(tmp_path, dataset_json)

def next():
    """Advance to the next dataset sample and display it.

    At the last sample this does nothing except print a message. The
    unguarded version raised IndexError and left global_idx past the end.
    NOTE: shadows the builtin next(); the name is kept because the "Next"
    button references it.
    """
    global global_idx
    if global_idx + 1 >= len(dataset):
        print("Already at the last sample.")
        return
    global_idx += 1
    refresh()

# Load the samples to review. Before using this application, generate the
# normal images first.
with open(ccs_content_file,'r', encoding='utf-8') as f:
    dataset = json.load(f)
# The file is expected to hold a list of lists of sample dicts. Flatten it so
# global_idx can index samples directly.
dataset = [j for i in dataset for j in i]
root = tk.Tk()
root.geometry("1000x800")


# Shared ttk widget styling.
style = ttk.Style()
style.configure("TButton", padding=6, relief="flat", background="#ccc")
style.configure("TFrame", background="#f0f0f0")
style.configure("TLabel", background="#f0f0f0")

main_frame = ttk.Frame(root, padding=20)
main_frame.pack(fill=tk.BOTH, expand=True)

# Top row: image preview spanning both columns (open_image draws into image_label).
image_frame = ttk.LabelFrame(main_frame, text="Preview", padding=10)
image_frame.grid(row=0, column=0, columnspan=2, sticky="nsew", pady=(0, 15))

image_label = tk.Label(image_frame, bg="white")
image_label.pack(fill="both", expand=True, padx=10, pady=10)


# Bottom-left: the labelled text fields for the current sample.
input_frame = ttk.LabelFrame(main_frame, text="Input", padding=10)
input_frame.grid(row=1, column=0, sticky="nsew", padx=(0, 10))

def create_labeled_entry(parent, label_text, default_text, row):
    """Add a "caption | entry" row to the grid of *parent* and return the entry.

    The caption goes in column 0 (left-aligned). The entry goes in column 1,
    stretches horizontally, and is pre-filled with *default_text*.
    """
    caption = ttk.Label(parent, text=label_text)
    caption.grid(row=row, column=0, sticky="w", pady=(5, 0))

    field = ttk.Entry(parent, width=30)
    field.grid(row=row, column=1, sticky="ew", pady=(5, 0), padx=(5, 0))
    field.insert(0, default_text)
    return field

# Form fields for the current sample; refresh_frame overwrites them.
# The ID field is display-only: nothing in this file reads it back.
text_entry_question = create_labeled_entry(input_frame, "Question:", "question", 0)
text_entry_truth = create_labeled_entry(input_frame, "Ground Truth:", "ground_truth", 1)
text_entry_hallu = create_labeled_entry(input_frame, "Hallu Answer:", "hallu_answer", 2)
text_entry_image_name = create_labeled_entry(input_frame, "Image_name:", "image_name", 3)
text_entry_image_edit = create_labeled_entry(input_frame, "Edit:", "edit content", 4)
text_entry_number = create_labeled_entry(input_frame, "ID:", "0", 5)


input_frame.columnconfigure(1, weight=1)


# Bottom-right: model reply box plus the action buttons.
action_frame = ttk.LabelFrame(main_frame, text="Edit", padding=10)
action_frame.grid(row=1, column=1, sticky="nsew")
test_response = tk.Text(action_frame, wrap=tk.WORD, height=8, width=40)
test_response.pack(fill=tk.X, pady=(0, 30))
button_frame = ttk.Frame(action_frame)
button_frame.pack(fill=tk.X)
# The lambdas read cnt_sample at click time, so each click uses the sample shown then.
buttons = [
    ("Regenerate", lambda: re_generate_image(cnt_sample)),
    ("Regenerate(4)", lambda: re_generate_image_v2(cnt_sample)),
    ("Test", lambda: test(cnt_sample)),
    # "saved" must be a relative component: os.path.join discards images_dir
    # when a later component starts with "/", which wrote into /saved.
    ("Save", lambda: save_instance(cnt_sample, os.path.join(images_dir, "dataset2.json"), os.path.join(images_dir, "saved"))),
    ("Next", next)
]

for i, (text, command) in enumerate(buttons):
    btn = ttk.Button(button_frame, text=text, command=command)
    btn.grid(row=0, column=i, padx=5, sticky="ew")
    button_frame.columnconfigure(i, weight=1)
status_bar = ttk.Frame(root, height=20)
status_bar.pack(fill=tk.X, side=tk.BOTTOM)
status_label = ttk.Label(status_bar, text="Done.")
status_label.pack(side=tk.LEFT, padx=10)
# Let both rows and both columns of the main grid share extra space.
main_frame.rowconfigure(0, weight=1)
main_frame.rowconfigure(1, weight=1)
main_frame.columnconfigure(0, weight=1)
main_frame.columnconfigure(1, weight=1)
root.mainloop()

