import ShapePrimitive
import ShapeOps
import plot_utils
import numpy as np
import random
from shapely.geometry import Point
from shapely.affinity import scale, rotate, translate
from shapely.prepared import prep


class XShapeGen:
    """Random generator of complex 2-D shapes built from primitives and shape ops.

    Configuration is read from ``config['COMPLEX_SHAPE_GEN']``. Keys used here:
    ``random_seed_id``, ``canvas_size``, ``scale_range``, ``translate_range``,
    ``rotate_range``, ``shape_num`` and ``op_num``.
    """

    # Op names accepted by `apply_unary_op` / `apply_binary_op`, in the order
    # they appear in `self.shape_ops`.
    _UNARY_OPS = ('scale', 'translate', 'rotate')
    _BINARY_OPS = ('union', 'intersect', 'subtract', 'xor', 'convex_hull')

    def __init__(self, config=None):
        """Seed the global RNGs and build the primitives and op helpers.

        Args:
            config: mapping with a ``'COMPLEX_SHAPE_GEN'`` section. It is required;
                the ``None`` default exists only for signature compatibility.

        Raises:
            TypeError: if ``config`` is None.
        """
        if config is None:
            raise TypeError("XShapeGen requires a config mapping with a 'COMPLEX_SHAPE_GEN' section")
        self.config = config
        gen_cfg = config['COMPLEX_SHAPE_GEN']
        # Process-wide side effect: seeds the *global* numpy and stdlib RNGs.
        np.random.seed(gen_cfg['random_seed_id'])
        random.seed(gen_cfg['random_seed_id'])
        self.ps = ShapePrimitive.PrimitiveShapes(canvas_size=gen_cfg['canvas_size'])

        self.get_shape_operations()
        self.get_shape_primitives()

        self.binary_op = ShapeOps.BinaryShapeOp()
        self.unary_op = ShapeOps.UnaryShapeOp()
        self.unary_transform = ShapeOps.UnaryShapeTransform()

    def get_shape_primitives(self):
        """Build the list of primitive shapes from `self.ps` into `self.shape_primitives`."""
        self.shape_primitives = [
            self.ps.circle(),
            self.ps.square(),
            self.ps.rectangle(),
            self.ps.triangle(),
            self.ps.ellipse(),
            self.ps.diamond(),
            self.ps.sector(),
            self.ps.pentagon(),
        ]

    def get_shape_operations(self):
        """Store the op catalogue in `self.shape_ops`.

        Each entry is a dict ``{'op_name': str, 'arity': 'unary' | 'binary'}``.
        Unary ops come first, then binary ops, in the order of `_UNARY_OPS` and
        `_BINARY_OPS`.
        """
        self.shape_ops = (
            [{'op_name': name, 'arity': 'unary'} for name in self._UNARY_OPS]
            + [{'op_name': name, 'arity': 'binary'} for name in self._BINARY_OPS]
        )

    def apply_unary_op(self, shape, op_name):
        """Apply a randomly parameterized unary transform to `shape`.

        Parameters are drawn from the global numpy RNG within the configured
        ``scale_range`` / ``translate_range`` / ``rotate_range``.

        Args:
            shape: geometry handed to ``self.unary_transform.run_transform``.
            op_name: one of ``'scale'``, ``'translate'``, ``'rotate'``.

        Returns:
            The result of ``run_transform`` for that op.

        Raises:
            ValueError: if ``op_name`` is not a known unary op.
        """
        if op_name not in self._UNARY_OPS:
            raise ValueError(f"Unknown unary op: {op_name!r}")
        cfg = self.config['COMPLEX_SHAPE_GEN']
        if op_name == 'scale':
            # x is sampled before y (keyword args evaluate left to right).
            return self.unary_transform.run_transform(
                shape, 'scale',
                scale_x=np.random.uniform(*cfg['scale_range']),
                scale_y=np.random.uniform(*cfg['scale_range']))
        if op_name == 'translate':
            return self.unary_transform.run_transform(
                shape, 'translate',
                x=np.random.uniform(*cfg['translate_range']),
                y=np.random.uniform(*cfg['translate_range']))
        return self.unary_transform.run_transform(
            shape, 'rotate', angle=np.random.uniform(*cfg['rotate_range']))

    def apply_binary_op(self, shape1, shape2, op_name):
        """Combine `shape1` and `shape2` with ``self.binary_op.run_op``.

        Args:
            shape1, shape2: operand geometries.
            op_name: one of ``'union'``, ``'intersect'``, ``'subtract'``,
                ``'xor'``, ``'convex_hull'``.

        Returns:
            The result of ``run_op``.

        Raises:
            ValueError: if ``op_name`` is not a known binary op.
        """
        if op_name not in self._BINARY_OPS:
            raise ValueError(f"Unknown binary op: {op_name!r}")
        # str() normalizes numpy string scalars (e.g. values from np.random.choice).
        return self.binary_op.run_op(shape1, shape2, str(op_name))

    def create_shapes_yh_impl(self):
        """Generate shapes by chaining random ops over a growing shape corpus.

        For each of ``shape_num`` shapes, ``op_num`` random ops are applied. Binary
        ops draw their second operand from the primitives plus every intermediate
        shape produced so far. A chain stops early as soon as an op yields an
        invalid geometry.

        Returns:
            list: the final shape of every chain that ended valid and non-empty.
            Chains that broke on an invalid geometry, or that ran zero ops, are
            omitted.
        """
        cfg = self.config['COMPLEX_SHAPE_GEN']
        results = []
        for _ in range(cfg['shape_num']):
            shape = None
            shape_corpus = list(self.shape_primitives)
            for _ in range(cfg['op_num']):
                op = np.random.choice(self.shape_ops)
                if shape is None:
                    shape = np.random.choice(shape_corpus)
                if op['arity'] == 'unary':
                    shape = self.apply_unary_op(shape, op['op_name'])
                elif op['arity'] == 'binary':
                    other_shape = np.random.choice(shape_corpus)
                    shape = self.apply_binary_op(shape, other_shape, op['op_name'])

                if not shape.is_valid:
                    break

                shape_corpus.append(shape)

            if shape is not None and shape.is_valid and not shape.is_empty:
                results.append(shape)
        return results

    def create_shapes(self, binary_op_num, min_steps, max_steps, max_binary_deferrals=100):
        """
        Build a connected complex shape with exactly `binary_op_num` binary ops.

        All composition happens inside the unit disk. The result is then mapped
        back to the canvas ([0, canvas_size]^2, with the disk centre at the canvas
        centre).

        A plan of `steps` operations is drawn, with
        ``steps ~ U{max(binary_op_num, min_steps), ..., max_steps}``. It contains
        exactly `binary_op_num` binary entries; the rest are small unary jitters.
        The plan is shuffled and executed. A binary step that cannot find an
        acceptable candidate is re-queued at the end of the plan.

        Args:
            binary_op_num: number of binary ops the result must contain (>= 0).
            min_steps: lower bound for the total number of planned steps.
            max_steps: upper bound for the total number of planned steps. Requires
                ``max(binary_op_num, min_steps) <= max_steps``.
            max_binary_deferrals: how many times a failed binary step may be
                re-queued. The same number also bounds the final top-up attempts.
                Without a bound, a degenerate state could retry forever. For
                example, once `current` covers the whole unit disk, no clipped
                partner can ever enlarge it by union.

        Returns:
            The final shape mapped to canvas coordinates.

        Raises:
            ValueError: if `binary_op_num` is negative or the step bounds are
                inconsistent.
            RuntimeError: if `binary_op_num` binary ops could not be executed
                within the retry budget.
        """
        if binary_op_num < 0:
            raise ValueError(f"binary_op_num must be >= 0, got {binary_op_num}")
        low_steps = max(binary_op_num, min_steps)
        if low_steps > max_steps:
            raise ValueError(
                f"need max(binary_op_num, min_steps) <= max_steps, got "
                f"binary_op_num={binary_op_num}, min_steps={min_steps}, max_steps={max_steps}")

        # --- Hyperparams that control complexity & anti-degeneracy ---
        min_chg, max_chg = 0.12, 0.55     # acceptable |A Δ B| / |A| change vs current
        min_ov, max_ov = 0.25, 0.70       # acceptable overlap ratio area(current ∩ other)/area(current)
        partner_area_frac_range = (0.25, 0.55)  # partner primitive area fraction w.r.t current
        binary_ops_menu = list(self._BINARY_OPS)
        unit_disk = Point(0.0, 0.0).buffer(1.0, resolution=128)

        # ---------------- helpers ----------------
        def _clean(g):
            """buffer(0) repair; best-effort, returns the input unchanged if it fails."""
            try:
                return g.buffer(0)
            except Exception:
                return g

        def _is_valid_nonempty(g):
            return (g is not None) and (not g.is_empty) and g.is_valid

        def _is_connected_poly(g):
            g = _clean(g)
            if not _is_valid_nonempty(g):
                return False
            return g.geom_type == 'Polygon'  # disallow MultiPolygon and non-polygonal

        def _connectedify(candidate, anchor):
            """Reduce a MultiPolygon to one part: the largest part touching `anchor`, else the largest part."""
            cand = _clean(candidate)
            if not _is_valid_nonempty(cand) or cand.geom_type != 'MultiPolygon':
                return cand
            anchor_c = _clean(anchor)
            parts = list(cand.geoms)
            touching = [p for p in parts if _clean(p).intersects(anchor_c)]
            return max(touching or parts, key=lambda g: g.area)

        def _to_unit(g):
            """Center a geometry at its centroid, scale by its largest |coordinate|, clip to the unit disk."""
            g = _clean(g)
            c = g.centroid
            g = translate(g, xoff=-c.x, yoff=-c.y)
            minx, miny, maxx, maxy = g.bounds
            s = max(abs(minx), abs(maxx), abs(miny), abs(maxy))
            if s == 0:
                s = 1.0
            g = scale(g, xfact=1.0/s, yfact=1.0/s, origin=(0.0, 0.0))
            # NOTE(review): this fits the shape in [-1, 1]^2, so anything outside the
            # inscribed disk (e.g. the corners of an axis-aligned square primitive) is
            # clipped off here. Such a square becomes (approximately) the full unit
            # disk, indistinguishable from the circle -- confirm this is intended.
            return _clean(g.intersection(unit_disk))

        def _from_unit(g):
            """Map unit-disk geometry back to canvas coordinates."""
            S = float(self.config['COMPLEX_SHAPE_GEN']['canvas_size'])
            g = scale(g, xfact=S/2.0, yfact=S/2.0, origin=(0.0, 0.0))
            g = translate(g, xoff=S/2.0, yoff=S/2.0)
            return _clean(g)

        def _rand_primitive_unit():
            """Random primitive normalized to the unit disk, with 0-2 light random jitters."""
            shp = _to_unit(np.random.choice(self.shape_primitives))
            for _ in range(np.random.randint(0, 3)):
                t = np.random.choice(['scale', 'rotate', 'translate'])
                if t == 'scale':
                    shp = scale(shp,
                                xfact=np.random.uniform(0.8, 1.2),
                                yfact=np.random.uniform(0.8, 1.2),
                                origin='center')
                elif t == 'rotate':
                    shp = rotate(shp, angle=np.random.uniform(-45, 45), origin='center')
                else:
                    dx = np.random.uniform(-0.1, 0.1)
                    dy = np.random.uniform(-0.1, 0.1)
                    shp = translate(shp, xoff=dx, yoff=dy)
                shp = _clean(shp.intersection(unit_disk))
            return _clean(shp)

        def _area(g):
            try:
                return float(g.area)
            except Exception:
                return 0.0

        def _symmetric_diff_ratio(a, b):
            """ |A Δ B| / |A|  (relative change magnitude wrt current A) """
            denom = _area(a)
            if denom <= 1e-12:
                return 1.0
            return _area(_clean(a.symmetric_difference(b))) / denom

        def _overlap_ratio(a, b):
            """ |A ∩ B| / |A|  (how much partner overlaps current) """
            denom = _area(a)
            if denom <= 1e-12:
                return 0.0
            return _area(_clean(a.intersection(b))) / denom

        def _almost_equal(a, b, tol=0.95):
            """Prevent 'snap back to primitive': IoU >= tol means candidate ≈ partner."""
            inter = _area(_clean(a.intersection(b)))
            uni = _area(_clean(a.union(b)))
            if uni <= 1e-12:
                return True
            return inter / uni >= tol

        def _warp_partner_for_current(current):
            """
            Create a 'partner' from a random primitive that:
            - is scaled to a random fraction (partner_area_frac_range) of current's area,
            - is offset/rotated at random, keeping the first placement whose overlap
              ratio with current lies in [min_ov, max_ov] (up to 12 tries),
            - is clipped to the unit disk.
            """
            base = _rand_primitive_unit()
            curA = max(_area(current), 1e-9)
            tgtA = curA * np.random.uniform(*partner_area_frac_range)
            baseA = max(_area(base), 1e-9)
            s = np.sqrt(tgtA / baseA)
            partner = scale(base, xfact=s, yfact=s, origin='center')
            partner = _clean(partner.intersection(unit_disk))

            for _ in range(12):
                # Random offset of up to 0.6 from the partner's own position (near the
                # origin after _to_unit), then a random rotation about its bbox center.
                ang = np.random.uniform(0, 2*np.pi)
                rad = np.random.uniform(0.0, 0.6)
                dx, dy = rad*np.cos(ang), rad*np.sin(ang)
                candidate = translate(partner, xoff=dx, yoff=dy)
                candidate = rotate(candidate, angle=np.random.uniform(-180, 180), origin='center')
                candidate = _clean(candidate.intersection(unit_disk))
                ov = _overlap_ratio(current, candidate)
                if min_ov <= ov <= max_ov:
                    return candidate
            # Fallback: the un-offset partner. Overlap with current is NOT guaranteed.
            return _clean(partner.buffer(0.0).intersection(unit_disk))

        def _finish(raw, anchor):
            """Keep one connected part of `raw` (preferring parts touching `anchor`) and clip to the disk."""
            cand = _connectedify(raw, anchor)
            if cand is None:
                return None
            return _clean(cand.intersection(unit_disk))

        def _acceptable(cand, cur):
            """Candidate is a valid, non-empty single Polygon that differs from `cur`."""
            return _is_valid_nonempty(cand) and _is_connected_poly(cand) and not cand.equals(cur)

        def _unary_step(cur):
            """Small scale/rotate/translate jitter (up to 8 tries); falls back to the convex hull."""
            for _try in range(8):
                t = np.random.choice(['scale', 'rotate', 'translate'])
                if t == 'scale':
                    cand = scale(cur,
                                 xfact=np.random.uniform(0.9, 1.1),
                                 yfact=np.random.uniform(0.9, 1.1),
                                 origin='center')
                elif t == 'rotate':
                    cand = rotate(cur, angle=np.random.uniform(-30, 30), origin='center')
                else:
                    cand = translate(cur,
                                     xoff=np.random.uniform(-0.08, 0.08),
                                     yoff=np.random.uniform(-0.08, 0.08))
                cand = _clean(cand.intersection(unit_disk))
                if _acceptable(cand, cur):
                    return cand
            return _clean(cur.convex_hull)

        def _union_step(cur):
            """Union with one freshly warped partner; returns the new shape or None."""
            partner = _warp_partner_for_current(cur)
            cand = _finish(self.apply_binary_op(cur, partner, 'union'), cur)
            return cand if _acceptable(cand, cur) else None

        def _binary_step(cur):
            """Try up to 14 random binary ops, then one union fallback; returns the new shape or None."""
            for _try in range(14):
                partner = _warp_partner_for_current(cur)
                op_name = str(np.random.choice(binary_ops_menu))

                # Rewrite ops that would be unproductive for this overlap:
                #  - 'intersect' with overlap below min_ov becomes 'union'
                #  - 'subtract' with overlap below 0.15 becomes 'xor'
                ov = _overlap_ratio(cur, partner)
                if op_name == 'intersect' and ov < min_ov:
                    op_name = 'union'
                if op_name == 'subtract' and ov < 0.15:
                    op_name = 'xor'

                cand = _finish(self.apply_binary_op(cur, partner, op_name), cur)
                if not _acceptable(cand, cur):
                    continue
                # Change magnitude must be moderate.
                if not (min_chg <= _symmetric_diff_ratio(cur, cand) <= max_chg):
                    continue
                # Prevent "collapsing to partner primitive".
                if _almost_equal(cand, partner, tol=0.95):
                    continue
                # Destructive ops must retain at least 60% of the area.
                if op_name in ('subtract', 'xor') and _area(cand) < 0.6 * _area(cur):
                    continue
                return cand
            return _union_step(cur)

        # ---------- seed in unit space ----------
        current = _to_unit(np.random.choice(self.shape_primitives))
        for _ in range(10):
            if _is_connected_poly(current):
                break
            current = _to_unit(np.random.choice(self.shape_primitives))
        if not _is_connected_poly(current):
            current = _clean(_to_unit(self.ps.square()).union(_to_unit(self.ps.circle()))).convex_hull

        # total steps in [max(binary_op_num, min_steps), max_steps]
        steps = int(np.random.randint(low_steps, max_steps + 1))
        plan = ['binary'] * binary_op_num + ['unary'] * (steps - binary_op_num)
        np.random.shuffle(plan)

        # ---------- execute plan with constraints ----------
        # Index-based loop because failed binary steps are re-queued onto `plan`.
        executed_binary = 0
        deferrals = 0
        idx = 0
        while idx < len(plan):
            typ = plan[idx]
            idx += 1
            if typ == 'unary':
                current = _unary_step(current)
                continue
            updated = _binary_step(current)
            if updated is None:
                # Postpone this binary op (bounded, to guarantee termination).
                if deferrals < max_binary_deferrals:
                    deferrals += 1
                    plan.append('binary')
                continue
            current = updated
            executed_binary += 1

        # Top-up: only reachable when the deferral budget above was exhausted.
        attempts = 0
        while executed_binary < binary_op_num:
            if attempts >= max_binary_deferrals:
                raise RuntimeError(
                    f"create_shapes: executed only {executed_binary} of {binary_op_num} "
                    f"binary ops after exhausting the retry budget "
                    f"(max_binary_deferrals={max_binary_deferrals})")
            attempts += 1
            updated = _binary_step(current)
            if updated is not None:
                current = updated
                executed_binary += 1

        return _from_unit(current)

    def create_shape_one_depth(self, depth, gen_num):
        """Generate `gen_num` shapes that each use exactly `depth` binary ops.

        Each shape gets between 0 and 4 extra unary steps (total steps in
        ``[depth, depth + 4]``).

        Args:
            depth: number of binary ops per shape (>= 0).
            gen_num: number of shapes to generate (>= 1).

        Returns:
            dict: ``{'depth': depth, 'shapes': [{'shape': geom, 'shape_info': info}, ...]}``
            where ``info`` is the output of `shape_to_masks_and_pose`.

        Raises:
            ValueError: on invalid `depth` / `gen_num`.
            AssertionError: if a generated shape is not a valid non-empty Polygon.
        """
        if depth < 0:
            raise ValueError(f"depth must be >= 0, got {depth}")
        if gen_num < 1:
            raise ValueError(f"gen_num must be >= 1, got {gen_num}")

        binary_op_num = depth
        min_step_num = depth
        max_step_num = depth + 4

        shapes = []
        for _ in range(gen_num):
            shape = self.create_shapes(binary_op_num, min_step_num, max_step_num)
            # Raised explicitly (not `assert`) so the check survives `python -O`.
            # AssertionError is kept so existing handlers keep working.
            if shape.is_empty or not shape.is_valid or shape.geom_type != 'Polygon':
                raise AssertionError("Generated shape must be a valid non-empty Polygon")
            shapes.append({'shape': shape,
                           'shape_info': self.shape_to_masks_and_pose(shape)})

        return {'depth': depth, 'shapes': shapes}

    def to_unit_and_pose(self, geom_canvas):
        """
        Decompose a canvas-space geometry into a unit-disk geometry and a pose.

        Pose: translation (tx, ty) = centroid, and isotropic scale s = largest
        |coordinate| of the centered bounds, so that
              geom_canvas ≈ translate(scale(unit_geom, s, s, origin=(0,0)), tx, ty)

        NOTE(review): unit_geom is clipped to the unit disk. Points farther than s
        from the centroid (e.g. bounding-box corner regions) are lost, so the
        reconstruction above is lossy for such shapes.

        Returns:
            (unit_geom, pose dict, 3x3 homogeneous pose_matrix,
             pose_vector [tx, ty, s], pose_vector_norm [tx_n, ty_n, s_n]).
            The normalized values are relative to the canvas center and half-size.

        Raises:
            ValueError: if `geom_canvas` is empty.
        """
        S = float(self.config['COMPLEX_SHAPE_GEN']['canvas_size'])
        g = geom_canvas
        if g.is_empty:
            raise ValueError("Cannot compute the pose of an empty geometry")

        # translation = centroid
        c = g.centroid
        tx, ty = c.x, c.y

        # center at origin
        g0 = translate(g, xoff=-tx, yoff=-ty)

        # isotropic scale to enclose in unit square, then clip to unit disk
        minx, miny, maxx, maxy = g0.bounds
        s = max(abs(minx), abs(maxx), abs(miny), abs(maxy))
        if s <= 0:
            s = 1.0

        unit_geom_raw = scale(g0, xfact=1.0/s, yfact=1.0/s, origin=(0.0, 0.0))
        unit_disk = Point(0.0, 0.0).buffer(1.0, resolution=256)
        unit_geom = unit_geom_raw.intersection(unit_disk)

        pose_matrix = np.array([[s, 0.0, tx],
                                [0.0, s, ty],
                                [0.0, 0.0, 1.0]], dtype=float)
        pose_vector = np.array([tx, ty, s], dtype=float)

        # normalized pose (unit-disk frame of the canvas)
        half = S / 2.0
        tx_n = (tx - half) / half
        ty_n = (ty - half) / half
        s_n = s / half
        pose_vector_norm = np.array([tx_n, ty_n, s_n], dtype=float)

        pose = dict(tx=tx, ty=ty, s=s, tx_norm=tx_n, ty_norm=ty_n, s_norm=s_n)

        return unit_geom, pose, pose_matrix, pose_vector, pose_vector_norm

    # ===== Cartesian raster (optional) =====
    def rasterize_unit_cart(self, unit_geom, res=300, ss=1):
        """
        Rasterize unit-disk geometry on a Cartesian grid over [-1,1]^2 into a (res x res) mask.

        mask[r, c] samples the pixel centered at (x, y) = (centers[c], centers[r]),
        where centers increases from about -1 to +1. Row index therefore grows
        with y.
        NOTE(review): if row 0 was meant to be the top (y = +1), a flip is needed.
        The original reversed the rows twice, which cancelled out.

        ss = supersampling factor per pixel side (1 = no SSAA, 2 = 4x, 3 = 9x, ...).
        A pixel is set when at least half of its samples are covered (interior or
        boundary).
        """
        mask = np.zeros((res, res), dtype=np.float32)
        if unit_geom.is_empty:
            return mask.astype(np.uint8)

        gu = prep(unit_geom)

        def _inside(x, y):
            p = Point(x, y)
            return gu.contains(p) or gu.touches(p)

        # pixel centers in [-1,1] (identical for x and y)
        centers = (np.arange(res) + 0.5) / res * 2.0 - 1.0

        if ss <= 1:
            for r, y in enumerate(centers):
                mask[r, :] = [_inside(x, y) for x in centers]
        else:
            # SSAA: average ss*ss sub-samples spread evenly inside each pixel
            offs = (np.arange(ss) + 0.5) / (ss * res) * 2.0 - 1.0 / res
            n_sub = ss * ss
            for r, y0 in enumerate(centers):
                mask[r, :] = [
                    sum(_inside(x0 + ox, y0 + oy) for oy in offs for ox in offs) / n_sub
                    for x0 in centers
                ]

        return (mask >= 0.5).astype(np.uint8)

    # ===== Polar raster (Zernike-ready) =====
    def rasterize_unit_polar(self, unit_geom, r_bins=300, theta_bins=300, ss=1):
        """
        Rasterize unit-disk geometry on a polar grid:
          r in [0,1], theta in [0, 2π)
        Returns a binary mask of shape (r_bins, theta_bins), where mask[ri, ti] corresponds to
        the sample at r=(ri+0.5)/r_bins, theta=(ti+0.5)/theta_bins * 2π.
        Intended as input for Zernike encoders that integrate over (r, theta).
        Optional SSAA (ss) takes ss*ss sub-samples per cell, offset by up to half a
        cell in r and theta (r clamped to [0, 1]); a cell is set when at least half
        of them are covered.
        """
        M = np.zeros((r_bins, theta_bins), dtype=np.float32)
        if unit_geom.is_empty:
            return M.astype(np.uint8)

        gu = prep(unit_geom)

        def _inside(x, y):
            p = Point(x, y)
            return gu.contains(p) or gu.touches(p)

        rs = (np.arange(r_bins) + 0.5) / r_bins
        thetas = (np.arange(theta_bins) + 0.5) / theta_bins * (2.0 * np.pi)

        if ss <= 1:
            for ri, r in enumerate(rs):
                for ti, th in enumerate(thetas):
                    M[ri, ti] = 1.0 if _inside(r * np.cos(th), r * np.sin(th)) else 0.0
        else:
            dr = 1.0 / r_bins
            dth = 2.0 * np.pi / theta_bins
            # symmetric jitter fractions in [-0.5, 0.5] (same set for r and theta)
            jitter = (np.arange(ss) + 0.5) / ss - 0.5
            for ri, r in enumerate(rs):
                for ti, th in enumerate(thetas):
                    cnt = 0
                    for a in jitter:
                        rj = max(0.0, min(1.0, r + a * dr))
                        for b in jitter:
                            thj = th + b * dth
                            if _inside(rj * np.cos(thj), rj * np.sin(thj)):
                                cnt += 1
                    M[ri, ti] = cnt / (ss * ss)

        return (M >= 0.5).astype(np.uint8)

    # ===== All-in-one convenience =====
    def shape_to_masks_and_pose(self, geom_canvas, cart_res=300, r_bins=300, theta_bins=300, ss=1):
        """
        Convert a canvas-space geometry into:
          - unit_geom (Shapely)
          - Cartesian mask [cart_res, cart_res] in {0,1}
          - Polar mask [r_bins, theta_bins] in {0,1}  (Zernike-ready)
          - pose dict/matrix/vectors (see `to_unit_and_pose`)
        """
        unit_geom, pose, pose_mat, pose_vec, pose_vec_norm = self.to_unit_and_pose(geom_canvas)
        cart_mask = self.rasterize_unit_cart(unit_geom, res=cart_res, ss=ss)
        polar_mask = self.rasterize_unit_polar(unit_geom, r_bins=r_bins, theta_bins=theta_bins, ss=ss)

        return {
            "unit_geom": unit_geom,
            "cart_mask": cart_mask,
            "polar_mask": polar_mask,          # ← feed this to Zernike encoder
            "pose": pose,
            "pose_matrix": pose_mat,
            "pose_vector": pose_vec,
            "pose_vector_norm": pose_vec_norm,
        }

