# mmr_gym/motifs/dotgrid2d.py
import random
from PIL import Image, ImageDraw
from ..base import Motif
from ..schema import MotifSpec
from ..config import SUPERSAMPLE, SS_CELL, COLORS
from ..registry import register_motif
from .helpers import _down_on_background

@register_motif
class DotGrid2DMotif(Motif):
    """
    A regular rows × cols grid of identical filled dots.

    Conventions:
      - `spec.count` holds the number of *columns*, so `Sequence` tasks
        (e.g., count_arith) can drive the horizontal cardinality directly.
      - The number of rows lives in `spec.extra["rows"]`.
      - `spec.size` is always normalised to 1.0; it is not used for layout.

    Extras (spec.extra):
      - rows  (int): number of rows, clamped to `attr_ranges["rows"]`.
      - scale (float): margin multiplier, clamped to [0.75, 1.30]. The margin
        on each side is `SS_CELL * 0.12 * scale`, so a LARGER scale yields a
        wider margin and therefore a SMALLER grid.
    """
    name = "dotgrid2d"
    attr_ranges = {"cols": (3, 9), "rows": (2, 7)}

    # --- sampling ---
    def sample_spec(self, rng):
        """Draw a random spec for this motif.

        `rng` must provide `randint`, `uniform` and `randrange` (e.g. a
        `random.Random`; for that type `randint` bounds are inclusive, so
        cols is drawn from [3, 9] and rows from [2, 7]). The column count is
        stored in `count`; rows and a scale in [0.9, 1.1] go into `extra`.
        """
        seed = rng.randint(0, 2**31 - 1)
        rows = rng.randint(*self.attr_ranges["rows"])
        cols = rng.randint(*self.attr_ranges["cols"])
        extra = {"scale": rng.uniform(0.9, 1.1), "rows": rows}
        return MotifSpec(
            self.name,
            seed,
            rng.randrange(len(COLORS)),
            count=cols,       # <- columns under 'count'
            size=1.0,
            extra=extra
        )

    # --- normalization ---
    def clamp_spec(self, spec):
        """Return a clone of `spec` with every attribute forced into range.

        - `count` (columns) is clamped to `attr_ranges["cols"]`; a missing
          or None count falls back to 5 before clamping.
        - `extra["rows"]` is clamped to `attr_ranges["rows"]`; a missing or
          None value falls back to 4 before clamping.
        - `extra["scale"]` is clamped to [0.75, 1.30]; a missing or None
          value falls back to 1.0.
        - `extra["mode"]` is removed and `size` is pinned to 1.0.

        The input spec and its `extra` dict are not mutated (a shallow copy
        of `extra` is edited and handed to `spec.clone`).
        """
        ex = dict(spec.extra or {})
        ex.pop("mode", None)  # ignore any generic symmetry-mode flag

        rmin, rmax = self.attr_ranges["rows"]
        cmin, cmax = self.attr_ranges["cols"]

        # An explicit None is treated like a missing value: int(None) /
        # float(None) would otherwise raise TypeError here.
        raw_rows = ex.get("rows")
        raw_cols = getattr(spec, "count", None)
        raw_scale = ex.get("scale")

        rows = int(raw_rows) if raw_rows is not None else 4
        cols = int(raw_cols) if raw_cols is not None else 5
        scale = float(raw_scale) if raw_scale is not None else 1.0

        rows = max(int(rmin), min(int(rmax), rows))
        cols = max(int(cmin), min(int(cmax), cols))
        scale = max(0.75, min(1.30, scale))

        ex.update({"rows": rows, "scale": scale})
        return spec.clone(count=cols, size=1.0, extra=ex)

    # --- rendering ---
    def render(self, spec):
        """Rasterise the dot grid described by `spec`.

        The spec is normalised with `clamp_spec` first. Dots are drawn on a
        transparent `SS_CELL` × `SS_CELL` RGBA canvas, one per grid cell,
        centred in the cell, filled with `COLORS[color_idx]` and outlined in
        black. The canvas is returned through `_down_on_background`
        (presumably downsamples the supersampled canvas and composites it onto
        the background — behaviour is defined in `helpers`; confirm there).
        """
        s = self.clamp_spec(spec)
        rows = int(s.extra["rows"])
        cols = max(1, int(s.count))  # defensive: clamp_spec already ensures >= cmin

        img = Image.new("RGBA", (SS_CELL, SS_CELL), (255, 255, 255, 0))
        d = ImageDraw.Draw(img)
        color = COLORS[s.color_idx]
        scale = float(s.extra["scale"])

        # Margin on each side; grows with `scale`, which shrinks the grid.
        M = int(SS_CELL * 0.12 * scale)
        usable_w = SS_CELL - 2 * M
        usable_h = SS_CELL - 2 * M
        cell_w = usable_w / cols
        cell_h = usable_h / max(1, rows)

        # Dot radius = 30% of the smaller cell dimension (diameter 60% of the
        # cell, so neighbours stay apart), with a floor of 2*SUPERSAMPLE px so
        # dots remain visible after downsampling.
        r = max(2 * SUPERSAMPLE, int(min(cell_w, cell_h) * 0.30))

        # Cell centres, truncated to integer canvas coordinates.
        xs = [int(M + (i + 0.5) * cell_w) for i in range(cols)]
        ys = [int(M + (j + 0.5) * cell_h) for j in range(rows)]

        ow = max(1, SUPERSAMPLE)  # outline width, at least 1 px
        for y in ys:
            for x in xs:
                d.ellipse((x - r, y - r, x + r, y + r), fill=color, outline="black", width=ow)

        return _down_on_background(img)
