import json
import base64
import os
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
import openai
from openai import OpenAI
import time
import re
import argparse
from PIL import Image
from io import BytesIO
import numpy as np
import shutil

# Module-wide OpenAI client; process_line() calls it from the worker threads.
# It is built with no explicit arguments, so credentials/endpoint come from the
# SDK's defaults (presumably the OPENAI_API_KEY environment variable — confirm
# for your deployment).
client = OpenAI()   

def args_parser():
    """Parse the evaluation script's command-line options.

    Every option is a string with a built-in default:

    * ``--save_path``    directory that receives the results
    * ``--mapping_file`` JSON file describing the benchmark entries
    * ``--mask_key``     key of the RLE mask inside each entry
    * ``--base_dir``     root directory of the ground-truth images
    * ``--image_dir``    root directory of the inpainted images

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    # (flag, default) pairs; all options share type=str.
    string_options = (
        ('--save_path', "/root/flux/runs/eval_4o/test"),
        ('--mapping_file', "/root/data/BrushBench/mapping_file_list.json"),
        ('--mask_key', "inpainting_mask"),
        ('--base_dir', "/root/data/BrushBench"),
        ('--image_dir', "/root/flux/runs/baseline/eval"),
    )

    cli = argparse.ArgumentParser()
    for flag, fallback in string_options:
        cli.add_argument(flag, type=str, default=fallback)
    return cli.parse_args()

def rle2mask(mask_rle, shape): # height, width
    """Decode a run-length-encoded mask into a 2-D binary array.

    Args:
        mask_rle: Flat sequence of alternating ``start, length`` values.
            Starts are 1-based positions in the flattened mask.
        shape: ``(height, width)`` of the decoded mask.

    Returns:
        A ``uint8`` array of the given shape holding 1 inside every run and
        0 elsewhere. The flat buffer is reshaped in NumPy's default
        (row-major) order.
    """
    # Compute the 0-based starts out of place. The previous in-place
    # ``starts -= 1`` wrote through to the caller's data whenever ``mask_rle``
    # was already an integer ndarray (np.asarray returns a view in that case).
    starts = np.asarray(mask_rle[0::2], dtype=int) - 1
    lengths = np.asarray(mask_rle[1::2], dtype=int)
    ends = starts + lengths

    binary_mask = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        binary_mask[lo:hi] = 1
    return binary_mask.reshape(shape)

def pil_image_to_base64(image: Image.Image, format: str = "PNG") -> str:
    """Serialize *image* in memory and return the bytes as base64 text.

    Args:
        image: Image object to encode; it is written via ``image.save``.
        format: Format name forwarded to ``image.save`` (default ``"PNG"``).

    Returns:
        The encoded image bytes as a UTF-8 decoded base64 string.
    """
    with BytesIO() as stream:
        image.save(stream, format=format)
        encoded = base64.b64encode(stream.getvalue())
    return encoded.decode("utf-8")

def encode_image(image_path: str) -> str:
    """Return the raw bytes of the file at *image_path* as a base64 string."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")

def _image_data_url(image_path):
    """Return a base64 ``data:`` URL for the image file at *image_path*.

    The MIME type is guessed from the file extension so that, for example, a
    ``.png`` file is not labelled as JPEG. Unknown or non-image extensions
    fall back to ``image/jpeg``.
    """
    import mimetypes  # local import keeps this block self-contained

    mime_type, _ = mimetypes.guess_type(image_path)
    if mime_type is None or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"
    return f"data:{mime_type};base64,{encode_image(image_path)}"


