import os
import glob
import random
import cv2
import numpy as np
import torch
import albumentations as A

def get_image_grid(image, annotation, grid_size, current_format='xywh'):
    """
    Divide an image into an N×N grid and mark any cell touched by any corner of each bbox.

    Args:
        image: object exposing ``.shape``. A 3-D shape is read as
            channels-first (C, H, W); any other rank is read as (H, W, ...).
        annotation: mapping whose ``'bbox'`` entry is an iterable of boxes.
        grid_size: number of cells along each axis (N).
        current_format: ``'xywh'`` for (x, y, width, height) boxes; any other
            value is treated as (x1, y1, x2, y2).

    Returns:
        Boolean array of shape (grid_size, grid_size), indexed [row, col];
        True where at least one bbox corner falls inside the cell.

    Note:
        Only the four corners are tested, so a cell lying strictly between
        the corners of a box spanning three or more cells is not marked.
        Corners outside the image are clamped to the nearest border cell.
    """
    shape = image.shape
    if len(shape) == 3:
        img_h, img_w = shape[1], shape[2]
    else:
        img_h, img_w = shape[0], shape[1]

    grid_h = img_h // grid_size
    grid_w = img_w // grid_size
    last = grid_size - 1

    image_grid = np.zeros((grid_size, grid_size), dtype=bool)
    for bbox in annotation['bbox']:
        if current_format == 'xywh':
            x1, y1 = bbox[0], bbox[1]
            x2, y2 = bbox[0] + bbox[2], bbox[1] + bbox[3]
        else:
            x1, y1, x2, y2 = bbox

        # Clamp on both sides: a negative coordinate would otherwise yield a
        # negative index, which NumPy interprets as counting from the end and
        # would mark a cell on the opposite side of the grid.
        cols = {min(max(int(x // grid_w), 0), last) for x in (x1, x2)}
        rows = {min(max(int(y // grid_h), 0), last) for y in (y1, y2)}
        for gy in rows:
            for gx in cols:
                image_grid[gy, gx] = True
    return image_grid


def build_aug_pipeline(p: float = 0.4) -> A.Compose:
    """
    Build the one Albumentations pipeline used for sign augmentation.

    Spatial transforms (rotation, perspective) come first, followed by the
    pixel-level ones (shadow, blur, brightness/contrast, sharpen). Every
    transform receives the same probability ``p``.
    """
    replicate = cv2.BORDER_REPLICATE

    spatial = [
        A.Rotate(limit=7.16, border_mode=replicate, p=p),
        A.Perspective(scale=(0.05, 0.1), keep_size=True, border_mode=replicate, p=p),
    ]
    pixel_level = [
        A.RandomShadow(shadow_roi=(0, 0, 1, 1), num_shadows_limit=(1, 4), p=p),
        A.GaussianBlur(blur_limit=(3, 7), p=p),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=p),
        A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=p),
    ]
    return A.Compose(spatial + pixel_level)

# Module-level pipeline instance, built once at import time; augment_sign
# applies it to every sign it processes.
aug_pipeline = build_aug_pipeline(p=0.4)


def augment_sign(img: np.ndarray) -> np.ndarray:
    """
    Apply all photometric & geometric transforms in one go.
    Expects img in RGBA (H, W, 4).

    The alpha channel is handed to the pipeline as a ``mask`` target, so the
    spatial transforms warp it together with the RGB channels (keeping the
    cut-out aligned with the warped content) while pixel-level transforms
    leave it untouched.

    Returns:
        The augmented RGB channels stacked with the warped alpha channel.
    """
    # Channel slices are non-contiguous views; hand contiguous buffers to the
    # OpenCV-backed transforms.
    rgb = np.ascontiguousarray(img[..., :3])
    alpha = np.ascontiguousarray(img[..., 3])
    result = aug_pipeline(image=rgb, mask=alpha)
    return np.dstack([result['image'], result['mask']])


def alpha_blend(bg: np.ndarray, fg_rgba: np.ndarray, y_off: int, x_off: int) -> None:
    """
    In-place alpha blend of fg_rgba onto bg at (y_off, x_off).
    bg can be H×W×3 or H×W×4; only RGB channels are modified.

    Args:
        bg: background image, modified in place.
        fg_rgba: overlay with 4 channels; channel 3 is alpha in 0..255.
        y_off: row in bg where the overlay's top edge is placed (may be negative).
        x_off: column in bg where the overlay's left edge is placed (may be negative).

    The overlay is clipped to the background: any part falling outside bg is
    ignored, and an overlay lying entirely outside is a no-op.

    Raises:
        ValueError: if bg has neither 3 nor 4 channels.
    """
    if bg.shape[2] not in (3, 4):
        raise ValueError(f"Unsupported bg channels: {bg.shape[2]}")

    fg_h, fg_w = fg_rgba.shape[:2]
    bg_h, bg_w = bg.shape[:2]

    # Intersection of the overlay rectangle with the background, in bg coordinates.
    top, left = max(y_off, 0), max(x_off, 0)
    bottom, right = min(y_off + fg_h, bg_h), min(x_off + fg_w, bg_w)
    if top >= bottom or left >= right:
        return

    # Matching window of the overlay, in overlay coordinates.
    fg_visible = fg_rgba[top - y_off:bottom - y_off, left - x_off:right - x_off]

    # Basic slicing yields a view, so writing to `target` updates bg itself.
    target = bg[top:bottom, left:right, :3]

    # Blend in float32; the cast back to uint8 truncates fractional parts.
    alpha = fg_visible[..., 3:4].astype(np.float32) / 255.0
    fg_rgb = fg_visible[..., :3].astype(np.float32)
    blended = target.astype(np.float32) * (1 - alpha) + fg_rgb * alpha
    target[...] = blended.astype(np.uint8)


def poison_sign(
    args: dict,
    sign_rgba: np.ndarray,
    trigger_rgba: np.ndarray,
):
    """
    Augment a sign and, with probability args['prob_add_trigger'], paste the
    trigger image onto it.

    Args:
        args: configuration dict. Keys read here:
            'prob_add_trigger' - probability of adding the trigger;
            'trigger_ratio'    - trigger height as a fraction of the sign height;
            'min_light', 'max_light' - brightness factor range, in percent;
            'trigger_position' - 'high', 'low', 'both', 'center' or 'random'.
        sign_rgba: sign image, RGBA.
        trigger_rgba: trigger image, RGBA.

    Returns:
        (rgba, is_poisoned): the augmented sign (with the trigger blended in
        when is_poisoned is True) and the flag itself.

    Raises:
        ValueError: if args['trigger_position'] is not a supported value.
    """
    # 1) Augment
    out = augment_sign(sign_rgba)

    # 2) Maybe overlay trigger
    if random.random() >= args['prob_add_trigger']:
        # Return the augmented sign, not the raw template: otherwise only
        # poisoned signs would ever carry augmentation.
        return out, False

    H, W = out.shape[:2]

    # Resize trigger to the requested fraction of the sign height, keeping
    # its aspect ratio; never let a dimension collapse to zero.
    h0, w0 = trigger_rgba.shape[:2]
    tgt_h = min(H, max(1, int(args['trigger_ratio'] * H)))
    tgt_w = max(1, int(w0 * (tgt_h / h0)))
    if tgt_w > W:
        # Trigger would be wider than the sign: fit it to the sign width instead.
        tgt_w = W
        tgt_h = max(1, int(h0 * (W / w0)))
    interp = cv2.INTER_AREA if tgt_h < h0 else cv2.INTER_LINEAR
    trig = cv2.resize(trigger_rgba, (tgt_w, tgt_h), interpolation=interp)

    # Random brightness on the RGB channels only; alpha is kept as is.
    factor = random.uniform(args['min_light'] / 100, args['max_light'] / 100)
    rgb = cv2.convertScaleAbs(trig[..., :3], alpha=factor, beta=0)
    trigger = np.dstack([rgb, trig[..., 3]])

    # Candidate offsets, kept inside the sign even when the trigger leaves
    # less room than the 5% margin.
    margin = int(0.05 * H)
    y_max, x_max = H - tgt_h, W - tgt_w
    top = min(margin, y_max)
    bottom = max(y_max - margin, 0)
    cx = x_max // 2

    pos = args['trigger_position']
    if pos == 'high':
        locs = [(top, cx)]
    elif pos == 'low':
        locs = [(bottom, cx)]
    elif pos == 'both':
        locs = [(top, cx), (bottom, cx)]
    elif pos == 'center':
        locs = [(y_max // 2, cx)]
    elif pos == 'random':
        locs = [(random.randint(top, max(top, bottom)), random.randint(0, x_max))]
    else:
        raise ValueError(f"Invalid position: {pos!r}")

    for y_off, x_off in locs:
        alpha_blend(out, trigger, y_off, x_off)

    return out, True


def _load_rgba(path):
    """
    Read an image file with OpenCV and return it as an H×W×4 RGBA array.

    Returns None when the file cannot be read or has a channel count other
    than 1 (grayscale), 3 (BGR) or 4 (BGRA).
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        return None
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGRA)
    elif img.shape[2] == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2BGRA)
    elif img.shape[2] != 4:
        return None
    # OpenCV orders channels B, G, R, A; reorder to R, G, B, A.
    return img[..., [2, 1, 0, 3]]


def add_signs_and_triggers(
    args: dict,
    image: torch.Tensor,
    grid: np.ndarray,
    mapping: list,
):
    """
    Paste a randomly chosen (and possibly poisoned) sign into every free grid
    cell of an image.

    Args:
        args: configuration dict. Keys read here: 'grid_size', 'sign_path'
            (directory with one sub-folder of PNG templates per class id) and
            'trigger_path'; the dict is also forwarded to poison_sign.
        image: image tensor, permuted here from (C, H, W) to (H, W, C) and
            scaled by 255 - i.e. expected to hold values in [0, 1].
        grid: G×G array; truthy cells are treated as occupied and skipped.
        mapping: list of dicts with keys 'id', 'class_label' and 'meta_label'.

    Returns:
        (image_out, new_anns): image_out is a float (C, H, W) tensor scaled
        back to [0, 1], created from a NumPy array (so it lives on the CPU);
        new_anns has one dict per pasted sign with keys 'class_id',
        'class_label', 'meta_label', 'bbox' ([x1, y1, x2, y2] in pixels) and
        'poison_mask'.

    Raises:
        ValueError: if free cells exist but the cells are smaller than 2×2
            pixels, or no class in mapping has a readable template.
        FileNotFoundError: if the trigger image cannot be read.
    """
    # Convert tensor to H×W×3 uint8 image. Adding 0.5 before the truncating
    # byte cast rounds to nearest, so values that land just below an integer
    # after scaling are not pushed down a level.
    img_np = (
        image.permute(1, 2, 0)
             .mul(255).add(0.5).clamp(0, 255)
             .byte().cpu().numpy()
    )
    out = img_np.copy()
    h, w = out.shape[:2]
    G = args['grid_size']
    cell_h, cell_w = h // G, w // G

    free_cells = [(r, c) for r in range(G) for c in range(G) if not grid[r, c]]
    new_anns = []

    if free_cells:
        if cell_h < 2 or cell_w < 2:
            raise ValueError(
                f"Grid cells of {cell_h}x{cell_w} px are too small to hold a sign"
            )

        # Preload sign templates. Paths are sorted so that seeded runs pick
        # the same templates regardless of filesystem listing order.
        signs_cache = {}
        for m in mapping:
            id_ = m['id']
            if id_ in signs_cache:
                continue
            folder = os.path.join(args['sign_path'], str(id_))
            loaded = (_load_rgba(p) for p in sorted(glob.glob(os.path.join(folder, '*.png'))))
            signs_cache[id_] = [s for s in loaded if s is not None]

        # Only classes that actually have a template can be drawn.
        candidates = [m for m in mapping if signs_cache[m['id']]]
        if not candidates:
            raise ValueError(f"No sign templates found under {args['sign_path']!r}")

        # Load trigger once
        trigger_rgba = _load_rgba(args['trigger_path'])
        if trigger_rgba is None:
            raise FileNotFoundError(
                f"Could not read trigger image: {args['trigger_path']!r}"
            )

        for r, c in free_cells:
            y0, x0 = r * cell_h, c * cell_w
            m = random.choice(candidates)
            sign_orig = random.choice(signs_cache[m['id']])
            sign_rgba, is_poisoned = poison_sign(args, sign_orig, trigger_rgba)

            # Random target height, width from the aspect ratio, then shrink
            # if needed so the sign is strictly smaller than the cell.
            sh, sw = sign_rgba.shape[:2]
            rh = random.randint(32, 192)
            rw = max(1, int(rh * (sw / sh)))
            sc = min((cell_h - 1) / rh, (cell_w - 1) / rw, 1.0)
            if sc < 1.0:
                nh, nw = max(1, int(rh * sc)), max(1, int(rw * sc))
            else:
                nh, nw = rh, rw

            # Resample once, straight to the final size.
            interp = cv2.INTER_AREA if nh < sh else cv2.INTER_LINEAR
            sign_rgba = cv2.resize(sign_rgba, (nw, nh), interpolation=interp)

            # nh <= cell_h - 1 and nw <= cell_w - 1, so both ranges are non-empty.
            y_off = y0 + random.randrange(cell_h - nh)
            x_off = x0 + random.randrange(cell_w - nw)
            alpha_blend(out, sign_rgba, y_off, x_off)

            new_anns.append({
                'class_id': m['id'],
                'class_label': m['class_label'],
                'meta_label': m['meta_label'],
                'bbox': [x_off, y_off, x_off + nw, y_off + nh],
                'poison_mask': is_poisoned,
            })

    image_out = torch.from_numpy(out).permute(2, 0, 1).float().div(255)
    return image_out, new_anns