import torch
import numpy as np
from diffusers.pipelines import FluxPipeline
from src.flux.condition import Condition
from PIL import Image
import argparse
import os
import re
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer
from scipy.ndimage import binary_dilation
import cv2
import openai

from src.flux.generate import generate, seed_everything

try:
    from mmengine.visualization import Visualizer
except ImportError:
    Visualizer = None
    print("Warning: mmengine is not installed, visualization is disabled.")

# Prefer a key supplied through the environment; "YOUR_API_KEY" is only a
# placeholder. Assigning the placeholder unconditionally would overwrite the
# key the legacy (pre-1.0) openai package reads from OPENAI_API_KEY by default.
openai.api_key = os.environ.get("OPENAI_API_KEY", "YOUR_API_KEY")

def infer_with_DiT(task, image, instruction):
    """Run one FLUX.1-dev + task-specific LoRA generation conditioned on *image*.

    Args:
        task: one of 'RoI Inpainting', 'RoI Editing', 'RoI Compositioning',
            'Global Transformation'. It selects the LoRA weights and the
            ``position_delta`` handed to ``Condition``.
        image: anything ``PIL.Image.open`` accepts (e.g. a file path). It is
            converted to RGB and resized to 512x512.
        instruction: text prompt forwarded to ``generate``.

    Returns:
        ``generate(...).images[0]``, presumably a PIL image (TODO confirm
        against ``src.flux.generate``).

    Raises:
        ValueError: if *task* is not one of the names above.

    Note: the FLUX pipeline is re-created, moved to CUDA and given the LoRA
    on every call, which is expensive when called repeatedly.
    """
    # task name -> (LoRA weights location, position_delta for the scene condition)
    task_specs = {
        'RoI Inpainting': ("local_path_for_inpainting", (0, 0)),
        'RoI Editing': ("local_path_for_editing", (0, -32)),
        'RoI Compositioning': ("local_path_for_compositioning", (0, 0)),
        'Global Transformation': ("local_path_for_global", (0, -32)),
    }

    scene_img = Image.open(image).convert("RGB").resize((512, 512))

    # Fixed seed so repeated runs with the same inputs are reproducible.
    seed_everything(3407)

    spec = next(
        (entry for name, entry in task_specs.items() if task == name),
        None,
    )
    if spec is None:
        raise ValueError(f"Invalid task: '{task}'")
    lora_location, delta = spec
    scene_condition = Condition("scene", scene_img, position_delta=delta)

    flux = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    flux.load_lora_weights(lora_location, adapter_name="scene")

    generation = generate(
        flux,
        prompt=instruction,
        conditions=[scene_condition],
        config_path="train/config/scene_512.yaml",
        num_inference_steps=28,
        height=512,
        width=512,
    )
    return generation.images[0]

def load_model(model_path):
    """Load a causal-LM checkpoint and its tokenizer from *model_path*.

    The model is loaded with ``torch_dtype="auto"`` and ``device_map="auto"``,
    allowing repository-provided code (``trust_remote_code=True``), and is
    switched to eval mode before being returned.

    Returns:
        ``(model, tokenizer)``.
    """
    remote_ok = True
    lm = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=remote_ok,
    )
    lm.eval()  # in-place; returns the same module object
    tok = AutoTokenizer.from_pretrained(model_path, trust_remote_code=remote_ok)
    return lm, tok

def extract_object_with_gpt(instruction):
    """Ask gpt-3.5-turbo for the noun phrase an editing instruction targets.

    Example: 'Replace the biggest bear with a tiger' -> 'the biggest bear'.

    This is best-effort: if the API call fails for any reason, or the reply is
    empty, the original *instruction* is returned unchanged. Downstream
    segmentation always receives some text.

    Args:
        instruction: free-form image-editing instruction.

    Returns:
        The extracted noun phrase (surrounding quotes removed), or
        *instruction* on failure.
    """
    system_prompt = (
        "You are a helpful assistant that extracts the object or target being edited in an image editing instruction. "
        "Only return a concise noun phrase describing the object. "
        "Examples:\n"
        "- Input: 'Remove the dog' → Output: 'the dog'\n"
        "- Input: 'Replace the biggest bear with a tiger' → Output: 'the biggest bear'\n"
        "- Input: 'Change the action of the girl to riding' → Output: 'the girl'\n"
        "- Input: 'Add a red hat to the man on the left' → Output: 'the man on the left'\n"
        "- Input: 'Make the sky look cloudy' → Output: 'the sky'\n"
        "Now extract the object for this instruction:"
    )

    try:
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": instruction}
            ],
            temperature=0.2,
            max_tokens=20,
        )
        content = response.choices[0].message['content']
    except Exception as e:
        # Deliberate best-effort boundary: network/auth/quota errors fall back
        # to the raw instruction rather than aborting the pipeline.
        print(f"GPT extraction failed: {e}")
        return instruction

    object_phrase = (content or "").strip().strip('"').strip()
    # The few-shot examples above wrap outputs in single quotes, so the model
    # often echoes them; drop one matched pair so they don't leak into the
    # segmentation prompt.
    if len(object_phrase) >= 2 and object_phrase[0] == object_phrase[-1] == "'":
        object_phrase = object_phrase[1:-1].strip()

    if not object_phrase:
        print("GPT extraction returned an empty phrase; using the raw instruction.")
        return instruction

    print(f"Identified object: {object_phrase}")
    return object_phrase
    
