# -*- coding: utf-8 -*-


import os, random
from typing import List, Tuple
from PIL import Image, ImageOps, ImageFilter, ImageDraw, ImageFont

# ===================== Basic configuration =====================
# Input image that the text is drawn onto (opened in main()).
IMAGE_PATH    = "/data1/lyl/jailbreak/00049999.jpg"
# Directory that receives every output image.
OUT_DIR       = "/data1/lyl/jailbreak/puzzle2"
# NOTE: this runs at import time, not only when main() is called.
os.makedirs(OUT_DIR, exist_ok=True)

# Text rendered onto the image before it is cut into pieces.
TEXT_TO_DRAW  = "B"
# Seeds the global `random` module in main() (used for rotation angles and
# scatter jitter) and the knob-direction RNG. If None, the global RNG is left
# unseeded and the knob directions fall back to seed 2025.
RANDOM_SEED   = 2025

# Text box and text style
# Box anchor inside the image. Any horizontal value other than "left"/"center"
# behaves like "right"; any vertical value other than "top"/"middle" behaves
# like "bottom".
TEXT_BOX_POS       = ("left", "top")     # ('left'|'center'|'right', 'top'|'middle'|'bottom')
TEXT_BOX_W_FRAC    = 0.68                # box width as a fraction of the image width
TEXT_BOX_H_FRAC    = 0.68                # box height as a fraction of the image height
TEXT_MARGIN_FRAC   = 0.04                # margin as a fraction of the shorter image side (min 2 px)
TEXT_ALIGN         = "left"              # 'left' | 'center' | 'right' (any other value acts like 'right')
TEXT_FILL          = (255, 255, 255, 255)  # RGBA glyph fill
STROKE_FILL        = (0, 0, 0, 255)        # RGBA glyph outline colour
STROKE_WIDTH       = 3                     # glyph outline width in px
LINE_SPACING       = 0.15                  # gap between lines = max(1, int(font_size * LINE_SPACING)) px
# Blurred drop shadow under the text; only drawn when alpha > 0 and blur > 0.
DROP_SHADOW_TEXT   = True
SHADOW_OFFSET_TEXT = (2, 2)                # (dx, dy) in px
SHADOW_BLUR_TEXT   = 4                     # Gaussian blur radius in px
SHADOW_ALPHA_TEXT  = 140                   # shadow opacity, 0-255
# Font files tried in order; the first one that loads is used, otherwise
# Pillow's built-in default font.
PREFERRED_FONTS = [
    "/System/Library/Fonts/Supplemental/Arial Unicode.ttf",
    "/System/Library/Fonts/Supplemental/Arial.ttf",
    "/Library/Fonts/Arial Unicode.ttf",
    "/Library/Fonts/Arial.ttf",
    "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
    "/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf",
    "C:/Windows/Fonts/arial.ttf",
]

# Jigsaw grid: the image is cut into GRID_ROWS x GRID_COLS pieces.
GRID_ROWS = 3
GRID_COLS = 3


# Piece shape
KNOB_RADIUS      = 28   # radius in px of each knob / matching hole circle
NECK_FRAC        = 0.6  # neck width as a fraction of KNOB_RADIUS
EDGE_ROUNDING    = 6    # corner radius of the piece body; > 0 also applies a light anti-alias blur to the mask
SAFETY_MARGIN    = 4    # extra px of crop room added on every convex (knobbed) side

# White outline around each piece, made by growing the mask BORDER_PX//2+1
# times with a 3x3 max filter. BORDER_PX is also added to the crop room.
ADD_BORDER           = True
BORDER_PX            = 4
# Blurred drop shadow under each piece; only drawn when alpha > 0 and blur > 0.
ADD_PIECE_SHADOW     = True
PIECE_SHADOW_OFFSET  = (6, 6)   # (dx, dy) in px; also reserved in the layout cell size
PIECE_SHADOW_BLUR    = 12       # Gaussian blur radius in px
PIECE_SHADOW_ALPHA   = 120      # shadow opacity, 0-255

