import torchvision
import torch
import os

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
from matplotlib.cm import ScalarMappable
from tqdm import tqdm

def bd_max_iou(predictions, score_threshold, device, multi_trigger=False):
    """
    Compute, for each poisoned ground-truth box, its best IoU with predictions
    of two label categories:
      (1) best_org    = highest IoU with a prediction carrying the box's original label
      (2) best_target = highest IoU with a prediction carrying the box's target label
    A category with no matching (or only non-overlapping) prediction scores 0.0.

    Args:
        predictions: iterable of per-image dicts with keys "pred_boxes",
            "pred_labels", "pred_scores", "gt_boxes", "gt_labels",
            "gt_target_ids" and "gt_poison_masks" (a GT box is poisoned where
            its mask equals 1). Boxes are passed to ``torchvision.ops.box_iou``,
            which expects (x1, y1, x2, y2) format.
        score_threshold: predictions scoring below this value are ignored;
            ``None`` keeps all predictions.
        device: ``torch.device`` or device string used for the computation.
        multi_trigger: if False, each sample must contain exactly one poisoned
            GT box and yields exactly one row. If True, any number of poisoned
            boxes is allowed and one row is produced per poisoned box.

    Returns:
        Float tensor of shape [N, 2] with (best_org, best_target) rows, on ``device``.

    Raises:
        ValueError: if ``multi_trigger`` is False and a sample that has
            predictions does not contain exactly one poisoned GT box.
    """
    device = torch.device(device)
    results = []
    for pred in predictions:
        pred_boxes = torch.as_tensor(pred["pred_boxes"], device=device)  # [num_preds, 4]
        poison_masks = torch.as_tensor(pred["gt_poison_masks"], device=device)
        poison_index = torch.where(poison_masks == 1)[0]

        if pred_boxes.numel() == 0:
            # Nothing can match, so every poisoned box scores (0, 0). Single-trigger
            # mode keeps its one-row-per-sample contract (without validation);
            # multi-trigger mode emits one row per poisoned box, like the path below.
            num_rows = poison_index.numel() if multi_trigger else 1
            results.extend([(0.0, 0.0)] * num_rows)
            continue

        if poison_index.numel() != 1 and not multi_trigger:
            raise ValueError(
                "There should be exactly one poison mask per sample. "
                f"gt_poison_masks={pred['gt_poison_masks']}, "
                f"gt_labels={pred['gt_labels']}, "
                f"gt_target_ids={pred['gt_target_ids']}"
            )
        if poison_index.numel() == 0:
            # Multi-trigger sample without poisoned boxes contributes no rows.
            continue

        pred_labels = torch.as_tensor(pred["pred_labels"], device=device)
        if score_threshold is not None:
            keep = torch.as_tensor(pred["pred_scores"], device=device) >= score_threshold
            pred_boxes = pred_boxes[keep]
            pred_labels = pred_labels[keep]

        gt_boxes = torch.as_tensor(pred["gt_boxes"], device=device)
        org_labels = torch.as_tensor(pred["gt_labels"], device=device)
        target_labels = torch.as_tensor(pred["gt_target_ids"], device=device)
        poison_boxes = gt_boxes[poison_index]  # [num_poison, 4]

        try:
            iou_values = torchvision.ops.box_iou(pred_boxes, poison_boxes)  # [num_kept, num_poison]
        except Exception:
            print(f'Pred boxes: {pred_boxes.shape}, GT box: {poison_boxes.shape}')
            raise
        # A running max seeded at 0.0 and updated with `>` ignores NaN IoUs
        # (0/0 from degenerate boxes) and negative values; zeroing NaNs and
        # clamping at 0 makes a plain max() agree with that.
        iou_values = torch.nan_to_num(iou_values, nan=0.0).clamp(min=0.0)

        for col, index in enumerate(poison_index.tolist()):
            ious = iou_values[:, col]
            org_ious = ious[pred_labels == org_labels[index]]
            target_ious = ious[pred_labels == target_labels[index]]
            # .item() keeps every row as plain Python floats.
            best_org = org_ious.max().item() if org_ious.numel() else 0.0
            best_target = target_ious.max().item() if target_ious.numel() else 0.0
            results.append((best_org, best_target))

    # Hand cached-but-unused CUDA allocator blocks back to the driver.
    if device.type == "cuda":
        torch.cuda.empty_cache()

    # Rows are plain floats, so this always builds a float tensor; reshape keeps
    # a [0, 2] shape when no rows were produced.
    return torch.tensor(results, device=device).reshape(-1, 2)