def get_masked(mask, image):
    """Return a copy of *image* with every pixel selected by *mask* set to 0.

    Args:
        mask: array-like whose first two dimensions match *image*. Nonzero/True
            entries mark the pixels to blank out. It is converted to bool, so a
            0/1 integer mask acts as a mask rather than as integer row indices.
        image: array-like image, e.g. (H, W) or (H, W, C). Anything
            ``np.asarray`` accepts (a NumPy array, a PIL image, ...).

    Returns:
        A new ndarray; the input image is not modified.

    Raises:
        ValueError: if the first two dimensions of *mask* and *image* differ.
    """
    image = np.asarray(image)
    mask = np.asarray(mask, dtype=bool)
    if mask.shape[:2] != image.shape[:2]:
        raise ValueError("The dimensions of the mask and image do not match.")
    masked_image = image.copy()
    masked_image[mask] = 0
    return masked_image


def roi_localization(image, instruction):
    """Segment the object an editing instruction refers to and black it out.

    Steps:
    1. Load Sa2VA-8B via ``load_model``.
    2. Reduce *instruction* to a noun phrase with ``extract_object_with_gpt``.
    3. Prompt the model to segment that phrase.
    4. Dilate the first predicted mask by 3 iterations.
    5. Zero the masked pixels with ``get_masked``.

    Args:
        image: path to the input image (opened with PIL and converted to RGB).
        instruction: free-form editing instruction.

    Returns:
        An (H, W, 3) uint8 ndarray with the region zeroed, or None when the
        model output has no ``[SEG]`` token or no mask.

    Raises:
        ValueError: from ``get_masked`` if the predicted mask size differs from
            the image size. This presumes Sa2VA returns masks at the input
            resolution — TODO confirm.
    """
    model, tokenizer = load_model("ByteDance/Sa2VA-8B")
    target_phrase = extract_object_with_gpt(instruction)
    prompt = f"<image>Please segment {target_phrase}."
    img = Image.open(image).convert('RGB')
    print(f"Processing image: {os.path.basename(image)}, Instruction: {prompt}")

    result = model.predict_forward(
        image=img,
        text=prompt,
        tokenizer=tokenizer,
    )

    prediction = result['prediction']
    print(f"Model Output: {prediction}")

    masks = result['prediction_masks'] if 'prediction_masks' in result else None
    if '[SEG]' not in prediction or masks is None or len(masks) == 0:
        print("No valid mask found in the prediction.")
        return None

    # squeeze drops singleton dims so the mask is 2-D (H, W).
    pred_mask_np = np.squeeze(np.array(masks[0]))
    dilated_mask = binary_dilation(pred_mask_np, iterations=3)

    # get_masked needs an ndarray (it reads .shape); a PIL Image has no .shape.
    return get_masked(dilated_mask, np.array(img))