# Scattered layout
ALLOW_ROTATION       = True
MAX_ROTATE_DEG       = 8        # each piece is rotated by a uniform angle in [-MAX, MAX] degrees; 0 disables
CELL_GAP             = 28       # px between neighbouring cells
CELL_INNER_JITTER    = 16       # max random offset (px) of a piece inside its cell
CANVAS_PADDING       = 40       # outer padding of the scattered canvas
BACKGROUND_RGBA      = (255, 255, 255, 255)  # background of both output canvases

# Neat grid layout
GRID_CANVAS_PAD      = 36       # outer padding of the grid canvas
GRID_GAP             = 24       # px between neighbouring grid slots

# ===================== text tools =====================
def load_font(size: int) -> ImageFont.FreeTypeFont:
    """Return the first loadable font from PREFERRED_FONTS at *size* px.

    Candidates are tried in order. If none can be loaded, Pillow's built-in
    default font is returned. On Pillow >= 10.1 that default font is scalable
    and honours *size*. On older Pillow it is a fixed-size bitmap font (an
    ``ImageFont.ImageFont``, not a FreeTypeFont) and *size* is ignored.
    """
    for path in PREFERRED_FONTS:
        try:
            return ImageFont.truetype(path, size=size)
        except (OSError, ImportError, ValueError):
            # OSError: file missing/unreadable; ImportError: Pillow built
            # without FreeType; ValueError: unusable size. Try the next one.
            continue
    try:
        # Pillow >= 10.1: the default font accepts a size, so callers that
        # search over font sizes (find_max_font) still get scaled text.
        return ImageFont.load_default(size=size)
    except TypeError:
        # Older Pillow: load_default() takes no arguments.
        return ImageFont.load_default()

def is_cjk(ch: str) -> bool:
    """Return True if the single character *ch* is a CJK ideograph.

    Covers the CJK Unified Ideographs block, Extensions A-E, and both
    Compatibility Ideograph blocks. Kana, Hangul and CJK punctuation are
    NOT included.
    """
    cp = ord(ch)
    ideograph_ranges = (
        (0x3400, 0x4DBF),    # Extension A
        (0x4E00, 0x9FFF),    # Unified Ideographs
        (0xF900, 0xFAFF),    # Compatibility Ideographs
        (0x20000, 0x2A6DF),  # Extension B
        (0x2A700, 0x2B73F),  # Extension C
        (0x2B740, 0x2B81F),  # Extension D
        (0x2B820, 0x2CEAF),  # Extension E
        (0x2F800, 0x2FA1F),  # Compatibility Ideographs Supplement
    )
    return any(lo <= cp <= hi for lo, hi in ideograph_ranges)

def tokenize_for_wrap(text: str):
    """Split *text* into the units used by wrap_text().

    Produces, in order of appearance:
      * each run of consecutive non-whitespace, non-CJK characters as one
        token (a "word");
      * each CJK ideograph (see is_cjk) as its own token, so CJK text can
        break between any two characters;
      * one " " token for every whitespace character (newlines and tabs
        included; they are normalised to a plain space).

    Previously CJK characters were stored in the buffer and then cleared
    before being emitted, so they vanished from the output entirely.
    """
    tokens = []
    word = ""
    for ch in text:
        if ch.isspace():
            if word:
                tokens.append(word)
                word = ""
            tokens.append(" ")
        elif is_cjk(ch):
            # Flush any pending latin word, then emit the ideograph alone.
            if word:
                tokens.append(word)
                word = ""
            tokens.append(ch)
        else:
            word += ch
    if word:
        tokens.append(word)
    return tokens

def measure_lines(draw, lines, font, spacing, stroke_w):
    """Return ``(width, height)`` of *lines* stacked top to bottom.

    *width* is the widest single line. *height* is the sum of every line's
    height plus *spacing* between consecutive lines (not after the last one).
    Each line is measured with ``draw.textbbox`` at the origin using *font*
    and *stroke_w*. An empty *lines* yields ``(0, 0)``.
    """
    widest = 0
    stacked = 0
    for n, text_line in enumerate(lines):
        left, top, right, bottom = draw.textbbox(
            (0, 0), text_line, font=font, stroke_width=stroke_w
        )
        if n:
            stacked += spacing  # gap above every line except the first
        stacked += bottom - top
        widest = max(widest, right - left)
    return widest, stacked

