import os
import sys
sys.path.append("../SLD")
sys.path.append("")
from dotenv import load_dotenv

load_dotenv()



import json
from tqdm import tqdm
import copy
import shutil
import random
import numpy as np
import argparse
import configparser
from PIL import Image
import cv2


from huggingface_hub import hf_hub_download
import torch
import diffusers

# Libraries heavily borrowed from LMD
import models
from models import sam
from utils import parse, utils

# SLD specific imports
from sld.detector import OWLVITV2Detector
from sld.sdxl_refine import sdxl_refine
from sld.utils import get_all_latents, run_sam, run_sam_postprocess, resize_image
from sld.llm_template import spot_object_template, spot_difference_template, image_edit_template
from sld.llm_chat_update import get_key_objects, get_updated_layout, get_update_prompt
from sld.llm_template import spot_difference_template_FoR2
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig
from torchvision import transforms
from diffusers import ControlNetModel
# from diffusers import StableDiffusion3ControlNetPipeline
# from diffusers.models import SD3ControlNetModel, SD3MultiControlNetModel

# 0.20.2

# SLD-FoR
from visual_module import OrientationModule, DepthModule

# Code borrowed from SLD
# Operation #1: Addition (The code is in sld/image_generator.py)

# Operation #2: Deletion (Preprocessing region mask for removal)
def get_remove_region(entry, remove_objects, move_objects, preserve_objs, models, config):
    """Build the mask of the region whose content has to be erased.

    The most recent image of ``entry`` is segmented with SAM once for every
    object that is deleted and once for the *original* location of every
    object that is moved.  Pixels belonging to ``preserve_objs`` are then
    carved out of the result so they are not erased.

    Args:
        entry: Sample record; ``entry["output"][-1]`` is the image path and
            ``entry["device"]`` the torch device handed to SAM.
        remove_objects: Objects to delete; element ``[1]`` of each is the
            bounding box passed to ``run_sam``.
        move_objects: ``(old, new)`` pairs; only ``old`` (element ``[0]``)
            contributes to the removal mask.
        preserve_objs: Objects whose region must be excluded from removal.
        models: Forwarded to ``run_sam``.
        config: Forwarded to ``run_sam_postprocess``.

    Returns:
        An all-zero ``int64`` array of shape ``(W // 8, H // 8)`` when there
        is nothing to remove, otherwise the boolean array
        ``remove_region AND NOT preserve_region``.
    """
    frame = np.array(Image.open(entry["output"][-1]))
    H, W, _ = frame.shape

    pending = len(remove_objects) + len(move_objects)
    if pending == 0:
        # Nothing to erase: hand back an empty mask.
        # NOTE(review): axes are (W // 8, H // 8); this is only interchangeable
        # with (H // 8, W // 8) for square images — confirm against the shape
        # produced by run_sam_postprocess.
        return np.zeros((W // 8, H // 8), dtype=np.int64)

    def union_of_masks(objects):
        """OR together the SAM masks of every object in ``objects``."""
        combined = np.zeros((H, W, 3), dtype=bool)
        for candidate in objects:
            combined = combined | run_sam(
                bbox=candidate[1],
                image_source=frame,
                models=models,
                torch_device=entry["device"],
            )
        return combined

    # Deleted objects plus the source location of every moved object.
    targets = remove_objects + [pair[0] for pair in move_objects]
    erase_mask = union_of_masks(targets)
    # Regions that must not be removed.
    keep_mask = union_of_masks(preserve_objs)

    # Post-process the raw SAM masks (per the original note: averaging,
    # thresholding and dilating).
    keep_region = run_sam_postprocess(keep_mask, H, W, config)
    erase_region = run_sam_postprocess(erase_mask, H, W, config)

    # TODO: This does not work with inside relation at all
    outside_preserved = np.logical_not(keep_region)
    return np.logical_and(erase_region, outside_preserved)

# Operation #3: Repositioning (Preprocessing latent)
def get_repos_info(entry, move_objects, models, config):
    """Prepare the per-object data needed to reposition / reshape objects.

    For every ``(old, new)`` pair the source image is resized at image level
    (not latent level) so the object fits its new box, the object is
    segmented in the resized image, and the resized image is inverted to
    latents.

    * Warning: for simplicity the object is not centred in its new region.

    Args:
        entry: Sample record; ``entry["output"][-1]`` is the image path and
            ``entry["device"]`` the torch device handed to SAM.
        move_objects: Sequence of ``(old, new)`` items where element ``[0]``
            is the object name and element ``[1]`` its bounding box.
        models: Forwarded to ``run_sam`` and ``get_all_latents``.
        config: Provides ``SLD.inv_seed`` and is forwarded to
            ``run_sam_postprocess``.

    Returns:
        ``move_objects`` itself when it is empty; otherwise a list of
        ``[old_name, resized_bbox, new_bbox, region_mask, all_latents]``.
    """
    # Nothing to move: return the (empty) input untouched.
    if not move_objects:
        return move_objects

    frame = np.array(Image.open(entry["output"][-1]))
    H, W, _ = frame.shape
    inversion_seed = int(config.get("SLD", "inv_seed"))

    def prepare(pair):
        """Build the repositioning record for one ``(old, new)`` pair."""
        old_obj, new_obj = pair[0], pair[1]
        resized_frame, resized_bbox = resize_image(frame, old_obj[1], new_obj[1])
        raw_mask = run_sam(resized_bbox, resized_frame, models, torch_device=entry["device"])
        region = run_sam_postprocess(raw_mask, H, W, config).astype(np.bool_)
        latents, _ = get_all_latents(resized_frame, models, inversion_seed)
        return [old_obj[0], resized_bbox, new_obj[1], region, latents]

    return [prepare(pair) for pair in move_objects]

# Operation #4: Attribute Modification (Preprocessing latent)
# Phrase appended to the edit prompt for each (lower-cased) orientation label.
# Each of the eight directions has its own key; the diagonal entries used to
# contain duplicated keys, which silently dropped "backward-left" and
# "forward-right" and mapped "forward-left" to the forward-right phrase.
FACING_ATTRIBUTE = {
    "front": "facing toward the camera",
    "left": "facing to the left of the image",
    "back": "facing away from the camera",
    "right": "facing to the right of the image",
    "forward-left": "facing to the forward-left direction of the camera",
    "backward-left": "facing to the backward-left direction of the camera",
    "forward-right": "facing to the forward-right direction of the camera",
    "backward-right": "facing to the backward-right direction of the camera",
}


def get_attrmod_latent(entry, change_attr_objects, models, config):
    """Edit objects whose attributes changed and invert the results to latents.

    Each object is segmented with SAM, re-rendered inside its mask with the
    DiffEdit pipeline using a prompt built from the object's new name (plus a
    facing phrase when the object carries a known orientation), and the
    edited image is inverted with ``get_all_latents``.

    Args:
        entry (dict): Sample record; ``entry["output"][-1]`` is the image path
            and ``entry["device"]`` the torch device used for SAM, the
            pipeline and the random generators.
        change_attr_objects (list): Objects to modify. Element ``[0]`` is the
            name (an optional `` #N`` suffix is stripped), element ``[1]`` the
            bounding box, and element ``[3]``, when present and a string, the
            orientation label looked up in ``FACING_ATTRIBUTE``.
        models: Forwarded to ``run_sam`` and ``get_all_latents``.
        config: Provides ``SLD.inv_seed``, ``SLD.diffedit_inpaint_strength``
            and ``SLD.diffedit_guidance_scale``.

    Returns:
        list: One ``[object_mask, all_latents]`` pair per input object, or an
        empty list when ``change_attr_objects`` is empty.
    """
    if len(change_attr_objects) == 0:
        return []
    # Imported lazily so the empty fast path does not need diffusers.
    from diffusers import StableDiffusionDiffEditPipeline
    from diffusers import DDIMScheduler, DDIMInverseScheduler

    img = Image.open(entry["output"][-1])
    image_source = np.array(img)
    H, W, _ = image_source.shape
    device = entry["device"]

    # Loop-invariant settings are read once.
    inv_seed = int(config.get("SLD", "inv_seed"))
    inpaint_strength = float(config.get("SLD", "diffedit_inpaint_strength"))
    guidance_scale = float(config.get("SLD", "diffedit_guidance_scale"))

    # Initialize the Stable Diffusion DiffEdit pipeline
    pipe = StableDiffusionDiffEditPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16
    )
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
    pipe.set_progress_bar_config(disable=True)
    pipe.to(device)

    new_change_objects = []
    for obj in change_attr_objects:
        # Mask of the object in the current image, with a leading batch axis
        # for the pipeline.
        old_object_region = run_sam_postprocess(
            run_sam(obj[1], image_source, models, torch_device=device), H, W, config
        )
        old_object_region = old_object_region.astype(np.bool_)[np.newaxis, ...]

        new_object = obj[0].split(" #")[0]
        # "cat" -> "cat", "brown cat" -> "cat", "brown fire hydrant" -> "fire hydrant"
        words = new_object.split(" ")
        base_object = " ".join(words[1:]) if len(words) > 1 else words[-1]
        mask_prompt = f"a {base_object}"
        new_prompt = f"a {new_object}"

        # The orientation slot may be missing or None; only a recognised
        # string label extends the prompt.
        orientation = obj[3] if len(obj) > 3 else None
        if isinstance(orientation, str) and orientation.lower() in FACING_ATTRIBUTE:
            new_prompt = f"a {new_object} {FACING_ATTRIBUTE[orientation.lower()]}"
            print(new_prompt)

        # Generators live on the same device as the pipeline rather than on
        # whatever MAIN_CUDA happens to contain.
        image_latents = pipe.invert(
            image=img,
            prompt=mask_prompt,
            inpaint_strength=inpaint_strength,
            generator=torch.Generator(device=device).manual_seed(inv_seed),
        ).latents

        image = pipe(
            prompt=new_prompt,
            mask_image=old_object_region,
            image_latents=image_latents,
            guidance_scale=guidance_scale,
            inpaint_strength=inpaint_strength,
            generator=torch.Generator(device=device).manual_seed(inv_seed),
            negative_prompt="",
        ).images[0]

        all_latents, _ = get_all_latents(np.array(image), models, inv_seed)
        new_change_objects.append([old_object_region[0], all_latents])
    return new_change_objects

# Operation #5: Reference Depth
def get_reference_depth(entry, change_depth_object, models, depth_models, config):
    """Re-render objects whose depth must change and invert them to latents.

    For every ``(old_item, new_item)`` pair:

    1. the object is segmented in the current image and its depth values are
       shifted by ``new_depth - old_depth`` (clipped to ``[0, 255]``), with
       everything outside the object set to 0;
    2. the patch under the old bounding box is resized and pasted at the new
       bounding box of an otherwise black depth map;
    3. a depth-conditioned ControlNet pipeline generates an image from that
       depth map and the prompt ``"a <object name>"``;
    4. the object is segmented in the generated image and the image is
       inverted with ``get_all_latents``.

    Args:
        entry: Sample record; ``entry["output"][-1]`` is the image path and
            ``entry["device"]`` the torch device.
        change_depth_object: Sequence of ``(old_item, new_item)`` pairs.
            Element ``[0]`` is the name, ``[1]`` the bounding box as
            ``(x, y, w, h)`` fractions of the image size and ``[2]`` the
            depth, which is multiplied by 255 here (presumably normalised to
            ``[0, 1]`` — confirm against the depth module).
        models: Forwarded to ``run_sam`` and ``get_all_latents``.
        depth_models: Object providing ``extract_depth`` and
            ``get_object_depth``.
        config: Forwarded to ``run_sam_postprocess``.

    Returns:
        tuple: ``(depth_edit_latent, None)`` where ``depth_edit_latent`` is a
        list of ``[object_mask, all_latents]`` pairs (empty when there is
        nothing to change).  The second element is always ``None``.

    Raises:
        ValueError: If an old or new bounding box covers no pixels.
    """
    # Fast path first: an empty request must not pay for the diffusers import
    # or for loading the ControlNet pipeline.
    if len(change_depth_object) == 0:
        return [], None

    from diffusers import StableDiffusionControlNetPipeline

    def change_depth(old, init, new_depth):
        # Shift by the difference between the new and the initial (average)
        # depth, then keep the values in the 8-bit range.
        shifted = (old - init) + new_depth
        return np.clip(shifted, 0, 255.0)

    def generate_new_depth(depth_map, object_mask, init_depth, new_depth):
        new_depth_map = np.array(depth_map).astype(np.float64).copy()
        inside = object_mask > 0
        new_depth_map[inside] = change_depth(new_depth_map[inside], init_depth, new_depth)
        new_depth_map[object_mask <= 0] = 0
        return new_depth_map.astype(np.uint8)

    image_source = np.array(Image.open(entry["output"][-1]))
    H, W, _ = image_source.shape
    _, depth_map = depth_models.extract_depth(image_source)

    def to_pixels(bbox):
        # Even positions (x, w) scale with the width, odd ones (y, h) with
        # the height.
        return [int(coor * (H if idx % 2 else W)) for idx, coor in enumerate(bbox)]

    controlnet = ControlNetModel.from_pretrained(
        "lllyasviel/sd-controlnet-depth", torch_dtype=torch.float16
    )

    # Load the Stable Diffusion v1.5 base model with the depth ControlNet.
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet,
        torch_dtype=torch.float16
    )
    pipe.set_progress_bar_config(disable=True)
    pipe.to(entry["device"])

    # Loop-invariant: depth map -> 512x512 tensor scaled to [0, 1].
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
    ])

    depth_edit_latent = []

    for old_item, new_item in change_depth_object:
        segment_sam = run_sam(old_item[1], image_source, models, torch_device=entry["device"])
        _, object_mask_org, _ = depth_models.get_object_depth(segment_sam, depth_map)

        new_edit_depth = generate_new_depth(
            depth_map, object_mask_org, old_item[2] * 255, new_item[2] * 255
        )

        x, y, w, h = to_pixels(old_item[1])
        region = new_edit_depth[y:y+h, x:x+w].copy()

        x_new, y_new, w_new, h_new = to_pixels(new_item[1])
        # Prevent the new box from running past the image border.
        w_new = min(W - x_new, w_new)
        h_new = min(H - y_new, h_new)

        # cv2.resize fails with an opaque assertion on empty inputs/outputs;
        # report the offending boxes instead.
        if region.size == 0 or w_new <= 0 or h_new <= 0:
            raise ValueError(
                f"Degenerate bounding box for depth edit of {old_item[0]!r}: "
                f"old={old_item[1]}, new={new_item[1]}"
            )

        region_resized = cv2.resize(region, (w_new, h_new), interpolation=cv2.INTER_LINEAR)

        new_depth_map = np.zeros((H, W)).astype(np.uint8)
        new_depth_map[y_new:y_new+h_new, x_new:x_new+w_new] = region_resized

        # Single-channel depth -> RGB image -> tensor with a batch axis.
        depth_tensor = transform(Image.fromarray(new_depth_map, mode="L").convert("RGB")).unsqueeze(0)

        new_object = new_item[0].split(" #")[0]
        new_prompt = f"a {new_object}"

        edited_image = pipe(prompt=new_prompt,
                            image=depth_tensor,
                            controlnet_conditioning_scale=1.0)
        # First image of the first field of the pipeline output.
        new_image_array = np.array(edited_image[0][0])

        object_region = run_sam_postprocess(
            run_sam(new_item[1], new_image_array, models, torch_device=entry["device"]), H, W, config
        )
        object_region = object_region.astype(np.bool_)[np.newaxis, ...]

        # NOTE(review): the inversion seed is hard-coded to 1 here while the
        # other operations read SLD.inv_seed from the config — confirm intent.
        all_latents, _ = get_all_latents(new_image_array, models, 1)

        depth_edit_latent.append([object_region[0], all_latents])

    return depth_edit_latent, None