def plot_iou_scores(predictions, device, save_path, current_epoch):

    iou_results = []

    for pred in predictions:

        pred_boxes = torch.tensor(pred["pred_boxes"], device=device)  # shape [num_preds, 4]

        if pred_boxes.numel() == 0:
            continue

        pred_labels = torch.tensor(pred["pred_labels"], device=device)
        pred_scores = torch.tensor(pred["pred_scores"], device=device)

        gt_boxs = torch.tensor(pred["gt_boxes"], device=device)
        org_labels = torch.tensor(pred["gt_labels"], device=device)
        target_labels = torch.tensor(pred["gt_target_ids"], device=device)
        poison_masks = torch.tensor(pred["gt_poison_masks"], device=device)

        # Find the org_label and target_label where poison_masks is 1
        poison_index = torch.where(poison_masks == 1)[0]

        for i in range(poison_index.numel()):

            index = poison_index[i]
            org_label = org_labels[index].item()
            target_label = target_labels[index].item()
            gt_box = gt_boxs[index].unsqueeze(0)  # Ensure gt_box is of shape [1, 4]

            iou_values = torchvision.ops.box_iou(pred_boxes, gt_box)  # shape [num_preds, 1]

            for i in range(len(pred_boxes)):
                iou = iou_values[i]
                label = pred_labels[i]
                score = pred_scores[i]

                if label == org_label:
                    iou_results.append((iou, score, 0))
                elif label == target_label:
                    iou_results.append((iou, score, 1))

    # If no results were collected, return early
    if not iou_results:
        print("No IOU results to plot.")
        return


    iou_results = torch.tensor(iou_results)

    # Check if gt_target has added values
    # if the 2nd column doesnt contain any values of 1, then we can skip plotting the second graph
    if not torch.any(iou_results[:, 2] == 1):
        num_plots = 1
    else:
        num_plots = 2

    fig, axs = plt.subplots(1, num_plots, figsize=(6 * num_plots, 5))
    if num_plots == 1:
        axs = [axs]  # make it iterable

    for i in range(num_plots):
        ax = axs[i]

        # filter
        mask = iou_results[:, 2] == i
        iou   = iou_results[mask, 0]
        score = iou_results[mask, 1]

        # --- 1) estimate density via Gaussian KDE ---
        xy  = np.vstack([iou, score])

        # If their are not enough points, skip the KDE
        if xy.shape[1] < 2:
            print(f"Not enough points for KDE in plot {i}. Skipping.")
            continue

        try:
            kde = gaussian_kde(xy)

            # build grid over [0,1]×[0,1]
            x_lin = np.linspace(0, 1, 200)
            y_lin = np.linspace(0, 1, 200)
            X, Y = np.meshgrid(x_lin, y_lin)
            Z = kde(np.vstack([X.ravel(), Y.ravel()])).reshape(X.shape)
        
            # --- 2) draw filled contours or just contour lines ---
            # contour lines
            # draw colored contour lines only
            cs = ax.contour(
                X, Y, Z,
                levels=10,
                cmap='viridis',
                linewidths=2
            )

            # build a continuous mappable from the same colormap & norm
            sm = ScalarMappable(cmap=cs.cmap, norm=cs.norm)
            sm.set_array([])   # dummy; the colorbar only needs cmap & norm

            # attach a *continuous* colorbar
            cbar = fig.colorbar(sm, ax=ax)
            cbar.set_label('Density')
        except Exception as e:
            print(f"Error in KDE or contour plotting for plot {i}: {e}")

        # --- 3) scatter original points underneath ---
        ax.scatter(iou, score, alpha=0.2, color='grey')

        # labels, titles, limits
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.set_xlabel("IOU")
        ax.set_ylabel("Score")
        if i == 0:
            ax.set_title("IOU vs Score (Original Label)")
        else:
            ax.set_title("IOU vs Score (Target Label)")
            
    # save
    plt.tight_layout()
    plt.savefig(os.path.join(save_path, f"iou_scores_epoch_{current_epoch}.png"))
    plt.close() 

