import einx._src.namedtensor.stage1 as stage1
import einx._src.namedtensor.stage2 as stage2
import einx._src.namedtensor.stage3 as stage3
from einx._src.namedtensor import solve
from einx._src.namedtensor import NamedTensor
from einx._src.error import Invocation
import numpy as np
import functools
import types
import einx._src.tracer as tracer
import einx._src.adapter as adapter
import uuid
from functools import partial
from .error import SemanticError

def _to_el_expr(expr):
    """Reduce an expression to the part that is written inside brackets.

    Bracketed content is kept (as a deep copy of the bracket's inner
    expression), axes and concatenations that are reached outside of
    brackets are replaced with an empty ``stage1.List``, and container
    nodes are rebuilt around their converted children.

    Args:
        expr: Either a plain ``list`` of expressions or a single ``stage1``
            expression node.

    Returns:
        The converted ``list`` (entries with ``ndim == 0`` removed) or the
        converted ``stage1`` node.

    Raises:
        TypeError: If ``expr`` is of a type that is not handled here.
    """
    # Plain Python list: convert every entry first, then drop the empty ones.
    if isinstance(expr, list):
        converted = [_to_el_expr(child) for child in expr]
        return [child for child in converted if child.ndim != 0]

    if isinstance(expr, stage1.List):
        return stage1.List.create(_to_el_expr(expr.children), expr.begin_pos, expr.end_pos)

    # The recursion never descends into brackets (see the Brackets case
    # below), so an axis that is reached here contributes nothing.
    if isinstance(expr, stage1.Axis):
        return stage1.List([])

    if isinstance(expr, stage1.FlattenedAxis):
        flattened = _to_el_expr(expr.inner)
        if flattened.ndim == 0:
            return stage1.List([])
        return stage1.FlattenedAxis.create(flattened, expr.begin_pos, expr.end_pos)

    if isinstance(expr, stage1.ConcatenatedAxis):
        # ConcatenatedAxis cannot be used with brackets
        assert not any(stage1.is_in_brackets(node) for node in expr.nodes())
        return stage1.List([])

    if isinstance(expr, stage1.Brackets):
        # Keep the bracket's content, but without the brackets themselves.
        return expr.inner.__deepcopy__()

    if isinstance(expr, stage1.Ellipsis):
        return stage1.Ellipsis.create(
            _to_el_expr(expr.inner), expr.begin_pos, expr.end_pos, expr.ellipsis_id
        )

    # Op and Args are rebuilt the same way from their converted children.
    for container in (stage1.Op, stage1.Args):
        if isinstance(expr, container):
            return container(
                [_to_el_expr(child) for child in expr.children],
                expr.begin_pos,
                expr.end_pos,
            )

    raise TypeError(f"Invalid expression type {type(expr)}")