def spot_objects(prompt, data, config=None):
    """Return the parsed key-object description for ``prompt``.

    Args:
        prompt: The user prompt.
        data: Dataset record; a non-``None`` ``"llm_parsed_prompt"`` entry is
            returned as is.
        config: Forwarded to ``get_key_objects`` when the LLM is queried.

    Returns:
        The cached ``data["llm_parsed_prompt"]`` when available, otherwise
        the first element of the ``get_key_objects`` result.
    """
    cached = data.get("llm_parsed_prompt")
    if cached is not None:
        return cached

    # No cached object list: ask the LLM to spot the objects.
    message = spot_object_template + f"User Prompt: {prompt}\nReasoning:\n"
    llm_output = get_key_objects(message, config)
    return llm_output[0]  # the object list
    
def spot_differences(prompt, det_results, data, mode="self_correction", config=None):
    """Return the layout the LLM suggests for the detected objects.

    A suggestion looks like
    ``[["woman #1", [0.542, 0.209, 0.237, 0.601], 0.525, null], ...]``.

    Args:
        prompt: The (possibly rewritten) user prompt.
        det_results: Current detections, embedded verbatim in the LLM query.
        data: Dataset record; a non-``None`` ``"llm_layout_suggestions"``
            entry is returned as is.
        mode: ``"self_correction"`` selects ``spot_difference_template_FoR2``;
            any other value selects ``image_edit_template``.
        config: Forwarded to ``get_updated_layout``.

    Returns:
        The predefined suggestions when present, otherwise the first element
        of the ``get_updated_layout`` result.
    """
    predefined = data["llm_layout_suggestions"]
    if predefined is not None:
        return predefined

    print(prompt, det_results)
    questions = f"User Prompt: {prompt}\nCurrent Objects: {det_results}\nReasoning:\n"
    template = spot_difference_template_FoR2 if mode == "self_correction" else image_edit_template
    return get_updated_layout(template + questions, config)[0]

