# ---------------------------------------------------------------
# DoG-multiscale structure prior as a "corruption" for OT-Bridge
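# Input x1: clean/target image batch, float in [-1, 1], layout BCHW, RGB order.
# Output x0: per-image structure map (Canny edges or a DoG response, optionally
# binarized/thinned), same shape and range as x1.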

# ---------------------------------------------------------------
import re
import numpy as np
import torch
import cv2
from skimage.morphology import skeletonize

# ---------- helpers: tensor<->uint8 ----------
def _to_u8(x):
    # x: [-1,1] float tensor -> uint8 numpy
    x = (x.clamp(-1, 1) + 1.0) * 127.5
    return x.round().byte().detach().cpu().numpy()

def _from_u8(xu8, device):
    # uint8 numpy -> [-1,1] float tensor
    x = torch.from_numpy(xu8.astype(np.float32) / 127.5 - 1.0)
    return x.to(device)

# ---------- image ops ----------
def _multiscale_dog(gray_u8, sigma_pairs, invert_for_dark=True):
    x = gray_u8.astype(np.float32)
    if invert_for_dark:
        x = 255.0 - x
    acc = None
    for s1, s2 in sigma_pairs:
        g1 = cv2.GaussianBlur(x, (0, 0), s1)
        g2 = cv2.GaussianBlur(x, (0, 0), s2)
        dog = np.maximum(g1 - g2, 0.0)
        acc = dog if acc is None else np.maximum(acc, dog)
    return cv2.normalize(acc, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

def _adaptive_binarize(img_u8, block_size=35, C=-4):
    """Gaussian-weighted adaptive threshold of a uint8 image to {0, 255}.

    ``block_size`` is coerced to a valid OpenCV neighbourhood size: even values
    are bumped to the next odd number, and the result is at least 3. ``C`` is
    subtracted from the local weighted mean. A negative ``C`` therefore raises
    the bar a pixel must clear to become foreground.
    """
    size = block_size + 1 if block_size % 2 == 0 else block_size
    size = max(size, 3)
    return cv2.adaptiveThreshold(
        img_u8,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY,
        size,
        C,
    )

def _percentile_binarize(img_u8, q=0.60):
    t = np.quantile(img_u8.reshape(-1), q)  
    return (img_u8 >= t).astype(np.uint8) * 255

def _postprocess(binary, open_ksize=3, close_ksize=3):
    out = binary
    if open_ksize > 0:
        k1 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (open_ksize, open_ksize))
        out = cv2.morphologyEx(out, cv2.MORPH_OPEN, k1, iterations=1)
    if close_ksize > 0:
        k2 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (close_ksize, close_ksize))
        out = cv2.morphologyEx(out, cv2.MORPH_CLOSE, k2, iterations=1)
    return out

def _canny_edge(gray_u8, low_thresh=50, high_thresh=150):
    """Return the Canny edge map of a uint8 grayscale image.

    ``low_thresh`` / ``high_thresh`` are the hysteresis thresholds passed to
    ``cv2.Canny``.
    """
    return cv2.Canny(gray_u8, threshold1=low_thresh, threshold2=high_thresh)

def _thin_edges(binary, method='skeleton'):

    if method == 'skeleton':
        binary_bool = binary > 127
        skeleton = skeletonize(binary_bool)
        return (skeleton.astype(np.uint8) * 255)
    elif method == 'morph':
        kernel = np.ones((3, 3), np.uint8)
        thinned = cv2.morphologyEx(binary, cv2.MORPH_GRADIENT, kernel)
        return thinned
    return binary

def _fine_dog(gray_u8, sigma_pairs, invert_for_dark=True):
    """DoG response used by the 'fine_dog' mode.

    This currently performs exactly the same computation as
    ``_multiscale_dog``: a max-pooled positive DoG over ``sigma_pairs``,
    min-max stretched to uint8. It delegates so that the two implementations
    cannot silently drift apart. The separate name is kept for existing callers.

    Args:
        gray_u8: grayscale uint8 image.
        sigma_pairs: iterable of ``(s1, s2)`` Gaussian sigmas.
        invert_for_dark: invert intensities first so dark features respond.

    Returns:
        uint8 response map with the same spatial size as ``gray_u8``.
    """
    return _multiscale_dog(gray_u8, sigma_pairs, invert_for_dark=invert_for_dark)

# ---------- parser ----------
def _parse_corrupt_args(corrupt_type):

    cfg = dict(
        mode="adapt",           # 'canny' | 'fine_dog' | 'adapt' | 'percentile' | 'raw'
        ms=[(0.3,0.8),(0.5,1.2),(0.8,2.0)],  
        block=35, C=-8, q=0.60,  
        post_open=1, post_close=1,  
        invert=1,
        thin='skeleton',        # 'skeleton' | 'morph' | None
        canny_low=30,         
        canny_high=100,      
    )
    if ":" not in corrupt_type:
        return cfg
    _, argstr = corrupt_type.split(":", 1)
    for kv in argstr.split(";"):
        if not kv.strip():
            continue
        if "=" not in kv:
            continue
        k, v = kv.split("=", 1)
        k = k.strip().lower()
        v = v.strip()
        if k == "mode":
            cfg["mode"] = v
        elif k == "ms":
            # "a,b|c,d|e,f" -> [(a,b),(c,d),(e,f)]
            pairs = []
            for token in v.split("|"):
                a, b = token.split(",")
                pairs.append((float(a), float(b)))
            cfg["ms"] = pairs
        elif k == "block":
            cfg["block"] = int(v)
        elif k == "c":
            cfg["C"] = int(v)
        elif k == "q":
            cfg["q"] = float(v)
        elif k == "post":
            a, b = v.split(",")
            cfg["post_open"] = int(a)
            cfg["post_close"] = int(b)
        elif k == "invert":
            cfg["invert"] = int(v)
        elif k == "thin":
            cfg["thin"] = None if v.lower() == "none" else v
        elif k == "low":
            cfg["canny_low"] = int(v)
        elif k == "high":
            cfg["canny_high"] = int(v)
    return cfg

# ---------- builder ----------
def build_dogms(opt, log, corrupt_type):
    """Build the DoG-multiscale structure "corruption" callable.

    Args:
        opt: not used by this builder. It is kept for the builder signature,
            presumably shared with other corruption builders (confirm).
        log: logger; only ``log.info`` is called, to report the parsed config.
        corrupt_type: spec string parsed by ``_parse_corrupt_args``
            (``"<name>:key=value;..."``).

    Returns:
        ``method(x1, *args, **kwargs) -> x0``. Extra positional and keyword
        arguments are accepted and ignored.
    """
    cfg = _parse_corrupt_args(corrupt_type)
    mode = cfg['mode']
    log.info(f"[dogms] mode={cfg['mode']} thin={cfg['thin']}")
    if mode in ['fine_dog', 'adapt', 'percentile']:
        log.info(f"[dogms] ms={cfg['ms']} post=({cfg['post_open']},{cfg['post_close']}) invert={cfg['invert']}")
    # Also report the thresholding knobs of the binarizing modes.
    if mode in ['fine_dog', 'adapt']:
        log.info(f"[dogms] block={cfg['block']} C={cfg['C']}")
    elif mode == 'percentile':
        log.info(f"[dogms] q={cfg['q']}")
    if mode == 'canny':
        log.info(f"[dogms] canny_low={cfg['canny_low']} canny_high={cfg['canny_high']}")

    invert = bool(cfg["invert"])

    def structure_map(gray):
        """Return the uint8 HxW structure map of one grayscale uint8 image."""
        if mode == "canny":
            out = _canny_edge(gray, cfg["canny_low"], cfg["canny_high"])
        elif mode in ("fine_dog", "adapt"):
            # These two modes differ only in which DoG helper they call.
            dog_fn = _fine_dog if mode == "fine_dog" else _multiscale_dog
            dog = dog_fn(gray, cfg["ms"], invert_for_dark=invert)
            binm = _adaptive_binarize(dog, cfg["block"], cfg["C"])
            out = _postprocess(binm, cfg["post_open"], cfg["post_close"])
        elif mode == "percentile":
            dog = _multiscale_dog(gray, cfg["ms"], invert_for_dark=invert)
            binm = _percentile_binarize(dog, cfg["q"])
            out = _postprocess(binm, cfg["post_open"], cfg["post_close"])
        else:
            # 'raw' (and any unrecognised mode): the un-binarized DoG response.
            out = _multiscale_dog(gray, cfg["ms"], invert_for_dark=invert)
        if cfg["thin"] is not None:
            out = _thin_edges(out, cfg["thin"])
        return out

    def method(x1, *args, **kwargs):
        """
        x1: torch.FloatTensor, [B,3,H,W], range [-1,1], RGB
        return x0: same shape/range (float32, on x1.device)
        """
        device = x1.device
        x1_u8 = _to_u8(x1)                                  # [B,3,H,W] uint8
        batch, _, height, width = x1_u8.shape
        if batch == 0:
            # np.stack below needs at least one array; return an empty batch.
            return _from_u8(np.zeros((0, 3, height, width), np.uint8), device)
        outs = []
        for img_chw in x1_u8:
            # cv2 expects HWC; make the transposed view contiguous explicitly.
            img = np.ascontiguousarray(img_chw.transpose(1, 2, 0))
            gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            out = structure_map(gray)
            outs.append(np.repeat(out[np.newaxis], 3, axis=0))   # 3xHxW
        x0_u8 = np.stack(outs, 0)                           # B,3,H,W uint8
        return _from_u8(x0_u8, device)                      # [-1,1]

    return method