def fusion(rgb_image, rgba_image):
    """Alpha-composite *rgba_image* over an opaque *rgb_image*.

    out = fg * alpha + bg * (1 - alpha), where alpha = rgba_image[..., 3] / 255.

    Args:
        rgb_image: (H, W, 3) background image.
        rgba_image: (H, W, 4) foreground; channel 3 is alpha in [0, 255].

    Returns:
        (H, W, 3) uint8 blended image, rounded to the nearest integer.

    Raises:
        ValueError: if the spatial sizes differ or *rgba_image* lacks an alpha
            channel.
    """
    rgb_image = np.asarray(rgb_image)
    rgba_image = np.asarray(rgba_image)
    if rgba_image.shape[:2] != rgb_image.shape[:2]:
        raise ValueError("The dimensions of the two images do not match.")
    if rgba_image.ndim != 3 or rgba_image.shape[2] < 4:
        raise ValueError("rgba_image must have shape (H, W, 4).")

    # Keep a trailing length-1 axis so alpha broadcasts over the RGB channels.
    alpha = rgba_image[..., 3:4].astype(np.float32) / 255.0
    foreground = rgba_image[..., :3].astype(np.float32)
    background = rgb_image.astype(np.float32)

    blended = foreground * alpha + background * (1.0 - alpha)
    # Round instead of truncating: a plain astype(uint8) floors, which biases
    # every blended pixel downward (e.g. 0.502 -> 0).
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)

def layout_change(bbox, instruction):
    """Ask gpt-4o to edit a bounding-box layout according to an instruction.

    Args:
        bbox: current layout, interpolated into the prompt with ``str()``. The
            few-shot examples use a list of ``(name, [x0, y0, x1, y1])`` tuples
            on a 512x512 canvas; presumably callers pass that format (TODO
            confirm).
        instruction: natural-language layout edit, e.g. "Move the car to the right".

    Returns:
        The model's raw reply, stripped (empty string if it has no content).
        The prompt asks the model to think step by step, so callers must parse
        the final layout out of this text.

    Raises:
        Whatever ``openai.ChatCompletion.create`` raises; nothing is caught here.
    """
    import textwrap

    # dedent so the model receives flush-left text rather than the source
    # file's indentation on every line.
    guidelines_and_examples = textwrap.dedent("""\
        You are an intelligent bounding box editor. I will provide you with the current bounding boxes and the editing instruction.
        Your task is to generate the new bounding boxes after editing.
        The images are of size 512x512. The top-left corner has coordinate [0, 0]. The bottom-right corner has coordinate [512, 512].
        The bounding boxes should not overlap or go beyond the image boundaries.
        Each bounding box should be in the format of (object name, [top-left x coordinate, top-left y coordinate, bottom-right x coordinate, bottom-right y coordinate]).
        Do not add new objects or delete any object provided in the bounding boxes. Do not change the size or the shape of any object unless the instruction requires it.
        Please consider the semantic information of the layout.
        When resizing, keep the bottom-left corner fixed by default. When swapping locations, change according to the center point.
        If needed, you can make reasonable guesses. Please refer to the examples below:

        Input bounding boxes: [('a green car', [21, 281, 232, 440]), ('a blue truck', [269, 283, 478, 443]), ('a red air balloon', [66, 8, 211, 143]), ('a bird', [296, 42, 439, 142])]
        Editing instruction: Move the car to the right.
        Output bounding boxes: [('a green car', [81, 281, 292, 440]), ('a blue truck', [269, 283, 478, 443]), ('a red air balloon', [66, 8, 211, 143]), ('a bird', [296, 42, 439, 142])]

        Input bounding boxes: [("bed", [50, 300, 450, 450]), ("pillow", [200, 200, 300, 230])]
        Editing instruction: Move the pillow to the left side of the bed.
        Output bounding boxes: [("bed", [50, 300, 450, 450]), ("pillow", [70, 270, 170, 300])]

        Input bounding boxes: [("sofa", [100, 300, 400, 400]), ("dog", [150, 250, 250, 300])]
        Editing instruction: Enlarge the dog.
        Output bounding boxes: [("sofa", [100, 300, 400, 400]), ("dog", [150, 225, 300, 300])]

        Input bounding boxes: [("chair", [100, 350, 200, 450]), ("lamp", [300, 200, 360, 300])]
        Editing instruction: Swap the location of the chair and the lamp.
        Output bounding boxes: [("chair", [280, 200, 380, 300]), ("lamp", [120, 350, 180, 450])]
        """)
    prompt = (
        f"{guidelines_and_examples}\n"
        f"Now, the current bounding boxes are {bbox}, and the instruction is {instruction}. "
        "Let's think step by step, and output the edited layout."
    )

    response = openai.ChatCompletion.create(
        model="gpt-4o",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                ],
            }
        ],
        max_tokens=1000,
    )
    content = response.choices[0].message.content
    # Guard against a reply without text content before stripping.
    return (content or "").strip()