def spot_camera_relation(prompt, det_results, full_rules=True):
    """Return the prompt rewritten by ``get_update_prompt``.

    Thin wrapper: ``prompt`` and ``det_results`` are passed through
    positionally and ``full_rules`` as a keyword argument.
    """
    rewritten_prompt = get_update_prompt(prompt, det_results, full_rules=full_rules)
    return rewritten_prompt

# Update depth and orientation
def update_det_results(entry, det_results, depth_models, orientation_models, models):
    """Append average depth and orientation to every detection.

    Args:
        entry: Sample record; ``entry["output"][-1]`` is the image path and
            ``entry["device"]`` the torch device handed to SAM.
        det_results: Detections whose element ``[0]`` is the object name and
            element ``[1]`` its bounding box.
        depth_models: Object providing ``extract_depth`` and
            ``get_object_depth``.
        orientation_models: Object providing ``predict_orientation``.
        models: Forwarded to ``run_sam``.

    Returns:
        A new list with one tuple per detection: a deep copy of the original
        fields followed by the average depth (rounded to 3 decimals) and the
        predicted orientation.
    """
    frame = np.array(Image.open(entry["output"][-1]))
    # The unpacking also requires the image array to be 3-dimensional.
    H, W, _ = frame.shape
    _, depth_map = depth_models.extract_depth(frame)

    def enrich(detection):
        """Return ``detection`` extended with depth and orientation."""
        bbox = detection[1]
        segmentation = run_sam(bbox, frame, models, torch_device=entry["device"])
        _, _, mean_depth = depth_models.get_object_depth(segmentation, depth_map)
        mean_depth = round(mean_depth, 3)
        # Drop the last whitespace-separated token of the name (e.g. "#1").
        label = " ".join(detection[0].split()[:-1])
        facing = orientation_models.predict_orientation(label, detection[1], entry["output"][-1])
        return tuple(list(copy.deepcopy(detection)) + [mean_depth, facing])

    return [enrich(detection) for detection in det_results]