def _parse_op(description, el_op, invocation, allow_concat=False, implicit_output=None, mark_reduced_axes=False, allow_duplicate_el_axes=True):
    """Parse an operation string into input and output expressions and validate it.

    Args:
        description: Operation string, parsed with ``stage1.parse_op``.
        el_op: Signature of the elementary operation. ``None`` uses the
            bracketed part of ``description`` as is, a ``str`` is parsed with
            ``stage1.parse_op``, and a callable is called with the bracketed
            part of ``description`` and must return such a string.
        invocation: Invocation object used for error reporting.
        allow_concat: Whether the concatenation operator (+) may be used.
        implicit_output: Rule for deriving the output expressions when
            ``description`` contains no output part: ``"same"``, ``"reduce"``,
            ``"superset"``, ``"argfind"``, an ``int`` index into the input
            expressions, a tuple/list of such indices, or ``None`` to require
            an explicit output.
        mark_reduced_axes: If true and no input expression uses brackets,
            every input axis whose name does not appear in the (single)
            output expression is wrapped in brackets.
        allow_duplicate_el_axes: If false, an expression must not contain
            the same axis name more than once inside brackets.

    Returns:
        Tuple ``(exprs_in, exprs_out)`` of stage-1 input and output
        expressions.

    Raises:
        SemanticError: If the description violates one of the checks below.
    """
    op = stage1.parse_op(description)

    if not allow_concat:
        # Disallow concatenation
        if any(isinstance(expr, stage1.ConcatenatedAxis) for expr in op.nodes()):
            raise SemanticError(
                invocation=invocation,
                pos=invocation.indicator.get_pos_for_concat(op),
                message="The concatenation operator (+) is not allowed in this function.\n%EXPR%",
            )

    exprs_in = op.children[0].children

    if len(op.children) == 1:
        # No output part was given: derive it according to implicit_output.
        if implicit_output == "same":
            # Same input as output
            exprs_out = [expr_in.__deepcopy__() for expr_in in exprs_in]
        elif implicit_output == "reduce":
            # Remove brackets for output expression
            exprs_out = [stage1.remove(expr_in, stage1.Brackets, keep_children=False) for expr_in in exprs_in]
        elif implicit_output == "superset":
            # Find superset expression
            if len(exprs_in) == 1:
                # Only one input -> use it as output
                exprs_out = [exprs_in[0].__deepcopy__()]
            else:
                # Use one of the input expression if it contains the axis names of
                # all others and if this choice is unique
                in_axis_names = [
                    {expr.name for expr in root.nodes() if isinstance(expr, stage1.Axis) and expr.value != 1}
                    for root in exprs_in
                ]

                valid_parents = set()
                for i, parent in enumerate(in_axis_names):
                    for j, child in enumerate(in_axis_names):
                        if i != j and not child.issubset(parent):
                            break
                    else:
                        # Found valid parent
                        valid_parents.add(exprs_in[i])

                if len(valid_parents) != 1:
                    raise SemanticError(
                        invocation=invocation,
                        message="The output expression is missing in this operation. einx allows implicitly determining the output expression"
                        " in this operation according to the rule: If one of the input expressions contains the axis names of all other input"
                        " expressions (excluding 1s) and if this choice is unique, then this expression is used as output expression. However, no unique"
                        " input expression was found.\n%EXPR%",
                    )
                exprs_out = [valid_parents.pop().__deepcopy__()]
        elif isinstance(implicit_output, int):
            exprs_out = [exprs_in[implicit_output].__deepcopy__()]
        elif isinstance(implicit_output, (tuple, list)) and all(isinstance(i, int) for i in implicit_output):
            exprs_out = [exprs_in[i].__deepcopy__() for i in implicit_output]
        elif implicit_output == "argfind":
            def _to_output(expr):
                bracket_num = len([e for e in expr.nodes() if isinstance(e, stage1.Brackets)])
                assert bracket_num > 0
                if bracket_num == 1:
                    # Replace the single bracket with a bracketed placeholder axis.
                    def _replace(expr):
                        if isinstance(expr, stage1.Brackets):
                            return stage1.Brackets.create(stage1.Axis.create("output.axis"))
                    return stage1.map(expr, _replace, include_children=False)
                else:
                    raise SemanticError(
                        invocation=invocation,
                        pos=invocation.indicator.get_pos_for_brackets(expr),
                        message="The output expression is missing in this operation. The output expression can only be determined implicitly, if the input expression contains exactly one usage of brackets ([]).\n%EXPR%",
                    )
            exprs_out = [_to_output(expr_in) for expr_in in exprs_in]
        elif implicit_output is None:
            # An explicit output is required, but this branch is only reached
            # when the description has no output part.
            raise SemanticError(
                invocation=invocation,
                message="The operation expects both input and output expressions, but no '->' was found.\n%EXPR%",
            )
        else:
            # Raised explicitly so that the check also holds when assertions
            # are disabled (python -O).
            raise AssertionError(f"Invalid value for implicit_output: {implicit_output!r}")
    else:
        exprs_out = op.children[1].children
    op = stage1.Op(
        [stage1.Args(exprs_in), stage1.Args(exprs_out)],
    )

    # Get elementary operation
    el_subop = _to_el_expr(op)
    if el_op is None:
        el_op = el_subop
    elif isinstance(el_op, str):
        el_op = stage1.parse_op(el_op)
    elif callable(el_op):
        el_op = el_op(el_subop)
        el_op = stage1.parse_op(el_op)
    else:
        # Raised explicitly so that the check also holds under python -O.
        raise AssertionError(f"Invalid type for el_op: {type(el_op)}")
    assert len(el_op.children) == 2 # TODO: requires better exception here for vmap with custom el_op

    # Check number of input and output expressions
    if len(el_op.children[0].children) != len(exprs_in):
        raise SemanticError(
            invocation=invocation,
            message=f"The operation expects {len(el_op.children[0].children)} input expression(s), but found {len(exprs_in)}.\n%EXPR%",
        )
    if len(el_op.children[1].children) != len(exprs_out):
        raise SemanticError(
            invocation=invocation,
            message=f"The operation expects {len(el_op.children[1].children)} output expression(s), but found {len(exprs_out)}.\n%EXPR%",
        )

    # Check bracket usage
    def _to_ordinal_str(i):
        # i is a zero-based index. 11, 12 and 13 take "th" although they end
        # in 1, 2 and 3.
        n = i + 1
        if 11 <= n % 100 <= 13:
            suffix = "th"
        else:
            suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
        return f"{n}{suffix}"
    def check(i, el_arg, el_subarg, inoutput):
        # el_arg: what the elementary operation expects at this position.
        # el_subarg: the bracketed part that the description provides there.
        if el_arg.ndim == 0 and el_subarg.ndim != 0:
            raise SemanticError(
                invocation=invocation,
                pos=invocation.indicator.get_pos_for_brackets(el_subarg),
                message=f"Brackets ([]) are not allowed in the {_to_ordinal_str(i)} {inoutput} expression of this operation.\n%EXPR%",
            )
        elif el_arg.ndim != 0 and el_subarg.ndim == 0:
            raise SemanticError(
                invocation=invocation,
                pos=invocation.indicator.get_pos_for_exprs(el_subarg),
                message=f"The {_to_ordinal_str(i)} {inoutput} expression of this operation requires brackets, but no brackets were found.\n%EXPR%",
            )
    for i, (el_arg, el_subarg) in enumerate(zip(el_op.children[0].children, el_subop.children[0].children)):
        check(i, el_arg, el_subarg, "input")
    for i, (el_arg, el_subarg) in enumerate(zip(el_op.children[1].children, el_subop.children[1].children)):
        check(i, el_arg, el_subarg, "output")

    if mark_reduced_axes and not any(isinstance(expr, stage1.Brackets) for expr_in in exprs_in for expr in expr_in.nodes()):
        assert len(exprs_out) == 1
        expr_out = exprs_out[0]
        # If no brackets appear in exprs_in, mark all axes that don't appear in expr_out.
        axes_names_out = {
            axis.name for axis in expr_out.nodes() if isinstance(axis, stage1.Axis) # TODO: specific error reporting for these implicit conventions
        }
        def _mark(expr):
            return isinstance(expr, stage1.Axis) and expr.name not in axes_names_out
        exprs_in = [stage1.map(
            expr_in,
            lambda expr: stage1.Brackets(expr) if _mark(expr) else None,
            include_children=False,
        ) for expr_in in exprs_in]

    # Check that no two vectorized axes in any output have the same name
    for expr_out in exprs_out:
        axis_names = [expr.name for expr in expr_out.nodes() if isinstance(expr, stage1.Axis) and not stage1.is_in_brackets(expr)]
        if len(axis_names) != len(set(axis_names)):
            duplicates = {name for name in axis_names if axis_names.count(name) > 1}
            raise SemanticError(
                invocation=invocation,
                pos=invocation.indicator.get_pos_for_axisnames([expr_out], duplicates),
                message="The output expression must not contain multiple vectorized axes with the same name.\n%EXPR%",
            )

    if not allow_duplicate_el_axes:
        # Check that no two marked axes in any tensor have the same name
        for expr in exprs_in + exprs_out:
            axis_names = [axis.name for axis in expr.nodes() if isinstance(axis, stage1.Axis) and stage1.is_in_brackets(axis)]
            if len(axis_names) != len(set(axis_names)):
                duplicates = {name for name in axis_names if axis_names.count(name) > 1}
                raise SemanticError(
                    invocation=invocation,
                    pos=invocation.indicator.get_pos_for_axisnames([expr], duplicates),
                    message="The expression must not contain multiple axes with the same name in brackets ([]).\n%EXPR%",
                )

    return exprs_in, exprs_out

