import torch
import torch.nn.functional as F
from typing import List, Dict, Tuple, Any


# -----------------------------------------------------------
#  GPU‑only YOLO augmentations (Mosaic + MixUp + Multi‑scale)
# -----------------------------------------------------------

def _random_dataset_sample(dataset, device):
    idx = torch.randint(0, len(dataset), ()).item()
    img, tgt, *_ = dataset[idx]  # img: (C,H,W), tgt["boxes"]: cx,cy,w,h (likely pixels)

    img = img.to(device, non_blocking=True)

    # Ensure boxes are normalized cxcywh in [0,1]
    boxes = tgt["boxes"].to(device, non_blocking=True)
    H, W = img.shape[1:]
    # If any coord looks like pixels, normalize it
    if boxes.numel() and (boxes[:, :2].max() > 1 or boxes[:, 2:].max() > 1):
        scale = torch.tensor([W, H, W, H], device=device, dtype=boxes.dtype)
        boxes = boxes / scale

    tgt_norm = {
        "boxes": boxes,  # normalized cxcywh
        "labels": tgt["labels"].to(device, non_blocking=True),
        "poison_masks": tgt["poison_masks"].to(device, non_blocking=True),
        "target_labels": tgt["target_labels"].to(device, non_blocking=True),
    }
    return img, tgt_norm