def get_ra_asr_matrix_oda(results, conf_th=np.arange(0.00, 1.01, 0.05), iou_th=np.arange(0.00, 1.01, 0.05)):
    """
    Sweep confidence and IoU thresholds and compute RA / ASR grids for an
    attack whose success means a poisoned box goes undetected (presumably
    "object disappearance", given the function name).

    Only poisoned GT boxes (``gt_poison_masks == 1``) are evaluated. At a
    confidence threshold ``ct`` and IoU threshold ``it``, a poisoned box is
    *detected* when some prediction with ``score >= ct`` has the box's
    original label and ``IoU > it``. RA is the fraction of poisoned boxes
    detected and ASR the fraction not detected (RA + ASR == 1 per cell).

    Args:
        results: list of per-image dicts with keys "gt_poison_masks",
            "gt_boxes", "gt_labels", "pred_boxes", "pred_labels", "pred_scores".
        conf_th: 1-D sequence of confidence thresholds (grid columns).
        iou_th: 1-D sequence of IoU thresholds (grid rows).

    Returns:
        (RA, ASR, conf_th, iou_th) as numpy arrays; RA and ASR have shape
        [len(iou_th), len(conf_th)].

    Raises:
        ValueError: if ``results`` contains no poisoned GT box.
    """
    # ---- collect every image's poisoned GT boxes once ----------------------
    poison_gt_boxes, poison_gt_labels = [], []
    for d in results:
        pm = torch.as_tensor(d["gt_poison_masks"]) == 1
        if pm.any():
            poison_gt_boxes.append(torch.as_tensor(d["gt_boxes"])[pm])
            poison_gt_labels.append(torch.as_tensor(d["gt_labels"])[pm])
        else:  # one entry per image keeps indices aligned with `results`
            poison_gt_boxes.append(torch.empty((0, 4)))
            poison_gt_labels.append(torch.empty((0,), dtype=torch.long))
    total_poison = sum(boxes.shape[0] for boxes in poison_gt_boxes)

    if total_poison == 0:
        raise ValueError("No poisoned boxes found – nothing to evaluate.")

    conf_th = torch.tensor(conf_th)
    iou_th = torch.tensor(iou_th)
    RA = torch.zeros(iou_th.numel(), conf_th.numel())
    ASR = torch.zeros_like(RA)

    for img_idx, d in enumerate(results):
        box_gt = poison_gt_boxes[img_idx]
        num_gt = box_gt.shape[0]
        if num_gt == 0:
            continue
        label_gt = poison_gt_labels[img_idx]

        pred_boxes = torch.as_tensor(d["pred_boxes"])
        pred_labels = torch.as_tensor(d["pred_labels"])
        pred_scores = torch.as_tensor(d["pred_scores"])

        if pred_boxes.numel() == 0:
            # No predictions: every poisoned box disappears at every threshold.
            ASR += num_gt
            continue

        try:
            iou_mat = torchvision.ops.box_iou(box_gt, pred_boxes)  # (P, F)
        except Exception:
            print(f'Pred boxes: {pred_boxes.shape}, GT boxes: {box_gt.shape}')
            raise

        # Label agreement and the IoU test do not depend on the confidence
        # threshold, so evaluate them for all IoU thresholds once per image.
        label_match = pred_labels[None, :] == label_gt[:, None]  # (P, F)
        # Cast thresholds to the IoU dtype so `iou > thr` runs in the same
        # precision as comparing against a 0-dim threshold tensor (exact ties).
        thr = iou_th.to(iou_mat.dtype)[:, None, None]  # (T, 1, 1)
        hit = (iou_mat[None] > thr) & label_match[None]  # (T, P, F)

        for j, ct in enumerate(conf_th):
            keep = pred_scores >= ct
            # detected[t, p]: some kept prediction hits poisoned box p at IoU
            # threshold t; all False when no prediction survives.
            detected = hit[:, :, keep].any(dim=2)  # (T, P)
            n_detected = detected.sum(dim=1)  # (T,)
            RA[:, j] += n_detected
            ASR[:, j] += num_gt - n_detected

    RA = RA / total_poison
    ASR = ASR / total_poison
    return RA.numpy(), ASR.numpy(), conf_th.numpy(), iou_th.numpy()