def _semantic_checks_dot(exprs_in, exprs_out, invocation):
    """Check the bracket usage of a dot operation.

    Every axis name that is marked with brackets in some input expression
    has to occur in exactly two of the input expressions.

    Args:
        exprs_in: Input expressions (``stage3``).
        exprs_out: Output expressions; only the first one is used, for
            error reporting.
        invocation: Invocation object used for error reporting.

    Raises:
        SemanticError: If a marked axis name occurs in more or fewer than
            two input expressions.
    """
    output_expr = exprs_out[0]

    # Names of all axes that are marked with brackets in any input
    contracted_names = {
        node.name
        for expr_in in exprs_in
        for node in expr_in.nodes()
        if isinstance(node, stage3.Axis) and stage3.is_in_brackets(node)
    }

    # Axis names (marked or not) of each input expression, computed once
    names_per_input = [
        {node.name for node in expr_in.nodes() if isinstance(node, stage3.Axis)}
        for expr_in in exprs_in
    ]

    offending_names = [
        name
        for name in contracted_names
        if sum(1 for names in names_per_input if name in names) != 2
    ]
    if not offending_names:
        return

    raise SemanticError(
        invocation=invocation,
        pos=invocation.indicator.get_pos_for_axisnames(exprs_in + [output_expr], offending_names),
        message="All contracted axes must appear in exactly two input expressions.\n%EXPR%",
    )

def _semantic_checks_get_at(exprs_in, exprs_out, invocation):
    """Check the bracket usage of a get-at operation.

    The first input is the tensor expression, all remaining inputs are
    coordinate expressions. Each coordinate expression may mark at most one
    axis; it contributes the value of that axis, or 1 if it marks none. The
    contributions must add up to the number of marked axes in the tensor
    expression.

    Args:
        exprs_in: Input expressions (``stage3``).
        exprs_out: Output expressions; only the first one is used, for
            error reporting.
        invocation: Invocation object used for error reporting.

    Raises:
        SemanticError: If one of the rules above is violated.
    """
    output_expr = exprs_out[0]
    if len(exprs_in) < 2:
        raise SemanticError(
            invocation=invocation,
            message=f"The operation expects at least 2 input expressions, but found {len(exprs_in)}.\n%EXPR%",
        )
    every_expr = [*exprs_in, output_expr]
    source_expr, *index_exprs = exprs_in

    def _bracketed_axes(expr):
        return [
            node for node in expr.nodes()
            if isinstance(node, stage3.Axis) and stage3.is_in_brackets(node)
        ]

    # Ensure that at most one axis is marked in each coordinate expression
    bracketed_per_index = []
    for index_expr in index_exprs:
        bracketed = _bracketed_axes(index_expr)
        if len(bracketed) > 1:
            raise SemanticError(
                invocation=invocation,
                pos=invocation.indicator.get_pos_for_axisnames(every_expr, [axis.name for axis in bracketed]),
                message="Each coordinate expression must contain at most one axis in brackets.\n%EXPR%",
            )
        bracketed_per_index.append(bracketed)

    # Ensure that summed number of coordinates is equal to the number of marked axes in first input expression
    coordinate_count = 0
    for bracketed in bracketed_per_index:
        coordinate_count += bracketed[0].value if bracketed else 1

    source_axes = _bracketed_axes(source_expr)
    if len(source_axes) != coordinate_count:
        raise SemanticError(
            invocation=invocation,
            pos=invocation.indicator.get_pos_for_axisnames(every_expr, [axis.name for axis in source_axes]),
            message="The number of coordinates must match the number of axes in brackets in the first input expression.\n%EXPR%",
        )

