from __future__ import annotations
import random
from typing import Any, Dict, List, Tuple
from collections import deque
from PIL import Image, ImageDraw, ImageFilter, ImageChops, ImageMath

from ...base import Task
from ...registry import register_task
from ...config import OUT_CELL, MAX_BUILD_RETRIES, OPT_HASH_MIN_BITS, OPT_UNIQUENESS_MIN, SS_CELL
from .common import diff_frac, flip_h, flip_v, rot, rot90, rot180, rot270
from ...utils.specs import _prefer_asym_mode
from ...utils.rng import choice_weighted
from ...utils.drawing import (
    tight_crop_rgba,
    load_font,
    labels_default,
)

# ─────────────────────────────────────────────────────────────────────────────
# All 17 wallpaper groups
# NOTE: key order matters — generate_instance iterates this dict in insertion
# order while drawing one random number per key, so reordering the names would
# change which instances a given seed produces.
WALLPAPER_WEIGHTS: Dict[str, float] = dict.fromkeys(
    (
        "pm", "pg", "p2", "pmm",
        "pgg", "cmm", "p4", "pmg",
        "p1", "cm", "p4m", "p4g",
        "p3", "p3m1", "p31m", "p6", "p6m",
    ),
    1.0,
)
# Identity mapping: each group name is already its IUC short symbol.
WALLPAPER_IUC = {name: name for name in WALLPAPER_WEIGHTS}


# Relative sampling weights per motif family (keys of ``motif_impls``).
# generate_instance keeps only families listed here with a positive weight
# (falling back to every available family if none match) and orders them by the
# weighted-random key ``u ** (1 / w)`` sorted descending, so heavier families
# tend to be tried first. "icons" is strongly favoured.
MOTIF_WEIGHTS = {
    "icons": 5,
    "arc": 0.75,
    "single_arrow": 1.0,
    "clock": 0.5,
    "crescent": 1.0,
    "fractal": 0.25,
    "glyph": 1.0,
    "keyhole": 0.25,
    "pictogram": 0.5,
    "polygon": 0.25,
    "polyhex": 0.125,
    "polyiamond": 0.125,
    "polyline": 0.125,
    "polyomino": 0.125,
    "segment": 1.0,
}


# Alternative phrasings of the same question: a 2×2 grid of patches labelled
# a–d, three sharing one wallpaper group and one odd-one-out. Presumably one
# template is picked per instance (selection is outside this view — TODO confirm).
PROMPT_TEMPLATES = [
    "The figure shows a 2×2 grid of patches (a–d). Exactly three share the same 2D symmetry; which option is different?",
    "In the 2×2 grid, three tiles belong to the same wallpaper group; which is the odd tile (a–d)?",
    "In the 2×2 grid, three patches follow the same reflections/rotations/glides. Which option (a–d) is different?",
    "Compare the symmetry relation repeated within each patch. Which tile in the 2×2 grid (a–d) comes from a different wallpaper group?",
    "Three panels follow one wallpaper rule, while one uses another. Which panel (a–d) in the 2×2 grid is different?",
    "Look at how the motif repeats inside each tile. Which option (a–d) in the 2×2 grid breaks the shared symmetry?",
    "In the 2×2 grid of patches, one tile is generated by a different combination of flips/rotations/glides. Which one (a–d) is it?",
    "Only one of the four tiles does not share the others’ 2D symmetry. Which option (a–d) in the 2×2 grid is the odd one?",
    "Three options belong to the same wallpaper symmetry. Which tile in the 2×2 grid (a–d) belongs to a different group?",
    "Which patch in the 2×2 grid (a–d) does not match the symmetry class of the other three?",
]


# Visual format
# Inclusive (min, max) motif-grid size inside each patch; generate_instance
# draws grid_cols with rng.randint over this range and uses a square grid.
GRID_SIZE_RANGE = (4, 6)

# Geometry (pre-fit, SS scale)
CELL_BASE_SCALE_FRAC = 0.24  # base cell height as a fraction of SS_CELL (floored at 12 px)
MOTIF_FIT_FRAC = 0.78        # motif box as a fraction of the cell (_render_patch_HR)
GLIDE_X_FRAC = 0.22          # upper bound on horizontal glide offset, fraction of cell width
GLIDE_Y_FRAC = 0.22          # upper bound on vertical glide offset, fraction of cell height
ROW_STAGGER_FRAC = 0.5       # x-shift of odd rows in staggered groups, fraction of cell width

# Render scale (high-res like frieze → crisp edges)
PATCH_TILE_SCALE: float = 2.0   # 2× SS resolution for composition; no downscale after

# Label band (same style as the sequence tasks). Not referenced in the code
# visible here — presumably used when drawing the a–d option labels.
LABEL_FRAC = 0.16
LABEL_MIN_PX = 14
LABEL_MAX_PX = 32
LABEL_PAD_FRAC = 1/38
LABEL_PAD_MIN = 4

# Gates
ADJ_MIN = float(OPT_UNIQUENESS_MIN) * 0.75  # 75% of the global option-uniqueness minimum
OPT_MIN_DELTA = float(OPT_UNIQUENESS_MIN)   # also scales the degeneracy threshold in _wallpaper_degenerate_for

# ─────────────────────────────────────────────────────────────────────────────
# Transforms

def _rot60(im: Image.Image) -> Image.Image:
    """Rotate *im* by 60° via the shared ``rot`` helper."""
    return rot(im, 60)


def _rot120(im: Image.Image) -> Image.Image:
    """Rotate *im* by 120° via the shared ``rot`` helper."""
    return rot(im, 120)


