from typing import List, Optional
from PIL import Image
from models.attention_processor import CrossRegionAttnProcessor
from diffusers import DDIMScheduler
from segment_anything import sam_model_registry, SamPredictor
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from compel import Compel
import numpy as np
import os
from utils.plot_utils import plot_results
import torch
import cv2
import torch.nn.functional as F


# Object category names (80 entries) used as the open-vocabulary text prompt
# for the Grounding DINO detector: ReConHelper.prepare_perception_text_feats
# joins them into a single "name. name. ..." string.
# NOTE(review): the list appears to be the standard COCO category set — confirm
# against the dataset the ground-truth labels come from, since detections are
# matched to ground truth by exact label string.
CLASSES = [
    "person",
    "bicycle",
    "car",
    "motorcycle",
    "airplane",
    "bus",
    "train",
    "truck",
    "boat",
    "traffic light",
    "fire hydrant",
    "stop sign",
    "parking meter",
    "bench",
    "bird",
    "cat",
    "dog",
    "horse",
    "sheep",
    "cow",
    "elephant",
    "bear",
    "zebra",
    "giraffe",
    "backpack",
    "umbrella",
    "handbag",
    "tie",
    "suitcase",
    "frisbee",
    "skis",
    "snowboard",
    "sports ball",
    "kite",
    "baseball bat",
    "baseball glove",
    "skateboard",
    "surfboard",
    "tennis racket",
    "bottle",
    "wine glass",
    "cup",
    "fork",
    "knife",
    "spoon",
    "bowl",
    "banana",
    "apple",
    "sandwich",
    "orange",
    "broccoli",
    "carrot",
    "hot dog",
    "pizza",
    "donut",
    "cake",
    "chair",
    "couch",
    "potted plant",
    "bed",
    "dining table",
    "toilet",
    "tv",
    "laptop",
    "mouse",
    "remote",
    "keyboard",
    "cell phone",
    "microwave",
    "oven",
    "toaster",
    "sink",
    "refrigerator",
    "book",
    "clock",
    "vase",
    "scissors",
    "teddy bear",
    "hair drier",
    "toothbrush",
]


def iou(box1, box2):
    """
    Compute the Intersection over Union (IoU) of two axis-aligned boxes.

    Args:
        box1: First box as [x_min, y_min, x_max, y_max].
        box2: Second box as [x_min, y_min, x_max, y_max].

    Returns:
        Intersection area divided by union area, or 0 when the union area
        is not positive (e.g. both boxes are degenerate).
    """
    # Corners of the overlap rectangle (may be inverted when boxes are disjoint).
    left = max(box1[0], box2[0])
    top = max(box1[1], box2[1])
    right = min(box1[2], box2[2])
    bottom = min(box1[3], box2[3])

    # Clamp each side length at zero so disjoint boxes give an empty overlap.
    overlap_width = max(0, right - left)
    overlap_height = max(0, bottom - top)
    intersection = overlap_width * overlap_height

    area_first = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area_second = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area_first + area_second - intersection

    if union > 0:
        return intersection / union
    return 0