def _semantic_checks_sort(exprs_in, exprs_out, invocation):
    """Check that the input and the output expression each mark exactly one axis.

    Args:
        exprs_in: List with exactly one input expression (``stage3``).
        exprs_out: List with exactly one output expression (``stage3``).
        invocation: Invocation object used for error reporting.

    Raises:
        SemanticError: If an expression does not contain exactly one axis in
            brackets.
    """
    assert len(exprs_in) == 1 and len(exprs_out) == 1
    every_expr = exprs_in + exprs_out
    for candidate in every_expr:
        bracketed_names = [
            node.name
            for node in candidate.nodes()
            if isinstance(node, stage3.Axis) and stage3.is_in_brackets(node)
        ]
        if len(bracketed_names) == 1:
            continue
        raise SemanticError(
            invocation=invocation,
            pos=invocation.indicator.get_pos_for_axisnames(every_expr, bracketed_names),
            message="The expression for this operation must contain exactly one axis in brackets.\n%EXPR%",
        )

def _semantic_checks_update_at(exprs_in, exprs_out, invocation):
    """Check the bracket usage of an update-at operation.

    The first input is the target tensor expression, the last input is the
    update expression, and all inputs in between are coordinate expressions.

    Checks, in this order:
      1. The target and the output expression mark the same set of axis names.
      2. Each coordinate expression marks at most one axis name.
      3. No axis name is marked both in the target and in a coordinate or
         update expression.
      4. The coordinate expressions contribute as many coordinates as there
         are marked axes in the target. A coordinate expression contributes
         the value of its marked axis, or 1 if it marks none.

    Args:
        exprs_in: List of input expressions (``stage3``).
        exprs_out: Output expressions; only the first one is checked.
        invocation: Invocation object used for error reporting.

    Raises:
        SemanticError: If one of the checks fails.
    """
    expr_out = exprs_out[0]
    if len(exprs_in) < 2:
        # The message states the limit that is actually enforced here: the
        # remaining checks accept an operation without coordinate expressions.
        raise SemanticError(
            invocation=invocation,
            message=f"The operation expects at least 2 input expressions, but found {len(exprs_in)}.\n%EXPR%",
        )

    all_exprs = list(exprs_in) + [expr_out]
    tensor_expr = exprs_in[0]
    coords_exprs = exprs_in[1:-1]
    update_expr = exprs_in[-1]

    # Ensure same set of axes is marked in first input and output expression
    input_axisnames = {expr.name for expr in tensor_expr.nodes() if isinstance(expr, stage3.Axis) and stage3.is_in_brackets(expr)}
    output_axisnames = {expr.name for expr in expr_out.nodes() if isinstance(expr, stage3.Axis) and stage3.is_in_brackets(expr)}
    if input_axisnames != output_axisnames:
        raise SemanticError(
            invocation=invocation,
            pos=invocation.indicator.get_pos_for_axisnames(all_exprs, input_axisnames.symmetric_difference(output_axisnames)),
            message="The first input and output expressions must have the same set of axes in brackets.\n%EXPR%",
        )

    # Ensure that at most one axis is marked in each coordinate expression (aside from reduced axes that are also marked in updates)
    for coord_expr in coords_exprs:
        marked_axisnames = {expr.name for expr in coord_expr.nodes() if isinstance(expr, stage3.Axis) and stage3.is_in_brackets(expr)}
        if len(marked_axisnames) > 1:
            raise SemanticError(
                invocation=invocation,
                pos=invocation.indicator.get_pos_for_axisnames(all_exprs, marked_axisnames),
                message="Each coordinate expression must contain at most one coordinate axis in brackets.\n%EXPR%",
            )

    # Ensure marked axes in (1) target tensor, and (2) coordinate and update expressions are non-overlapping
    # The marked names of the target tensor were already collected above.
    target_axisnames = input_axisnames
    # list(...) keeps the concatenation valid if exprs_in is a tuple.
    coordupdate_axisnames = {
        axis.name
        for expr in list(coords_exprs) + [update_expr]
        for axis in expr.nodes()
        if isinstance(axis, stage3.Axis) and stage3.is_in_brackets(axis)
    }
    intersection = target_axisnames.intersection(coordupdate_axisnames)
    if len(intersection) > 0:
        raise SemanticError(
            invocation=invocation,
            pos=invocation.indicator.get_pos_for_axisnames(all_exprs, intersection),
            message="Axes may not appear in brackets both in the first input tensor and in a coordinate or update expression.\n%EXPR%",
        )

    # Ensure that summed number of coordinates is equal to the number of marked axes in first input expression
    n = 0
    for coord_expr in coords_exprs:
        marked_axis = {expr for expr in coord_expr.nodes() if isinstance(expr, stage3.Axis) and stage3.is_in_brackets(expr)}
        assert len(marked_axis) <= 1
        if len(marked_axis) == 1:
            n += marked_axis.pop().value
        else:
            n += 1
    marked_axes = [expr for expr in tensor_expr.nodes() if isinstance(expr, stage3.Axis) and stage3.is_in_brackets(expr)]
    if len(marked_axes) != n:
        raise SemanticError(
            invocation=invocation,
            pos=invocation.indicator.get_pos_for_axisnames(all_exprs, [expr.name for expr in marked_axes]),
            message="The number of coordinates must match the number of axes in brackets in the first input expression.\n%EXPR%",
        )

