import os
import sys
sys.path.append("../SLD")
sys.path.append("")
from dotenv import load_dotenv

load_dotenv()

import json
import copy
import shutil
import random
import numpy as np
import argparse
import configparser
from PIL import Image
import logging
import cv2

import torch
import diffusers
from tqdm import tqdm

# Libraries heavily borrowed from LMD
import models
from models import sam
from utils import parse, utils

# SLD specific imports
from sld.detector import OWLVITV2Detector
from sld.sdxl_refine import sdxl_refine
from sld.utils import get_all_latents, run_sam, run_sam_postprocess, resize_image
from sld.llm_template import spot_object_template, spot_difference_template, image_edit_template
from sld.llm_chat import get_key_objects, get_updated_layout
from eval.eval import eval_prompt, Evaluator
from eval.lmd import get_lmd_prompts

# Disable parallelism inside the HuggingFace `tokenizers` library.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


# Configure root logging with exactly one console handler.
# NOTE: logging.basicConfig() installs its own stderr StreamHandler; adding
# `console_handler` on top of it made every record print twice. The root logger
# is therefore configured by hand: INFO level plus the single handler below.
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)
logging.getLogger().setLevel(logging.INFO)
logging.getLogger().addHandler(console_handler)

# Function to change the file handler
def set_file_handler(log_file_name):
    """Route root-logger output to the console plus a freshly created log file.

    Every handler currently attached to the root logger is detached, then the
    module-level ``console_handler`` is re-attached together with a new
    ``logging.FileHandler`` that opens ``log_file_name`` in ``'w'`` mode
    (truncating any existing file) and uses the module-level ``formatter``.

    Parameters:
    log_file_name (str): Path of the log file to (re)create.
    """
    logger = logging.getLogger()  # Get the root logger
    for handler in logger.handlers[:]:  # Iterate over a copy: the list is mutated below
        logger.removeHandler(handler)
        # A detached FileHandler still holds its file open; close it so that
        # calling this function repeatedly does not leak file descriptors.
        if isinstance(handler, logging.FileHandler):
            handler.close()
    logger.addHandler(console_handler)  # Add back the console handler
    file_handler = logging.FileHandler(log_file_name, mode='w')  # Create a file handler for the new log file
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)  # Add the new file handler


# Operation #1: Addition (The code is in sld/image_generator.py)