def get_ra_asr_matrix_rma(results, conf_th=np.arange(0.00, 1.01, 0.05), iou_th=np.arange(0.00, 1.01, 0.05)):
    """
    Sweep confidence and IoU thresholds and compute RA / ASR grids for an
    attack whose success means a poisoned box is predicted as its target label.

    Only poisoned GT boxes (``gt_poison_masks == 1``) are evaluated. At a
    confidence threshold ``ct`` and IoU threshold ``it``, considering only
    predictions with ``score >= ct``:
      * ASR counts poisoned boxes overlapped (``IoU > it``) by a prediction
        carrying the box's target label (``gt_target_ids``);
      * RA counts poisoned boxes overlapped by a prediction carrying the box's
        original label (``gt_labels``).
    Both counts are divided by the total number of poisoned boxes. A box can
    count toward both, so RA + ASR need not equal 1.

    Args:
        results: list of per-image dicts with keys "gt_poison_masks",
            "gt_boxes", "gt_labels", "gt_target_ids", "pred_boxes",
            "pred_labels", "pred_scores".
        conf_th: 1-D sequence of confidence thresholds (grid columns).
        iou_th: 1-D sequence of IoU thresholds (grid rows).

    Returns:
        (RA, ASR, conf_th, iou_th) as numpy arrays; RA and ASR have shape
        [len(iou_th), len(conf_th)].

    Raises:
        ValueError: if ``results`` contains no poisoned GT box, or a poisoned
            box has a target label < 1.
    """
    # ---- collect every image's poisoned GT boxes once ----------------------
    poison_gt_boxes, poison_gt_labels, poison_target_labels = [], [], []
    for d in results:
        pm = torch.as_tensor(d["gt_poison_masks"]) == 1
        if pm.any():
            poison_gt_boxes.append(torch.as_tensor(d["gt_boxes"])[pm])
            poison_gt_labels.append(torch.as_tensor(d["gt_labels"])[pm])
            poison_target_labels.append(torch.as_tensor(d["gt_target_ids"])[pm])
        else:  # one entry per image keeps indices aligned with `results`
            poison_gt_boxes.append(torch.empty((0, 4)))
            poison_gt_labels.append(torch.empty((0,), dtype=torch.long))
            poison_target_labels.append(torch.empty((0,), dtype=torch.long))
    total_poison = sum(boxes.shape[0] for boxes in poison_gt_boxes)

    if total_poison == 0:
        raise ValueError("No poisoned boxes found – nothing to evaluate.")

    conf_th = torch.tensor(conf_th)
    iou_th = torch.tensor(iou_th)
    RA = torch.zeros(iou_th.numel(), conf_th.numel())
    ASR = torch.zeros_like(RA)

    for img_idx, d in enumerate(results):
        box_gt = poison_gt_boxes[img_idx]
        if box_gt.numel() == 0:
            continue
        label_gt = poison_gt_labels[img_idx]
        target_gt = poison_target_labels[img_idx]

        pred_boxes = torch.as_tensor(d["pred_boxes"])
        pred_labels = torch.as_tensor(d["pred_labels"])
        pred_scores = torch.as_tensor(d["pred_scores"])

        if (target_gt < 1).any():
            raise ValueError(f"Target labels should be >= 1, found {target_gt[target_gt < 1]} in image {img_idx}")

        if pred_boxes.numel() == 0:
            # No predictions: neither RA nor ASR gains anything from this image.
            continue

        try:
            iou_mat = torchvision.ops.box_iou(box_gt, pred_boxes)  # (P, F)
        except Exception:
            print(f'Pred boxes: {pred_boxes.shape}, GT boxes: {box_gt.shape}')
            raise

        # Cast thresholds to the IoU dtype so `iou > thr` runs in the same
        # precision as comparing against a 0-dim threshold tensor (exact ties).
        thr = iou_th.to(iou_mat.dtype)[:, None, None]  # (T, 1, 1)
        iou_ok = iou_mat[None] > thr  # (T, P, F)
        # Label masks and IoU tests are independent of the confidence threshold.
        target_hit = iou_ok & (pred_labels[None, :] == target_gt[:, None])[None]  # (T, P, F)
        org_hit = iou_ok & (pred_labels[None, :] == label_gt[:, None])[None]  # (T, P, F)

        for j, ct in enumerate(conf_th):
            keep = pred_scores >= ct
            if not keep.any():
                continue
            # Count, per IoU threshold, poisoned boxes hit by a kept prediction.
            ASR[:, j] += target_hit[:, :, keep].any(dim=2).sum(dim=1)
            RA[:, j] += org_hit[:, :, keep].any(dim=2).sum(dim=1)

    RA = RA / total_poison
    ASR = ASR / total_poison
    return RA.numpy(), ASR.numpy(), conf_th.numpy(), iou_th.numpy()