def build_prompt(gt_image, mask_image, inpainting_image, caption):
    """Build the multimodal message content for one inpainting evaluation.

    Args:
        gt_image: Path of the original image to be inpainted.
        mask_image: Path of the binary mask image (black = region to inpaint,
            as stated in the instruction text below).
        inpainting_image: Path of the inpainted result to be scored.
        caption: Text prompt describing the overall image.

    Returns:
        A list of content parts in this order: task instructions, the caption,
        the three images (original, mask, result) as base64 data URLs, and the
        scoring rubric with the JSON format the reply should follow.

    Raises:
        OSError: If one of the image files cannot be read.
    """
    content = [
        {
            "type": "input_text",
            "text": "You are an expert in image aesthetics and quality assessment. "
                    "Your task is to evaluate the performance of our image inpainting model and assign scores to the inpainting images from 1 to 100. "
                    "You will receive one prompt describing the overall image, as well as three images: "
                    "(1) the original image to be inpainted, "
                    "(2) a binary inpainting mask (black region is the area to be inpainted), "
                    "and (3) the inpainted result generated by our model."
        },
        {
            "type": "input_text",
            "text": f"Prompt: {caption}"
        },
        {
            "type": "input_image",
            "image_url": _image_data_url(gt_image)
        },
        {
            "type": "input_image",
            "image_url": _image_data_url(mask_image)
        },
        {
            "type": "input_image",
            "image_url": _image_data_url(inpainting_image)
        },
        {
            "type": "input_text",
            "text": "Evaluate the inpainting result based on the following three criteria:\n\n"
                    "1. Aesthetic Quality (0–40 pts):\n"
                    "   - Visual appeal: color harmony, composition, style coherence\n"
                    "   - Texture realism and naturalness\n\n"
                    "2. Structural Accuracy (0–30 pts):\n"
                    "   - Preservation of geometric structures and content continuity\n"
                    "   - Seamlessness at mask boundaries\n\n"
                    "3. Semantic Alignment (0–30 pts):\n"
                    "   - Faithfulness to the Text Prompt instructions\n"
                    "   - Contextual consistency of added or restored content\n\n"
                    "For each criterion, provide:\n"
                    "- A sub‑score (integer).\n"
                    "- A 1–2‑sentence justification.\n\n"
                    "Then compute the total score (1–100).\n"
                    "Return your result in the following JSON format:\n\n"
                    "{\n"
                    '  "Aesthetic Quality": {\n'
                    '    "score": <int 0–40>,\n'
                    '    "comment": "<brief justification>"\n'
                    "  },\n"
                    '  "Structural Accuracy": {\n'
                    '    "score": <int 0–30>,\n'
                    '    "comment": "<brief justification>"\n'
                    "  },\n"
                    '  "Semantic Alignment": {\n'
                    '    "score": <int 0–30>,\n'
                    '    "comment": "<brief justification>"\n'
                    "  },\n"
                    '  "Overall Score": <int 1–100>,\n'
                    '  "Summary": "<one‑sentence overall remark>"\n'
                    "}"
        }
    ]

    return content

# Command-line options are parsed once at import time; the code below and the
# worker function read them through this module-level `args`.
args = args_parser()

# Per-run scratch directory, named after the current Unix timestamp.
temp_dir = os.path.join("/root/temp/eval_4o", f"timestamp_{int(time.time())}")
os.makedirs(temp_dir, exist_ok=True)

# Results and the list of given-up entries both live under --save_path.
os.makedirs(args.save_path, exist_ok=True)
output_file = os.path.join(args.save_path, 'output.json')
refusal_file = os.path.join(args.save_path, 'refusal.txt')

num_threads = 64  # size of the worker thread pool
max_retries = 3   # attempts per entry before it is written to the refusal file

# Separate locks serialize appends to the refusal file and the output file.
write_lock_refusal = threading.Lock()
write_lock = threading.Lock()

processed_ids = set()
refusal_ids = set()

# Reload ids recorded as refused by earlier runs (one id per line).
if os.path.exists(refusal_file):
    with open(refusal_file, 'r', encoding='utf-8') as f:
        refusal_ids.update(line.strip() for line in f)

# Workers return immediately once this event is set.
stop_processing_event = threading.Event()