def wrap_text(text, font, max_w, draw, stroke_w):
    """Greedily wrap *text* into lines whose rendered width is <= *max_w* px.

    Tokens come from tokenize_for_wrap(). Whitespace tokens are only kept
    inside a line, never at its start, and trailing whitespace is stripped
    from every finished line. A token that does not fit after the current
    line starts a new line. If it is still too wide on its own, it is broken
    character by character. A single character wider than *max_w* is placed
    on its own line and will overflow.

    Widths are measured with measure_lines() (``draw.textbbox`` with
    *stroke_w*). Returns a list of strings, which is empty for empty or
    whitespace-only text.
    """
    def width(s):
        return measure_lines(draw, [s], font, 0, stroke_w)[0]

    lines, cur = [], ""
    for tok in tokenize_for_wrap(text):
        if tok == " ":
            # Spaces never force a break, so there is no need to measure them.
            if cur:
                cur += tok
            continue
        probe = cur + tok
        if width(probe) <= max_w:
            cur = probe
            continue
        if cur:
            lines.append(cur.rstrip())
            cur = ""
        # Lay the token out from a fresh line, hard-breaking it wherever the
        # next character would overflow. If the whole token fits, it simply
        # ends up in `buf` unchanged.
        buf = ""
        for ch in tok:
            if width(buf + ch) <= max_w:
                buf += ch
            elif buf:
                lines.append(buf)
                buf = ch
            else:
                lines.append(ch)
        cur = buf
    if cur:
        lines.append(cur.rstrip())
    return lines

def find_max_font(img_w, img_h, text, box_wh, margin_px, stroke_w, line_spacing_ratio):
    """Binary-search the largest font size at which *text* fits in *box_wh*.

    For each candidate size, *text* is wrapped to the box width minus
    ``2*margin_px``. The size fits when the wrapped block, with line spacing
    ``max(1, int(size * line_spacing_ratio))``, is no larger than the box
    minus the margins on both axes. Sizes searched: 8 .. max(20, min(img_w, img_h)).

    Returns ``(size, lines, font)`` for the best size found. If not even the
    minimum size fits, the text is still wrapped at the minimum size and
    returned; it will overflow the box. Previously an empty line list was
    returned, so the text silently disappeared.
    """
    # textbbox depends only on the font and the draw's font mode, not on the
    # canvas size, so a 1x1 scratch canvas is enough for measuring.
    draw = ImageDraw.Draw(Image.new("RGBA", (1, 1)))
    inner_w = box_wh[0] - 2 * margin_px
    inner_h = box_wh[1] - 2 * margin_px

    def layout(size):
        font = load_font(size)
        spacing = max(1, int(size * line_spacing_ratio))
        lines = wrap_text(text, font, inner_w, draw, stroke_w)
        w, h = measure_lines(draw, lines, font, spacing, stroke_w)
        return font, lines, (w <= inner_w and h <= inner_h)

    min_size = 8
    lo, hi = min_size, max(20, min(img_w, img_h))
    best = None
    while lo <= hi:
        mid = (lo + hi) // 2
        font, lines, fits = layout(mid)
        if fits:
            # Track the size actually chosen, rather than inferring it from
            # `hi` after the loop.
            best = (mid, lines, font)
            lo = mid + 1
        else:
            hi = mid - 1

    if best is None:
        font, lines, _ = layout(min_size)
        best = (min_size, lines, font)
    return best

