# Copyright (c) 2023 Michael Hodel, licensed under MIT
"""Adapted from https://github.com/michaelhodel/re-arc/blob/main/dsl.py"""

import inspect
import math
import random
import sys
import typing
from typing import Tuple, Any, Container, Callable, FrozenSet

from src.datasets.task_gen import utils, types_
from src.datasets.task_gen.re_arc_generators import GENERATORS_SRC_CODE
from src.datasets.task_gen.types_ import (
    Boolean,
    Integer,
    IntegerTuple,
    Numerical,
    IntegerSet,
    Grid,
    Object,
    Objects,
    Indices,
    IndicesSet,
    Patch,
    Element,
    Piece,
    T,
    T2,
    T3,
)

## Random


def rand_uniforminput():
    """Build a random rectangular grid of 1-30 rows by 1-30 columns.

    Cell values are drawn uniformly from 0..5 inclusive (``random.randint`` includes both ends).
    """
    max_color = 5
    n_rows = random.randint(1, 30)
    n_cols = random.randint(1, 30)
    rows = []
    for _ in range(n_rows):
        rows.append(tuple(random.randint(0, max_color) for _ in range(n_cols)))
    return tuple(rows)


def rand_sample(container: Container[T], num: Integer):
    """Draw ``num`` distinct items from ``container`` uniformly at random.

    ``num`` is clamped to ``[0, len(container)]``. The container is converted to a
    tuple first because ``random.sample`` only accepts sequences (sets and frozensets
    raise TypeError since Python 3.11). For a tuple this conversion is a no-op.

    Returns:
        A list of the sampled items.
    """
    num = utils.ensure_bounds(num, low=0, high=len(container))
    return random.sample(tuple(container), num)


def rand_choice(container: Container[T]):
    return random.choice(tuple(container))


def rand_randint(a: Integer, b: Integer):
    return random.randint(a, b)


def rand_shuffle(container: Container[T]):
    container_list = list(container)
    random.shuffle(container_list)
    return type(container)(container_list)


## Constants


def const_None():
    """Return the constant ``None``."""
    return None


def const_0():
    """Return the integer constant 0."""
    return 0


def const_1():
    """Return the integer constant 1."""
    return 1


def const_2():
    """Return the integer constant 2."""
    return 2


def const_3():
    """Return the integer constant 3."""
    return 3


def const_4():
    """Return the integer constant 4."""
    return 4


def const_5():
    """Return the integer constant 5."""
    return 5


def const_6():
    """Return the integer constant 6."""
    return 6


def const_7():
    """Return the integer constant 7."""
    return 7


def const_8():
    """Return the integer constant 8."""
    return 8


def const_9():
    """Return the integer constant 9."""
    return 9


def const_10():
    """Return the integer constant 10."""
    return 10


def const_11():
    """Return the integer constant 11."""
    return 11


def const_12():
    """Return the integer constant 12."""
    return 12


def const_13():
    """Return the integer constant 13."""
    return 13


def const_14():
    """Return the integer constant 14."""
    return 14


def const_15():
    """Return the integer constant 15."""
    return 15


def const_16():
    """Return the integer constant 16."""
    return 16


def const_17():
    """Return the integer constant 17."""
    return 17


def const_18():
    """Return the integer constant 18."""
    return 18


def const_19():
    """Return the integer constant 19."""
    return 19


def const_20():
    """Return the integer constant 20."""
    return 20


def const_21():
    """Return the integer constant 21."""
    return 21


def const_22():
    """Return the integer constant 22."""
    return 22


def const_23():
    """Return the integer constant 23."""
    return 23


def const_24():
    """Return the integer constant 24."""
    return 24


def const_25():
    """Return the integer constant 25."""
    return 25


def const_26():
    """Return the integer constant 26."""
    return 26


def const_27():
    """Return the integer constant 27."""
    return 27


def const_28():
    """Return the integer constant 28."""
    return 28


def const_29():
    """Return the integer constant 29."""
    return 29


def const_30():
    """Return the integer constant 30."""
    return 30


def const_true():
    """Return the boolean constant True."""
    return True


def const_false():
    """Return the boolean constant False."""
    return False


def const_neg1():
    """Return the integer constant -1."""
    return -1


def const_neg2():
    """Return the integer constant -2."""
    return -2


def const_origin():
    """Return the origin offset (row 0, column 0)."""
    return 0, 0


def const_unity():
    """Return the offset (1, 1): one row down and one column right."""
    return 1, 1


def const_down():
    """Return the offset (1, 0): one row down."""
    return 1, 0


def const_right():
    """Return the offset (0, 1): one column right."""
    return 0, 1


def const_up():
    """Return the offset (-1, 0): one row up."""
    return -1, 0


def const_left():
    """Return the offset (0, -1): one column left."""
    return 0, -1


def const_neg_unity():
    """Return the offset (-1, -1): one row up and one column left."""
    return -1, -1


def const_up_right():
    """Return the offset (-1, 1): one row up and one column right."""
    return -1, 1


def const_down_left():
    """Return the offset (1, -1): one row down and one column left."""
    return 1, -1


def const_zero_by_two():
    """Return the pair (0, 2)."""
    return 0, 2


def const_two_by_zero():
    """Return the pair (2, 0)."""
    return 2, 0


def const_two_by_two():
    """Return the pair (2, 2)."""
    return 2, 2


def const_three_by_three():
    """Return the pair (3, 3)."""
    return 3, 3


# primitives


def toinput(x: Grid):
    return x


def tooutput(x: Grid):
    return x


def identity(x: T):
    """identity function"""
    return x


def add(a: Numerical, b: Numerical):
    """addition

    Ints add normally, pairs add component-wise, and an int combined with a pair is
    broadcast to both components. Returns None when any resulting component fails
    ``utils.int_in_bounds``.
    """
    a_int = isinstance(a, int)
    b_int = isinstance(b, int)
    if a_int and b_int:
        total = a + b
        return total if utils.int_in_bounds(total) else None
    if isinstance(a, tuple) and isinstance(b, tuple):
        total = (a[0] + b[0], a[1] + b[1])
    elif a_int and isinstance(b, tuple):
        total = (a + b[0], a + b[1])
    else:
        total = (a[0] + b, a[1] + b)
    if not (utils.int_in_bounds(total[0]) and utils.int_in_bounds(total[1])):
        return None
    return total


def subtract(a: Numerical, b: Numerical):
    """subtraction

    Computes ``a - b`` as ``add(a, invert(b))``, with the same int/pair broadcasting.
    Returns None when negating ``b`` or the final sum falls out of bounds. Previously a
    None from ``invert`` was passed on to ``add``, which then failed with a TypeError.
    """
    negated = invert(b)
    if negated is None:
        return None
    return add(a, negated)


def multiply(a: Numerical, b: Numerical):
    """multiplication

    Ints multiply normally, pairs multiply component-wise, and an int combined with a
    pair scales both components. Returns None when any resulting component fails
    ``utils.int_in_bounds``.
    """
    a_int = isinstance(a, int)
    b_int = isinstance(b, int)
    if a_int and b_int:
        prod = a * b
        return prod if utils.int_in_bounds(prod) else None
    if isinstance(a, tuple) and isinstance(b, tuple):
        prod = (a[0] * b[0], a[1] * b[1])
    elif a_int and isinstance(b, tuple):
        prod = (a * b[0], a * b[1])
    else:
        prod = (a[0] * b, a[1] * b)
    if not (utils.int_in_bounds(prod[0]) and utils.int_in_bounds(prod[1])):
        return None
    return prod