# Operation #2: Deletion (Preprocessing region mask for removal)
def get_remove_region(entry, remove_objects, move_objects, preserve_objs, models, config):
    """Build the removal mask for the most recent output image.

    Parameters:
    entry (dict): Must provide ``entry["output"]``; its last element is opened as the image.
    remove_objects (list): Objects to delete; element ``[1]`` of each is passed to SAM as the bbox.
    move_objects (list): Objects to move; element ``[0]`` of each is treated like a removal object.
    preserve_objs (list): Objects whose regions are excluded from the removal mask.
    models: Passed through to ``run_sam``.
    config: Passed through to ``run_sam_postprocess``.

    Returns:
    An all-zero int64 array when there is nothing to remove or move; otherwise the
    post-processed removal mask with the post-processed preserve mask cut out.
    """
    source_image = np.array(Image.open(entry["output"][-1]))
    height, width, _ = source_image.shape

    # Nothing to delete or move: return an empty mask straight away.
    # NOTE(review): the shape is (W // 8, H // 8), i.e. width first; this only matches
    # (H // 8, W // 8) for square images — confirm the intended orientation.
    if len(remove_objects) == 0 and len(move_objects) == 0:
        return np.zeros((width // 8, height // 8), dtype=np.int64)

    def _union_of_sam_masks(targets):
        # OR together the SAM masks of every target, starting from an all-False canvas.
        combined = np.zeros((height, width, 3), dtype=bool)
        for target in targets:
            combined = combined | run_sam(bbox=target[1], image_source=source_image, models=models)
        return combined

    # Objects that are moved away also vacate their old location.
    vacated_sources = [pair[0] for pair in move_objects]
    removal_mask = _union_of_sam_masks(remove_objects + vacated_sources)

    # Regions that must survive the removal.
    keep_mask = _union_of_sam_masks(preserve_objs)

    # Post-process both raw SAM masks, then carve the preserved area out of the removal area.
    keep_region = run_sam_postprocess(keep_mask, height, width, config)
    removal_region = run_sam_postprocess(removal_mask, height, width, config)
    return np.logical_and(removal_region, np.logical_not(keep_region))


# Operation #3: Repositioning (Preprocessing latent)
def get_repos_info(entry, move_objects, models, config):
    """
    Prepare every object that has to be moved / reshaped for the repositioning step.
    * Important: the reshaping happens on the image (via ``resize_image``), not on the latent.
    * Warning: For simplicity, the object is not positioned to the center of the new region...

    Parameters:
    entry (dict): Must provide ``entry["output"]``; its last element is opened as the image.
    move_objects (list): Items indexed as ``item[0][0]``, ``item[0][1]`` and ``item[1][1]``
        (presumably old name, old bbox and new bbox — confirm against the caller).
    models: Passed through to ``run_sam`` and ``get_all_latents``.
    config: Provides ``SLD.inv_seed`` and is passed to ``run_sam_postprocess``.

    Returns:
    ``move_objects`` itself when it is empty/falsy; otherwise a list with one
    ``[item[0][0], resized_box, item[1][1], boolean_region, all_latents]`` entry per item.
    """
    # Nothing to reposition: hand the input straight back.
    if not move_objects:
        return move_objects

    source_image = np.array(Image.open(entry["output"][-1]))
    height, width, _ = source_image.shape
    inversion_seed = int(config.get("SLD", "inv_seed"))

    def _prepare(move_item):
        origin, destination = move_item[0], move_item[1]
        # Resize at the image level, then segment the object inside the resized image.
        resized_image, resized_box = resize_image(source_image, origin[1], destination[1])
        raw_mask = run_sam(resized_box, resized_image, models)
        object_region = run_sam_postprocess(raw_mask, height, width, config).astype(np.bool_)
        # Latents of the resized image.
        latents, _ = get_all_latents(resized_image, models, inversion_seed)
        return [origin[0], resized_box, destination[1], object_region, latents]

    return [_prepare(move_item) for move_item in move_objects]


# Operation #4: Attribute Modification (Preprocessing latent)
def get_attrmod_latent(entry, change_attr_objects, models, config):
    """
    Runs DiffEdit on every object whose attributes changed and collects the resulting latents.

    Parameters:
    entry (dict): Must provide ``entry["output"]`` (its last element is opened as the image)
        and ``entry["device"]`` (device the pipeline and the random generators are placed on).
    change_attr_objects (list): Objects with changed attributes; ``obj[0]`` is the
        target description (anything after ``" #"`` is dropped) and ``obj[1]`` is passed to SAM.
    models (Model): The models used for SAM segmentation and latent extraction.
    config (configparser.ConfigParser): Provides ``SLD.inv_seed``,
        ``SLD.diffedit_inpaint_strength`` and ``SLD.diffedit_guidance_scale``.

    Returns:
    list: One ``[object_mask, all_latents]`` pair per object, or ``[]`` when there is
    nothing to modify.
    """
    if change_attr_objects is None or len(change_attr_objects) == 0:
        return []
    # Imported lazily so the DiffEdit pipeline is only needed when an edit is requested.
    from diffusers import StableDiffusionDiffEditPipeline
    from diffusers import DDIMScheduler, DDIMInverseScheduler

    img = Image.open(entry["output"][-1])
    image_source = np.array(img)
    H, W, _ = image_source.shape
    device = entry["device"]

    # Loop-invariant settings: read them once instead of once per object.
    inv_seed = int(config.get("SLD", "inv_seed"))
    inpaint_strength = float(config.get("SLD", "diffedit_inpaint_strength"))
    guidance_scale = float(config.get("SLD", "diffedit_guidance_scale"))

    # Initialize the Stable Diffusion pipeline
    pipe = StableDiffusionDiffEditPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16
    )
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
    pipe.set_progress_bar_config(disable=True)
    pipe.to(device)
    # NOTE(review): enabling CPU offload right after moving the pipeline to `device`
    # is kept from the original flow — confirm both calls are really wanted.
    pipe.enable_model_cpu_offload()

    new_change_objects = []
    for obj in change_attr_objects:
        # Run diffedit
        old_object_region = run_sam_postprocess(run_sam(obj[1], image_source, models), H, W, config)
        # Add a leading axis: the mask is handed to the pipeline as `mask_image`.
        old_object_region = old_object_region.astype(np.bool_)[np.newaxis, ...]

        # "red apple #1" -> new_object "red apple", base_object "apple"
        new_object = obj[0].split(" #")[0]
        base_object = new_object.split(" ")[-1]
        mask_prompt = f"a {base_object}"
        new_prompt = f"a {new_object}"

        # The generators live on the same device the pipeline was given, rather than
        # on whatever the MAIN_CUDA environment variable happens to hold.
        image_latents = pipe.invert(
            image=img,
            prompt=mask_prompt,
            inpaint_strength=inpaint_strength,
            generator=torch.Generator(device=device).manual_seed(inv_seed),
        ).latents

        image = pipe(
            prompt=new_prompt,
            mask_image=old_object_region,
            image_latents=image_latents,
            guidance_scale=guidance_scale,
            inpaint_strength=inpaint_strength,
            generator=torch.Generator(device=device).manual_seed(inv_seed),
            negative_prompt="",
        ).images[0]

        all_latents, _ = get_all_latents(np.array(image), models, inv_seed)
        new_change_objects.append(
            [
                old_object_region[0],
                all_latents,
            ]
        )
    return new_change_objects


def correction(
    entry, add_objects, move_objects,
    remove_region, change_attr_objects,
    models, config, image_generator
):
    """Assemble the edit specification and run the image generator on it.

    Parameters:
    entry (dict): Provides ``instructions``, ``llm_suggestion``, ``bg_prompt``,
        ``neg_prompt`` and ``output`` (its last element is opened as the current image).
    add_objects, move_objects, remove_region, change_attr_objects: Pre-processed
        inputs of the four edit operations, forwarded unchanged inside the spec.
    models: Passed through to ``get_all_latents``.
    config: Provides ``SLD.inv_seed``, ``SLD.fg_seed``, ``SLD.bg_seed`` and
        ``SLD.frozen_step_ratio``.
    image_generator: Object exposing ``run(spec, ...)``.

    Returns:
    Whatever ``image_generator.run`` returns.
    """
    def _sld_option(name):
        # All tunables used here live in the [SLD] section of the config.
        return config.get("SLD", name)

    edit_spec = {
        "add_objects": add_objects,
        "move_objects": move_objects,
        "prompt": entry["instructions"],
        "remove_region": remove_region,
        "change_objects": change_attr_objects,
        "change_depth_objects": [],
        "reference_depth": None,
        "all_objects": entry["llm_suggestion"],
        "bg_prompt": entry["bg_prompt"],
        "extra_neg_prompt": entry["neg_prompt"],
    }

    # Background latent preprocessing on the current image
    current_image = np.array(Image.open(entry["output"][-1]))
    background_latents, _ = get_all_latents(current_image, models, int(_sld_option("inv_seed")))

    generator_kwargs = {
        "fg_seed_start": int(_sld_option("fg_seed")),
        "bg_seed": int(_sld_option("bg_seed")),
        "bg_all_latents": background_latents,
        "frozen_step_ratio": float(_sld_option("frozen_step_ratio")),
    }
    return image_generator.run(edit_spec, **generator_kwargs)

def spot_objects(prompt, data, config=None):
    """Return the parsed key objects for ``prompt``.

    A non-``None`` ``data["llm_parsed_prompt"]`` is returned as is; otherwise the LLM
    is queried through ``get_key_objects`` and the first element of its result is returned.
    """
    cached_parse = data.get("llm_parsed_prompt")
    if cached_parse is not None:
        return cached_parse

    # No cached object list: ask the LLM to spot the objects
    questions = f"User Prompt: {prompt}\nReasoning:\n"
    llm_reply = get_key_objects(spot_object_template + questions, config)
    return llm_reply[0]  # Extracting the object list
    
def spot_differences(prompt, det_results, data, mode="self_correction", config=None):
    """Return the layout the LLM suggests for ``prompt`` given the detected objects.

    A non-``None`` ``data["llm_layout_suggestions"]`` is returned as is. Otherwise the
    LLM is queried through ``get_updated_layout`` and the first element of its result
    is returned. A missing key is treated like ``None`` (matching ``data.get`` usage
    for the parsed-prompt cache) instead of raising ``KeyError``.

    Example of a cached value:
    [["woman #1", [0.542, 0.209, 0.237, 0.601], 0.525, null],
     ["clown #1", [0.206, 0.153, 0.244, 0.81], 0.748, null]]

    Parameters:
    prompt (str): The user prompt.
    det_results: Detection results, interpolated into the LLM question.
    data (dict): Dataset record that may carry ``llm_layout_suggestions``.
    mode (str): ``"self_correction"`` selects ``spot_difference_template``; any other
        value selects ``image_edit_template``.
    config: Passed through to ``get_updated_layout``.
    """
    cached_layout = data.get("llm_layout_suggestions")
    if cached_layout is not None:
        return cached_layout

    questions = (
        f"User Prompt: {prompt}\nCurrent Objects: {det_results}\nReasoning:\n"
    )
    if mode == "self_correction":
        message = spot_difference_template + questions
    else:
        message = image_edit_template + questions
    llm_suggestions = get_updated_layout(message, config)
    return llm_suggestions[0]

def main(args):
    """Run the detect -> LLM layout suggestion -> latent edit loop over the dataset.

    For every round in ``range(args.st_round, args.round)`` and every dataset record,
    the input image is resized to 512x512, objects are detected, the LLM is asked for
    a corrected layout and (unless ``args.detect_only``) the image is edited and written
    to the next round's directory. Per-round results are optionally dumped to JSON.

    Parameters:
    args (argparse.Namespace): Uses ``config``, ``st_round``, ``round``,
        ``init_model_name``, ``detect_only``, ``version`` and ``save_results``.

    Raises:
    FileNotFoundError: If an input image cannot be read by OpenCV.
    """
    # Pre-load all models
    config = configparser.ConfigParser()
    config.read(args.config)

    # NOTE(review): with CUDA available but MAIN_CUDA unset this evaluates to None —
    # confirm the loaders below are meant to accept that.
    cur_device = os.getenv("MAIN_CUDA") if torch.cuda.is_available() else "cpu"

    # Load models
    models.sd_key = "gligen/diffusers-generation-text-box"
    models.sd_version = "sdv1.4"
    diffusion_scheduler = None

    models.model_dict = models.load_sd(
        key=models.sd_key,
        use_fp16=False,
        load_inverse_scheduler=True,
        scheduler_cls=diffusers.schedulers.__dict__[diffusion_scheduler]
        if diffusion_scheduler is not None
        else None,
    )

    sam_model_dict = sam.load_sam(torch_device=cur_device)
    models.model_dict.update(sam_model_dict)
    from sld import image_generator

    det = OWLVITV2Detector(device=cur_device)

    default_seed = int(config.get("SLD", "default_seed"))
    np.random.seed(default_seed)
    random.seed(default_seed)
    torch.manual_seed(default_seed)

    # Detection parameters never change during the run: read them once.
    attr_threshold = float(config.get("SLD", "attr_detection_threshold"))
    prim_threshold = float(config.get("SLD", "prim_detection_threshold"))
    nms_threshold = float(config.get("SLD", "nms_threshold"))

    # Currently a single pass; `setting` itself is not used inside the body.
    for setting in [0]:
        # Placeholder: must be replaced with the path of the dataset JSON file.
        data_file = "FILL DATA FILE HERE"

        with open(data_file) as fjson:
            dataset = json.load(fjson)["data"]

        for edit_round in range(args.st_round, args.round):
            results_round = []
            error_ids = []

            model_name = args.init_model_name
            if edit_round == 0:
                IMG_DIR = f"LMD_results/{model_name}"
            else:
                IMG_DIR = f"LMD_results/{model_name}_round{edit_round}"

            NEXT_IMG_DIR = f"LMD_results/{model_name}_round{edit_round + 1}"

            suffix = f"original_{model_name}"

            # Change the directory name based on suffix
            # (round 0 reads from the un-suffixed initial directory)
            if not args.detect_only:
                IMG_DIR = IMG_DIR + suffix if edit_round != 0 else IMG_DIR
                NEXT_IMG_DIR = NEXT_IMG_DIR + suffix
                os.makedirs(NEXT_IMG_DIR, exist_ok=True)

            idx = 0 if args.version == 0 else 500
            for data in tqdm(dataset[:], desc=f"Editing Round{edit_round}"):
                idx += 1
                sample_id = "FoREDIT_{:04d}".format(idx)
                img_id = sample_id + ".png"

                initial_fname = os.path.join(IMG_DIR, img_id)
                # Resize image to 512,512
                img = cv2.imread(initial_fname)
                if img is None:
                    # cv2.imread signals failure by returning None; fail with a clear
                    # message instead of an opaque error inside cv2.resize.
                    raise FileNotFoundError(f"Could not read input image: {initial_fname}")
                resized = cv2.resize(img, (512, 512))  # Resize to 512×512
                new_initial_fname = os.path.join(IMG_DIR, "resize_" + img_id)
                cv2.imwrite(new_initial_fname, resized)

                initial_fname = new_initial_fname
                output_fname = os.path.join(NEXT_IMG_DIR, img_id)

                # Getting key objects
                prompt = data["prompt"]
                llm_parsed_prompt = spot_objects(prompt, data, config=config)

                # Setting up entry
                entry = {
                    "instructions": prompt,
                    "output": [initial_fname],
                    "objects": llm_parsed_prompt["objects"],
                    "bg_prompt": llm_parsed_prompt["bg_prompt"],
                    "neg_prompt": llm_parsed_prompt["neg_prompt"],
                    "device": cur_device,
                }

                det_results = det.run(
                    prompt, entry["objects"], entry["output"][-1],
                    attr_detection_threshold=attr_threshold,
                    prim_detection_threshold=prim_threshold,
                    nms_threshold=nms_threshold,
                )

                # Detection-only mode: record the detections and skip the LLM / editing
                if args.detect_only:
                    results_round.append({
                        "id": sample_id,
                        "detection_result": copy.deepcopy(det_results),
                        "llm_suggestion": None,
                    })
                    continue

                # Drop any cached layout so that a fresh suggestion is requested
                data["llm_layout_suggestions"] = None

                # Getting layout suggestion from LLM
                llm_suggestions = spot_differences(
                    prompt, det_results, data, mode="self_correction", config=config
                )

                results_round.append({
                    "id": sample_id,
                    "detection_result": copy.deepcopy(det_results),
                    "llm_suggestion": copy.deepcopy(llm_suggestions),
                })

                if llm_suggestions is None:
                    # No usable suggestion: carry the current image over unchanged
                    error_ids.append({
                        "id": img_id,
                        "error_round": edit_round,
                        "error_type": "LLM",
                    })
                    shutil.copy(entry["output"][-1], output_fname)
                    continue

                entry["det_results"] = copy.deepcopy(det_results)
                entry["llm_suggestion"] = copy.deepcopy(llm_suggestions)

                (
                    preserve_objs,
                    deletion_objs,
                    addition_objs,
                    repositioning_objs,
                    attr_modification_objs,
                    _
                ) = det.parse_list(det_results, llm_suggestions)
                total_ops = (
                    len(deletion_objs) + len(addition_objs)
                    + len(repositioning_objs) + len(attr_modification_objs)
                )
                if total_ops == 0:
                    # Nothing to edit: carry the current image over unchanged
                    shutil.copy(entry["output"][-1], output_fname)
                    continue
                try:
                    print("Trying to edit")
                    deletion_region = get_remove_region(
                        entry, deletion_objs, repositioning_objs, preserve_objs, models, config
                    )
                    repositioning_objs = get_repos_info(
                        entry, repositioning_objs, models, config
                    )
                    attr_modification_objs = get_attrmod_latent(
                        entry, attr_modification_objs, models, config
                    )

                    ret_dict = correction(
                        entry, addition_objs, repositioning_objs,
                        deletion_region, attr_modification_objs,
                        models, config, image_generator
                    )
                    # Save an intermediate file without the SDXL refinement
                    print("Sucessfully edited.")
                    Image.fromarray(ret_dict.image).save(output_fname)
                    utils.free_memory()

                except Exception as e:
                    # Best effort: one failed edit must not abort the whole run, so the
                    # error is recorded and the unedited image is carried over.
                    print("Editing Error:", e)
                    error_ids.append({
                        "id": img_id,
                        "error_round": edit_round,
                        "error_type": "correction (Prosibly Latent Operation of SLD)" + str(e),
                    })
                    shutil.copy(entry["output"][-1], output_fname)
                    utils.free_memory()
                    continue

            if args.save_results:
                save_file_name = f"FoR_editing_results_V{args.version}_round-{edit_round}{suffix}.json"
                print("Saving file at", save_file_name)
                # Context manager guarantees the file is flushed and closed.
                with open(save_file_name, "w") as fout:
                    json.dump({"results": results_round, "error": error_ids}, fout, indent=3)



def str2bool(value):
    """Convert a command-line value into a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty text (including ``"False"``) becomes ``True``. This converter maps the
    usual false spellings to ``False`` and, to stay backward-compatible, every
    other value to ``True``.
    """
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() not in {"", "0", "false", "f", "no", "n", "off"}


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_ver", type=str, default="")
    parser.add_argument("--version", type=int, default=0)
    parser.add_argument("--st_round", type=int, default=0)
    parser.add_argument("--round", type=int, default=1)
    parser.add_argument("--config", type=str, default="../benchmark_FoR_config.ini")
    parser.add_argument("--use_depth_guide", type=str2bool, default=False)
    parser.add_argument("--not_controlnet", type=str2bool, default=False)
    parser.add_argument("--use_both_depth", type=str2bool, default=False)
    parser.add_argument("--detect_only", help="Only detect the results", action="store_true")
    parser.add_argument("--save_results", type=str2bool, default=False)
    parser.add_argument("--init_model_name", type=str, default="GPT")
    args = parser.parse_args()

    main(args)