def process_line(data):
    """Score one benchmark entry with the vision model and record the result.

    The entry is skipped when the stop event is set or when its id is already
    in ``processed_ids`` or ``refusal_ids``. Otherwise the RLE mask is rendered
    to a PNG in ``temp_dir`` (black where the RLE marks a run), the prompt is
    built once, and the model is queried up to ``max_retries`` times. On
    success one JSON line ``{"image_id", "score", "output"}`` is appended to
    ``output_file``; after the last failed attempt the id is appended to
    ``refusal_file``.

    Args:
        data: Mapping for one entry. Reads the keys ``image_id``, ``image``
            (relative image path), ``caption`` and ``args.mask_key``.

    Returns:
        None. Entries whose inputs cannot be prepared (missing key, unreadable
        file, malformed mask) are reported on stdout and skipped without being
        added to the refusal list, so a later run will try them again.
    """
    if stop_processing_event.is_set():
        return

    # Preparation: nothing here talks to the API, so a failure is a problem
    # with the entry itself and is not worth retrying.
    try:
        entry_id = data['image_id']
        print(f"Processing entry {entry_id}")

        if entry_id in processed_ids:
            print(f"Entry {entry_id} already processed, skipping.")
            return

        if entry_id in refusal_ids:
            print(f"Entry {entry_id} is in refusal list, skipping.")
            return

        gt_image_path = os.path.join(args.base_dir, data["image"])
        # Rename only the relative image name, so a ".jpg" that happens to
        # appear inside image_dir is left alone.
        inpainting_image_path = os.path.join(
            args.image_dir, data["image"].replace(".jpg", "_blended.jpg")
        )

        # Invert the decoded mask so the runs become black, then expand it to
        # three channels. NOTE(review): the mask size is hard-coded to
        # 512x512 — confirm this matches every image in the benchmark.
        mask_path = os.path.join(temp_dir, f"{entry_id}_mask.png")
        inverted = 1 - rle2mask(data[args.mask_key], (512, 512))[:, :, np.newaxis]
        Image.fromarray(inverted.repeat(3, -1) * 255).convert("RGB").save(mask_path)

        # The prompt embeds three base64 images; build it once, not per retry.
        message = [
            {
                'role': 'user',
                'content': build_prompt(gt_image_path, mask_path, inpainting_image_path, data["caption"])
            }
        ]
    except (KeyError, OSError, ValueError) as e:
        # ValueError also covers json.JSONDecodeError, which the previous
        # handler caught.
        print(f"Invalid or unreadable entry, skipping: {e!r}")
        return

    for attempt in range(1, max_retries + 1):
        try:
            response = client.responses.create(
                model='gpt-4.1',
                input=message
            )
            text = response.output_text
            match = re.search(r'"Overall Score"\s*:\s*(\d+)', text, re.IGNORECASE)
            if not match:
                raise ValueError(f"Invalid response format for entry {entry_id}: {text}")
            score = int(match.group(1))

            # Append the result as one JSON line; the lock keeps lines from
            # different threads from interleaving.
            with write_lock:
                with open(output_file, 'a', encoding='utf-8') as outfile:
                    json.dump({'image_id': entry_id, 'score': score, "output": text}, outfile)
                    outfile.write('\n')
        except Exception as e:
            print(f"Error processing entry {entry_id}: {e}")

            if attempt >= max_retries:
                # Give up immediately; sleeping here would only hold the
                # worker thread for minutes with nothing left to retry.
                print(f"Max retries reached for entry {entry_id}. Adding to refusal list.")
                refusal_ids.add(entry_id)
                with write_lock_refusal:
                    with open(refusal_file, 'a', encoding='utf-8') as f:
                        f.write(f"{entry_id}\n")
                return

            print(f"Retrying {entry_id} ({attempt}/{max_retries})")
            time.sleep(60 * attempt)  # back off: 60s, 120s, ...
        else:
            time.sleep(1)  # brief pause between requests
            return
    
if __name__ == "__main__":
    try:
        # Resume support: ids already present in the output file are skipped
        # by the workers.
        if os.path.exists(output_file):
            with open(output_file, 'r', encoding='utf-8') as outfile:
                for line in outfile:
                    try:
                        processed_ids.add(json.loads(line)['image_id'])
                    except (json.JSONDecodeError, KeyError, TypeError):
                        # Corrupt/partial line, or valid JSON without an
                        # 'image_id' field: ignore it and keep scanning.
                        continue

        with open(args.mapping_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            futures = [executor.submit(process_line, entry) for entry in data]

            for future in as_completed(futures):
                # Re-raises any exception that escaped a worker.
                future.result()
    finally:
        # Remove the temporary directory even when the run fails, so mask
        # PNGs from aborted runs do not pile up.
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)

    print("Processing complete.")