def draw_text_block(base: Image.Image, text: str) -> Image.Image:
    """Render *text* onto an RGBA copy of *base* inside a configurable box.

    The box size and position come from TEXT_BOX_POS / TEXT_BOX_*_FRAC /
    TEXT_MARGIN_FRAC. The font size is the largest that fits (see
    find_max_font). Each wrapped line is aligned individually according to
    TEXT_ALIGN, and the block of lines is vertically centred in the box.
    When DROP_SHADOW_TEXT is set (and alpha/blur > 0), a blurred shadow is
    composited under the text.

    Lines are positioned by their ink bounding box: the ``textbbox`` offset
    from the draw origin (glyph bearing plus stroke) is subtracted. The drawn
    text therefore occupies exactly the area find_max_font measured.
    Previously it was shifted right and down by that offset.

    Returns a new RGBA image; *base* is not modified.
    """
    img = base.convert("RGBA")
    W, H = img.size
    box_w = int(W * TEXT_BOX_W_FRAC)
    box_h = int(H * TEXT_BOX_H_FRAC)
    margin_px = max(2, int(min(W, H) * TEXT_MARGIN_FRAC))

    xa, ya = TEXT_BOX_POS
    bx = margin_px if xa == "left" else ((W - box_w) // 2 if xa == "center" else W - box_w - margin_px)
    by = margin_px if ya == "top" else ((H - box_h) // 2 if ya == "middle" else H - box_h - margin_px)

    best_size, lines, font = find_max_font(W, H, text, (box_w, box_h), margin_px, STROKE_WIDTH, LINE_SPACING)
    spacing = max(1, int(best_size * LINE_SPACING))

    layer = Image.new("RGBA", img.size, (0, 0, 0, 0))
    d = ImageDraw.Draw(layer)
    _, h_all = measure_lines(d, lines, font, spacing, STROKE_WIDTH)

    # Compute every line's draw origin once; the shadow and text passes share it.
    placements = []
    cy = by + (box_h - h_all) // 2  # top of the ink of the current line
    for line in lines:
        left, top, right, bottom = d.textbbox((0, 0), line, font=font, stroke_width=STROKE_WIDTH)
        line_w = right - left
        if TEXT_ALIGN == "left":
            ink_x = bx + margin_px
        elif TEXT_ALIGN == "center":
            ink_x = bx + (box_w - line_w) // 2
        else:
            ink_x = bx + box_w - line_w - margin_px
        # Shift the origin so the ink box (not the origin) lands at (ink_x, cy).
        placements.append(((ink_x - left, cy - top), line))
        cy += (bottom - top) + spacing

    if DROP_SHADOW_TEXT and SHADOW_ALPHA_TEXT > 0 and SHADOW_BLUR_TEXT > 0:
        shadow = Image.new("RGBA", img.size, (0, 0, 0, 0))
        sd = ImageDraw.Draw(shadow)
        shadow_rgba = (0, 0, 0, SHADOW_ALPHA_TEXT)
        ox, oy = SHADOW_OFFSET_TEXT
        for (x, y), line in placements:
            sd.text((x + ox, y + oy), line, font=font, fill=shadow_rgba,
                    stroke_width=STROKE_WIDTH, stroke_fill=shadow_rgba)
        img = Image.alpha_composite(img, shadow.filter(ImageFilter.GaussianBlur(SHADOW_BLUR_TEXT)))

    for (x, y), line in placements:
        d.text((x, y), line, font=font, fill=TEXT_FILL,
               stroke_width=STROKE_WIDTH, stroke_fill=STROKE_FILL)
    return Image.alpha_composite(img, layer)

# ===================== Jigsaw shape =====================
def split_grid_rects(W:int, H:int, rows:int, cols:int):
    """Split a W x H area into rows x cols cells.

    Cell edges are ``round(k * size / count)``, using Python's round-half-to-even,
    so adjacent cells share an edge and the cells tile the area exactly.
    Returns a list of ``(x0, y0, x1, y1)`` tuples in row-major order.
    """
    col_edges = [round(k * W / cols) for k in range(cols + 1)]
    row_edges = [round(k * H / rows) for k in range(rows + 1)]
    return [
        (left, top, right, bottom)
        for top, bottom in zip(row_edges, row_edges[1:])
        for left, right in zip(col_edges, col_edges[1:])
    ]

def gen_edge_signs(rows:int, cols:int, seed:int=2025):
    """Draw a random +1/-1 knob direction for every interior grid edge.

    Returns ``(h_dir, v_dir)``:
      * h_dir has rows+1 rows of cols entries. h_dir[r][c] is the horizontal
        edge above row r. The first and last rows (image border) are 0.
      * v_dir has rows rows of cols+1 entries. v_dir[r][c] is the vertical
        edge left of column c. The first and last columns are 0.
    Uses a private ``random.Random(seed)``, drawing all h_dir values row-major
    before all v_dir values, so the result depends only on the arguments.
    """
    rng = random.Random(seed)
    signs = (-1, 1)
    h_dir = [
        [rng.choice(signs) if 0 < r < rows else 0 for _ in range(cols)]
        for r in range(rows + 1)
    ]
    v_dir = [
        [rng.choice(signs) if 0 < c < cols else 0 for c in range(cols + 1)]
        for _ in range(rows)
    ]
    return h_dir, v_dir

def build_piece_mask(rect, W, H, rows, cols, r, c, h_dir, v_dir,
                     knob_r:int, neck_frac:float, edge_rounding:int,
                     border_px:int, safety:int):
    """Build the alpha mask of jigsaw piece (r, c) and locate it in the image.

    Edge convention (as drawn below):
      * h_dir[r][c] is the edge between rows r-1 and r. +1 means the knob
        hangs down: piece (r-1, c) gets a convex bottom and piece (r, c) a
        concave top. Any other value gives the opposite.
      * v_dir[r][c] is the edge between columns c-1 and c. +1 means the knob
        points right: piece (r, c-1) gets a convex right side and piece (r, c)
        a concave left side. Any other value gives the opposite.
      * Sides on the image border are always flat.

    The crop window around *rect* is widened only on convex sides, by
    ``2*knob_r + border_px + safety`` px, clamped to the W x H image.
    Concave and flat sides get no extra room.

    Args:
        rect: (x0, y0, x1, y1) grid cell of this piece in image coordinates.
        W, H: image size.
        rows, cols: grid dimensions; r, c: this piece's row and column.
        h_dir, v_dir: edge sign tables indexed as described above.
        knob_r: radius in px of each knob/hole circle.
        neck_frac: neck width as a fraction of knob_r.
        edge_rounding: corner radius of the piece body. When > 0 the mask
            also gets a light anti-aliasing blur.
        border_px, safety: extra px added to the crop room on convex sides
            (the caller passes the outline width and a safety margin).

    Returns:
        (mask, box): an "L" mask cropped to its non-zero pixels, and
        box = (left, top, right, bottom) of that mask in image coordinates,
        i.e. the region of the source image to crop for this piece.
    """
    x0, y0, x1, y1 = rect
    cw, ch = x1 - x0, y1 - y0

    # A side is convex when it carries an outward knob (see drawing below).
    top_convex = (r > 0 and h_dir[r][c] != +1)
    bot_convex = (r < rows - 1 and h_dir[r + 1][c] == +1)
    left_convex = (c > 0 and v_dir[r][c] != +1)
    right_convex = (c < cols - 1 and v_dir[r][c + 1] == +1)

    ext = 2 * knob_r + border_px + safety
    pad_top = ext if top_convex else 0
    pad_bottom = ext if bot_convex else 0
    pad_left = ext if left_convex else 0
    pad_right = ext if right_convex else 0

    # Crop window in image coordinates, clamped to the image.
    crop = (max(0, x0 - pad_left),
            max(0, y0 - pad_top),
            min(W, x1 + pad_right),
            min(H, y1 + pad_bottom))
    crop_w, crop_h = crop[2] - crop[0], crop[3] - crop[1]

    # The grid cell expressed in crop-local coordinates.
    rx0, ry0 = x0 - crop[0], y0 - crop[1]
    rx1, ry1 = rx0 + cw, ry0 + ch

    # Start from the (optionally rounded) cell body.
    mask = Image.new("L", (crop_w, crop_h), 0)
    d = ImageDraw.Draw(mask)
    if edge_rounding > 0:
        d.rounded_rectangle([rx0, ry0, rx1, ry1], radius=edge_rounding, fill=255)
    else:
        d.rectangle([rx0, ry0, rx1, ry1], fill=255)

    neck = knob_r * neck_frac
    cx = (rx0 + rx1) / 2  # x of the knobs on the top/bottom sides
    cy = (ry0 + ry1) / 2  # y of the knobs on the left/right sides

    # Each knob is a neck rectangle plus a circle tangent to the cell edge.
    # fill=0 carves a hole (concave side), fill=255 adds a knob (convex side).

    # Top side
    if r > 0:
        if h_dir[r][c] == +1:
            d.rectangle([cx - neck / 2, ry0, cx + neck / 2, ry0 + knob_r], fill=0)
            d.ellipse([cx - knob_r, ry0, cx + knob_r, ry0 + 2 * knob_r], fill=0)
        else:
            d.rectangle([cx - neck / 2, ry0 - knob_r, cx + neck / 2, ry0], fill=255)
            d.ellipse([cx - knob_r, ry0 - 2 * knob_r, cx + knob_r, ry0], fill=255)

    # Bottom side
    if r < rows - 1:
        if h_dir[r + 1][c] == +1:
            d.rectangle([cx - neck / 2, ry1, cx + neck / 2, ry1 + knob_r], fill=255)
            d.ellipse([cx - knob_r, ry1, cx + knob_r, ry1 + 2 * knob_r], fill=255)
        else:
            d.rectangle([cx - neck / 2, ry1 - knob_r, cx + neck / 2, ry1], fill=0)
            d.ellipse([cx - knob_r, ry1 - 2 * knob_r, cx + knob_r, ry1], fill=0)

    # Left side
    if c > 0:
        if v_dir[r][c] == +1:
            d.rectangle([rx0, cy - neck / 2, rx0 + knob_r, cy + neck / 2], fill=0)
            d.ellipse([rx0, cy - knob_r, rx0 + 2 * knob_r, cy + knob_r], fill=0)
        else:
            d.rectangle([rx0 - knob_r, cy - neck / 2, rx0, cy + neck / 2], fill=255)
            d.ellipse([rx0 - 2 * knob_r, cy - knob_r, rx0, cy + knob_r], fill=255)

    # Right side
    if c < cols - 1:
        if v_dir[r][c + 1] == +1:
            d.rectangle([rx1, cy - neck / 2, rx1 + knob_r, cy + neck / 2], fill=255)
            d.ellipse([rx1, cy - knob_r, rx1 + 2 * knob_r, cy + knob_r], fill=255)
        else:
            d.rectangle([rx1 - knob_r, cy - neck / 2, rx1, cy + neck / 2], fill=0)
            d.ellipse([rx1 - 2 * knob_r, cy - knob_r, rx1, cy + knob_r], fill=0)

    # Light anti-aliasing, gated on the *parameter* (previously this read the
    # global EDGE_ROUNDING and ignored the caller's edge_rounding).
    if edge_rounding > 0:
        mask = mask.filter(ImageFilter.GaussianBlur(0.6))

    # Trim to the visible pixels and report where that lies in the image.
    bbox = mask.getbbox() or (0, 0, crop_w, crop_h)
    return mask.crop(bbox), (crop[0] + bbox[0], crop[1] + bbox[1], crop[0] + bbox[2], crop[1] + bbox[3])

def _outline_piece(piece: Image.Image, mask: Image.Image, border_px: int) -> Image.Image:
    """Return *piece* on a larger canvas with a white outline grown from *mask*."""
    grow = border_px // 2 + 1
    # The mask is tightly cropped to its content, so it has to be padded
    # before dilating. Otherwise the MaxFilter growth is clipped at the canvas
    # edge and the outline is missing wherever the piece touches its bbox.
    border_mask = ImageOps.expand(mask, border=grow, fill=0)
    for _ in range(grow):
        border_mask = border_mask.filter(ImageFilter.MaxFilter(3))
    edge = Image.new("RGBA", border_mask.size, (255, 255, 255, 255))
    edge.putalpha(border_mask)
    final = Image.new("RGBA", edge.size, (0, 0, 0, 0))
    final.alpha_composite(edge, (0, 0))
    final.alpha_composite(piece, (grow, grow))
    return final


def _drop_shadow(piece: Image.Image) -> Image.Image:
    """Return a blurred, semi-transparent black silhouette of *piece* (same size)."""
    alpha = piece.split()[-1]
    shadow_color = Image.new("RGBA", piece.size, (0, 0, 0, PIECE_SHADOW_ALPHA))
    shadow = Image.new("RGBA", piece.size, (0, 0, 0, 0))
    shadow.paste(shadow_color, (0, 0), mask=alpha)
    # NOTE(review): the blur is confined to piece.size, so its soft edge is
    # clipped where the silhouette touches the canvas border. Padding here
    # would also require shifting "shadow_offset" by the pad, which the layout
    # functions would then have to accept as a negative offset.
    return shadow.filter(ImageFilter.GaussianBlur(PIECE_SHADOW_BLUR))


def generate_jigsaw_pieces(img: Image.Image, rows:int, cols:int,
                           knob_r:int, neck_frac:float, edge_round:int):
    """Cut *img* into rows x cols jigsaw pieces.

    Knob directions come from gen_edge_signs seeded with RANDOM_SEED (or 2025).
    Each piece is cropped with its mask as alpha. Depending on the module
    config, it then gets a white outline (ADD_BORDER), a random rotation
    drawn from the global ``random`` module (ALLOW_ROTATION), and a blurred
    shadow (ADD_PIECE_SHADOW).

    Returns a list in row-major order of dicts with keys:
      "img"           RGBA piece image,
      "shadow"        RGBA shadow image of the same size, or None,
      "shadow_offset" (dx, dy) at which to draw the shadow relative to "img".
    """
    W, H = img.size
    rects = split_grid_rects(W, H, rows, cols)
    h_dir, v_dir = gen_edge_signs(rows, cols, seed=RANDOM_SEED if RANDOM_SEED is not None else 2025)
    border_px = BORDER_PX if ADD_BORDER else 0

    pieces = []
    for idx, rect in enumerate(rects):
        r, c = divmod(idx, cols)  # split_grid_rects returns cells row-major
        mask, tight_crop = build_piece_mask(
            rect, W, H, rows, cols, r, c, h_dir, v_dir,
            knob_r, neck_frac, edge_round, border_px, SAFETY_MARGIN,
        )
        piece = img.crop(tight_crop).convert("RGBA")
        piece.putalpha(mask)

        if ADD_BORDER and BORDER_PX > 0:
            piece = _outline_piece(piece, mask, BORDER_PX)

        if ALLOW_ROTATION and abs(MAX_ROTATE_DEG) > 0:
            deg = random.uniform(-MAX_ROTATE_DEG, MAX_ROTATE_DEG)
            piece = piece.rotate(deg, expand=True, resample=Image.BICUBIC)

        if ADD_PIECE_SHADOW and PIECE_SHADOW_ALPHA > 0 and PIECE_SHADOW_BLUR > 0:
            pieces.append({"img": piece, "shadow": _drop_shadow(piece), "shadow_offset": PIECE_SHADOW_OFFSET})
        else:
            pieces.append({"img": piece, "shadow": None, "shadow_offset": (0, 0)})
    return pieces

# ===================== output =====================
def scatter_nonoverlap(pieces: List[dict], rows:int, cols:int, out_path:str):
    """Place pieces one per cell with random jitter and save as RGB.

    Every cell is as large as the biggest piece, plus the absolute shadow
    offset, plus ``2*CELL_INNER_JITTER``. Cells are separated by CELL_GAP and
    surrounded by CANVAS_PADDING, so pieces never overlap. Pieces are placed
    in list order, row-major. Each is centred in its cell, nudged by a random
    offset from the global ``random`` module, and clamped to stay inside the
    cell. Pieces beyond ``rows*cols`` are not drawn. Raises ValueError if
    *pieces* is empty.
    """
    cell_w = max(p["img"].size[0] for p in pieces) + abs(PIECE_SHADOW_OFFSET[0]) + 2 * CELL_INNER_JITTER
    cell_h = max(p["img"].size[1] for p in pieces) + abs(PIECE_SHADOW_OFFSET[1]) + 2 * CELL_INNER_JITTER

    canvas = Image.new(
        "RGBA",
        (cols * cell_w + (cols - 1) * CELL_GAP + 2 * CANVAS_PADDING,
         rows * cell_h + (rows - 1) * CELL_GAP + 2 * CANVAS_PADDING),
        BACKGROUND_RGBA,
    )

    cells = ((r, c) for r in range(rows) for c in range(cols))
    for (r, c), piece in zip(cells, pieces):
        w, h = piece["img"].size
        left = CANVAS_PADDING + c * (cell_w + CELL_GAP)
        top = CANVAS_PADDING + r * (cell_h + CELL_GAP)
        start_x = left + max(0, cell_w - w) // 2
        start_y = top + max(0, cell_h - h) // 2
        dx = random.randint(-CELL_INNER_JITTER, CELL_INNER_JITTER)
        dy = random.randint(-CELL_INNER_JITTER, CELL_INNER_JITTER)
        # Keep the jittered piece fully inside its own cell.
        x = min(max(start_x + dx, left), left + cell_w - w)
        y = min(max(start_y + dy, top), top + cell_h - h)

        if piece["shadow"] is not None:
            sx, sy = piece["shadow_offset"]
            canvas.alpha_composite(piece["shadow"], (x + sx, y + sy))
        canvas.alpha_composite(piece["img"], (x, y))

    rgb = canvas.convert("RGB")
    rgb.save(out_path, quality=95)
    print(f"✅ 散落拼图已保存：{out_path}  尺寸：{rgb.size}")

def grid_neat(pieces: List[dict], rows:int, cols:int, out_path:str):
    """Place pieces in a regular rows x cols grid and save as RGB.

    Every slot is as large as the biggest piece plus the absolute shadow
    offset. Slots are separated by GRID_GAP and surrounded by
    GRID_CANVAS_PAD. Each piece is drawn at its slot's top-left corner, in
    list order, row-major. Pieces beyond ``rows*cols`` are not drawn. Raises
    ValueError if *pieces* is empty.
    """
    slot_w = max(p["img"].size[0] for p in pieces) + abs(PIECE_SHADOW_OFFSET[0])
    slot_h = max(p["img"].size[1] for p in pieces) + abs(PIECE_SHADOW_OFFSET[1])
    step_x, step_y = slot_w + GRID_GAP, slot_h + GRID_GAP

    size = (cols * slot_w + (cols - 1) * GRID_GAP + 2 * GRID_CANVAS_PAD,
            rows * slot_h + (rows - 1) * GRID_GAP + 2 * GRID_CANVAS_PAD)
    canvas = Image.new("RGBA", size, BACKGROUND_RGBA)

    cells = [(r, c) for r in range(rows) for c in range(cols)]
    for (r, c), piece in zip(cells, pieces):
        x = GRID_CANVAS_PAD + c * step_x
        y = GRID_CANVAS_PAD + r * step_y
        shadow = piece["shadow"]
        if shadow is not None:
            ox, oy = piece["shadow_offset"]
            canvas.alpha_composite(shadow, (x + ox, y + oy))
        canvas.alpha_composite(piece["img"], (x, y))

    rgb = canvas.convert("RGB")
    rgb.save(out_path, quality=95)
    print(f"✅ 整齐矩阵拼图已保存：{out_path}  尺寸：{rgb.size}")

# ===================== pipeline =====================
def main():
    """Run the full pipeline and write all outputs into OUT_DIR.

    Steps: seed the global RNG (if RANDOM_SEED is set); draw TEXT_TO_DRAW onto
    IMAGE_PATH and save it; cut the result into GRID_ROWS x GRID_COLS pieces;
    save the scattered layout and then the neat grid layout.
    """
    if RANDOM_SEED is not None:
        random.seed(RANDOM_SEED)

    with Image.open(IMAGE_PATH) as source:
        base = source.convert("RGB")

    captioned = draw_text_block(base, TEXT_TO_DRAW)
    text_path = os.path.join(OUT_DIR, "step2_with_text.png")
    captioned.save(text_path, quality=95)
    print(f"✅ 写字完成：{text_path}  尺寸：{captioned.size}")

    pieces = generate_jigsaw_pieces(
        captioned, GRID_ROWS, GRID_COLS,
        knob_r=KNOB_RADIUS, neck_frac=NECK_FRAC, edge_round=EDGE_ROUNDING,
    )

    layouts = (
        (scatter_nonoverlap, "step3_jigsaw_scatter.png"),
        (grid_neat, "step4_jigsaw_grid.png"),
    )
    for render, filename in layouts:
        render(pieces, GRID_ROWS, GRID_COLS, os.path.join(OUT_DIR, filename))

if __name__ == "__main__":
    main()
