import os
import json
import base64
import io
import argparse
from PIL import Image
import openai
from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type
from utils import extract_instructions, infer_with_DiT, roi_localization, fusion, layout_change


def encode_image_to_datauri(path, size=(512, 512)):
    """Load an image, resize it, and return it as base64-encoded PNG data.

    Despite the function name, the return value is the *bare* base64 payload,
    not a full ``data:`` URI. Because the image is serialized as PNG, anyone
    building a data URI from it should prefix ``data:image/png;base64,``.

    Args:
        path: Anything accepted by ``PIL.Image.open`` (typically a file path).
        size: Target ``(width, height)``. The image is resized to exactly this
            size, so the original aspect ratio is NOT preserved.

    Returns:
        str: Base64 text (ASCII/UTF-8) of the resized RGB image saved as PNG.
    """
    # Manage the opened image itself so its file handle is released
    # deterministically; wrapping only the converted copy (which owns no file)
    # would leave the source image's handle to the garbage collector.
    with Image.open(path) as src:
        # convert() forces the pixel data to load, so the copy stays valid
        # after the source file is closed.
        img = src.convert('RGB').resize(size, Image.LANCZOS)
    buffer = io.BytesIO()
    img.save(buffer, format='PNG')
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


@retry(
    reraise=True,
    # Exponential backoff clamped to [1, 60] seconds, at most 6 attempts;
    # the last exception is re-raised unchanged when attempts run out.
    wait=wait_exponential(min=1, max=60),
    stop=stop_after_attempt(6),
    # Retry only transient failures (rate limits, server errors, dropped
    # connections, timeouts). Auth / invalid-request errors fail fast.
    retry=retry_if_exception_type((
        openai.error.RateLimitError,
        openai.error.APIError,
        openai.error.APIConnectionError,
        openai.error.ServiceUnavailableError,
        openai.error.Timeout,
    )),
)
def cot_with_gpt(image_uri, instruction, model="gpt-4o", max_tokens=300):
    """Ask a GPT model to decompose an edit request into atomic instructions.

    The model sees the image plus a prompt listing 13 atomic edit categories.
    It is asked to reply with a numbered list of ``[Category] instruction``
    lines, which are then parsed by ``extract_instructions``.

    Args:
        image_uri: Bare base64 payload of a PNG image (no ``data:`` prefix).
        instruction: The user's free-form editing request, interpolated
            verbatim into the prompt.
        model: Chat model name (defaults to ``"gpt-4o"``).
        max_tokens: Completion token limit; a long plan may be truncated.

    Returns:
        Whatever ``extract_instructions`` returns for the reply text —
        presumably parallel sequences of categories and instructions
        (TODO confirm against ``utils.extract_instructions``).

    Raises:
        ValueError: If the model returns empty content.
        openai.error.OpenAIError: Non-transient API errors, or transient ones
            after all retries are exhausted.
    """
    response = openai.ChatCompletion.create(
        model=model,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": f'''
                    Now you are an expert in image editing. Based on the given single image, what atomic image editing instructions should be if the user wants to {instruction}? Let's think step by step. 
                    Atomic instructions include 13 categories as follows:
                    - Add: e.g.: add a car on the road
                    - Remove: e.g.: remove the sofa in the image
                    - Color Change: e.g.: change the color of the shoes to blue
                    - Material Change: e.g.: change the material of the sign like stone
                    - Action Change: e.g.: change the action of the boy to raising hands
                    - Expression Change: e.g.: change the expression to smile
                    - Replace: e.g.: replace the coffee with an apple
                    - Background Change: e.g.: change the background into forest
                    - Appearance Change: e.g.: make the cup have a floral pattern
                    - Move: e.g.: move the plane to the left
                    - Resize: e.g.: enlarge the clock
                    - Tone Transfer: e.g.: change the weather to foggy
                    - Style Change: e.g.: make the style of the image to cartoon
                    Respond *only* with a numbered list.  
                    Each line must begin with the category in square brackets, then the instruction. Please strictly follow the atomic categories.
                    The operation (what) and the target (to what) are crystal clear.  
                    Do not split replace to add and remove.
                    For example:
                    “1. [Add] add a car on the road\n
                    2. [Color Change] change the color of the shoes to blue\n
                    3. [Move] move the lamp to the left\n"
                    Do not include any extra text, explanations, JSON or markdown—just the list.
                    '''},
                    {
                        "type": "image_url",
                        "image_url": {
                            # The payload is PNG-encoded, so declare it as such;
                            # a mismatched MIME type can make the API reject or
                            # mis-decode the image.
                            "url": f"data:image/png;base64,{image_uri}"
                        }
                    },
                ],
            }
        ],
        max_tokens=max_tokens,
    )
    # content can be None (e.g. refusals); fail with a clear message instead
    # of an AttributeError on .strip().
    text = (response.choices[0].message.content or "").strip()
    if not text:
        raise ValueError(f"{model} returned an empty response; no instructions to execute.")

    categories, instructions = extract_instructions(text)
    return categories, instructions


def _apply_atomic_edit(category, image, instruction):
    """Run the sub-pipeline for one atomic instruction and return its output.

    Args:
        category: Atomic category label produced by the planner, e.g. 'Add'.
        image: Current image as accepted by the ``utils`` helpers (the input
            path on the first step, the previous step's output afterwards).
        instruction: The atomic instruction text for this step.

    Returns:
        The edited image as returned by the final ``infer_with_DiT`` call.

    Raises:
        ValueError: If ``category`` is not one of the 13 known categories.
    """
    if category in ('Add', 'Remove', 'Replace'):
        ### RoI Localization -> RoI Inpainting
        mask_image = roi_localization(image, instruction)
        return infer_with_DiT('RoI Inpainting', mask_image, instruction)

    if category in ('Action Change', 'Move', 'Resize'):
        # Shared "re-place an instance" pipeline: inpaint the localized region
        # to obtain a background, produce the changed instance, fuse the two,
        # then let the DiT blend the result.
        mask_image = roi_localization(image, instruction)
        background = infer_with_DiT('RoI Inpainting', mask_image, instruction)
        if category == 'Action Change':
            changed_instance = infer_with_DiT('RoI Editing', image, instruction)
        else:  # 'Move' / 'Resize'
            changed_instance = layout_change(image, instruction)
        fusion_image = fusion(background, changed_instance)
        ### RoI Compositioning
        return infer_with_DiT('RoI Compositioning', fusion_image, instruction)

    if category in ('Appearance Change', 'Background Change', 'Color Change', 'Material Change', 'Expression Change'):
        ### RoI Editing
        return infer_with_DiT('RoI Editing', image, instruction)

    if category in ('Tone Transfer', 'Style Change'):
        ### Global Transformation
        return infer_with_DiT('Global Transformation', image, instruction)

    raise ValueError(f"Invalid category: '{category}'")


def main():
    """CLI entry point: plan atomic edits with GPT, then execute them in order.

    Reads the OpenAI key from the ``OPENAI_API_KEY`` environment variable.
    Each atomic edit consumes the previous edit's output.

    Returns:
        The final edited image as produced by the last step, or the input
        path unchanged if the planner returned no instructions.

    Raises:
        ValueError: If ``OPENAI_API_KEY`` is unset, the planner returns
            mismatched category/instruction lists, or an unknown category.
    """
    parser = argparse.ArgumentParser(description="Evaluate single image + instruction using GPT-4o")
    parser.add_argument("image_path", help="Path to input image")
    parser.add_argument("prompt", help="Original instruction")
    args = parser.parse_args()

    # Read the key from the environment: a hard-coded placeholder made the
    # check below unreachable and its error message misleading.
    openai.api_key = os.environ.get("OPENAI_API_KEY")
    if not openai.api_key:
        raise ValueError("OPENAI_API_KEY environment variable not set.")

    ###########################################
    ###         CoT -> instructions         ###
    ###########################################
    uri = encode_image_to_datauri(args.image_path)
    categories, instructions = cot_with_gpt(uri, args.prompt)
    if len(categories) != len(instructions):
        raise ValueError(
            f"Planner returned {len(categories)} categories but "
            f"{len(instructions)} instructions."
        )

    ###########################################
    ###      Neural Program Interpreter     ###
    ###########################################
    image = args.image_path
    for category, instruction in zip(categories, instructions):
        image = _apply_atomic_edit(category, image, instruction)
    return image


if __name__ == '__main__':
    # Run the planner + interpreter pipeline only when executed as a script,
    # not on import.
    main()