def _semantic_checks_argfind(exprs_in, exprs_out, invocation):
    """Check the bracket usage of an argfind operation.

    The operation takes exactly one input and one output expression. The
    output expression may mark at most one axis. If it marks none, the input
    expression must mark exactly one axis; otherwise the number of marked
    input axes must equal the value of the marked output axis.

    Args:
        exprs_in: Input expressions (``stage3``).
        exprs_out: Output expressions (``stage3``).
        invocation: Invocation object used for error reporting.

    Raises:
        SemanticError: If one of the rules above is violated.
    """
    if len(exprs_in) != 1:
        raise SemanticError(
            invocation=invocation,
            message=f"The operation expects exactly one input expression, but found {len(exprs_in)}.\n%EXPR%",
        )
    if len(exprs_out) != 1:
        # Report the number of output expressions (not of input expressions).
        raise SemanticError(
            invocation=invocation,
            message=f"The operation expects exactly one output expression, but found {len(exprs_out)}.\n%EXPR%",
        )
    expr_in = exprs_in[0]
    expr_out = exprs_out[0]
    all_exprs = [expr_in, expr_out]

    # Ensure that at most one axis is marked in output expression
    marked_output_axes = [expr for expr in expr_out.nodes() if isinstance(expr, stage3.Axis) and stage3.is_in_brackets(expr)]
    if len(marked_output_axes) > 1:
        raise SemanticError(
            invocation=invocation,
            pos=invocation.indicator.get_pos_for_axisnames(all_exprs, [a.name for a in marked_output_axes]),
            message="The output expression must contain at most one axis in brackets.\n%EXPR%",
        )
    marked_output_axis = marked_output_axes[0] if len(marked_output_axes) == 1 else None

    # Ensure that number of marked axes in input expression is equal to value of marked axis in output expression
    marked_input_axes = [expr for expr in expr_in.nodes() if isinstance(expr, stage3.Axis) and stage3.is_in_brackets(expr)]
    if marked_output_axis is None:
        if len(marked_input_axes) != 1:
            raise SemanticError(
                invocation=invocation,
                pos=invocation.indicator.get_pos_for_axisnames(all_exprs, [expr.name for expr in marked_input_axes]),
                message=f"If no axis is marked in the output expression, exactly one axis must be marked in the input expression, but found {len(marked_input_axes)} marked axes.\n%EXPR%",
            )
    else:
        if len(marked_input_axes) != marked_output_axis.value:
            raise SemanticError(
                invocation=invocation,
                pos=invocation.indicator.get_pos_for_axisnames(all_exprs, [a.name for a in marked_input_axes] + [marked_output_axis.name]),
                message=f"The number of axes in brackets in the input expression ({len(marked_input_axes)}) must match the value of the marked axis in the output expression ({marked_output_axis.value}).\n%EXPR%",
            )

def _cast_shape(tensor, expr):
    """Return ``tensor`` with the shape of ``expr``.

    A tensor whose shape is already known is returned unchanged after its
    shape has been asserted to equal ``expr.shape``. A tensor without a shape
    has to be a ``ConvertibleTensor`` and is cast to one that carries
    ``expr.shape``.
    """
    if tensor.shape is not None:
        assert tuple(tensor.shape) == tuple(expr.shape)
        return tensor

    assert isinstance(tensor, tracer.signature.classical.ConvertibleTensor)

    def _with_shape(origin):
        return tracer.signature.classical.ConvertibleTensor(origin, tensor.concrete, expr.shape)

    return tracer.cast(tensor, _with_shape)