def correction(entry, add_objects, move_objects,
    remove_region, change_attr_objects, change_depth_object_new,
    reference_depth, depth_model, controlnet,
    models, config, image_generator, depth_guidece_scale=4):
    """Run the image generator on one sample with all prepared edit operations.

    The edit operations and the prompts stored in ``entry`` are bundled into
    a spec dict, the current image is inverted to background latents, and
    ``image_generator.run`` is called with seeds and the frozen-step ratio
    read from the ``SLD`` section of ``config``.

    Returns:
        Whatever ``image_generator.run`` returns.
    """
    spec = dict(
        add_objects=add_objects,
        move_objects=move_objects,
        prompt=entry["instructions"],
        remove_region=remove_region,
        change_objects=change_attr_objects,
        change_depth_objects=change_depth_object_new,
        reference_depth=reference_depth,
        all_objects=entry["llm_suggestion"],
        bg_prompt=entry["bg_prompt"],
        extra_neg_prompt=entry["neg_prompt"],
    )

    # Background latent preprocessing: invert the current image.
    current_image = np.array(Image.open(entry["output"][-1]))
    inversion_seed = int(config.get("SLD", "inv_seed"))
    background_latents, _ = get_all_latents(current_image, models, inversion_seed)

    run_kwargs = {
        "fg_seed_start": int(config.get("SLD", "fg_seed")),
        "bg_seed": int(config.get("SLD", "bg_seed")),
        "bg_all_latents": background_latents,
        "frozen_step_ratio": float(config.get("SLD", "frozen_step_ratio")),
        "depth_model": depth_model,
        "depth_guidece_scale": depth_guidece_scale,
        "control_net": controlnet,
    }
    return image_generator.run(spec, **run_kwargs)