def divide(a: Numerical, b: Numerical):
    """floor division

    Works on ints and int pairs (component-wise, with broadcasting). Every divisor
    component is first clamped to at least 1 with ``utils.ensure_bounds``, so a zero or
    negative divisor never reaches ``//``.
    """
    if isinstance(b, int):
        divisor = utils.ensure_bounds(b, low=1, high=math.inf)
    else:
        divisor = utils.ensure_bounds(*b, low=1, high=math.inf)
    if isinstance(a, int):
        if isinstance(divisor, int):
            return a // divisor
        if isinstance(divisor, tuple):
            return (a // divisor[0], a // divisor[1])
    elif isinstance(a, tuple) and isinstance(divisor, tuple):
        return (a[0] // divisor[0], a[1] // divisor[1])
    return (a[0] // divisor, a[1] // divisor)


def invert(n: Numerical):
    """inversion with respect to addition

    Negates an int, or both components of an int pair. Returns None when any negated
    component fails ``utils.int_in_bounds``.
    """
    if not isinstance(n, int):
        negated = (-n[0], -n[1])
        in_bounds = utils.int_in_bounds(negated[0]) and utils.int_in_bounds(negated[1])
        return negated if in_bounds else None
    return -n if utils.int_in_bounds(-n) else None


def even(n: Integer):
    """evenness"""
    return n % 2 == 0


def double(n: Numerical):
    """Scale an int or int pair by two (None when ``multiply`` reports out of bounds)."""
    factor = 2
    return multiply(n, factor)


def halve(n: Numerical):
    """scaling by one half"""
    return n // 2 if isinstance(n, int) else (n[0] // 2, n[1] // 2)


def flip(b: Boolean):
    """logical not"""
    return not b


def equality(a: T, b: T):
    """equality"""
    return a == b


def contained(value: T, container: Container[T]):
    """element of"""
    return value in container


def combine(a: Container[T], b: Container[T]):
    """union

    When both arguments are tuples and at least one of them is a grid, they are stacked
    vertically. The result is returned only if both are grids and the stack is still a
    grid; otherwise None. For any other input, the items of ``b`` are appended to ``a``
    and the result keeps ``a``'s type. Returns None if that result exceeds 900 items.
    """
    a_is_grid = utils.is_grid(a)
    b_is_grid = utils.is_grid(b)
    both_tuples = isinstance(a, tuple) and isinstance(b, tuple)
    if both_tuples and (a_is_grid or b_is_grid):
        stacked = a + b
        return stacked if (a_is_grid and b_is_grid and utils.is_grid(stacked)) else None
    merged = type(a)((*a, *b))
    if len(merged) > 900:
        return None
    return merged


def intersection(a: FrozenSet[T], b: FrozenSet[T]):
    """returns the intersection of two containers"""
    return a & b


def difference(a: Container[T], b: Container[T]):
    """difference"""
    return type(a)(e for e in a if e not in b)


def dedupe(iterable: Tuple[T]):
    """remove duplicates

    Keeps the first occurrence of each item and preserves order. A value typed as an
    ``IntegerTuple`` is returned untouched. Hashable items are deduplicated in O(n)
    with ``dict.fromkeys``; the original O(n^2) ``index`` scan is used only as a
    fallback when some item is unhashable.
    """
    if types_.infer_type(iterable) is IntegerTuple:
        return iterable
    try:
        return tuple(dict.fromkeys(iterable))
    except TypeError:
        # unhashable items: fall back to an equality-based quadratic scan
        return tuple(e for i, e in enumerate(iterable) if iterable.index(e) == i)


def order(container: Container[T], compfunc: Callable[[T], Integer]):
    """order container by custom key"""
    return tuple(sorted(container, key=compfunc))


def repeat(item: T, num: Integer):
    """Tuple holding ``item`` repeated ``num`` times, with ``num`` clamped to 1..30."""
    num = utils.ensure_bounds(num, low=1, high=30)
    return (item,) * num


def greater(a: Integer, b: Integer):
    """greater"""
    return a > b


def less(a: Integer, b: Integer):
    """less"""
    return a < b


def size(container: Container[T]):
    """cardinality"""
    return len(container)


def merge(containers: Container[T]):
    """merging"""
    if (
        len(containers) != 0
        and not hasattr(next(iter(containers)), "__iter__")
        or types_.infer_type(containers) in [Object, Grid, Indices]
    ):
        return containers
    result = type(containers)(e for c in containers for e in c)
    return result if len(result) <= 900 else None


def max_(a: Integer, b: Integer):
    """maximum of two integers"""
    return max(a, b)


def min_(a: Integer, b: Integer):
    """minimum of two integers"""
    return min(a, b)


def maximum(container: IntegerSet):
    """maximum"""
    return max(container, default=0)


def minimum(container: IntegerSet):
    """minimum"""
    return min(container, default=0)


def valmax(container: Container[T], compfunc: Callable[[T], Integer]):
    """maximum by custom function"""
    return compfunc(max(container, key=compfunc, default=None))


def valmin(container: Container[T], compfunc: Callable[[T], Integer]):
    """minimum by custom function"""
    return compfunc(min(container, key=compfunc, default=None))


def argmax(container: Container[T], compfunc: Callable[[T], Integer]):
    """largest item by custom order"""
    return max(container, key=compfunc, default=None)


def argmin(container: Container[T], compfunc: Callable[[T], Integer]):
    """smallest item by custom order"""
    return min(container, key=compfunc, default=None)


def mostcommon(container: Container[T]):
    """most common item"""
    if isinstance(container, frozenset):
        # frozensets are hashed and each item is unique
        return None
    return max(set(container), key=container.count, default=None)


def leastcommon(container: Container[T]):
    """least common item"""
    if isinstance(container, frozenset):
        # frozensets are hashed and each item is unique
        return None
    return min(set(container), key=container.count, default=None)


def initset(value: T):
    """initialize container"""
    return frozenset({value})


def both(a: Boolean, b: Boolean):
    """logical and"""
    return a and b


def either(a: Boolean, b: Boolean):
    """logical or"""
    return a or b


def increment(x: Numerical):
    """incrementing"""
    return x + 1 if isinstance(x, int) else (x[0] + 1, x[1] + 1)


def decrement(x: Numerical):
    """decrementing"""
    return x - 1 if isinstance(x, int) else (x[0] - 1, x[1] - 1)


def crement(x: Numerical):
    """incrementing positive and decrementing negative"""
    if isinstance(x, int):
        return 0 if x == 0 else (x + 1 if x > 0 else x - 1)
    return (
        0 if x[0] == 0 else (x[0] + 1 if x[0] > 0 else x[0] - 1),
        0 if x[1] == 0 else (x[1] + 1 if x[1] > 0 else x[1] - 1),
    )


def sign(x: Numerical):
    """sign"""
    if isinstance(x, int):
        return 0 if x == 0 else (1 if x > 0 else -1)
    return (0 if x[0] == 0 else (1 if x[0] > 0 else -1), 0 if x[1] == 0 else (1 if x[1] > 0 else -1))


def positive(x: Integer):
    """positive"""
    return x > 0


def toivec(i: Integer):
    """vector pointing vertically"""
    return (i, 0)


def tojvec(j: Integer):
    """vector pointing horizontally"""
    return (0, j)


def sfilter(container: Container[T], condition: Callable[[T], Boolean]):
    """keep elements in container that satisfy condition"""
    return type(container)(e for e in container if condition(e))


def mfilter(container: Container[T], function: Callable[[T], Boolean]):
    """Keep the items for which ``function`` is truthy, then flatten them with ``merge``."""
    kept = sfilter(container, function)
    return merge(kept)


def extract(container: Container[T], condition: Callable[[T], Boolean]):
    """first element of container that satisfies condition"""
    if len(container) == 0 or not any(condition(e) for e in container):
        return None
    return next(e for e in container if condition(e))


def totuple(container: FrozenSet[T]):
    """conversion to tuple"""
    return tuple(container)


def first(container: Container[T]):
    """first item of container"""
    if len(container) == 0:
        return None
    return next(iter(container))


def last(container: Container[T]):
    """last item of container"""
    if len(container) == 0:
        return None
    return max(enumerate(container))[1]


def insert(value: T, container: FrozenSet[T]):
    """insert item into container"""
    result = container.union(frozenset({value}))
    return result if len(result) <= 900 else None


def remove(value: T, container: Container[T]):
    """remove item from container"""
    return type(container)(e for e in container if e != value)


def other(container: Container[T], value: T):
    """First item of ``container`` that differs from ``value`` (None if there is none)."""
    remaining = remove(value, container)
    return first(remaining)


def interval(start: Integer, stop: Integer, step: Integer):
    """range"""
    if step == 0:
        step = 1
    if start > stop and step > 0 or start < stop and step < 0:
        result = tuple(range(start, stop, -step))
    else:
        result = tuple(range(start, stop, step))
    return result if len(result) <= 900 else None


def astuple(a: Integer, b: Integer):
    """constructs a tuple"""
    return (a, b)


def product(a: Container[T], b: Container[T]):
    """cartesian product"""
    if len(a) * len(b) > 900:
        return None
    return frozenset((i, j) for j in b for i in a)


def pair(a: Tuple[T], b: Tuple[T]):
    """zipping of two tuples"""
    return tuple(zip(a, b))


def branch(condition: Boolean, if_value: T, else_value: T):
    """if else branching"""
    return if_value if condition else else_value


def compose(outer: Callable[[T], Any], inner: Callable[[Any], T]):
    """function composition

    Returns ``x -> outer(inner(x))``, annotated with ``inner``'s input type and
    ``outer``'s return type (both taken from ``types_.infer_type``). Returns None
    unless both functions take exactly one positional argument.
    """
    unary = inner.__code__.co_argcount == 1 and outer.__code__.co_argcount == 1
    if not unary:
        return None
    fn = lambda x: outer(inner(x))
    inner_args, _ = typing.get_args(types_.infer_type(inner))
    _, outer_ret = typing.get_args(types_.infer_type(outer))
    assert len(inner_args) == 1
    fn.__annotations__ = {"x": inner_args[0], "return": outer_ret}
    return fn


def chain(h: Callable[[T2], Any], g: Callable[[T], T2], f: Callable[[Any], T]):
    """function composition with three functions

    Returns ``x -> h(g(f(x)))``, annotated with ``f``'s input type and ``h``'s return
    type. Returns None unless all three functions take exactly one positional argument.
    """
    unary = f.__code__.co_argcount == 1 and g.__code__.co_argcount == 1 and h.__code__.co_argcount == 1
    if not unary:
        return None
    fn = lambda x: h(g(f(x)))
    first_args, _ = typing.get_args(types_.infer_type(f))
    _, last_ret = typing.get_args(types_.infer_type(h))
    assert len(first_args) == 1
    fn.__annotations__ = {"x": first_args[0], "return": last_ret}
    return fn


def matcher(function: Callable[[Any], T], target: T):
    """construction of equality function

    Returns the predicate ``x -> function(x) == target``, annotated with ``function``'s
    input type and a Boolean return. Returns None unless ``function`` is unary.
    """
    if function.__code__.co_argcount == 1:
        fn = lambda x: function(x) == target
        arg_types, _ = typing.get_args(types_.infer_type(function))
        assert len(arg_types) == 1
        fn.__annotations__ = {"x": arg_types[0], "return": Boolean}
        return fn
    return None


def rbind(
    function: Callable[[Any, T], Any] | Callable[[Any, Any, T], Any] | Callable[[Any, Any, Any, T], Any],
    fixed: T,
):
    """fix the rightmost argument

    Partially applies ``fixed`` as the last argument of a 2-, 3- or 4-ary function.
    The returned lambda is annotated with the remaining input types (keys ``arg0``,
    ``arg1``, ...) and the original return type. Returns None for any other arity.
    """
    argcount = function.__code__.co_argcount
    if argcount == 2:
        fn = lambda x: function(x, fixed)
    elif argcount == 3:
        fn = lambda x, y: function(x, y, fixed)
    elif argcount == 4:
        fn = lambda x, y, z: function(x, y, z, fixed)
    else:
        return None
    arg_types, result_type = typing.get_args(types_.infer_type(function))
    annotations = {f"arg{idx}": arg_type for idx, arg_type in enumerate(arg_types[:-1])}
    annotations["return"] = result_type
    fn.__annotations__ = annotations
    return fn


def lbind(
    function: Callable[[T, Any], Any] | Callable[[T, Any, Any], Any] | Callable[[T, Any, Any, Any], Any],
    fixed: T,
):
    """fix the leftmost argument

    Partially applies ``fixed`` as the first argument of a 2-, 3- or 4-ary function.
    The returned lambda is annotated with the remaining input types (keys ``arg0``,
    ``arg1``, ...) and the original return type. Returns None for any other arity.
    """
    argcount = function.__code__.co_argcount
    if argcount == 2:
        fn = lambda y: function(fixed, y)
    elif argcount == 3:
        fn = lambda y, z: function(fixed, y, z)
    elif argcount == 4:
        fn = lambda y, z, a: function(fixed, y, z, a)
    else:
        return None
    arg_types, result_type = typing.get_args(types_.infer_type(function))
    annotations = {f"arg{idx}": arg_type for idx, arg_type in enumerate(arg_types[1:])}
    annotations["return"] = result_type
    fn.__annotations__ = annotations
    return fn


def power(function: Callable[[T], T], n: Integer):
    """power of function

    n-fold composition of ``function`` with itself. ``n`` is clamped to be non-negative.
    ``n == 0`` yields ``identity``; ``n == 1`` returns ``function`` unchanged.

    For n >= 2 the result is ``compose(function, power(function, n - 1))``. ``compose``
    returns None for a non-unary function. That None is now propagated; previously the
    next level called ``compose(function, None)`` and crashed on ``None.__code__``.
    """
    n = utils.ensure_bounds(n, low=0, high=math.inf)
    if n == 0:
        return identity
    if n == 1:
        return function
    inner = power(function, n - 1)
    if inner is None:
        return None
    return compose(function, inner)


def fork(outer: Callable[[T2, T3], Any], a: Callable[[T], T2], b: Callable[[T], T3]):
    """creates a wrapper function

    Returns ``x -> outer(a(x), b(x))``, annotated with ``a``'s input type and
    ``outer``'s return type. Returns None unless ``a`` and ``b`` are unary and
    ``outer`` is binary.
    """
    unary_branches = a.__code__.co_argcount == 1 and b.__code__.co_argcount == 1
    if not (unary_branches and outer.__code__.co_argcount == 2):
        return None
    fn = lambda x: outer(a(x), b(x))
    left_args, _ = typing.get_args(types_.infer_type(a))
    right_args, _ = typing.get_args(types_.infer_type(b))
    _, combined_ret = typing.get_args(types_.infer_type(outer))
    assert len(left_args) == len(right_args) == 1
    fn.__annotations__ = {"x": left_args[0], "return": combined_ret}
    return fn


def apply(function: Callable[[T], Any], container: Container[T]):
    """apply function to each item in container"""
    if len(container) > 900:
        return None
    return type(container)(function(e) for e in container)


def rapply(functions: Container[Callable[[T], Any]], value: T):
    """apply each function in container to value"""
    if len(functions) > 900:
        return None
    return type(functions)(function(value) for function in functions)


def mapply(function: Callable[[T], Any], container: Container[T]):
    """Map ``function`` over ``container``, then flatten with ``merge``.

    Returns None for containers with more than 900 items.
    """
    if len(container) > 900:
        return None
    mapped = apply(function, container)
    return merge(mapped)


def papply(function: Callable[[T, T2], Any], a: Tuple[T], b: Tuple[T2]):
    """apply function on two vectors"""
    if len(a) != len(b) or len(a) > 900 or len(b) > 900:
        return None
    return tuple(function(i, j) for i, j in zip(a, b))


def mpapply(function: Callable[[T, T2], Any], a: Tuple[T], b: Tuple[T2]):
    """Pairwise-apply ``function`` over ``a`` and ``b``, then flatten with ``merge``.

    Returns None when the lengths differ or either tuple has more than 900 items.
    """
    mismatched = len(a) != len(b)
    if mismatched or len(a) > 900 or len(b) > 900:
        return None
    pairwise = papply(function, a, b)
    return merge(pairwise)


def prapply(function: Callable[[T, T2], Any], a: Container[T], b: Container[T2]):
    """apply function on cartesian product"""
    if len(a) * len(b) > 900:
        return None
    return frozenset(function(i, j) for j in b for i in a)


def mostcolor(element: Element):
    """most common color"""
    values = [v for r in element for v in r] if isinstance(element, tuple) else [v for v, _ in element]
    return max(set(values), key=values.count, default=None)


def leastcolor(element: Element):
    """least common color"""
    values = [v for r in element for v in r] if isinstance(element, tuple) else [v for v, _ in element]
    return min(set(values), key=values.count, default=None)


def height(piece: Piece):
    """height of grid or patch"""
    if len(piece) == 0:
        return 0
    if isinstance(piece, tuple):
        return len(piece)
    return lowermost(piece) - uppermost(piece) + 1


def width(piece: Piece):
    """width of grid or patch"""
    if len(piece) == 0:
        return 0
    if isinstance(piece, tuple):
        return len(piece[0])
    return rightmost(piece) - leftmost(piece) + 1


def shape(piece: Piece):
    """``(height, width)`` of a grid or patch."""
    rows = height(piece)
    cols = width(piece)
    return rows, cols


def portrait(piece: Piece):
    """Whether the piece is strictly taller than it is wide."""
    return width(piece) < height(piece)


def colorcount(element: Element, value: Integer):
    """number of cells with color"""
    if isinstance(element, tuple):
        return sum(row.count(value) for row in element)
    value = utils.ensure_bounds(value, low=0, high=9)
    return sum(v == value for v, _ in element)


def colorfilter(objs: Objects, value: Integer):
    """Objects whose (first-iterated) cell has colour ``value``, clamped to 0..9."""
    value = utils.ensure_bounds(value, low=0, high=9)
    matching = [obj for obj in objs if next(iter(obj))[0] == value]
    return frozenset(matching)


def sizefilter(container: Objects | IndicesSet, n: Integer):
    """filter items by size"""
    if "__len__" not in dir(container):
        return container
    n = utils.ensure_bounds(n, low=0, high=math.inf)
    return frozenset(item for item in container if len(item) == n)


def asindices(grid: Grid):
    """Frozenset of every ``(row, col)`` index of ``grid``; None if it is not a grid."""
    if not utils.is_grid(grid):
        return None
    n_rows, n_cols = len(grid), len(grid[0])
    return frozenset((r, c) for r in range(n_rows) for c in range(n_cols))


def ofcolor(grid: Grid, value: Integer):
    """Indices of grid cells equal to ``value`` (clamped to 0..9); None if not a grid."""
    if not utils.is_grid(grid):
        return None
    value = utils.ensure_bounds(value, low=0, high=9)
    hits = {(r, c) for r, row in enumerate(grid) for c, cell in enumerate(row) if cell == value}
    return frozenset(hits)


def ulcorner(patch: Patch):
    """``(min row, min col)`` of the patch; ``()`` for an empty patch."""
    axes = zip(*toindices(patch))
    return tuple(min(axis) for axis in axes)


def urcorner(patch: Patch):
    """``(min row, max col)`` of the patch; ``()`` for an empty patch."""
    axes = tuple(zip(*toindices(patch)))
    if not axes:
        return ()
    return (min(axes[0]), max(axes[1]))


def llcorner(patch: Patch):
    """``(max row, min col)`` of the patch; ``()`` for an empty patch."""
    axes = tuple(zip(*toindices(patch)))
    if not axes:
        return ()
    return (max(axes[0]), min(axes[1]))


def lrcorner(patch: Patch):
    """``(max row, max col)`` of the patch; ``()`` for an empty patch."""
    axes = zip(*toindices(patch))
    return tuple(max(axis) for axis in axes)


def crop(grid: Grid, start: IntegerTuple, dims: IntegerTuple):
    """Sub-grid whose upper-left corner is ``start`` and whose size is ``dims``.

    The start is clamped inside the grid, and the dimensions are clamped to at least
    1 and at most what remains from the start.
    """
    row0 = utils.ensure_bounds(start[0], low=0, high=len(grid) - 1)
    col0 = utils.ensure_bounds(start[1], low=0, high=len(grid[0]) - 1)
    n_rows = utils.ensure_bounds(dims[0], low=1, high=len(grid) - row0)
    n_cols = utils.ensure_bounds(dims[1], low=1, high=len(grid[0]) - col0)
    return tuple(row[col0 : col0 + n_cols] for row in grid[row0 : row0 + n_rows])


def toindices(patch: Patch):
    """indices of object cells"""
    if len(patch) == 0:
        return frozenset()
    if isinstance(next(iter(patch))[1], tuple):
        result = frozenset(index for value, index in patch)
        return result if len(result) <= 900 else None
    return patch


def recolor(value: Integer, patch: Patch):
    """Object covering ``patch`` with every cell set to ``value`` (clamped to 0..9)."""
    value = utils.ensure_bounds(value, low=0, high=9)
    cells = {(value, loc) for loc in toindices(patch)}
    return frozenset(cells)


def shift(patch: Patch, directions: IntegerTuple):
    """shift patch"""
    if len(patch) == 0 or directions is None:
        return patch
    di, dj = utils.ensure_bounds(*directions, low=-29, high=29)
    if isinstance(next(iter(patch))[1], tuple):
        return frozenset((value, (i + di, j + dj)) for value, (i, j) in patch)
    return frozenset((i + di, j + dj) for i, j in patch)


def normalize(patch: Patch):
    """moves upper left corner to origin"""
    if len(patch) == 0:
        return patch
    return shift(patch, (-uppermost(patch), -leftmost(patch)))


def dneighbors(loc: IntegerTuple):
    """The four orthogonally adjacent indices of ``loc`` (each coordinate clamped to 0..29 first)."""
    r, c = utils.ensure_bounds(*loc, low=0, high=29)
    offsets = ((-1, 0), (1, 0), (0, -1), (0, 1))
    return frozenset((r + dr, c + dc) for dr, dc in offsets)


def ineighbors(loc: IntegerTuple):
    """The four diagonally adjacent indices of ``loc`` (each coordinate clamped to 0..29 first)."""
    r, c = utils.ensure_bounds(*loc, low=0, high=29)
    return frozenset((r + dr, c + dc) for dr in (-1, 1) for dc in (-1, 1))


def neighbors(loc: IntegerTuple):
    """All eight adjacent indices of ``loc``: orthogonal plus diagonal."""
    orthogonal = dneighbors(loc)
    diagonal = ineighbors(loc)
    return orthogonal | diagonal


def objects(grid: Grid, univalued: Boolean, diagonal: Boolean, without_bg: Boolean):
    """objects occurring on the grid

    Flood-fills the grid into connected components. Each component is returned as an
    object: a frozenset of ``(value, (i, j))`` cells.

    Args:
        grid: grid to segment; None is returned if ``utils.is_grid`` rejects it.
        univalued: if truthy, a component only grows into cells with the same value as
            its seed cell; otherwise it grows into any cell whose value differs from
            ``bg`` (i.e. any cell at all when ``bg`` is None).
        diagonal: if truthy, connectivity uses ``neighbors`` (8-way); otherwise
            ``dneighbors`` (4-way).
        without_bg: if truthy, ``bg`` is ``mostcolor(grid)`` and background cells are
            never part of an object; otherwise ``bg`` is None.

    Returns:
        frozenset of objects, or None if there are more than 900 of them.
    """
    if not utils.is_grid(grid):
        return None
    bg = mostcolor(grid) if without_bg else None
    objs = set()
    # cells already claimed by some object; they are never used as seeds again
    occupied = set()
    h, w = len(grid), len(grid[0])
    # every (i, j) index of the grid; each unclaimed, non-background one seeds a flood fill
    unvisited = asindices(grid)
    diagfun = neighbors if diagonal else dneighbors
    for loc in unvisited:
        if loc in occupied:
            continue
        val = grid[loc[0]][loc[1]]
        if val == bg:
            continue
        obj = {(val, loc)}
        # current frontier of cells to test for membership
        cands = {loc}
        while len(cands) > 0:
            neighborhood = set()
            for cand in cands:
                v = grid[cand[0]][cand[1]]
                # membership test: same colour as the seed, or simply "not background"
                if (val == v) if univalued else (v != bg):
                    obj.add((v, cand))
                    occupied.add(cand)
                    # only in-bounds neighbours of newly claimed cells are explored next
                    neighborhood |= {(i, j) for i, j in diagfun(cand) if 0 <= i < h and 0 <= j < w}
            # next frontier excludes claimed cells; once no new cell is claimed the
            # neighbourhood stays empty and the loop ends
            cands = neighborhood - occupied
        objs.add(frozenset(obj))
    result = frozenset(objs)
    return result if len(result) <= 900 else None


def partition(grid: Grid):
    """Split a grid into one object per colour, each holding every cell of that colour.

    Returns None if ``grid`` is not a grid.
    """
    if not utils.is_grid(grid):
        return None
    by_colour = {}
    for i, row in enumerate(grid):
        for j, v in enumerate(row):
            by_colour.setdefault(v, set()).add((v, (i, j)))
    return frozenset(frozenset(cells) for cells in by_colour.values())


def fgpartition(grid: Grid):
    """Like ``partition``, but without the object of the most common (background) colour.

    Returns None if ``grid`` is not a grid.
    """
    if not utils.is_grid(grid):
        return None
    background = mostcolor(grid)
    by_colour = {}
    for i, row in enumerate(grid):
        for j, v in enumerate(row):
            if v != background:
                by_colour.setdefault(v, set()).add((v, (i, j)))
    return frozenset(frozenset(cells) for cells in by_colour.values())


def uppermost(patch: Patch):
    """Smallest row index in the patch (None if empty)."""
    rows = [i for i, _ in toindices(patch)]
    return min(rows) if rows else None


def lowermost(patch: Patch):
    """Largest row index in the patch (None if empty)."""
    rows = [i for i, _ in toindices(patch)]
    return max(rows) if rows else None


def leftmost(patch: Patch):
    """Smallest column index in the patch (None if empty)."""
    cols = [j for _, j in toindices(patch)]
    return min(cols) if cols else None


def rightmost(patch: Patch):
    """Largest column index in the patch (None if empty)."""
    cols = [j for _, j in toindices(patch)]
    return max(cols) if cols else None


def square(piece: Piece):
    """whether the piece forms a square"""
    return (
        len(piece) == len(piece[0])
        if isinstance(piece, tuple)
        else height(piece) * width(piece) == len(piece) and height(piece) == width(piece)
    )


def vline(patch: Patch):
    """Whether the patch is one column wide and its height equals its cell count."""
    single_column = width(patch) == 1
    return height(patch) == len(patch) and single_column


def hline(patch: Patch):
    """Whether the patch is one row tall and its width equals its cell count."""
    single_row = height(patch) == 1
    return width(patch) == len(patch) and single_row


def hmatching(a: Patch, b: Patch):
    """Whether the two patches share at least one row index."""
    rows_a = {i for i, _ in toindices(a)}
    rows_b = {i for i, _ in toindices(b)}
    return not rows_a.isdisjoint(rows_b)


def vmatching(a: Patch, b: Patch):
    """Whether the two patches share at least one column index."""
    cols_a = {j for _, j in toindices(a)}
    cols_b = {j for _, j in toindices(b)}
    return not cols_a.isdisjoint(cols_b)


def manhattan(a: Patch, b: Patch):
    """Smallest Manhattan distance between any cell of ``a`` and any cell of ``b``.

    Returns 0 if either patch is empty, and None if more than 900 cell pairs would be compared.
    """
    cells_a = toindices(a)
    cells_b = toindices(b)
    if len(cells_a) == 0 or len(cells_b) == 0:
        return 0
    if len(cells_a) * len(cells_b) > 900:
        return None
    distances = (abs(ra - rb) + abs(ca - cb) for ra, ca in cells_a for rb, cb in cells_b)
    return min(distances)


def adjacent(a: Patch, b: Patch):
    """Whether the closest Manhattan distance between the patches is exactly 1."""
    distance = manhattan(a, b)
    return distance == 1


def bordering(patch: Patch, grid: Grid):
    """Whether the patch touches any edge of ``grid``."""
    if uppermost(patch) == 0 or leftmost(patch) == 0:
        return True
    return lowermost(patch) == len(grid) - 1 or rightmost(patch) == len(grid[0]) - 1


def centerofmass(patch: Patch):
    """center of mass"""
    if len(patch) == 0:
        return None
    return tuple(map(lambda x: sum(x) // len(patch), zip(*toindices(patch))))


def palette(element: Element):
    """colors occurring in object or grid"""
    if isinstance(element, tuple):
        result = frozenset({v for r in element for v in r})
    else:
        result = frozenset({v for v, _ in element})
    return result if len(result) <= 900 else None


def numcolors(element: Element):
    """Number of distinct cell values in a grid or object."""
    colours = palette(element)
    return len(colours)


def color(obj: Object):
    """color of object"""
    if len(obj) == 0:
        return 0
    return next(iter(obj))[0]


def toobject(patch: Patch, grid: Grid):
    """Object built from the in-bounds indices of ``patch``, coloured by ``grid``.

    Returns None if ``grid`` is not a grid.
    """
    if not utils.is_grid(grid):
        return None
    n_rows, n_cols = len(grid), len(grid[0])
    in_bounds = [(i, j) for i, j in toindices(patch) if 0 <= i < n_rows and 0 <= j < n_cols]
    return frozenset((grid[i][j], (i, j)) for i, j in in_bounds)


def asobject(grid: Grid):
    """Every cell of ``grid`` as ``(value, (i, j))``; None if it is not a grid."""
    if not utils.is_grid(grid):
        return None
    cells = {(v, (i, j)) for i, row in enumerate(grid) for j, v in enumerate(row)}
    return frozenset(cells)


def rot90(grid: Grid):
    """Rotate a grid a quarter turn clockwise; None if it is not a grid."""
    if not utils.is_grid(grid):
        return None
    return tuple(zip(*grid[::-1]))


def rot180(grid: Grid):
    """Rotate a grid a half turn; None if it is not a grid."""
    if not utils.is_grid(grid):
        return None
    return tuple(tuple(row[::-1]) for row in reversed(grid))


def rot270(grid: Grid):
    """Rotate a grid a quarter turn anticlockwise (transpose, then reverse row order)."""
    if not utils.is_grid(grid):
        return None
    return tuple(zip(*grid))[::-1]


def hmirror(piece: Piece):
    """mirroring along horizontal"""
    if len(piece) == 0:
        return piece
    if isinstance(piece, tuple):
        return piece[::-1]
    d = ulcorner(piece)[0] + lrcorner(piece)[0]
    if isinstance(next(iter(piece))[1], tuple):
        return frozenset((v, (d - i, j)) for v, (i, j) in piece)
    return frozenset((d - i, j) for i, j in piece)


def vmirror(piece: Piece):
    """mirroring along vertical"""
    if len(piece) == 0:
        return piece
    if isinstance(piece, tuple):
        return tuple(row[::-1] for row in piece)
    d = ulcorner(piece)[1] + lrcorner(piece)[1]
    if isinstance(next(iter(piece))[1], tuple):
        return frozenset((v, (i, d - j)) for v, (i, j) in piece)
    return frozenset((i, d - j) for i, j in piece)


def dmirror(piece: Piece):
    """mirroring along diagonal"""
    if len(piece) == 0:
        return piece
    if isinstance(piece, tuple):
        return tuple(zip(*piece))
    a, b = ulcorner(piece)
    if isinstance(next(iter(piece))[1], tuple):
        return frozenset((v, (j - b + a, i - a + b)) for v, (i, j) in piece)
    return frozenset((j - b + a, i - a + b) for i, j in piece)


def cmirror(piece: Piece):
    """mirroring along counterdiagonal"""
    if len(piece) == 0:
        return piece
    if isinstance(piece, tuple):
        return tuple(zip(*(r[::-1] for r in piece[::-1])))
    return vmirror(dmirror(vmirror(piece)))


def fill(grid: Grid, value: Integer, patch: Patch):
    """fill value at indices

    Sets every in-bounds cell of ``grid`` covered by ``patch`` to ``value``.
    ``value`` is clamped to the colour range 0..9 with ``utils.ensure_bounds``, like the
    other colour-writing primitives, so ``fill`` cannot write out-of-range colours.
    Indices outside the grid are ignored.

    Returns:
        The new grid, or None when ``grid`` is not a grid or ``patch`` is None.
    """
    if not utils.is_grid(grid) or patch is None:
        return None
    value = utils.ensure_bounds(value, low=0, high=9)
    n_rows, n_cols = len(grid), len(grid[0])
    cells = [list(row) for row in grid]
    for i, j in toindices(patch):
        if 0 <= i < n_rows and 0 <= j < n_cols:
            cells[i][j] = value
    return tuple(tuple(row) for row in cells)


def paint(grid: Grid, obj: Object):
    """Write each in-bounds object cell into the grid, clamping its colour to 0..9.

    Returns None if ``grid`` is not a grid.
    """
    if not utils.is_grid(grid):
        return None
    n_rows, n_cols = len(grid), len(grid[0])
    canvas = [list(row) for row in grid]
    for value, (i, j) in obj:
        if not (0 <= i < n_rows and 0 <= j < n_cols):
            continue
        canvas[i][j] = utils.ensure_bounds(value, low=0, high=9)
    return tuple(tuple(row) for row in canvas)


def underfill(grid: Grid, value: Integer, patch: Patch):
    """Fill the patch cells that currently hold the background (most common) colour.

    ``value`` is clamped to 0..9 and out-of-bounds indices are ignored. Returns None
    if ``grid`` is not a grid.
    """
    if not utils.is_grid(grid):
        return None
    n_rows, n_cols = len(grid), len(grid[0])
    background = mostcolor(grid)
    canvas = [list(row) for row in grid]
    value = utils.ensure_bounds(value, low=0, high=9)
    for i, j in toindices(patch):
        inside = 0 <= i < n_rows and 0 <= j < n_cols
        if inside and canvas[i][j] == background:
            canvas[i][j] = value
    return tuple(tuple(row) for row in canvas)


def underpaint(grid: Grid, obj: Object):
    """Paint object cells only where the grid holds the background (most common) colour.

    Colours are clamped to 0..9 and out-of-bounds cells are ignored. Returns None if
    ``grid`` is not a grid.
    """
    if not utils.is_grid(grid):
        return None
    n_rows, n_cols = len(grid), len(grid[0])
    background = mostcolor(grid)
    canvas = [list(row) for row in grid]
    for value, (i, j) in obj:
        inside = 0 <= i < n_rows and 0 <= j < n_cols
        if inside and canvas[i][j] == background:
            canvas[i][j] = utils.ensure_bounds(value, low=0, high=9)
    return tuple(tuple(row) for row in canvas)


def hupscale(grid: Grid, factor: Integer):
    """upscale grid horizontally

    Repeats every cell ``factor`` times along its row (``factor`` is bounded to
    1..30 first). Returns None for an invalid grid or when the new width would
    exceed 30.
    """
    factor = utils.ensure_bounds(factor, low=1, high=30)
    if not utils.is_grid(grid) or len(grid[0]) * factor > 30:
        return None
    return tuple(
        tuple(cell for cell in row for _ in range(factor))[:30]
        for row in grid
    )


def vupscale(grid: Grid, factor: Integer):
    """upscale grid vertically

    Repeats every row ``factor`` times (``factor`` is bounded to 1..30 first).
    Returns None for an invalid grid or when the new height would exceed 30.
    """
    factor = utils.ensure_bounds(factor, low=1, high=30)
    if not utils.is_grid(grid) or len(grid) * factor > 30:
        return None
    stacked = tuple(row for row in grid for _ in range(factor))
    return stacked[:30]


def upscale(element: Element, factor: Integer):
    """upscale object or grid

    Every cell becomes a ``factor`` x ``factor`` block (``factor`` is bounded to
    1..30 first). A grid (tuple) is returned as a new grid. An object is scaled
    relative to its upper-left corner, which keeps its position. Returns None
    if the scaled height or width would exceed 30. An empty object yields an
    empty frozenset.
    """
    factor = utils.ensure_bounds(factor, low=1, high=30)
    h, w = shape(element)
    if h * factor > 30 or w * factor > 30:
        return None

    if isinstance(element, tuple):
        widened = [tuple(v for v in row for _ in range(factor))[:30] for row in element]
        return tuple(row for row in widened for _ in range(factor))[:30]

    if len(element) == 0:
        return frozenset()
    top, left = ulcorner(element)
    anchored = shift(element, (-top, -left))
    enlarged = frozenset(
        (value, (i * factor + di, j * factor + dj))
        for value, (i, j) in anchored
        for di in range(factor)
        for dj in range(factor)
    )
    return shift(enlarged, (top, left))


def downscale(grid: Grid, factor: Integer):
    """downscale grid

    Keeps every ``factor``-th row and every ``factor``-th column, starting at the
    top-left cell (``factor`` is bounded to 1..30 first). Returns None if
    ``grid`` is not a valid grid, instead of raising on ``len(grid[0])``.
    """
    factor = utils.ensure_bounds(factor, low=1, high=30)
    if not utils.is_grid(grid):
        return None
    # A strided slice keeps exactly the indices where index % factor == 0.
    # Row 0 is always kept, so the result is non-empty for a non-empty grid.
    return tuple(tuple(row[::factor]) for row in grid[::factor])


def hconcat(a: Grid, b: Grid):
    """concatenate two grids horizontally

    Returns None if either input is not a valid grid, the heights differ, or the
    combined width would exceed 30.
    """
    # Validate before measuring: len(None) / None[0] would raise TypeError.
    if not utils.is_grid(a) or not utils.is_grid(b):
        return None
    if len(a) != len(b) or len(a[0]) + len(b[0]) > 30:
        return None
    return tuple((row_a + row_b)[:30] for row_a, row_b in zip(a, b))


def vconcat(a: Grid, b: Grid):
    """concatenate two grids vertically

    Returns None if either input is not a valid grid, the widths differ, or the
    combined height would exceed 30.
    """
    # Validate before measuring: a[0] on None would raise TypeError.
    if not utils.is_grid(a) or not utils.is_grid(b):
        return None
    if len(a[0]) != len(b[0]) or len(a) + len(b) > 30:
        return None
    return (a + b)[:30]


def subgrid(patch: Patch, grid: Grid):
    """smallest subgrid containing object

    Crops ``grid`` to the bounding box of ``patch``: its upper-left corner and
    its shape.
    """
    origin = ulcorner(patch)
    dims = shape(patch)
    return crop(grid, origin, dims)


def hsplit(grid: Grid, n: Integer):
    """split grid horizontally

    Cuts the grid into ``n`` side-by-side parts of width ``width // n``
    (``n`` is bounded to 1..30 first). When the width is not divisible by ``n``,
    a one-column gap is skipped between consecutive parts. Returns None for an
    invalid grid or when ``n`` exceeds the width.
    """
    n = utils.ensure_bounds(n, low=1, high=30)
    if not utils.is_grid(grid):
        return None
    total_width = len(grid[0])
    if n > total_width:
        return None
    part_height = len(grid)
    part_width, remainder = divmod(total_width, n)
    gap = 1 if remainder else 0
    return tuple(
        crop(grid, (0, k * (part_width + gap)), (part_height, part_width))
        for k in range(n)
    )


def vsplit(grid: Grid, n: Integer):
    """split grid vertically

    Cuts the grid into ``n`` stacked parts of height ``height // n``
    (``n`` is bounded to 1..30 first). When the height is not divisible by ``n``,
    a one-row gap is skipped between consecutive parts. Returns None for an
    invalid grid or when ``n`` exceeds the height.
    """
    n = utils.ensure_bounds(n, low=1, high=30)
    if not utils.is_grid(grid):
        return None
    total_height = len(grid)
    if n > total_height:
        return None
    part_height, remainder = divmod(total_height, n)
    part_width = len(grid[0])
    gap = 1 if remainder else 0
    return tuple(
        crop(grid, (k * (part_height + gap), 0), (part_height, part_width))
        for k in range(n)
    )


def cellwise(a: Grid, b: Grid, fallback: Integer):
    """cellwise match of two grids

    Returns a grid shaped like ``a``. Cells where ``a`` and ``b`` hold the same
    colour keep it. Every other cell, including those outside ``b``, becomes
    ``fallback`` (bounded to 0..9). Returns None if either input is not a valid
    grid.
    """
    if not utils.is_grid(a) or not utils.is_grid(b):
        return None
    fallback = utils.ensure_bounds(fallback, low=0, high=9)
    n_rows, n_cols = len(a), len(a[0])

    def merged_cell(i, j):
        candidate = a[i][j]
        # Short-circuits before touching b[0] when row i lies outside b.
        inside_b = i < len(b) and j < len(b[0])
        return candidate if inside_b and candidate == b[i][j] else fallback

    return tuple(tuple(merged_cell(i, j) for j in range(n_cols)) for i in range(n_rows))


def replace(grid: Grid, replacee: Integer, replacer: Integer):
    """color substitution

    Returns a copy of ``grid`` where every cell equal to ``replacee`` becomes
    ``replacer``. Both colours are bounded to 0..9 first. Returns None for an
    invalid grid.
    """
    if not utils.is_grid(grid):
        return None
    replacer, replacee = utils.ensure_bounds(replacer, replacee, low=0, high=9)

    def recolor(cell):
        return replacer if cell == replacee else cell

    return tuple(tuple(map(recolor, row)) for row in grid)


def switch(grid: Grid, a: Integer, b: Integer):
    """color switching

    Swaps colours ``a`` and ``b`` everywhere in the grid, leaving other colours
    untouched. Both colours are bounded to 0..9 first. Returns None for an
    invalid grid.
    """
    if not utils.is_grid(grid):
        return None
    a, b = utils.ensure_bounds(a, b, low=0, high=9)
    swap = {a: b, b: a}
    return tuple(tuple(swap.get(cell, cell) for cell in row) for row in grid)


def center(patch: Patch):
    """center of the patch

    Returns ``(row, col)``: the top/left edge of the bounding box plus half its
    height/width, rounded down.
    """
    top = uppermost(patch)
    row = top + height(patch) // 2
    left = leftmost(patch)
    col = left + width(patch) // 2
    return (row, col)


def position(a: Patch, b: Patch):
    """relative position between two patches

    Compares the centers of ``a`` and ``b`` and returns a unit step
    ``(di, dj)`` pointing from ``a`` towards ``b``. A component is 0 when the
    centers share that coordinate. Equal centers give ``(0, -1)``.
    """
    ia, ja = center(toindices(a))
    ib, jb = center(toindices(b))
    step_j = 1 if ja < jb else -1
    step_i = 1 if ia < ib else -1
    if ia == ib:
        return (0, step_j)
    if ja == jb:
        return (step_i, 0)
    return (step_i, step_j)


def index(grid: Grid, loc: IntegerTuple):
    """color at location

    Returns ``grid[i][j]`` for ``loc == (i, j)``. Returns None when the location
    lies outside the grid; negative indices count as out of bounds and do not
    wrap around. A missing/empty grid or a None location also yields None
    instead of raising, like the other primitives here that return None on
    invalid input.
    """
    if not grid or loc is None:
        return None
    i, j = loc
    h, w = len(grid), len(grid[0])
    if not (0 <= i < h and 0 <= j < w):
        return None
    return grid[i][j]


def canvas(value: Integer, dimensions: Numerical):
    """grid construction

    Builds a grid filled with ``value`` (bounded to 0..9). ``dimensions`` is
    either a ``(height, width)`` pair or a single int used for both sides. It is
    passed through ``utils.ensure_bounds`` with that helper's default bounds.
    """
    value = utils.ensure_bounds(value, low=0, high=9)
    size = (dimensions, dimensions) if isinstance(dimensions, int) else dimensions
    bounded = utils.ensure_bounds(*size)
    return tuple(tuple(value for _ in range(bounded[1])) for _ in range(bounded[0]))


def corners(patch: Patch):
    """indices of corners

    Returns the four bounding-box corners (upper-left, upper-right, lower-left,
    lower-right) as a frozenset. Coinciding corners collapse into one entry.
    """
    corner_getters = (ulcorner, urcorner, llcorner, lrcorner)
    return frozenset(getter(patch) for getter in corner_getters)


def connect(a: IntegerTuple, b: IntegerTuple):
    """line between two points

    Returns the indices of the horizontal, vertical or 45-degree diagonal segment
    joining ``a`` and ``b`` (both ends included). Points not aligned on one of
    those directions give an empty frozenset. So does a span whose bounds
    (min coordinate, max coordinate + 1) leave [-30, 60]. A result larger than
    900 cells gives None.
    """
    (ai, aj), (bi, bj) = a, b
    row_lo, row_hi = sorted((ai, bi))
    col_lo, col_hi = sorted((aj, bj))
    if not all(-30 <= c <= 60 for c in (row_lo, col_lo, row_hi + 1, col_hi + 1)):
        return frozenset()

    rows = range(row_lo, row_hi + 1)
    cols = range(col_lo, col_hi + 1)
    if ai == bi:
        line = frozenset((ai, j) for j in cols)
    elif aj == bj:
        line = frozenset((i, aj) for i in rows)
    elif bi - ai == bj - aj:
        # Main diagonal: row and column grow together.
        line = frozenset(zip(rows, cols))
    elif bi - ai == aj - bj:
        # Anti-diagonal: column shrinks as row grows.
        line = frozenset(zip(rows, reversed(cols)))
    else:
        line = frozenset()
    return line if len(line) <= 900 else None


def cover(grid: Grid, patch: Patch):
    """remove object from grid

    Fills the patch's indices with the grid's most common colour.
    """
    background = mostcolor(grid)
    return fill(grid, background, toindices(patch))


def trim(grid: Grid):
    """trim border of grid

    Drops the outermost row and column on every side. Grids with fewer than
    three rows or columns are returned unchanged. Returns None for an invalid
    grid.
    """
    if not utils.is_grid(grid):
        return None
    if min(len(grid), len(grid[0])) < 3:
        return grid
    inner_rows = grid[1:-1]
    return tuple(row[1:-1] for row in inner_rows)


def move(grid: Grid, obj: Object, offset: IntegerTuple):
    """move object on grid

    Erases ``obj`` (covering it with the background colour), then paints it
    back shifted by ``offset``.
    """
    erased = cover(grid, obj)
    relocated = shift(obj, offset)
    return paint(erased, relocated)


def tophalf(grid: Grid):
    """upper half of grid

    Returns the first ``len(grid) // 2`` rows. For an odd height the middle row
    is excluded. Grids with fewer than two rows are returned unchanged.
    """
    n_rows = len(grid)
    return grid if n_rows < 2 else grid[: n_rows // 2]


def bottomhalf(grid: Grid):
    """lower half of grid

    Returns the last ``len(grid) // 2`` rows. For an odd height the middle row
    is excluded. Grids with fewer than two rows are returned unchanged.
    """
    n_rows = len(grid)
    if n_rows < 2:
        return grid
    # (n + 1) // 2 == n // 2 + n % 2, i.e. skip the top half plus any middle row.
    return grid[(n_rows + 1) // 2 :]


def lefthalf(grid: Grid):
    """left half of grid

    Grids narrower than two columns are returned unchanged. Otherwise the grid
    is rotated with ``rot90``, halved with ``tophalf`` and rotated back with
    ``rot270``.
    """
    if len(grid[0]) >= 2:
        rotated = rot90(grid)
        return rot270(tophalf(rotated))
    return grid


def righthalf(grid: Grid):
    """right half of grid

    Grids narrower than two columns are returned unchanged. Otherwise the grid
    is rotated with ``rot90``, halved with ``bottomhalf`` and rotated back with
    ``rot270``.
    """
    if len(grid[0]) >= 2:
        rotated = rot90(grid)
        return rot270(bottomhalf(rotated))
    return grid


def vfrontier(location: Numerical):
    """vertical frontier

    Returns the 30 indices of the column through ``location``. ``location`` is
    either a ``(row, col)`` pair or an int column; coordinates are bounded to
    0..29 first.
    """
    coords = (0, location) if isinstance(location, int) else location
    column = utils.ensure_bounds(*coords, low=0, high=29)[1]
    return frozenset((row, column) for row in range(30))


def hfrontier(location: Numerical):
    """horizontal frontier

    Returns the 30 indices of the row through ``location``. ``location`` is
    either a ``(row, col)`` pair or an int row; coordinates are bounded to
    0..29 first.
    """
    coords = (location, 0) if isinstance(location, int) else location
    row = utils.ensure_bounds(*coords, low=0, high=29)[0]
    return frozenset((row, column) for column in range(30))


def backdrop(patch: Patch):
    """indices in bounding box of patch

    Returns every index inside the patch's bounding box. An empty patch gives an
    empty frozenset, as does a corner coordinate outside [-30, 60]. A box with
    more than 900 cells gives None.
    """
    if len(patch) == 0:
        return frozenset()
    top, left = ulcorner(toindices(patch))
    bottom, right = lrcorner(patch)
    if not all(-30 <= c <= 60 for c in (top, left, bottom, right)):
        return frozenset()
    cells = frozenset(
        (i, j) for i in range(top, bottom + 1) for j in range(left, right + 1)
    )
    return cells if len(cells) <= 900 else None


def delta(patch: Patch):
    """indices in bounding box but not part of patch

    An empty patch gives an empty frozenset. Returns None when ``backdrop``
    reports an oversized bounding box (None) or when the difference itself has
    more than 900 cells.
    """
    if len(patch) == 0:
        return frozenset()
    box_cells = backdrop(patch)
    # backdrop signals an oversized box with None; subtracting from None would raise TypeError.
    if box_cells is None:
        return None
    result = box_cells - toindices(patch)
    return result if len(result) <= 900 else None


# def gravitate(source: Patch, destination: Patch):  # TODO: fix this function that sometimes takes too long
#     """direction to move source until adjacent to destination"""
#     source_i, source_j = center(source)
#     destination_i, destination_j = center(destination)
#     i, j = 0, 0
#     if vmatching(source, destination):
#         i = 1 if source_i < destination_i else -1
#     else:
#         j = 1 if source_j < destination_j else -1
#     direction = (i, j)
#     gravitation_i, gravitation_j = i, j
#     if source_i == destination_i and source_j == destination_j:
#         gravitation_i += -height(source)
#         source = shift(source, (-height(source), 0))
#     maxcount = 30
#     count = 0
#     while not adjacent(source, destination) and count < maxcount:
#         count += 1
#         gravitation_i += i
#         gravitation_j += j
#         source = shift(source, direction)
#     return (gravitation_i - i, gravitation_j - j)


def inbox(patch: Patch):
    """inbox for patch

    Returns the rectangle outline lying one cell inside the patch's bounding box.
    A corner coordinate outside [-30, 60] gives an empty frozenset. An outline
    of more than 900 cells gives None.
    """
    top, left = uppermost(patch) + 1, leftmost(patch) + 1
    bottom, right = lowermost(patch) - 1, rightmost(patch) - 1
    r0, r1 = sorted((top, bottom))
    c0, c1 = sorted((left, right))
    if not all(-30 <= c <= 60 for c in (r0, c0, r1, c1)):
        return frozenset()
    sides = {(r, c) for r in range(r0, r1 + 1) for c in (c0, c1)}
    caps = {(r, c) for r in (r0, r1) for c in range(c0, c1 + 1)}
    outline = frozenset(sides | caps)
    return outline if len(outline) <= 900 else None


def outbox(patch: Patch):
    """outbox for patch

    Returns the rectangle outline lying one cell outside the patch's bounding
    box. A corner coordinate outside [-30, 60] gives an empty frozenset. An
    outline of more than 900 cells gives None.
    """
    top, left = uppermost(patch) - 1, leftmost(patch) - 1
    bottom, right = lowermost(patch) + 1, rightmost(patch) + 1
    r0, r1 = sorted((top, bottom))
    c0, c1 = sorted((left, right))
    if not all(-30 <= c <= 60 for c in (r0, c0, r1, c1)):
        return frozenset()
    sides = {(r, c) for r in range(r0, r1 + 1) for c in (c0, c1)}
    caps = {(r, c) for r in (r0, r1) for c in range(c0, c1 + 1)}
    outline = frozenset(sides | caps)
    return outline if len(outline) <= 900 else None


def box(patch: Patch):
    """outline of patch

    Returns the indices on the border of the patch's bounding box. Returns None
    when that outline has more than 900 cells. An empty patch is returned
    unchanged.
    """
    if len(patch) == 0:
        return patch
    ai, aj = ulcorner(patch)
    bi, bj = lrcorner(patch)
    si, sj = min(ai, bi), min(aj, bj)
    ei, ej = max(ai, bi), max(aj, bj)
    n_rows, n_cols = ei - si + 1, ej - sj + 1
    # Count outline cells arithmetically so far-flung patches are rejected before
    # materialising a huge set that would be discarded anyway. A 1-wide box is a
    # single line; otherwise the outline is the perimeter minus the 4 shared corners.
    if n_rows == 1 or n_cols == 1:
        outline_size = n_rows * n_cols
    else:
        outline_size = 2 * (n_rows + n_cols) - 4
    if outline_size > 900:
        return None
    vlines = {(i, sj) for i in range(si, ei + 1)} | {(i, ej) for i in range(si, ei + 1)}
    hlines = {(si, j) for j in range(sj, ej + 1)} | {(ei, j) for j in range(sj, ej + 1)}
    return frozenset(vlines | hlines)


def shoot(start: IntegerTuple, direction: IntegerTuple):
    """line from starting point and direction

    ``direction`` must equal one of the supported step vectors; otherwise the
    result is None. Starting from ``start``, the ray takes at most 42 steps and
    stops at the first point outside the 30x30 area. The cells between
    ``start`` and that end point come from ``connect``. A line of more than 900
    cells gives None.
    """
    supported_steps = (
        (0, 1), (0, -1), (1, 0), (-1, 0),
        (1, 1), (-1, -1), (1, -1), (-1, 1),
        (0, 2), (2, 0), (2, 2), (3, 3),
    )
    if direction not in supported_steps:
        return None
    step_i, step_j = direction
    row, col = start
    for _ in range(42):
        row, col = row + step_i, col + step_j
        if not (0 <= row < 30 and 0 <= col < 30):
            break
    line = connect(start, (row, col))
    return line if len(line) <= 900 else None


def occurrences(grid: Grid, obj: Object):
    """locations of occurrences of object in grid

    Returns every offset ``(i, j)`` at which the normalized object, shifted by
    that offset, lies fully inside the grid with every cell colour matching.
    Returns None for an invalid grid.
    """
    if not utils.is_grid(grid):
        return None
    template = normalize(obj)
    n_rows, n_cols = len(grid), len(grid[0])

    def matches_at(di, dj):
        for colour, (r, c) in shift(template, (di, dj)):
            if not (0 <= r < n_rows and 0 <= c < n_cols) or grid[r][c] != colour:
                return False
        return True

    return frozenset(
        (i, j) for i in range(n_rows) for j in range(n_cols) if matches_at(i, j)
    )


def frontiers(grid: Grid):
    """set of frontiers

    Returns the single-colour rows and columns of the grid, each as an object
    (a frozenset of ``(colour, (i, j))`` cells). Returns None for an invalid
    grid or if there are more than 900 frontiers.
    """
    if not utils.is_grid(grid):
        return None
    n_rows, n_cols = len(grid), len(grid[0])
    uniform_rows = [i for i, row in enumerate(grid) if len(set(row)) == 1]
    uniform_cols = [j for j, col in enumerate(dmirror(grid)) if len(set(col)) == 1]
    row_lines = {frozenset((grid[i][j], (i, j)) for j in range(n_cols)) for i in uniform_rows}
    col_lines = {frozenset((grid[i][j], (i, j)) for i in range(n_rows)) for j in uniform_cols}
    result = frozenset(row_lines | col_lines)
    return result if len(result) <= 900 else None


def compress(grid: Grid):
    """removes frontiers from grid

    Drops every single-colour row and every single-colour column, where column
    uniformity is measured on the original grid. Returns None for an invalid
    grid.
    """
    if not utils.is_grid(grid):
        return None
    drop_rows = {i for i, row in enumerate(grid) if len(set(row)) == 1}
    drop_cols = {j for j, col in enumerate(dmirror(grid)) if len(set(col)) == 1}
    kept_rows = (row for i, row in enumerate(grid) if i not in drop_rows)
    return tuple(
        tuple(v for j, v in enumerate(row) if j not in drop_cols) for row in kept_rows
    )


def hperiod(obj: Object):
    """horizontal periodicity

    Returns the smallest shift ``p`` in ``1 <= p < width`` such that the
    normalized object, shifted left by ``p`` and cropped to non-negative
    columns, is a subset of itself. Falls back to the full width.
    """
    base = normalize(obj)
    span = width(base)
    periods = (
        p
        for p in range(1, span)
        if frozenset((c, (i, j)) for c, (i, j) in shift(base, (0, -p)) if j >= 0).issubset(base)
    )
    return next(periods, span)


def vperiod(obj: Object):
    """vertical periodicity

    Returns the smallest shift ``p`` in ``1 <= p < height`` such that the
    normalized object, shifted up by ``p`` and cropped to non-negative rows, is
    a subset of itself. Falls back to the full height.
    """
    base = normalize(obj)
    span = height(base)
    periods = (
        p
        for p in range(1, span)
        if frozenset((c, (i, j)) for c, (i, j) in shift(base, (-p, 0)) if i >= 0).issubset(base)
    )
    return next(periods, span)


class _LocalFunctions:
    """Class to add functions to the module to make pickling work"""

    @classmethod
    def add_functions(cls, *args):
        for function in args:
            setattr(cls, function.__name__, function)
            function.__qualname__ = cls.__qualname__ + "." + function.__name__


# Make primitive nodes available as constants.
# For every module function whose name does not start with "_", "rand_", "const_",
# "toinput" or "tooutput", register a zero-argument ``const_<name>`` function that
# returns the primitive itself.
# ``inspect.getmembers`` returns a list computed once before the loop runs, so the
# ``const_*`` functions created here are not themselves revisited.
# NOTE: ``primitive_name`` and ``_function`` remain as module globals after the loop.
for primitive_name, _function in inspect.getmembers(sys.modules[__name__], inspect.isfunction):
    if not primitive_name.startswith(("_", "rand_", "const_", "toinput", "tooutput")):

        # The factory binds the current ``_function`` through its argument; a closure
        # defined directly in the loop body would late-bind to the last primitive.
        def make_const_fn(_function) -> Callable[[], Callable]:
            def const_fn() -> Callable:
                return _function

            return const_fn

        const_fn = make_const_fn(_function)
        const_fn.__name__ = f"const_{primitive_name}"
        # Replaces the placeholder ``Callable`` return annotation; presumably
        # infer_type derives the primitive's typed signature — see types_.infer_type.
        const_fn.__annotations__["return"] = types_.infer_type(_function)
        setattr(sys.modules[__name__], const_fn.__name__, const_fn)
        _LocalFunctions.add_functions(const_fn)  # To make pickle work for the dataloader with multiprocessing
        # Remove the helper globals: otherwise "const_fn"/"make_const_fn" would stay as
        # module functions and be picked up by the getmembers scan in ALL_PRIMITIVES.
        del const_fn, make_const_fn


# Add the generators as random primitives.
# GENERATORS_SRC_CODE is executed at module level, so whatever it defines (the
# ``generate_*`` functions checked below) lands in this module's global namespace.
# NOTE(review): exec of a source string is only acceptable because the code ships with
# the package; never feed it external input.
exec(GENERATORS_SRC_CODE)
# Generator names are recovered by splitting the source on whitespace, keeping tokens
# that start with "generate" and cutting each at its first "(" (e.g. the token after
# ``def`` in ``def generate_xxx(``). Any other whitespace token starting with
# "generate" would also be collected and change the count checked below.
function_names = [name.split("(")[0] for name in GENERATORS_SRC_CODE.split() if name.startswith("generate")]
# NOTE(review): these sanity checks are asserts, which are stripped under ``python -O``.
assert len(function_names) == 400
for fn_name in function_names:
    assert fn_name in globals()

    # The factory binds ``fn_name`` per iteration, avoiding late binding in the closure.
    def make_rand_fn(fn_name) -> Callable[[], Grid]:
        def rand_fn() -> Grid:
            # Looked up in this module's globals at call time, not at definition time.
            generate_fn = globals()[fn_name]
            # Presumably the re-arc (diff_lb, diff_ub) difficulty bounds, with the
            # generator returning a dict holding an "input" grid — TODO confirm.
            return generate_fn(0, 1)["input"]

        return rand_fn

    rand_fn = make_rand_fn(fn_name)
    rand_fn.__name__ = f"rand_{fn_name}"
    rand_fn.__annotations__["return"] = Grid
    setattr(sys.modules[__name__], rand_fn.__name__, rand_fn)
    _LocalFunctions.add_functions(rand_fn)  # To make pickle work for the dataloader with multiprocessing
    # Remove helper globals so they are not collected as module functions later.
    del rand_fn, make_rand_fn


# Registry of every public module function:
#   name -> {"inputs": [(parameter name, annotation), ...], "fn": the function}.
# Private names, the exec'd ``generate_*`` functions and names starting with
# "unifint" are left out.
ALL_PRIMITIVES = {
    name: {
        "inputs": [(param.name, param.annotation) for param in inspect.signature(fn).parameters.values()],
        "fn": fn,
    }
    for name, fn in inspect.getmembers(sys.modules[__name__], inspect.isfunction)
    if not name.startswith(("_", "generate_", "unifint"))
}

# Names of the registered primitives with the "rand_" / "const_" prefixes.
RANDOM_PRIMITIVES = set(filter(lambda name: name.startswith("rand_"), ALL_PRIMITIVES))

CONSTANT_PRIMITIVES = set(filter(lambda name: name.startswith("const_"), ALL_PRIMITIVES))