def op(op, el_op=None, allow_concat=False, implicit_output=None, mark_reduced_axes=False, check=None, cse_in_brackets=True, kwargnames=[], equations_stage3=None, allow_nontrivial_unmarked_reduced_axes=False, allow_duplicate_el_axes=True):
    """Wrap a named-tensor operation into a function that takes an operation string.

    The returned function has the signature
    ``inner(description, *tensors, **kwargs)``. It parses and solves
    ``description`` for the given tensors, runs the semantic checks, calls
    ``op`` with the tensors wrapped as ``NamedTensor`` and returns the
    ``value`` of the result(s).

    Args:
        op: Callable that is invoked as ``op(*named_tensors, out=..., **kwargs)``.
            ``out`` is the single output expression, or the list of output
            expressions if there are several.
        el_op, allow_concat, implicit_output, mark_reduced_axes,
        allow_duplicate_el_axes: Forwarded to ``_parse_op``.
        check: Optional callable ``check(exprs_in, exprs_out, invocation)``
            that is run on the solved expressions.
        cse_in_brackets: Forwarded to ``solve``.
        kwargnames: Names of the keyword arguments that are passed through to
            ``op``. All other keyword arguments are passed to ``solve`` as
            parameters. These names may not be used as axis names. The
            default list is only read, never modified.
        equations_stage3: Optional callable forwarded to ``solve`` with
            ``invocation`` bound as keyword argument.
        allow_nontrivial_unmarked_reduced_axes: If false, every input axis
            outside of brackets with a value other than 1 must also appear
            outside of brackets in an output expression.

    Returns:
        The wrapper function. It carries the ``__name__``, ``__qualname__``,
        ``__doc__`` and ``__module__`` of ``op`` where ``op`` defines them.
    """
    def inner(description, *tensors, **kwargs):
        invocation = Invocation(
            description,
            name=op.__name__ if hasattr(op, "__name__") else "operation",
            tensors=tensors,
            kwargs=kwargs,
        ) # Used for error reporting

        exprs_in, exprs_out = _parse_op(
            description,
            el_op,
            invocation=invocation,
            allow_concat=allow_concat,
            implicit_output=implicit_output,
            mark_reduced_axes=mark_reduced_axes,
            allow_duplicate_el_axes=allow_duplicate_el_axes,
        )

        if len(exprs_in) != len(tensors):
            raise ValueError(
                f"The operation is defined with {len(exprs_in)} input expression(s), but {len(tensors)} input tensor(s) are given as argument(s).",
            )

        # Keyword arguments of the elementary operation may not double as axis names.
        used_axis_names = {axis.name for expr in exprs_in + exprs_out for axis in expr.nodes() if isinstance(axis, stage1.Axis)}
        if any(name in used_axis_names for name in kwargnames):
            raise SemanticError(
                invocation=invocation,
                pos=invocation.indicator.get_pos_for_axisnames(exprs_in + exprs_out, kwargnames),
                message=f"The following axis names may not be used in the expression, since they are keyword arguments of the elementary operation: {', '.join(kwargnames)}.\n%EXPR%",
            )

        # Split the keyword arguments: parameters for the solver vs. arguments for op.
        parameters = {key: value for key, value in kwargs.items() if key not in kwargnames}
        kwargs = {key: value for key, value in kwargs.items() if key in kwargnames}

        exprs_in, exprs_out = solve(
            exprs_in,
            exprs_out,
            [tensor.shape for tensor in tensors],
            invocation,
            parameters,
            cse_in_brackets=cse_in_brackets,
            equations_stage3=partial(equations_stage3, invocation=invocation) if equations_stage3 is not None else None,
        )

        if not allow_nontrivial_unmarked_reduced_axes:
            axis_names_in = {axis.name for expr_in in exprs_in for axis in expr_in.nodes() if isinstance(axis, stage3.Axis) and not stage3.is_in_brackets(axis) and axis.value != 1}
            axis_names_out = {axis.name for expr_out in exprs_out for axis in expr_out.nodes() if isinstance(axis, stage3.Axis) and not stage3.is_in_brackets(axis) and axis.value != 1}
            axis_names_reduced = axis_names_in - axis_names_out
            if len(axis_names_reduced) > 0:
                raise SemanticError(
                    invocation=invocation,
                    pos=invocation.indicator.get_pos_for_axisnames(exprs_in + exprs_out, axis_names_reduced),
                    message=f"The input axes {axis_names_reduced} must appear in the output expression.\n%EXPR%",
                )

        if check is not None:
            check(exprs_in, exprs_out, invocation)

        tensors = [_cast_shape(tensor, expr_in) for tensor, expr_in in zip(tensors, exprs_in)]

        tensors = [NamedTensor(tensor, expr_in) for tensor, expr_in in zip(tensors, exprs_in)]
        try:
            tensors = op(*tensors, out=exprs_out[0] if len(exprs_out) == 1 else exprs_out, **kwargs)
        except SemanticError as e:
            # Re-raise with this invocation attached and keep the original
            # error as explicit cause.
            raise SemanticError(
                invocation=invocation,
                pos=e.pos,
                message=e.message,
            ) from e

        if len(exprs_out) > 1:
            return tuple([t.value for t in tensors])
        else:
            return tensors.value

    # Copy the metadata of op. Callables such as functools.partial objects do
    # not define all of these attributes (the invocation name above already
    # accounts for a missing __name__), so only existing ones are copied.
    for attr in ("__name__", "__qualname__", "__doc__", "__module__"):
        if hasattr(op, attr):
            setattr(inner, attr, getattr(op, attr))
    return inner

def _elementwise_el_op(op, n_in=None, n_out=None):
    if n_in is None:
        n_in = len(op.children[0].children)
    if n_out is None:
        n_out = len(op.children[1].children)
    args_in = ", ".join("" for i in range(n_in))
    args_out = ", ".join("" for i in range(n_out))
    return f"{args_in} -> {args_out}"

def id(op, **kwargs):
    """Create the frontend function for an identity-style operation.

    Uses the elementwise elementary signature (see ``_elementwise_el_op``)
    and allows the concatenation operator. Additional keyword arguments are
    forwarded to the module-level ``op`` factory.
    """
    # The parameter ``op`` shadows the module-level factory of the same
    # name, which is therefore fetched from the module globals.
    make_frontend = globals()["op"]
    return make_frontend(op, el_op=_elementwise_el_op, allow_concat=True, **kwargs)

def dot(op, **kwargs):
    """Create the frontend function for a dot operation.

    The elementary operation takes the bracketed part of every input
    expression and has an empty output. Reduced axes are marked
    automatically, duplicate axis names in brackets are rejected, and
    ``_semantic_checks_dot`` is run on the solved expressions. Additional
    keyword arguments are forwarded to the module-level ``op`` factory.
    """
    def _el_signature(subop):
        return f"{subop.children[0]} ->"

    # The parameter ``op`` shadows the module-level factory of the same name.
    make_frontend = globals()["op"]
    return make_frontend(
        op,
        el_op=_el_signature,
        mark_reduced_axes=True,
        allow_duplicate_el_axes=False,
        check=_semantic_checks_dot,
        **kwargs,
    )