def main(args):
    """Run the detect -> LLM layout -> latent-edit self-correction rounds.

    For every round in ``range(args.st_round, args.round)`` each dataset
    sample is read from the round's input directory, resized to 512x512 and
    analysed (object detection, depth, orientation).  Unless this is the
    last round or ``args.detect_only`` is set, the LLM is asked for an
    updated layout and the image is edited accordingly; the result is written
    to the next round's directory.  When no edit is needed, the LLM returns
    nothing, or the edit raises, the input image is copied there unchanged
    and (for the two failure cases) an error record is kept.

    Args:
        args: Parsed command-line namespace.  Reads ``config``,
            ``peft_orient``, ``st_round``, ``round``, ``init_model_name``,
            ``use_depth_guide``, ``not_controlnet``, ``use_both_depth``,
            ``detect_only``, ``version`` and ``save_results``.

    Raises:
        FileNotFoundError: If an expected input image cannot be read.
    """
    # Pre-load all models
    config = configparser.ConfigParser()
    config.read(args.config)

    # MAIN_CUDA selects the GPU; fall back to the default CUDA device rather
    # than None when the variable is unset or empty.
    cur_device = (os.getenv("MAIN_CUDA") or "cuda") if torch.cuda.is_available() else "cpu"

    # Load models
    models.sd_key = "gligen/diffusers-generation-text-box"
    models.sd_version = "sdv1.4"
    diffusion_scheduler = None

    models.model_dict = models.load_sd(
        key=models.sd_key,
        use_fp16=False,
        load_inverse_scheduler=True,
        scheduler_cls=diffusers.schedulers.__dict__[diffusion_scheduler]
        if diffusion_scheduler is not None
        else None,
    )

    sam_model_dict = sam.load_sam(torch_device=cur_device)
    models.model_dict.update(sam_model_dict)
    # Kept after models.model_dict is populated, exactly where the original
    # import was placed.
    from sld import image_generator

    det = OWLVITV2Detector(device=cur_device)
    depth_det = DepthModule(cur_device=cur_device)
    orient_det = OrientationModule(peft_model_id=args.peft_orient, device=cur_device)
    # No ControlNet model is loaded here; None is what gets passed to
    # correction() as the controlnet argument.
    controlnet = None

    default_seed = int(config.get("SLD", "default_seed"))
    np.random.seed(default_seed)
    random.seed(default_seed)
    torch.manual_seed(default_seed)

    # Detection parameters do not change between samples; read them once.
    attr_threshold = float(config.get("SLD", "attr_detection_threshold"))
    prim_threshold = float(config.get("SLD", "prim_detection_threshold"))
    nms_threshold = float(config.get("SLD", "nms_threshold"))

    # Placeholder: must be replaced by the path of the dataset JSON file.
    data_file = "FILL YOUR DATA FILE HERE"
    with open(data_file) as fjson:
        dataset = json.load(fjson)["data"]

    for edit_round in range(args.st_round, args.round):
        results_round = []
        error_ids = []
        # The last round only evaluates detection; nothing is edited.
        is_last_round = edit_round == args.round - 1

        model_name = args.init_model_name
        if edit_round == 0:
            img_dir = f"LMD_results/{model_name}"
        else:
            img_dir = f"LMD_results/{model_name}_round{edit_round}"
        next_img_dir = f"LMD_results/{model_name}_round{edit_round + 1}"

        if args.use_depth_guide:
            suffix = "depth_guide"
        elif args.not_controlnet:
            suffix = "none"
        elif args.use_both_depth:
            suffix = "both"
        else:
            suffix = ""
        suffix += f"full_rules_update_prompt_with_SD1-5_init_{model_name}"

        # Directory names carry the suffix, except for the round-0 input.
        if not args.detect_only:
            if edit_round != 0:
                img_dir += suffix
            next_img_dir += suffix
            os.makedirs(next_img_dir, exist_ok=True)

        idx = 0 if args.version == 0 else 500
        for data in tqdm(dataset, desc=f"Editing Round{edit_round}"):
            idx += 1
            sample_id = "FoREDIT_{:04d}".format(idx)
            img_id = sample_id + ".png"

            # Work on a 512x512 copy stored next to the original.
            initial_fname = os.path.join(img_dir, img_id)
            img = cv2.imread(initial_fname)
            if img is None:
                # cv2.imread signals a missing/unreadable file with None.
                raise FileNotFoundError(f"Cannot read input image: {initial_fname}")
            resized = cv2.resize(img, (512, 512))
            new_initial_fname = os.path.join(img_dir, "resize_" + img_id)
            cv2.imwrite(new_initial_fname, resized)

            initial_fname = new_initial_fname
            output_fname = os.path.join(next_img_dir, img_id)

            # Getting key objects
            prompt = data["prompt"]
            llm_parsed_prompt = spot_objects(prompt, data, config=config)

            # Setting up entry
            entry = {
                "instructions": prompt,
                "output": [initial_fname],
                "objects": llm_parsed_prompt["objects"],
                "bg_prompt": llm_parsed_prompt["bg_prompt"],
                "neg_prompt": llm_parsed_prompt["neg_prompt"],
                "device": cur_device,
            }

            initial_det_results = det.run(prompt, entry["objects"], entry["output"][-1],
                                          attr_detection_threshold=attr_threshold,
                                          prim_detection_threshold=prim_threshold,
                                          nms_threshold=nms_threshold)

            # Update with depth and orientation
            det_results = update_det_results(entry, initial_det_results, depth_det, orient_det, models)

            if args.detect_only:
                results_round.append({"id": sample_id,
                                      "detection_result": copy.deepcopy(det_results),
                                      "llm_suggestion": None})
                continue

            # Force a fresh LLM query instead of reusing stored suggestions.
            data["llm_layout_suggestions"] = None

            if not is_last_round:
                # Getting updated prompt for LLM-edit layout
                update_prompt_LLM = spot_camera_relation(prompt, det_results)
                # Getting updated layout
                llm_suggestions = spot_differences(update_prompt_LLM, det_results, data,
                                                   mode="self_correction", config=config)
            else:
                llm_suggestions = None

            results_round.append({"id": sample_id,
                                  "detection_result": copy.deepcopy(det_results),
                                  "llm_suggestion": copy.deepcopy(llm_suggestions)})

            if is_last_round:
                print("Last round only evaluate detection.")
                continue

            if llm_suggestions is None:
                error_ids.append({"id": img_id,
                                  "error_round": edit_round,
                                  "error_type": "LLM"})
                shutil.copy(entry["output"][-1], output_fname)
                continue

            entry["det_results"] = copy.deepcopy(det_results)
            entry["llm_suggestion"] = copy.deepcopy(llm_suggestions)

            (
                preserve_objs,
                deletion_objs,
                addition_objs,
                repositioning_objs,
                attr_modification_objs,
                change_depth_object
            ) = det.parse_list(det_results, llm_suggestions)
            total_ops = (len(deletion_objs) + len(addition_objs) + len(repositioning_objs)
                         + len(attr_modification_objs) + len(change_depth_object))
            if total_ops == 0:
                # Layout already matches: carry the image over unchanged.
                shutil.copy(entry["output"][-1], output_fname)
                continue

            try:
                print("Trying to edit")
                deletion_region = get_remove_region(
                    entry, deletion_objs, repositioning_objs + change_depth_object,
                    preserve_objs, models, config
                )
                repositioning_objs = get_repos_info(
                    entry, repositioning_objs, models, config
                )
                attr_modification_objs = get_attrmod_latent(
                    entry, attr_modification_objs, models, config
                )
                change_depth_object_new, reference_depth = get_reference_depth(
                    entry, change_depth_object, models, depth_det, config
                )

                # The depth mode only changes four arguments of correction():
                # (depth-edit objects, depth model, controlnet, guidance scale).
                if args.use_depth_guide:
                    depth_objs, depth_model, control, scale = change_depth_object_new, depth_det, None, 400
                elif args.not_controlnet:
                    depth_objs, depth_model, control, scale = [], None, None, 0
                elif args.use_both_depth:
                    depth_objs, depth_model, control, scale = change_depth_object_new, depth_det, controlnet, 1.5
                else:
                    depth_objs, depth_model, control, scale = change_depth_object_new, None, controlnet, 4

                ret_dict = correction(
                    entry, addition_objs, repositioning_objs,
                    deletion_region, attr_modification_objs, depth_objs,
                    reference_depth, depth_model, control,
                    models, config, image_generator, depth_guidece_scale=scale
                )
                # Save an intermediate file without the SDXL refinement
                print("Sucessfully edited.")
                Image.fromarray(ret_dict.image).save(output_fname)
                utils.free_memory()
            except Exception as e:
                # Best effort: record the failure and keep the input image so
                # the next round still has a file to read.
                print("Editing Error:", e)
                error_ids.append({"id": img_id,
                                  "error_round": edit_round,
                                  "error_type": "correction (Prosibly Latent Operation of SLD)" + str(e)})
                shutil.copy(entry["output"][-1], output_fname)
                utils.free_memory()
                continue

        if args.save_results:
            save_file_name = f"FoR_editing_results_V{args.version}_round-{edit_round}{suffix}.json"
            print("Saving file at", save_file_name)
            # Context manager so the file is flushed and closed deterministically.
            with open(save_file_name, "w") as fout:
                json.dump({"results": results_round, "error": error_ids}, fout, indent=3)