def _rot240(im: Image.Image) -> Image.Image:
    """Rotate *im* by 240° via the shared ``rot`` helper."""
    return rot(im, 240)


def _rot300(im: Image.Image) -> Image.Image:
    """Rotate *im* by 300° via the shared ``rot`` helper."""
    return rot(im, 300)

# ─────────────────────────────────────────────────────────────────────────────
# Crisp scaling & alpha hygiene (ported from the frieze task)

# --- crisp scaling & subtle alpha grow (same approach as the frieze task).
def _scale_longer_side_crisp(im: Image.Image, target: int) -> Image.Image:
    """
    Downscale *im* so its longer side is about *target* px, keeping edges crisp.

    The scale factor is capped at 1.0 (never upscales), but each output side is
    floored at 8 px. For strong reductions (factor < 0.7) the resize is staged:
    an integer-factor BOX reduction first, then LANCZOS to the final size.
    Afterwards the alpha channel is snapped: values > 230 become 255 and values
    < 6 become 0, removing faint fringes while keeping mid-range anti-aliasing.

    Args:
        im: input image; converted to RGBA if needed.
        target: desired longer side in pixels.

    Returns:
        An RGBA image (the converted input unchanged if it has zero size).
    """
    if im.mode != "RGBA":
        im = im.convert("RGBA")
    w, h = im.size
    longest = max(w, h)
    if longest == 0:
        return im
    s = min(1.0, target / float(longest))  # never upscale
    nw, nh = max(8, int(round(w * s))), max(8, int(round(h * s)))
    if s < 0.7:
        # staged: BOX then LANCZOS
        k = max(1, int(round(0.72 * longest / max(nw, nh))))
        # Clamp to >= 1 px: for very thin images w // k (or h // k) can be 0,
        # and PIL refuses to resize to a zero-sized target.
        tmp = im.resize((max(1, w // k), max(1, h // k)), Image.BOX)
        out = tmp.resize((nw, nh), Image.LANCZOS)
    else:
        out = im.resize((nw, nh), Image.LANCZOS)
    r, g, b, a = out.split()
    a = a.point(lambda t: 255 if t > 230 else (0 if t < 6 else t))
    return Image.merge("RGBA", (r, g, b, a))

def _thicken_alpha_1px(im: Image.Image) -> Image.Image:
    """
    Dilate the alpha channel by one pixel (3×3 max filter); RGB is left as is.

    Non-RGBA input is converted to RGBA first.
    """
    rgba = im if im.mode == "RGBA" else im.convert("RGBA")
    *rgb, alpha = rgba.split()
    grown = alpha.filter(ImageFilter.MaxFilter(size=3))
    return Image.merge("RGBA", (*rgb, grown))

def _edge_connected_mask(maskL: Image.Image) -> Image.Image:
    """
    Return an L-mode mask of the 'on' pixels of *maskL* 4-connected to the border.

    Any source value > 0 counts as 'on' (0/255 masks are typical). Reached
    pixels are 255 in the output, everything else 0. Implemented as a
    breadth-first flood fill over 4-neighbours, seeded from every 'on' border
    pixel, using plain PIL pixel access.

    A zero-width or zero-height input yields an equally empty output mask.
    """
    w, h = maskL.size
    out = Image.new("L", (w, h), 0)
    if w == 0 or h == 0:
        # No border pixels exist; indexing row/column 0 would fail.
        return out
    src = maskL.load()
    dst = out.load()
    Q = deque()

    def seed(x: int, y: int) -> None:
        # Enqueue a border pixel once (corners are visited by both loops).
        if dst[x, y] == 0 and src[x, y] > 0:
            dst[x, y] = 255
            Q.append((x, y))

    for x in range(w):
        seed(x, 0)
        seed(x, h - 1)
    for y in range(h):
        seed(0, y)
        seed(w - 1, y)

    while Q:
        x, y = Q.popleft()
        for nx, ny in ((x - 1, y), (x + 1, y), (x, y - 1), (x, y + 1)):
            if 0 <= nx < w and 0 <= ny < h and dst[nx, ny] == 0 and src[nx, ny] > 0:
                dst[nx, ny] = 255
                Q.append((nx, ny))
    return out


def _ops_list_to_dict(ops_list: List[str]) -> dict:
    d = {"flip_h": False, "flip_v": False, "rot": 0}
    for op in ops_list:
        if op == "H": d["flip_h"] = True
        elif op == "V": d["flip_v"] = True
        elif op == "R90": d["rot"] = 90
        elif op == "R180": d["rot"] = 180
        elif op == "R270": d["rot"] = 270
    return d

def _apply_ops_list(im: Image.Image, ops_list: List[str]) -> Image.Image:
    """
    Apply op codes to *im* strictly in list order, then tight-crop (2 px pad).

    "H"/"V" call ``flip_h``/``flip_v``. "R60", "R90", "R120", "R180", "R240",
    "R270" and "R300" rotate counter-clockwise by that many degrees (bicubic,
    canvas expanded). Any other code is ignored.
    """
    rotations = {
        "R60": 60, "R90": 90, "R120": 120, "R180": 180,
        "R240": 240, "R270": 270, "R300": 300,
    }
    result = im.convert("RGBA")
    for code in ops_list:
        if code == "H":
            result = flip_h(result)
        elif code == "V":
            result = flip_v(result)
        elif code in rotations:
            result = result.rotate(rotations[code], resample=Image.BICUBIC, expand=True)
    return tight_crop_rgba(result, pad=2)



def _knockout_edge_white_rgba(im: Image.Image, white_thr: int = 252, alpha_thr: int = 8) -> Image.Image:
    """
    Make edge-connected near-white (or low-alpha) background transparent,
    preserving interior white regions.

    A pixel is a background candidate when R, G and B are all ``> white_thr``
    OR its alpha is ``<= alpha_thr``. Only candidates 4-connected to the image
    border are knocked out. That region is grown by 1 px (3×3 max filter) to
    remove light halos, and alpha is zeroed there. RGB values are kept.

    The masks are built with ``point`` and ``ImageChops`` instead of
    ``ImageMath.eval``, which is string-eval based and deprecated since
    Pillow 10.3.

    Args:
        im: input image (converted to RGBA).
        white_thr: per-channel threshold above which a pixel counts as near-white.
        alpha_thr: alpha at or below which a pixel counts as transparent.

    Returns:
        An RGBA image with the edge-connected background made transparent.
    """
    im = im.convert("RGBA")
    r, g, b, a = im.split()
    thr = int(white_thr)

    def above(channel: Image.Image) -> Image.Image:
        # 255 where channel > thr, else 0
        return channel.point(lambda v: 255 if v > thr else 0)

    # Masks are 0/255, so per-pixel min acts as AND and per-pixel max acts as OR.
    near_white = ImageChops.darker(ImageChops.darker(above(r), above(g)), above(b))
    low_alpha = a.point(lambda v: 255 if v <= alpha_thr else 0)
    cand = ImageChops.lighter(near_white, low_alpha)

    # keep only edge-connected background candidates; 1px grow to kill halos
    edge_bg = _edge_connected_mask(cand).filter(ImageFilter.MaxFilter(size=3))

    # alpha' = alpha * (1 - edge_bg)   (ImageChops.multiply computes L*L/255)
    a2 = ImageChops.multiply(a, ImageChops.invert(edge_bg))

    return Image.merge("RGBA", (r, g, b, a2))

# ─────────────────────────────────────────────────────────────────────────────
# Group-specific neighbor mapping

def _ops_and_offsets(kind: str, i: int, j: int, cw: int, ch: int):
    """
    Neighbour rule for cell (i, j) of wallpaper group *kind*, with default glides.

    Thin wrapper around ``_ops_and_offsets_amp`` that passes the default glide
    amplitudes ``int(round(cw * GLIDE_X_FRAC))`` and ``int(round(ch * GLIDE_Y_FRAC))``.
    Keeping one copy of the per-group rule table prevents the two entry
    points from drifting apart.

    Returns:
        ``(ops, (dx, dy), row_stagger_x)``, exactly as ``_ops_and_offsets_amp``:
        the ordered op codes for the cell, the in-slot glide offset, and the
        horizontal stagger of row ``j``.
    """
    gx = int(round(cw * GLIDE_X_FRAC))
    gy = int(round(ch * GLIDE_Y_FRAC))
    return _ops_and_offsets_amp(kind, i, j, cw, ch, gx, gy)



def _apply_ops(im: Image.Image, ops: dict) -> Image.Image:
    """
    Apply a dict-form transform to *im* and return a tight crop.

    The order is fixed: ``flip_h`` (if truthy), then ``flip_v`` (if truthy),
    then a counter-clockwise rotation by ``int(rot) % 360`` degrees (bicubic,
    canvas expanded). The result is re-cropped with a 2 px pad so thin rims at
    the bbox edge are not shaved off.
    """
    out = im.convert("RGBA")
    for key, flip in (("flip_h", flip_h), ("flip_v", flip_v)):
        if ops.get(key):
            out = flip(out)
    angle = int(ops.get("rot", 0)) % 360
    if angle != 0:
        out = out.rotate(angle, resample=Image.BICUBIC, expand=True)
    return tight_crop_rgba(out, pad=2)

# --- compute worst‑case bbox over the ops used by a wallpaper group
def _max_bbox_under_ops(base: Image.Image, ops_list: list[dict]) -> tuple[int,int]:
    """
    Return the largest (width, height) of *base* after each transform in *ops_list*.

    Width and height are maximised independently, so they may come from
    different transforms. An empty *ops_list* yields ``(0, 0)``.
    """
    sizes = [_apply_ops(base, ops).size for ops in ops_list]
    widest = max((w for w, _ in sizes), default=0)
    tallest = max((h for _, h in sizes), default=0)
    return widest, tallest

# --- place an item inside a slot with clamping & RGBA compositing
def _place_in_slot(slot_w: int, slot_h: int, bmp: Image.Image, dx: int, dy: int) -> Image.Image:
    """
    Return a transparent ``slot_w × slot_h`` RGBA canvas with *bmp* composited on it.

    *bmp* is centred, shifted by ``(dx, dy)``, and its top-left corner is then
    clamped so it stays inside the slot. If *bmp* is larger than the slot on an
    axis, it is pinned at 0 on that axis and the excess falls outside the canvas.
    """
    spare_x = slot_w - bmp.width
    spare_y = slot_h - bmp.height
    left = max(0, min(spare_x, spare_x // 2 + dx))
    top = max(0, min(spare_y, spare_y // 2 + dy))
    canvas = Image.new("RGBA", (slot_w, slot_h), (0, 0, 0, 0))
    canvas.alpha_composite(bmp, (left, top))
    return canvas

# ─────────────────────────────────────────────────────────────────────────────
# Degeneracy guard (same logic as frieze; only check the ops required by kind)

def _wallpaper_degenerate_for(motif_im: Image.Image, kind: str) -> bool:
    """
    Return True if the motif is (nearly) invariant under a transform *kind* uses.

    The motif is downscaled (crisply) to a longer side of
    ``max(12, int(SS_CELL * 0.22))``. Each transform associated with *kind*
    (H/V mirrors, 180°, 90°/270°, 120°/240°, 60°/300°) is applied, and the copy
    is compared to the original with ``diff_frac(thresh=8)``. Any difference
    below ``max(0.006, 0.5 * OPT_MIN_DELTA)`` marks the motif as degenerate for
    that group, presumably because the symmetry would be invisible in the patch.
    """
    probe = _scale_longer_side_crisp(motif_im, max(12, int(SS_CELL * 0.22)))

    # (transforms, groups whose rendering relies on them)
    rules = (
        ((flip_h,), ("pm", "pmm", "pgg", "cmm", "pmg", "cm", "p4m", "p3m1", "p31m", "p6m")),
        ((flip_v,), ("pg", "pmm", "pgg", "cmm", "pmg", "p4m", "p31m", "p6m")),
        ((rot180,), ("p2", "pmm", "pgg", "cmm", "p4", "p4m", "p4g", "p6", "p6m")),
        ((rot90, rot270), ("p4", "p4m", "p4g")),
        ((_rot120, _rot240), ("p3", "p3m1", "p31m", "p6", "p6m")),
        ((_rot60, _rot300), ("p6", "p6m")),
    )
    references = [fn(probe) for fns, groups in rules if kind in groups for fn in fns]

    tolerance = max(0.006, 0.5 * OPT_MIN_DELTA)
    return any(diff_frac(probe, ref, thresh=8) < tolerance for ref in references)

def _render_wallpaper_tile(base_item: Image.Image,
                           group_kind: str,
                           rows: int, cols: int,
                           slot_gap_px: int,
                           row_offsets: list[int],  # one per row
                           col_offsets: list[int],  # one per col (if used)
                           ops_grid: list[list[dict]],  # [j][i] → dict like {"rot":90,"flip_h":1}
                           inner_anchor: str = "center") -> Image.Image:
    """
    Build a single bordered wallpaper tile.

    The motif is laid out on a ``rows × cols`` grid of slots at supersampled
    resolution (``SS_CELL``). The inner border is drawn at that resolution, the
    result is flattened onto opaque white, and it is then downsampled to
    ``OUT_CELL`` with LANCZOS.

    Args:
        base_item: motif image; trimmed with ``tight_crop_rgba`` before use.
        group_kind: currently unused (per-cell transforms come from ``ops_grid``).
        rows, cols: grid dimensions.
        slot_gap_px: currently unused; the gap is derived from ``SS_CELL``.
        row_offsets: extra x-shift for row ``j`` (missing entries count as 0).
        col_offsets: extra y-shift for column ``i`` (missing entries count as 0).
        ops_grid: ``ops_grid[j][i]`` is a dict with optional keys ``flip_h``,
            ``flip_v``, ``rot`` (degrees) and ``dx``/``dy`` (in-slot offset).
        inner_anchor: currently unused.

    Returns:
        An ``OUT_CELL × OUT_CELL`` RGBA image.
    """
    # --- per-item slot size from rows/cols, a fixed gap and an outer margin (SS px)
    gap = max(6, SS_CELL // 40)
    margin = max(10, SS_CELL // 20)
    w_slot = (SS_CELL - 2*margin - (cols - 1)*gap) // cols
    h_slot = (SS_CELL - 2*margin - (rows - 1)*gap) // rows

    # --- 1) pick one scale so the worst-case transformed motif fits every slot
    ops_list = [ops_grid[j][i] for j in range(rows) for i in range(cols)]
    raw = tight_crop_rgba(base_item)                 # keep RGBA, remove empty margins
    max_w, max_h = _max_bbox_under_ops(raw, ops_list)
    # leave room for the per-cell inner (dx, dy) offsets
    max_inner_dx = max((abs(o.get("dx", 0)) for o in ops_list), default=0)
    max_inner_dy = max((abs(o.get("dy", 0)) for o in ops_list), default=0)
    safe_w = max(8, w_slot - 2*max_inner_dx - 2)    # 2px rim
    safe_h = max(8, h_slot - 2*max_inner_dy - 2)
    s = min(1.0, safe_w / float(max_w or 1), safe_h / float(max_h or 1))
    # Apply s to the *untransformed* motif's own longer side. max_w/max_h were
    # measured on the transformed (expanded, padded) copies, which are at least
    # as large as raw. Scaling raw to s * max(max_w, max_h) would under-shrink
    # it and let transformed copies overflow their slots.
    target = int(round(max(raw.width, raw.height) * s))
    base_scaled = _thicken_alpha_1px(_scale_longer_side_crisp(raw, target))

    # --- 2) compose all cells into a transparent canvas with row/col offsets
    big = Image.new("RGBA", (SS_CELL, SS_CELL), (0, 0, 0, 0))
    for j in range(rows):
        for i in range(cols):
            ops = ops_grid[j][i]
            # transforms (canvas expanded, never cropped) + recrop
            item = _apply_ops(base_scaled, ops)
            # per-cell inner offset (within the slot)
            dx = int(ops.get("dx", 0))
            dy = int(ops.get("dy", 0))
            slot = _place_in_slot(w_slot, h_slot, item, dx, dy)
            # slot top-left in the big canvas (+ constant row/col offsets)
            x0 = margin + i * (w_slot + gap) + int(row_offsets[j] if j < len(row_offsets) else 0)
            y0 = margin + j * (h_slot + gap) + int(col_offsets[i] if i < len(col_offsets) else 0)
            # 3) RGBA composite (never plain paste)
            big.alpha_composite(slot, (x0, y0))

    # --- draw the inner border at SS resolution, flatten onto white, downsample
    border_ss = max(3, SS_CELL // 160)
    ImageDraw.Draw(big).rectangle([0, 0, SS_CELL - 1, SS_CELL - 1], outline=(0, 0, 0), width=border_ss)
    white = Image.new("RGBA", (SS_CELL, SS_CELL), (255, 255, 255, 255))
    return Image.alpha_composite(white, big).resize((OUT_CELL, OUT_CELL), Image.LANCZOS).convert("RGBA")

def _ops_and_offsets_amp(kind: str, i: int, j: int, cw: int, ch: int, gx: int, gy: int):
    ops: List[str] = []
    inner = (0, 0)
    row_stagger_x = 0

    # ── the original 8 ──────────────────────────────────────────────────────
    if kind == "pm":
        if (i % 2) == 1: ops.append("H")

    elif kind == "pg":
        if (i % 2) == 1: ops.append("V"); inner = (0, gy)

    elif kind == "p2":
        if ((i + j) % 2) == 1: ops.append("R180")

    elif kind == "pmm":
        if (i % 2) == 1: ops.append("H")
        if (j % 2) == 1: ops.append("V")

    elif kind == "pgg":
        if (i % 2) == 1: ops.append("V"); inner = (inner[0], inner[1] + gy)
        if (j % 2) == 1: ops.append("H"); inner = (inner[0] + gx, inner[1])

    elif kind == "cmm":
        if (j % 2) == 1: row_stagger_x = int(round(cw * ROW_STAGGER_FRAC))
        if (i % 2) == 1: ops.append("H")
        if (j % 2) == 1: ops.append("V")

    elif kind == "p4":
        u, v = (i % 2), (j % 2)
        if   (u, v) == (1, 0): ops.append("R90")
        elif (u, v) == (1, 1): ops.append("R180")
        elif (u, v) == (0, 1): ops.append("R270")

    elif kind == "pmg":
        if (i % 2) == 1: ops.append("H")
        if (j % 2) == 1: ops.append("V"); inner = (inner[0] + gx, inner[1])

    # ── the added 9 ─────────────────────────────────────────────────────────
    elif kind == "p1":
        pass

    elif kind == "cm":
        if (j % 2) == 1: row_stagger_x = int(round(cw * ROW_STAGGER_FRAC))
        if (i % 2) == 1: ops.append("H")

    elif kind == "p4m":
        u, v = (i % 2), (j % 2)
        if   (u, v) == (1, 0): ops.append("R90")
        elif (u, v) == (1, 1): ops.append("R180")
        elif (u, v) == (0, 1): ops.append("R270")
        if (i % 2) == 1: ops.append("H")
        if (j % 2) == 1: ops.append("V")

    elif kind == "p4g":
        u, v = (i % 2), (j % 2)
        if   (u, v) == (1, 0): ops.append("R90")
        elif (u, v) == (1, 1): ops.append("R180")
        elif (u, v) == (0, 1): ops.append("R270")
        if (i % 2) == 1: inner = (inner[0] + gx, inner[1])
        if (j % 2) == 1: inner = (inner[0], inner[1] + gy)

    elif kind == "p3":
        if (j % 2) == 1: row_stagger_x = int(round(cw * ROW_STAGGER_FRAC))
        phase = (i + j) % 3
        if   phase == 1: ops.append("R120")
        elif phase == 2: ops.append("R240")

    elif kind == "p3m1":
        if (j % 2) == 1: row_stagger_x = int(round(cw * ROW_STAGGER_FRAC))
        phase = (i + j) % 3
        if   phase == 1: ops.append("R120")
        elif phase == 2: ops.append("R240")
        if (i % 2) == 1: ops.append("H")

    elif kind == "p31m":
        if (j % 2) == 1: row_stagger_x = int(round(cw * ROW_STAGGER_FRAC))
        phase = (i + j) % 3
        if   phase == 1: ops.append("R120")
        elif phase == 2: ops.append("R240")
        if (j % 2) == 1: ops.append("V")

    elif kind == "p6":
        if (j % 2) == 1: row_stagger_x = int(round(cw * ROW_STAGGER_FRAC))
        phase = (i + j) % 6
        if   phase == 1: ops.append("R60")
        elif phase == 2: ops.append("R120")
        elif phase == 3: ops.append("R180")
        elif phase == 4: ops.append("R240")
        elif phase == 5: ops.append("R300")

    elif kind == "p6m":
        if (j % 2) == 1: row_stagger_x = int(round(cw * ROW_STAGGER_FRAC))
        phase = (i + j) % 6
        if   phase == 1: ops.append("R60")
        elif phase == 2: ops.append("R120")
        elif phase == 3: ops.append("R180")
        elif phase == 4: ops.append("R240")
        elif phase == 5: ops.append("R300")
        if (i % 2) == 1: ops.append("H")
        if (j % 2) == 1: ops.append("V")

    return ops, inner, row_stagger_x



# ─────────────────────────────────────────────────────────────────────────────

def _render_patch_HR(
    base_item: Image.Image,
    kind: str,
    *,
    cell_w: int,   # HR cell size (pixels)
    cell_h: int,
    cols: int,
    rows: int,
    edge: int,     # reserved (border drawn later by caller)
) -> Tuple[Image.Image, List[Image.Image]]:
    """
    Build a borderless, transparent RGBA wallpaper patch at high resolution.

    1. Clean the motif (edge-connected white knocked out, tight crop) and fit
       it inside ``MOTIF_FIT_FRAC`` of a cell, then grow its alpha by 1 px.
    2. Cap the glide amplitudes at the free margin around the fitted
       (untransformed) motif, so glides never force a shrink.
    3. Place each transformed copy in its own cell-sized slot, clamped to stay
       inside the slot, and composite the slots row by row with the group's
       row stagger. The canvas is widened by the maximum stagger.

    NOTE(review): the fit in step 1 uses the untransformed motif. Copies
    rotated with ``expand=True`` may exceed the slot and would then be cut off
    at the slot edge.

    Returns:
        ``(patch, reps)``; *reps* is always an empty list here. ``edge`` is
        accepted but unused.
    """
    # 1) transparent, trimmed, fitted motif
    motif = tight_crop_rgba(_knockout_edge_white_rgba(base_item.convert("RGBA")))
    box_w = int(round(cell_w * MOTIF_FIT_FRAC))
    box_h = int(round(cell_h * MOTIF_FIT_FRAC))
    motif = _scale_longer_side_crisp(motif, min(box_w, box_h))
    if motif.width > box_w or motif.height > box_h:
        k = min(box_w / motif.width, box_h / motif.height)
        fitted = (max(8, int(round(motif.width * k))), max(8, int(round(motif.height * k))))
        motif = motif.resize(fitted, Image.LANCZOS)
    motif = _thicken_alpha_1px(motif)

    # 2) glide amplitudes bounded by the free margin (minus a small safety rim)
    rim = 2
    glide_x = min(int(round(cell_w * GLIDE_X_FRAC)), max(0, (cell_w - motif.width) // 2 - rim))
    glide_y = min(int(round(cell_h * GLIDE_Y_FRAC)), max(0, (cell_h - motif.height) // 2 - rim))

    # 3) transparent canvas wide enough for the largest row stagger
    canvas_w = cols * cell_w + int(round(cell_w * ROW_STAGGER_FRAC))
    canvas = Image.new("RGBA", (canvas_w, rows * cell_h), (255, 255, 255, 0))

    for row in range(rows):
        _, _, shift = _ops_and_offsets_amp(kind, 0, row, cell_w, cell_h, glide_x, glide_y)
        top = row * cell_h
        for col in range(cols):
            ops, (off_x, off_y), _ = _ops_and_offsets_amp(kind, col, row, cell_w, cell_h, glide_x, glide_y)
            tile = _apply_ops_list(motif, ops)
            spare_x = cell_w - tile.width
            spare_y = cell_h - tile.height
            px = max(0, min(spare_x, spare_x // 2 + int(off_x)))
            py = max(0, min(spare_y, spare_y // 2 + int(off_y)))
            slot = Image.new("RGBA", (cell_w, cell_h), (255, 255, 255, 0))
            slot.alpha_composite(tile, (px, py))
            canvas.alpha_composite(slot, (col * cell_w + shift, top))

    # Patch stays transparent; the caller adds the border and flattens.
    return canvas, []


# ─────────────────────────────────────────────────────────────────────────────
@register_task
class SymmetryWallpaperGroupsTask(Task):
    """
    Four wallpaper patches (same motif family). Exactly three use one 2D rule; the fourth uses a different rule.
    Display format: fixed 2×2 grid. Bordered patches are *perfect squares* and identically sized.
    """
    name = "symmetry_wallpaper_groups"

    def __init__(self):
        self.max_retries = int(MAX_BUILD_RETRIES)
        self.opt_hash_min_bits = int(OPT_HASH_MIN_BITS)
        # Minimum pairwise diff_frac between final tiles; rejects near-duplicate options.
        self.opt_min_delta = float(OPT_UNIQUENESS_MIN)

    # ── helpers ──────────────────────────────────────────────────────────────
    @staticmethod
    def _weighted_order(keys: List[str], weights: Dict[str, float], rng: random.Random) -> List[str]:
        """Return *keys* in a random order biased toward larger weights.

        Each key gets the sort key ``u ** (1 / w)`` with ``u`` uniform in (0, 1]; keys are then
        sorted in descending order of that value. Exactly one ``rng.random()`` is drawn per key,
        in input order. Missing or non-positive weights are treated as 1.0 so the fallback motif
        list (motifs without a positive weight) cannot raise KeyError/ZeroDivisionError.
        """
        items = []
        for k in keys:
            w = float(weights.get(k, 0.0))
            if w <= 0.0:
                w = 1.0
            u = max(rng.random(), 1e-12)
            items.append((u ** (1.0 / w), k))
        items.sort(reverse=True)
        return [k for _, k in items]

    @staticmethod
    def _square_and_border(patches: List[Image.Image], edge_px: int) -> List[Image.Image]:
        """Center every patch on one shared square canvas, draw a black border, flatten onto white.

        The square side is the largest width/height over all patches, so every returned tile
        has the same size. The border is drawn while the canvas is still transparent; only
        afterwards is the tile composited over an opaque white background.
        """
        side = max(max(p.width for p in patches), max(p.height for p in patches))
        out: List[Image.Image] = []
        for p in patches:
            sq = Image.new("RGBA", (side, side), (255, 255, 255, 0))
            sq.alpha_composite(p, ((side - p.width) // 2, (side - p.height) // 2))
            ImageDraw.Draw(sq).rectangle(
                [0, 0, side - 1, side - 1], outline=(0, 0, 0), width=edge_px - 1
            )
            white = Image.new("RGBA", (side, side), (255, 255, 255, 255))
            out.append(Image.alpha_composite(white, sq))
        return out

    @staticmethod
    def _tiles_distinct(tiles: List[Image.Image], min_delta: float) -> bool:
        """True when every pair of tiles differs by at least *min_delta* (per ``diff_frac``)."""
        from itertools import combinations
        return all(diff_frac(a, b, thresh=8) >= min_delta for a, b in combinations(tiles, 2))

    @staticmethod
    def _label_tile_compact(
        tile: Image.Image, label: str, font: Any, target_label_h: int, label_pad_y: int
    ) -> Image.Image:
        """Stack a crisp "(label)" text band under *tile* without resampling the tile itself.

        The text is rendered at k× resolution (k from SS_CELL / OUT_CELL and PATCH_TILE_SCALE)
        and LANCZOS-downsampled so its height lands near *target_label_h* pixels.
        """
        txt = label if (label.startswith("(") and label.endswith(")")) else f"({label})"
        k = max(2, int(round((SS_CELL / max(1, OUT_CELL)) * PATCH_TILE_SCALE)))
        try:
            bx0, by0, bx1, by1 = font.getbbox(txt)
            base_h = max(1, by1 - by0)
            scale = target_label_h * k / base_h
            f_hr = font.font_variant(size=max(6, int(round(getattr(font, "size", 16) * scale))))
        except Exception:
            # Best effort: if the font cannot be measured/resized, keep it at its native size.
            f_hr = font
        stroke_w = max(1, k // 2)
        # Measure *with* the stroke so the outline is not clipped by the label canvas.
        x0, y0, x1, y1 = f_hr.getbbox(txt, stroke_width=stroke_w)
        tw_hr, th_hr = x1 - x0, y1 - y0
        lab_hr = Image.new("RGBA", (max(1, tw_hr), max(1, th_hr)), (255, 255, 255, 0))
        ImageDraw.Draw(lab_hr).text(
            (-x0, -y0), txt, fill=(0, 0, 0), font=f_hr,
            stroke_width=stroke_w, stroke_fill=(0, 0, 0),
        )
        lab_size = (max(1, int(round(tw_hr / k))), max(1, int(round(th_hr / k))))
        lab = lab_hr.resize(lab_size, Image.LANCZOS)
        W = max(tile.width, lab.width)
        H = tile.height + label_pad_y + lab.height
        out = Image.new("RGBA", (W, H), (255, 255, 255, 255))
        out.paste(tile, ((W - tile.width) // 2, 0), tile)
        out.paste(lab, ((W - lab.width) // 2, tile.height + label_pad_y), lab)
        return out

    @staticmethod
    def _compose_grid_2x2(tiles: List[Image.Image]) -> Image.Image:
        """Lay out four tiles row-major (a b / c d) on a white canvas with small margins/gaps."""
        cw = max(im.width for im in tiles)
        ch = max(im.height for im in tiles)
        gap_x = max(6, min(14, cw // 36))
        gap_y = max(4, min(10, ch // 36))
        margin_x = max(6, min(12, cw // 40))
        margin_y = max(6, min(12, ch // 40))
        W = margin_x + cw + gap_x + cw + margin_x
        H = margin_y + ch + gap_y + ch + margin_y
        canvas = Image.new("RGBA", (W, H), (255, 255, 255, 255))
        cells = [
            (margin_x, margin_y),
            (margin_x + cw + gap_x, margin_y),
            (margin_x, margin_y + ch + gap_y),
            (margin_x + cw + gap_x, margin_y + ch + gap_y),
        ]
        for im, (cx, cy) in zip(tiles, cells):
            canvas.paste(im, (cx + (cw - im.width) // 2, cy + (ch - im.height) // 2), im)
        return canvas

    # ── main entry point ─────────────────────────────────────────────────────
    def generate_instance(self, motif_impls: Dict[str, Any], rng: random.Random):
        """Build one odd-one-out wallpaper puzzle.

        Returns ``(composite, payloads, meta)``: the labeled 2×2 RGBA composite, one payload
        dict per slot, and a metadata dict including the answer label.

        Raises:
            RuntimeError: if no motifs are available, or every motif family exhausts
                ``self.max_retries`` attempts without producing a valid instance.
        """
        # Motif families with positive weight, or every available motif as a fallback.
        allowed = [k for k in motif_impls if MOTIF_WEIGHTS.get(k, 0) > 0] or list(motif_impls)
        if not allowed:
            raise RuntimeError(f"{self.name}: no motifs available.")
        motif_order = self._weighted_order(allowed, MOTIF_WEIGHTS, rng)

        # Wallpaper kinds
        kinds_all = [k for k, w in WALLPAPER_WEIGHTS.items() if w > 0.0]
        kind_order = self._weighted_order(kinds_all, WALLPAPER_WEIGHTS, rng)
        kind_weights = [WALLPAPER_WEIGHTS[k] for k in kind_order]

        labels = labels_default()
        font = load_font()

        # Shared base cell geometry at SS scale
        cell_h0 = max(12, int(round(SS_CELL * CELL_BASE_SCALE_FRAC)))
        margin_soft = max(10, SS_CELL // 20)
        max_inner_w = SS_CELL - 2 * margin_soft
        max_inner_h = SS_CELL - 2 * margin_soft

        # High-resolution (HR) factor and border width do not depend on the motif: compute once.
        HR = max(1.0, float(PATCH_TILE_SCALE))
        edge = max(3, SS_CELL // 160)
        edgeHR = max(2, int(round(edge * HR)))

        grid_cols = rng.randint(GRID_SIZE_RANGE[0], GRID_SIZE_RANGE[1])
        grid_rows = grid_cols

        for mk in motif_order:
            motif = motif_impls[mk]
            for _ in range(self.max_retries):
                try:
                    specs = [_prefer_asym_mode(motif, motif.sample_spec(rng)) for _ in range(4)]
                    raws = [tight_crop_rgba(motif.render(sp)) for sp in specs]
                except Exception:
                    # Deliberate best effort: a motif that fails to sample/render just costs a retry.
                    continue
                if any(im.width < 8 or im.height < 8 for im in raws):
                    continue

                common_kind = choice_weighted(rng, kind_order, kind_weights)
                odd_choices = [k for k in kinds_all if k != common_kind]
                odd_kind = choice_weighted(rng, odd_choices, [WALLPAPER_WEIGHTS[k] for k in odd_choices])

                # Guard against motifs whose own symmetry makes EITHER rule degenerate
                if any(_wallpaper_degenerate_for(im, common_kind) for im in raws):
                    continue
                if any(_wallpaper_degenerate_for(im, odd_kind) for im in raws):
                    continue

                # Cell size that fits the SS tile (row stagger included in the width budget)
                tmp = [_scale_longer_side_crisp(im, int(round(cell_h0 * MOTIF_FIT_FRAC))) for im in raws]
                ref_w = max(int(im.width / max(1e-3, MOTIF_FIT_FRAC)) for im in tmp)
                ref_w = max(ref_w, cell_h0)
                need_w = ref_w * grid_cols + int(round(ref_w * ROW_STAGGER_FRAC))
                need_h = cell_h0 * grid_rows
                fit = min(1.0, min(max_inner_w / max(1, need_w), max_inner_h / max(1, need_h)))
                cell_w_ss = max(12, int(round(ref_w * fit)))
                cell_h_ss = max(12, int(round(cell_h0 * fit)))

                # HR geometry (no later downscale)
                cell_w = max(12, int(round(cell_w_ss * HR)))
                cell_h = max(12, int(round(cell_h_ss * HR)))

                # Assign odd slot and build all four borderless patches
                odd_slot = rng.randrange(4)
                kinds_for_slot = [common_kind] * 4
                kinds_for_slot[odd_slot] = odd_kind

                patch_rgba: List[Image.Image] = [
                    _render_patch_HR(
                        raw, kind,
                        cell_w=cell_w, cell_h=cell_h, cols=grid_cols, rows=grid_rows, edge=edgeHR,
                    )[0]
                    for raw, kind in zip(raws, kinds_for_slot)
                ]

                # Square, identically sized, bordered tiles flattened onto white
                square_bordered = self._square_and_border(patch_rgba, edgeHR)

                # Distinctness across tiles (avoid near duplicates)
                if not self._tiles_distinct(square_bordered, self.opt_min_delta):
                    continue

                # Label band (uniform metrics for all four tiles)
                common_h = max(im.height for im in square_bordered)
                label_max_px_hr = max(LABEL_MAX_PX, int(round(LABEL_MAX_PX * HR)))
                target_label_h = max(LABEL_MIN_PX, min(int(round(common_h * LABEL_FRAC)), label_max_px_hr))
                label_pad_y = max(LABEL_PAD_MIN, int(round(common_h * LABEL_PAD_FRAC)))

                labeled = [
                    self._label_tile_compact(im, lab, font, target_label_h, label_pad_y)
                    for im, lab in zip(square_bordered, labels)
                ]

                # Pad labeled tiles to identical size (centered); content stays crisp
                cell_w2 = max(im.width for im in labeled)
                cell_h2 = max(im.height for im in labeled)
                padded = []
                for im in labeled:
                    out = Image.new("RGBA", (cell_w2, cell_h2), (255, 255, 255, 255))
                    out.paste(im, ((cell_w2 - im.width) // 2, (cell_h2 - im.height) // 2), im)
                    padded.append(out)

                composite = self._compose_grid_2x2(padded)

                question = rng.choice(PROMPT_TEMPLATES)
                answer_label = labels[odd_slot]

                meta = {
                    "pattern_kind": "sequence",
                    "pattern": self.name,
                    "grid": (2, 2),
                    "motif_kind": mk,
                    "labels": labels,
                    "answer": answer_label,
                    "question": question,
                    "composite_ready": True,
                    "common_kind": common_kind,
                    "common_iuc": WALLPAPER_IUC.get(common_kind, ""),
                    "odd_kind": odd_kind,
                    "odd_iuc": WALLPAPER_IUC.get(odd_kind, ""),
                    "odd_index": int(odd_slot),
                    "patch": {
                        "rows": int(grid_rows), "cols": int(grid_cols),
                        "cell_w_hr": int(cell_w), "cell_h_hr": int(cell_h),
                    },
                }
                payloads = [{
                    "slot": i, "is_odd": (i == odd_slot),
                    "wallpaper_kind": kinds_for_slot[i],
                    "wallpaper_iuc": WALLPAPER_IUC.get(kinds_for_slot[i], ""),
                } for i in range(4)]

                return composite, payloads, meta

        raise RuntimeError(
            f"{self.name}: failed to produce an instance after {self.max_retries} attempts "
            f"per motif ({len(motif_order)} motif families tried)."
        )