def _equations_stage3_index_at(exprs_in, exprs_out, invocation, is_update):
    """Create the additional equation for get-at and update-at operations.

    The equation ties the concatenation of the marked coordinate axes (an
    unnamed axis of length 1 stands in for a coordinate expression without
    marked axis) to the number of marked axes in the first input expression.

    Args:
        exprs_in: Input expressions (``stage2``). The first one is the tensor
            expression, followed by the coordinate expressions.
        exprs_out: Output expressions; only used for error reporting.
        invocation: Invocation object used for error reporting.
        is_update: If true, the last input expression is not treated as a
            coordinate expression.

    Returns:
        A list with a single ``stage3.Equation``.

    Raises:
        SemanticError: If a coordinate expression marks more than one axis,
            or if the lengths are known and do not match.
    """
    def _bracketed_axes(expr):
        return [
            node for node in expr.nodes()
            if isinstance(node, stage2.Axis) and stage2.is_in_brackets(node)
        ]

    def _error(axis_names, message):
        return SemanticError(
            invocation=invocation,
            pos=invocation.indicator.get_pos_for_axisnames(exprs_in + exprs_out, axis_names),
            message=message,
        )

    source_expr = exprs_in[0]
    index_exprs = exprs_in[1:-1] if is_update else exprs_in[1:]

    # One axis per coordinate expression
    index_axes = []
    for index_expr in index_exprs:
        bracketed = _bracketed_axes(index_expr)
        if len(bracketed) > 1:
            raise _error(
                [axis.name for axis in bracketed],
                "Each coordinate expression must contain at most one axis in brackets.\n%EXPR%",
            )
        if bracketed:
            index_axes.append(bracketed[0])
        else:
            index_axes.append(stage2.Axis(f"unnamed.{uuid.uuid4().int}", 1, ellipsis_indices=[]))

    concatenated = stage2.ConcatenatedAxis.create(index_axes, ellipsis_indices=[])
    source_axes = _bracketed_axes(source_expr)
    source_count = len(source_axes)

    # Fail early if the summed length is already known and does not match
    if concatenated.value is not None and concatenated.value != source_count:
        raise _error(
            [concatenated.name] + [axis.name for axis in source_axes],
            f"The sum of the lengths of marked coordinate axes ({concatenated.value}) must match the number of marked axes in the first input expression ({source_count}).\n%EXPR%",
        )

    expected = stage2.Axis(f"unnamed.{uuid.uuid4().int}", source_count, ellipsis_indices=[])
    return [stage3.Equation(concatenated, expected)]

def get_at(op, **kwargs):
    """Create the frontend function for a get-at operation.

    The elementary operation takes the bracketed part of every input
    expression and has an empty output. ``_semantic_checks_get_at`` is run on
    the solved expressions and ``_equations_stage3_index_at`` contributes an
    additional equation. Additional keyword arguments are forwarded to the
    module-level ``op`` factory.
    """
    def _el_signature(subop):
        return f"{subop.children[0]} ->"

    # The parameter ``op`` shadows the module-level factory of the same name.
    make_frontend = globals()["op"]
    return make_frontend(
        op,
        el_op=_el_signature,
        check=_semantic_checks_get_at,
        equations_stage3=partial(_equations_stage3_index_at, is_update=False),
        **kwargs,
    )

def update_at(op, **kwargs):
    """Create the frontend function for an update-at operation.

    The elementary operation takes the bracketed part of every input
    expression except the last one, followed by an empty expression, and
    returns the bracketed part of the first output expression. If the output
    is omitted, the first input expression is used. Additional keyword
    arguments are forwarded to the module-level ``op`` factory.
    """
    def _el_signature(subop):
        leading_inputs = ", ".join([str(arg) for arg in subop.children[0].children[:-1]])
        return f"{leading_inputs}, -> {subop.children[1].children[0]}"

    # The parameter ``op`` shadows the module-level factory of the same name.
    make_frontend = globals()["op"]
    return make_frontend(
        op,
        el_op=_el_signature,
        check=_semantic_checks_update_at,
        equations_stage3=partial(_equations_stage3_index_at, is_update=True),
        implicit_output=0,
        allow_nontrivial_unmarked_reduced_axes=True,
        **kwargs,
    )

def elementwise(op, **kwargs):
    """Create the frontend function for an elementwise operation.

    Uses the elementwise elementary signature with exactly one output (see
    ``_elementwise_el_op``). If the output is omitted, it is determined with
    the ``"superset"`` rule. Additional keyword arguments are forwarded to
    the module-level ``op`` factory.
    """
    def _el_signature(subop):
        return _elementwise_el_op(subop, n_out=1)

    # The parameter ``op`` shadows the module-level factory of the same name.
    make_frontend = globals()["op"]
    return make_frontend(op, el_op=_el_signature, implicit_output="superset", **kwargs)

def reduce(op, **kwargs):
    """Create the frontend function for a reduction.

    The elementary operation takes the bracketed part of the first input
    expression and has an empty output. If the output is omitted, it is
    determined with the ``"reduce"`` rule; reduced axes are marked
    automatically. Additional keyword arguments are forwarded to the
    module-level ``op`` factory.
    """
    def _el_signature(subop):
        return f"{subop.children[0].children[0]} ->"

    # The parameter ``op`` shadows the module-level factory of the same name.
    make_frontend = globals()["op"]
    return make_frontend(
        op,
        el_op=_el_signature,
        implicit_output="reduce",
        mark_reduced_axes=True,
        **kwargs,
    )

