import einx._src.namedtensor.stage3 as stage3
from einx._src.namedtensor import NamedTensor

def _expr_to_einsumstr(exprs_in, expr_out):
    einsum_variables = {}

    def get_einsum_variable(key):
        if key in einsum_variables:
            return einsum_variables[key]
        else:
            v = chr(ord("a") + len(einsum_variables))
            if ord(v) > ord("z"):
                raise ValueError(f"The function only supports up to {ord('z') - ord('a') + 1} unique input axes")
            einsum_variables[key] = v
            return v

    def to_einsum(expr):
        axes = [a for a in expr if isinstance(a, stage3.Axis)]
        return "".join(get_einsum_variable(a.name) for a in axes)

    einsum_str = (
        ",".join(to_einsum(expr) for expr in exprs_in)
        + "->"
        + to_einsum(expr_out)
    )
    return einsum_str

def dot(einsum):
    """Adapt an ``einsum`` callable so that it works on named tensors.

    Args:
        einsum: Callable invoked as ``einsum(subscripts, *operands)``, where
            ``subscripts`` is the string produced by ``_expr_to_einsumstr``.

    Returns:
        A function ``dot(*tensors, out)``. Each positional argument must
        expose ``.expr`` and ``.value`` attributes; ``out`` is the
        keyword-only output expression. The function returns
        ``NamedTensor(result, out)`` where ``result`` is whatever ``einsum``
        returned.
    """
    def dot(*tensors, out):
        # Split each named tensor into its expression and its raw value.
        input_exprs = [named.expr for named in tensors]
        operands = [named.value for named in tensors]

        subscripts = _expr_to_einsumstr(input_exprs, out)
        result = einsum(subscripts, *operands)

        return NamedTensor(result, out)

    return dot