def plot_asr_ra_heat(ra_matrix, asr_matrix, x_ticks, y_ticks, save_path, current_epoch, title=None):
    """
    Save side-by-side RA and ASR heatmaps to
    ``<save_path>/asr_ra_heatmap_epoch_<current_epoch>.png``.

    Args:
        ra_matrix, asr_matrix: 2-D arrays with rows indexed by IoU threshold
            (``y_ticks``) and columns by confidence threshold (``x_ticks``);
            the color scale is fixed to [0, 1].
        x_ticks: confidence thresholds, one per column (assumed evenly spaced).
        y_ticks: IoU thresholds, one per row (assumed evenly spaced).
        save_path: directory the PNG is written to.
        current_epoch: used in the output file name.
        title: figure title; defaults to "RA and ASR Heatmaps".
    """
    if title is None:
        title = "RA and ASR Heatmaps"

    def _cell_edges(ticks):
        # Pad by half a step so each cell is centered on its threshold value
        # rather than the outermost cells' edges sitting on the end ticks.
        ticks = np.asarray(ticks, dtype=float)
        half = (ticks[-1] - ticks[0]) / (2 * (ticks.size - 1)) if ticks.size > 1 else 0.5
        return [ticks[0] - half, ticks[-1] + half]

    extent = _cell_edges(x_ticks) + _cell_edges(y_ticks)

    fig, axs = plt.subplots(1, 2, figsize=(12, 6))
    fig.suptitle(title)
    for ax, mat, sub_title in zip(axs, (ra_matrix, asr_matrix), ("RA Heatmap", "ASR Heatmap")):
        ax.set_title(sub_title)
        im = ax.imshow(mat, origin='lower', aspect='auto', extent=extent, vmin=0, vmax=1)
        ax.set_xlabel('Confidence Threshold')
        ax.set_ylabel('IoU Threshold')
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    fig.tight_layout()
    fig.savefig(os.path.join(save_path, f'asr_ra_heatmap_epoch_{current_epoch}.png'))
    plt.close(fig)