def str2bool(value):
    """Convert a command-line token to ``bool``.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    any non-empty value — including ``"False"`` — becomes ``True``.  This
    parser accepts the usual spellings instead.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in {"1", "true", "t", "yes", "y", "on"}:
        return True
    # The empty string stays falsy, as it was with ``type=bool``.
    if token in {"", "0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_arg_parser():
    """Create the command-line parser for the editing script.

    Boolean options accept ``--flag``, ``--flag True`` and ``--flag False``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_ver", type=str, default="")
    parser.add_argument("--version", type=int, default=0)
    parser.add_argument("--st_round", type=int, default=0)
    parser.add_argument("--round", type=int, default=1)
    parser.add_argument("--config", type=str, default="../benchmark_FoR_config.ini")
    # nargs="?" + const=True lets the bare flag mean "enabled".
    for flag in ("--use_depth_guide", "--not_controlnet", "--use_both_depth", "--save_results"):
        parser.add_argument(flag, type=str2bool, nargs="?", const=True, default=False)
    parser.add_argument("--detect_only", help="Only detect the results", action="store_true")
    parser.add_argument("--peft_orient", type=str, default="")
    parser.add_argument("--init_model_name", type=str, default="GPT")
    parser.add_argument("--predefine_llm_suggestion", help="Use pre-define layout for round 1", action="store_true")
    return parser


if __name__ == "__main__":
    main(build_arg_parser().parse_args())