def _filter_and_normalise_boxes(
    boxes_px: torch.Tensor,
    labels: torch.Tensor,
    poison_masks: torch.Tensor,
    target_labels: torch.Tensor,
    x1_crop: int,
    y1_crop: int,
    crop: int,
    min_pixels: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Shift → clip → filter → normalise *and* synchronise every target tensor.

    Args:
        boxes_px:        (N,4) cx,cy,w,h **in pixels**.
        labels:          (N,)
        poison_masks:    (N,)
        target_labels:   (N,)
        x1_crop, y1_crop:   top‑left of the s×s window cut from 2s×2s canvas.
        crop:              side length *s* of the final window.

    Returns:
        (boxes_norm, labels_f, poison_masks_f, target_labels_f) – each length M ≤ N
    """
    if boxes_px.numel() == 0:
        empty = torch.zeros((0,), dtype=torch.long, device=boxes_px.device)
        return (
            boxes_px,  # already empty
            empty,  # labels
            empty,  # poison_masks
            empty,  # target_labels
        )

    # Shift by crop offset (still pixel units)
    boxes_shifted = boxes_px.clone()
    boxes_shifted[:, 0] -= x1_crop
    boxes_shifted[:, 1] -= y1_crop

    # Convert to corner form
    cx, cy, w, h = boxes_shifted.T
    x_min, y_min = cx - w / 2, cy - h / 2
    x_max, y_max = cx + w / 2, cy + h / 2

    # Clip to crop window
    x_min_cl = x_min.clamp(0, crop)
    y_min_cl = y_min.clamp(0, crop)
    x_max_cl = x_max.clamp(0, crop)
    y_max_cl = y_max.clamp(0, crop)

    w_cl = x_max_cl - x_min_cl
    h_cl = y_max_cl - y_min_cl

    keep = (w_cl >= min_pixels) & (h_cl >= min_pixels)
    if not keep.any():
        empty = torch.zeros((0,), dtype=torch.long, device=boxes_px.device)
        return (
            torch.zeros((0, 4), device=boxes_px.device),
            empty,
            empty,
            empty,
        )

    # Re‑compute centres / sizes from clipped corners and normalise
    cx_n = (x_min_cl + x_max_cl) / 2.0 / crop
    cy_n = (y_min_cl + y_max_cl) / 2.0 / crop
    w_n = w_cl / crop
    h_n = h_cl / crop

    boxes_norm = torch.stack([cx_n, cy_n, w_n, h_n], dim=1)[keep]
    boxes_norm[:, 2:].clamp_(min=1e-3, max=1.0)

    # Synchronise other target tensors
    labels_f = labels[keep]
    poison_masks_f = poison_masks[keep]
    target_labels_f = target_labels[keep]

    return boxes_norm, labels_f, poison_masks_f, target_labels_f


def yolo_gpu_augment(
    images: List[torch.Tensor],
    targets: List[Dict[str, torch.Tensor]],
    dataset: Any,
    device: torch.device,
    img_size: int = 640,
    max_stride: int = 64,
    mosaic_prob: float = 0.5,
    mixup_prob: float = 0.15,
):
    """GPU-side Mosaic, MixUp and multi-scale augmentation for YOLO-style models.

    * Expects **(cx,cy,w,h)** boxes normalised to [0,1] on entry.
    * Keeps batch order untouched.
    * ``images`` and ``targets`` are updated in place (entries are replaced)
      and the same two lists are returned.

    Args:
        images:      list of (C,H,W) image tensors.
                     NOTE(review): assumed to be float tensors in [0,1] already
                     on ``device`` (the mosaic fill value is 114/255) - confirm.
        targets:     one dict per image with keys ``"boxes"``, ``"labels"``,
                     ``"poison_masks"`` and ``"target_labels"``.
        dataset:     indexable source of extra samples for Mosaic / MixUp.
        device:      device the augmentation runs on.
        img_size:    side length of the mosaic crop and base size for multi-scale.
        max_stride:  the final side length is a positive multiple of this value.
        mosaic_prob: per-sample probability of applying Mosaic.
        mixup_prob:  per-sample probability of applying MixUp.

    Returns:
        ``(images, targets)``; every image is ``(C, sz, sz)`` for one
        batch-wide ``sz``.
    """

    B = len(images)
    s = img_size  # alias

    # -------------------------------------------------------------
    # 1) MOSAIC (per-sample)
    # -------------------------------------------------------------
    for b in range(B):
        if torch.rand(1, device=device) >= mosaic_prob:
            continue

        base_img, base_tgt = images[b], targets[b]

        # Gather 3 extra random samples
        mosaic_imgs = [base_img]
        mosaic_tgts = [base_tgt]
        for _ in range(3):
            im_i, tgt_i = _random_dataset_sample(dataset, device)
            mosaic_imgs.append(im_i)
            mosaic_tgts.append(tgt_i)

        # 2s x 2s canvas with as many channels as the base image
        canvas = torch.full(
            (base_img.shape[0], 2 * s, 2 * s),
            114.0 / 255.0,
            dtype=base_img.dtype,
            device=device,
        )
        # Mosaic centre: the four tiles meet at (xc, yc)
        yc, xc = [int(torch.empty(1, device=device).uniform_(0.5 * s, 1.5 * s)) for _ in range(2)]

        merged_boxes, merged_lbls, merged_pm, merged_tl = [], [], [], []

        for i, (im, tgt) in enumerate(zip(mosaic_imgs, mosaic_tgts)):
            _, h, w = im.shape
            scale = torch.empty(1, device=device).uniform_(0.4, 1.0).item()
            # at least one pixel so interpolate never sees an empty size
            nh, nw = max(int(h * scale), 1), max(int(w * scale), 1)
            im_rs = F.interpolate(im.unsqueeze(0), size=(nh, nw), mode="bilinear", align_corners=False)[0]

            # Canvas position of the tile's own top-left corner. Every tile
            # touches the centre with one corner (0: top-left quadrant,
            # 1: top-right, 2: bottom-left, 3: bottom-right), so tiles never
            # overlap and whatever does not fit falls off the canvas edge.
            ox = xc - nw if i in (0, 2) else xc
            oy = yc - nh if i in (0, 1) else yc

            # Part of the tile that lands on the canvas
            x1a, y1a = max(ox, 0), max(oy, 0)
            x2a, y2a = min(ox + nw, 2 * s), min(oy + nh, 2 * s)

            canvas[:, y1a:y2a, x1a:x2a] = im_rs[:, y1a - oy : y2a - oy, x1a - ox : x2a - ox]

            if tgt["boxes"].numel():
                # normalised -> canvas pixels; boxes in the cut-off part end
                # up outside the canvas and are removed by the crop filter.
                boxes_px = tgt["boxes"].clone()
                boxes_px[:, [0, 2]] *= nw
                boxes_px[:, [1, 3]] *= nh
                boxes_px[:, 0] += ox
                boxes_px[:, 1] += oy

                merged_boxes.append(boxes_px)
                merged_lbls.append(tgt["labels"])
                merged_pm.append(tgt["poison_masks"])
                merged_tl.append(tgt["target_labels"])

        # Crop an s x s patch centred on the mosaic centre (always inside the canvas)
        x1_crop, y1_crop = xc - s // 2, yc - s // 2
        cropped = canvas[:, y1_crop : y1_crop + s, x1_crop : x1_crop + s]

        # Target post-processing
        if merged_boxes:
            boxes_px_all = torch.cat(merged_boxes, 0)
            labels_all = torch.cat(merged_lbls, 0)
            pm_all = torch.cat(merged_pm, 0)
            tl_all = torch.cat(merged_tl, 0)

            boxes_n, lbls_f, pm_f, tl_f = _filter_and_normalise_boxes(
                boxes_px_all,
                labels_all,
                pm_all,
                tl_all,
                x1_crop,
                y1_crop,
                s,
            )

            targets[b] = {
                "boxes": boxes_n,
                "labels": lbls_f,
                "poison_masks": pm_f,
                "target_labels": tl_f,
            }
        else:
            # No boxes at all from the four images: emit empties that keep the
            # dtypes of the incoming target so later concatenation is type-stable.
            targets[b] = {
                "boxes": base_tgt["boxes"].new_zeros((0, 4), device=device),
                "labels": base_tgt["labels"].new_zeros((0,), device=device),
                "poison_masks": base_tgt["poison_masks"].new_zeros((0,), device=device),
                "target_labels": base_tgt["target_labels"].new_zeros((0,), device=device),
            }

        images[b] = cropped

    # ------------------------------------------------------------------
    # 2. MIXUP (per-sample, *after* Mosaic but *before* batch resize)
    # ------------------------------------------------------------------
    beta = torch.distributions.Beta(32.0, 32.0)
    for b in range(B):
        if torch.rand(1, device=device) >= mixup_prob:
            continue

        mix_img, mix_tgt = _random_dataset_sample(dataset, device)
        h_dest, w_dest = images[b].shape[1:]

        # If sizes differ, resize image; boxes are normalized so no change needed
        if mix_img.shape[1:] != (h_dest, w_dest):
            mix_img = F.interpolate(mix_img.unsqueeze(0), size=(h_dest, w_dest),
                                    mode="bilinear", align_corners=False)[0]

        lam = float(beta.sample(()))
        images[b] = images[b] * lam + mix_img * (1.0 - lam)

        # Append the sampled annotations. Work on a copy so the caller's dict
        # is not mutated, and only touch keys the sample provides so extra
        # metadata keys in the incoming target are carried over unchanged.
        mixed_tgt = dict(targets[b])
        for key, extra in mix_tgt.items():
            if key in mixed_tgt:
                mixed_tgt[key] = torch.cat((mixed_tgt[key], extra), dim=0)
        targets[b] = mixed_tgt

    # ------------------------------------------------------------------
    # 3. MULTI-SCALE (batch-level)
    # ------------------------------------------------------------------
    sf = torch.empty(1).uniform_(0.5, 1.5).item()
    # Round to the nearest multiple of max_stride, but never down to zero.
    sz = max(int(img_size * sf / max_stride + 0.5), 1) * max_stride

    # Bring every image to the common size, including non-mosaiced inputs
    # whose size differs from img_size when sz happens to equal img_size.
    for b in range(B):
        if tuple(images[b].shape[-2:]) != (sz, sz):
            images[b] = F.interpolate(
                images[b].unsqueeze(0), size=(sz, sz), mode="bilinear", align_corners=False
            )[0]

    return images, targets