def get_bboxes_union(
    det_results, gt_bboxes, gt_labels, iou_threshold=0.5, image_size=512
):
    """
    Match detections against ground-truth boxes and split them into TP / FP / FN.

    Each ground-truth box is greedily paired (in input order) with the
    still-unmatched detection of the same label that has the highest IoU.
    The pair counts as a true positive when that IoU reaches ``iou_threshold``.

    Args:
        det_results: A dictionary containing keys 'boxes' and 'labels':
            - det_results['boxes']: Detected boxes in pixel coordinates, each
              in the format [x_min, y_min, x_max, y_max].
            - det_results['labels']: Detected labels, aligned with the boxes.
        gt_bboxes: Ground-truth boxes in normalized coordinates, each in the
            format [x_min, y_min, x_max, y_max]; they are scaled by
            ``image_size`` and truncated to int before matching.
        gt_labels: Ground-truth labels, aligned with ``gt_bboxes``.
        iou_threshold: Minimum IoU for a detection to count as a match.
        image_size: Pixel side length used to denormalize ``gt_bboxes``.
            A single scalar is applied to both axes, i.e. a square image is
            assumed. Defaults to 512.

    Returns:
        A 7-tuple:
            TP_bboxes: np.ndarray of the *detected* boxes that matched a
                ground-truth box.
            TP_labels: List of labels of the matched ground-truth boxes.
            FP_bboxes: np.ndarray of detected boxes that matched no
                ground-truth box (false positives).
            FP_labels: List of labels of those false-positive detections.
            FN_bboxes: np.ndarray of denormalized ground-truth boxes that were
                not matched by any detection (missed detections).
            FN_labels: List of labels of those missed ground-truth boxes.
            FN_indices: List of indices into ``gt_bboxes`` of the missed boxes.
    """
    det_boxes = np.array(det_results["boxes"])
    det_labels = np.array(det_results["labels"])

    TP_bboxes, TP_labels = [], []
    FN_bboxes, FN_labels, FN_indices = [], [], []

    # A detection may satisfy at most one ground-truth box.
    matched_dets = [False] * len(det_boxes)

    for idx, (gt_coords, gt_label) in enumerate(zip(gt_bboxes, gt_labels)):
        gt_coords = [int(x * image_size) for x in gt_coords]

        max_iou = 0
        best_match_index = -1
        for i, (det_box, det_label) in enumerate(zip(det_boxes, det_labels)):
            if matched_dets[i] or det_label != gt_label:
                continue
            current_iou = iou(det_box, gt_coords)
            if current_iou > max_iou:
                max_iou = current_iou
                best_match_index = i

        # best_match_index stays -1 when no same-label detection overlaps at all.
        if best_match_index >= 0 and max_iou >= iou_threshold:
            TP_bboxes.append(det_boxes[best_match_index])
            TP_labels.append(gt_label)
            matched_dets[best_match_index] = True
        else:
            FN_bboxes.append(gt_coords)
            FN_labels.append(gt_label)
            FN_indices.append(idx)

    # Whatever was never claimed by a ground-truth box is a false positive.
    unmatched = [i for i, matched in enumerate(matched_dets) if not matched]
    FP_bboxes = [det_boxes[i] for i in unmatched]
    FP_labels = [det_labels[i] for i in unmatched]

    return (
        np.array(TP_bboxes),
        TP_labels,
        np.array(FP_bboxes),
        FP_labels,
        np.array(FN_bboxes),
        FN_labels,
        FN_indices,
    )


def dilate_with_no_overlap(gt_masks, kernel_size=3, iterations=1):
    """
    Dilate each mask without letting it grow into pixels owned by another mask.

    Masks are processed in order; pixels claimed by an earlier mask's dilation
    are marked occupied and are therefore unavailable to later masks.

    Args:
        gt_masks: Iterable of tensors exposing ``.cpu().numpy()``; each is
            squeezed to a 2-D array and cast to uint8. The ``== 1`` test below
            assumes binary 0/1 masks.
        kernel_size: Side length of the square all-ones dilation kernel.
        iterations: Number of dilation passes handed to ``cv2.dilate``.

    Returns:
        np.ndarray of dtype uint8 with shape (N, H, W) holding the dilated
        masks. For empty input an empty array of shape (0, 0, 0) is returned
        (the spatial size cannot be inferred from nothing).
    """
    mask_list = [x.cpu().numpy().squeeze() for x in gt_masks]
    if len(mask_list) == 0:
        # Without this guard the shape unpacking below fails on a (0,) array.
        return np.zeros((0, 0, 0), dtype=np.uint8)

    gt_masks = np.array(mask_list).astype(np.uint8)
    N, H, W = gt_masks.shape
    dilated_masks = np.zeros_like(gt_masks, dtype=np.uint8)

    # 1 wherever any mask (original or, later, dilated) already covers a pixel.
    occupied_region = np.sum(gt_masks, axis=0).clip(0, 1)
    kernel = np.ones((kernel_size, kernel_size), np.uint8)

    for i in range(N):
        # NOTE: occupancy is clipped to 0/1 before the subtraction, so a pixel
        # shared by mask i and another *original* mask counts as free here and
        # stays assigned to both masks.
        occupied_exclude_self = (occupied_region - gt_masks[i]).clip(0, 1)
        dilated = cv2.dilate(gt_masks[i], kernel, iterations=iterations)

        new_region = (dilated == 1) & (occupied_exclude_self == 0)
        dilated_masks[i][new_region] = 1
        occupied_region[new_region] = 1

    return dilated_masks