def _equations_stage3_argfind(exprs_in, exprs_out, invocation):
    """Create the additional equation for argfind operations.

    If the output expression marks an axis, its value is tied to the number
    of marked axes in the input expression. If it marks none, the input
    expression must mark exactly one axis and no equation is needed.

    Args:
        exprs_in: Input expressions (``stage2``); only the first is inspected.
        exprs_out: Output expressions (``stage2``); only the first is
            inspected.
        invocation: Invocation object used for error reporting.

    Returns:
        A list with zero or one ``stage3.Equation``.

    Raises:
        SemanticError: If the bracket usage is invalid.
    """
    def _bracketed_axes(expr):
        return [
            node for node in expr.nodes()
            if isinstance(node, stage2.Axis) and stage2.is_in_brackets(node)
        ]

    def _error(axis_names, message):
        return SemanticError(
            invocation=invocation,
            pos=invocation.indicator.get_pos_for_axisnames(exprs_in + exprs_out, axis_names),
            message=message,
        )

    input_axes = _bracketed_axes(exprs_in[0])
    output_axes = _bracketed_axes(exprs_out[0])
    input_count = len(input_axes)

    if len(output_axes) > 1:
        raise _error(
            [axis.name for axis in output_axes],
            "The output expression must contain at most one axis in brackets.\n%EXPR%",
        )

    if not output_axes:
        # Nothing is marked in the output: no equation, but exactly one marked input axis
        if input_count != 1:
            raise _error(
                [axis.name for axis in input_axes],
                f"If no axis is marked with brackets in the output expression, exactly one axis must be marked in the input expression, but found {input_count} marked axes.\n%EXPR%",
            )
        return []

    output_axis = output_axes[0]
    # Fail early if the value is already known and does not match
    if output_axis.value is not None and output_axis.value != input_count:
        raise _error(
            [output_axis.name] + [axis.name for axis in input_axes],
            f"The value of the marked axis in the output expression ({output_axis.value}) must match the number of marked axes in the input expression ({input_count}).\n%EXPR%",
        )

    expected = stage2.Axis(f"unnamed.{uuid.uuid4().int}", input_count, ellipsis_indices=[])
    return [stage3.Equation(output_axis, expected)]

def argfind(op, **kwargs):
    """Create the frontend function for an argfind operation.

    The elementary operation takes the bracketed part of the first input
    expression. Its output is a single, uniquely named axis, unless the
    bracketed part of the first output expression has an expansion of 0, in
    which case the output is empty. Additional keyword arguments are
    forwarded to the module-level ``op`` factory.
    """
    def _el_signature(subop):
        if subop.children[1].children[0].expansion() == 0:
            return f"{subop.children[0].children[0]} ->"
        return f"{subop.children[0].children[0]} -> a{uuid.uuid4().int}"

    # The parameter ``op`` shadows the module-level factory of the same name.
    make_frontend = globals()["op"]
    return make_frontend(
        op,
        el_op=_el_signature,
        check=_semantic_checks_argfind,
        implicit_output="argfind",
        equations_stage3=_equations_stage3_argfind,
        **kwargs,
    )

def preserve_shape(op, **kwargs):
    """Create the frontend function for an operation whose output matches its input.

    The elementary operation maps the bracketed part of the first input
    expression onto itself. If the output is omitted, the first input
    expression is used. Additional keyword arguments are forwarded to the
    module-level ``op`` factory.
    """
    def _el_signature(subop):
        return f"{subop.children[0].children[0]} -> {subop.children[0].children[0]}"

    # The parameter ``op`` shadows the module-level factory of the same name.
    make_frontend = globals()["op"]
    return make_frontend(op, el_op=_el_signature, implicit_output=0, **kwargs)

def ops(namedtensor_ops):
    """Create the frontend functions for all named-tensor operations.

    Args:
        namedtensor_ops: Mapping from operation name to the named-tensor
            implementation of that operation.

    Returns:
        A dict that maps each operation name to its frontend function.
    """
    frontends = {}

    # Operation families whose member names are listed in adapter.ops
    for name in adapter.ops.elementwise:
        frontends[name] = elementwise(namedtensor_ops[name], kwargnames=[])
    for name in adapter.ops.reduce:
        frontends[name] = reduce(namedtensor_ops[name], kwargnames=[], cse_in_brackets=True)
    for name in adapter.ops.update_at:
        frontends[name] = update_at(namedtensor_ops[name], kwargnames=[])
    for name in adapter.ops.argfind:
        frontends[name] = argfind(namedtensor_ops[name], kwargnames=[])

    # Individually registered operations
    frontends.update({
        "get_at": get_at(namedtensor_ops["get_at"], kwargnames=[]),
        "dot": dot(namedtensor_ops["dot"], kwargnames=[]),
        "id": id(namedtensor_ops["id"], kwargnames=[]),
        "roll": preserve_shape(namedtensor_ops["roll"], kwargnames=["shift"]),
        "flip": preserve_shape(namedtensor_ops["flip"], kwargnames=[]),
        "sort": preserve_shape(namedtensor_ops["sort"], kwargnames=[], check=_semantic_checks_sort),
        "argsort": preserve_shape(namedtensor_ops["argsort"], kwargnames=[], check=_semantic_checks_sort),
        "softmax": preserve_shape(namedtensor_ops["softmax"], kwargnames=[]),
        "log_softmax": preserve_shape(namedtensor_ops["log_softmax"], kwargnames=[]),
    })
    return frontends