class ReConHelper:
    """
    Perception helper that couples a diffusion pipeline with Grounding DINO and SAM.

    The helper detects objects with Grounding DINO, compares them against the
    ground-truth layout stored in ``gt_bboxes`` / ``gt_labels``, segments the
    relevant boxes with SAM, and turns the result into:

    * region masks attached to the prompt embeddings and a control image with
      falsely detected objects blanked out
      (``papare_init_masks_and_text_embedding``);
    * a latent-resolution mask of regions that disagree with the layout
      (``compute_rectification_mask``), consumed by
      ``region_guided_rectification``.

    ``gt_bboxes`` and ``gt_labels`` are initialised to ``None`` and must be
    assigned by the caller before either perception method is used.
    """

    def __init__(
        self,
        pipe,
        enable_fast_sampling=True,
        enable_region_guided_rectification=True,
        enable_region_aligned_cross_attention=True,
        perception_steps: Optional[List[float]] = None,
        num_cache_steps=5,
        debug_mode=False,
        is_controlnet=True,
        device="cuda",
        debug_log_dir="./example_output/controlnet_recon",
        sam_checkpoint="ckpts/sam_vit_h_4b8939.pth",
        sam_model_type="vit_h",
        dino_model_id="ckpts/grounding-dino-tiny",
    ):
        """
        Load the perception models and configure the pipeline.

        Args:
            pipe: Diffusion pipeline; its scheduler is replaced in place by a
                DDIMScheduler built from the current scheduler config. The
                pipeline itself is not stored on the helper.
            enable_fast_sampling: When true, a DeepCacheSDHelper is created for
                ``pipe`` (cache_interval=3, cache_branch_id=0).
            enable_region_guided_rectification: Stored flag; not read inside
                this class.
            enable_region_aligned_cross_attention: Gates
                ``apply_region_aligned_cross_attention``.
            perception_steps: Stored for the caller; defaults to
                ``[0.75, 0.5, 0.25, 0.1]``. Presumably fractions of the
                sampling schedule at which perception runs — confirm against
                the pipeline that reads it.
            num_cache_steps: Stored for the caller; not read inside this class.
            debug_mode: When true, intermediate detections and masks are
                written to ``debug_log_dir``.
            is_controlnet: Stored flag. Per the original author's note, True
                means the visual condition is filtered and region-aligned
                cross attention is used for the ControlNet branch.
            device: Device the SAM and Grounding DINO models are moved to.
            debug_log_dir: Output directory for debug images.
            sam_checkpoint: Path of the SAM checkpoint file.
            sam_model_type: Key into ``sam_model_registry``.
            dino_model_id: Path or hub id of the Grounding DINO model,
                e.g. "IDEA-Research/grounding-dino-tiny".
        """
        pipe.scheduler = DDIMScheduler.from_config(
            pipe.scheduler.config, subfolder="scheduler"
        )
        self.enable_fast_sampling = enable_fast_sampling
        self.enable_region_guided_rectification = enable_region_guided_rectification
        self.enable_region_aligned_cross_attention = (
            enable_region_aligned_cross_attention
        )
        self.is_controlnet = is_controlnet

        sam = sam_model_registry[sam_model_type](checkpoint=sam_checkpoint).to(
            device=device
        )
        self.sam_predictor = SamPredictor(sam)

        self.dino_processor = AutoProcessor.from_pretrained(dino_model_id)
        self.dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(
            dino_model_id
        ).to(device)

        if enable_fast_sampling:
            # Imported lazily so the extension is only required when it is used.
            from deepcache_extension import DeepCacheSDHelper

            deepcache_helper = DeepCacheSDHelper(pipe=pipe)
            deepcache_helper.set_params(
                cache_interval=3,
                cache_branch_id=0,
            )
        else:
            deepcache_helper = None
        self.deepcache_helper = deepcache_helper
        self.num_cache_steps = num_cache_steps
        # A None sentinel avoids sharing one mutable default list between instances.
        self.perception_steps = (
            [0.75, 0.5, 0.25, 0.1] if perception_steps is None else perception_steps
        )
        self.debug_mode = debug_mode
        self.debug_log_dir = debug_log_dir
        if debug_mode:
            os.makedirs(self.debug_log_dir, exist_ok=True)

        # The class prompt never changes, so it is tokenized once and reused.
        self.perception_text_feats = self.prepare_perception_text_feats(device)

        self.gt_bboxes, self.gt_labels, self.gt_masks = None, None, None

    def prepare_perception_text_feats(self, device):
        """
        Tokenize the detector text prompt built from ``CLASSES``.

        The prompt has the form "person. bicycle. ... toothbrush." (one
        period-terminated phrase per class, separated by single spaces).

        Args:
            device: Device the processor output is moved to.

        Returns:
            The processor output for the prompt, moved to ``device``.
        """
        dino_text = " ".join(f"{phrase}." for phrase in CLASSES)
        return self.dino_processor(text=dino_text, return_tensors="pt").to(device)

    def _detect_objects(self, image, device):
        """
        Run Grounding DINO on a PIL image with the cached class prompt.

        Callers are expected to run this under ``torch.no_grad()`` because the
        scores and boxes are converted to numpy.

        Returns:
            The first post-processed detection dict, with 'scores' and 'boxes'
            converted to numpy arrays ('labels' is left as returned).
        """
        inputs = self.dino_processor(images=image, return_tensors="pt").to(device)
        inputs.update(self.perception_text_feats)
        for key, value in inputs.items():
            if not isinstance(value, torch.Tensor):
                continue
            if value.dtype in (torch.float32, torch.float16):
                # The detector is run in float32 regardless of the input precision.
                inputs[key] = value.to(dtype=torch.float32, device=device)
            else:
                # Leave the dtype untouched; only move the tensor to the device.
                inputs[key] = value.to(device)
        outputs = self.dino_model(**inputs)

        det_results = self.dino_processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=0.4,
            text_threshold=0.3,
            # PIL size is (width, height); the processor is given (height, width).
            target_sizes=[image.size[::-1]],
        )[0]
        det_results["scores"] = det_results["scores"].cpu().numpy()
        det_results["boxes"] = det_results["boxes"].cpu().numpy()
        return det_results

    def _segment_boxes(self, image, boxes, image_hw, device):
        """
        Segment ``image`` with SAM, one mask per pixel-space xyxy box.

        Args:
            image: PIL image to segment.
            boxes: Array of boxes in pixel coordinates.
            image_hw: (height, width) of ``image``.
            device: Device the transformed boxes are moved to before prediction.

        Returns:
            The mask tensor returned by ``SamPredictor.predict_torch``
            (single mask per box).
        """
        self.sam_predictor.set_image(np.array(image))
        transformed_boxes = self.sam_predictor.transform.apply_boxes_torch(
            torch.tensor(boxes, device=self.dino_model.device), image_hw
        ).to(device)
        masks, _, _ = self.sam_predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )
        return masks

    @staticmethod
    def _merge_masks(masks, height, width):
        """OR together the positive pixels of ``masks`` into one (height, width) uint8 map."""
        merged = np.zeros((height, width), dtype=np.uint8)
        for mask in masks:
            # .float() keeps bool / half / bfloat16 masks convertible to numpy.
            binary_mask = np.where(
                mask.float().cpu().numpy().squeeze() > 0.0, 1, 0
            ).astype(np.uint8)
            merged = np.bitwise_or(merged, binary_mask)
        return merged

    def _save_debug_mask(self, binary_mask, filename):
        """Render a 0/1 mask as a grayscale matplotlib figure in ``debug_log_dir``."""
        import matplotlib.pyplot as plt

        fig = plt.figure()
        plt.imshow(np.array(binary_mask * 255).astype(np.uint8), cmap="gray")
        plt.savefig(f"{self.debug_log_dir}/{filename}")
        # Close the figure so repeated debug calls do not accumulate open figures.
        plt.close(fig)

    @torch.no_grad()
    def papare_init_masks_and_text_embedding(
        self, pil_image, prompt_embeds, device="cuda", control_image=None
    ):
        """
        Build region masks and a filtered control image from an initial image.

        Requires ``self.gt_bboxes`` (normalized xyxy boxes) and
        ``self.gt_labels`` to have been set by the caller. Matching against
        detections is delegated to ``get_bboxes_union``.

        Args:
            pil_image: PIL image to run detection and segmentation on.
            prompt_embeds: Prompt embedding tensor. When it is 4-D it is
                treated as multiple (per-region) embeddings and wrapped in a
                dict together with the dilated ground-truth region masks.
            device: Device used for detector inputs and SAM boxes.
            control_image: Optional control image; pixels covered by (dilated)
                false-positive masks are zeroed. It is multiplied by an
                (H, W, 3) mask, so it presumably must be an RGB image of the
                same size as ``pil_image`` — confirm with the caller.

        Returns:
            (prompt_embeds, control_image): ``prompt_embeds`` is either the
            input tensor or, for 4-D input, a dict with keys 'prompt_embeds',
            'region_masks', 'n_prefix' and 'gt_labels'. ``control_image`` is
            the filtered PIL image, or the input unchanged when there were no
            false positives or no control image.

        Side effects:
            Sets ``self.gt_masks`` to the SAM masks of the ground-truth boxes;
            writes debug images when ``debug_mode`` is on.
        """
        # PIL reports size as (width, height).
        width, height = pil_image.size
        gt_bboxes, gt_labels = self.gt_bboxes, self.gt_labels  # !! required

        det_results = self._detect_objects(pil_image, device)
        _, _, FP_bboxes, _, _, _, _ = get_bboxes_union(
            det_results, gt_bboxes, gt_labels
        )

        if self.debug_mode:
            plot_results(
                pil_image,
                det_results["scores"],
                det_results["labels"],
                det_results["boxes"],
                f"{self.debug_log_dir}/origin_det.png",
            )
            plot_results(
                pil_image,
                None,
                gt_labels,
                gt_bboxes,
                f"{self.debug_log_dir}/origin_gt.png",
                de_normalze=True,
            )

        # Ground-truth boxes are normalized xyxy, so x scales with the width and
        # y with the height; false-positive boxes are already in pixels.
        box_scale = np.array([width, height, width, height])
        box_groups = []
        if len(gt_bboxes) > 0:
            box_groups.append(np.array(gt_bboxes) * box_scale)
        if len(FP_bboxes) > 0:
            box_groups.append(np.array(FP_bboxes))

        if box_groups:
            seg_masks = self._segment_boxes(
                pil_image,
                np.concatenate(box_groups, axis=0),
                (height, width),
                device,
            )
            # Ground-truth boxes were placed first, so their masks come first.
            gt_masks = seg_masks[: len(gt_bboxes)]
            FP_masks = seg_masks[len(gt_bboxes) :]
        else:
            gt_masks = []
            FP_masks = []

        # Generate the false-annotation suppression mask: 0 inside dilated
        # false-positive objects, 1 everywhere else.
        if len(FP_masks) > 0:
            suppression_mask = np.ones((height, width), dtype=np.uint8)
            for dilated_mask in dilate_with_no_overlap(FP_masks, 7):
                suppression_mask[dilated_mask == 1] = 0
            suppression_mask = np.repeat(
                suppression_mask[..., np.newaxis], 3, axis=-1
            )  # (H, W, 3)

            ori_control_image = control_image
            if control_image is not None:
                control_image = np.array(control_image) * suppression_mask
                control_image = Image.fromarray(control_image.astype(np.uint8))

            if self.debug_mode:
                Image.fromarray((suppression_mask * 255).astype(np.uint8)).save(
                    f"{self.debug_log_dir}/false_anno_supperession_mask.png"
                )
                # Only dump control images when one was actually supplied.
                if control_image is not None:
                    control_image.save(
                        f"{self.debug_log_dir}/filtered_control_image.png"
                    )
                    ori_control_image.save(
                        f"{self.debug_log_dir}/ori_control_image.png"
                    )

        dilated_gt_masks = None
        if prompt_embeds.ndim == 4:
            # Multiple prompt embeddings were provided: attach the region masks.
            dilated_gt_masks = torch.tensor(dilate_with_no_overlap(gt_masks, 7)).to(
                dtype=prompt_embeds.dtype, device=prompt_embeds.device
            )
            prompt_embeds = {
                "prompt_embeds": prompt_embeds,
                "region_masks": dilated_gt_masks,
                "n_prefix": 1,  # [SOS]  # "An image of"
                "gt_labels": gt_labels,  # class names
            }

        if self.debug_mode:
            self._save_debug_mask(
                self._merge_masks(gt_masks, height, width), "origin_seg.png"
            )
            # Dilated masks only exist on the multi-embedding path.
            if dilated_gt_masks is not None:
                self._save_debug_mask(
                    self._merge_masks(dilated_gt_masks, height, width),
                    "origin_dilated_seg_mask.png",
                )

        self.gt_masks = gt_masks

        return prompt_embeds, control_image

    # Correctly spelled alias; the original (misspelled) name is kept for callers.
    prepare_init_masks_and_text_embedding = papare_init_masks_and_text_embedding

    @torch.no_grad()
    def compute_rectification_mask(self, x_0, step_i=0, latent_size=(64, 64)):
        """
        Compute a latent-resolution mask of regions that disagree with the layout.

        The mask is the union of SAM masks of false-positive detections in
        ``x_0`` and the stored ground-truth masks (``self.gt_masks``) of
        objects the detector missed, dilated with a 7x7 kernel and resized
        with nearest-neighbour interpolation.

        Args:
            x_0: PIL image of the current prediction.
            step_i: Step index, used only in debug file names.
            latent_size: (height, width) of the returned mask. Defaults to
                (64, 64).

        Returns:
            Tensor of shape (1, 1, *latent_size) on the detector's device;
            non-zero where ``region_guided_rectification`` should take the
            original latents.
        """
        gt_bboxes, gt_labels, gt_masks = self.gt_bboxes, self.gt_labels, self.gt_masks
        # PIL reports size as (width, height).
        width, height = x_0.size
        device = self.dino_model.device

        det_results = self._detect_objects(x_0, device)

        if self.debug_mode:
            plot_results(
                x_0,
                det_results["scores"],
                det_results["labels"],
                det_results["boxes"],
                f"{self.debug_log_dir}/step_{step_i}_det.png",
            )

        _, _, FP_bboxes, _, _, _, FN_indices = get_bboxes_union(
            det_results, gt_bboxes, gt_labels
        )

        seg_masks = []
        if len(FP_bboxes) > 0:
            seg_masks.extend(
                self._segment_boxes(x_0, FP_bboxes, (height, width), device)
            )
        # Missed objects reuse the masks computed for the initial image.
        seg_masks.extend(gt_masks[fn_index] for fn_index in FN_indices)

        edit_mask = self._merge_masks(seg_masks, height, width)
        edit_mask = cv2.dilate(edit_mask, np.ones((7, 7), np.uint8))
        edit_mask = torch.tensor(edit_mask)
        edit_mask = (
            F.interpolate(
                edit_mask.unsqueeze(0).unsqueeze(0), size=latent_size, mode="nearest"
            )
            .squeeze(0)
            .squeeze(0)
        )

        if self.debug_mode:
            self._save_debug_mask(edit_mask, f"step_{step_i}_edit_mask.png")

        # Add batch and channel dimensions so the mask broadcasts over latents.
        mask = edit_mask.unsqueeze(0).to(device=device)
        return mask.unsqueeze(1)

    def region_guided_rectification(self, latents, original_latents, mask):
        """
        Blend latents: take ``original_latents`` where ``mask`` is 1, else ``latents``.

        The mask is cast to the dtype and device of ``latents`` before blending.
        """
        mask = mask.to(dtype=latents.dtype, device=latents.device)
        return original_latents * mask + latents * (1 - mask)

    def apply_region_aligned_cross_attention(self, module):
        """
        Install ``CrossRegionAttnProcessor`` on ``module`` when the feature is enabled.

        ``module`` should be the UNet or the ControlNet branch (anything
        exposing ``set_attn_processor``).
        """
        if self.enable_region_aligned_cross_attention:
            module.set_attn_processor(CrossRegionAttnProcessor())
