import einx._src.tracer as tracer
from functools import partial

class einsum:
    """Tracing wrapper around a backend ``einsum`` that annotates the output shape.

    Calling an instance validates the operand shapes against ``subscripts``. It then
    invokes the wrapped backend ``einsum`` unchanged and casts the result via
    ``tracer.cast`` to a ``tracer.signature.classical.Tensor`` whose shape is inferred
    from ``subscripts`` and the input shapes.
    """

    def __init__(self, einsum):
        # Backend implementation, invoked as ``einsum(subscripts, *tensors)``.
        self.einsum = einsum

    @staticmethod
    def _infer_output_shape(subscripts, *tensors):
        """Return the output shape of ``einsum(subscripts, *tensors)``.

        Both explicit (``"ij,jk->ik"``) and implicit (``"ij,jk"``) mode are supported.
        In implicit mode the output consists of the labels that occur exactly once
        across all inputs, in sorted order (the NumPy/PyTorch/opt_einsum convention).
        Spaces inside the subscripts are ignored.

        NOTE(review): ``...`` (ellipsis) broadcasting is not handled here; every
        character is treated as an individual axis label.

        Args:
            subscripts: Einsum specification string.
            *tensors: Operands; only their ``.shape`` attribute is read.

        Returns:
            Tuple with one size per output label.

        Raises:
            ValueError: If the number of operands or an operand's rank does not match
                ``subscripts``, if a repeated label has conflicting sizes, or if an
                output label does not appear in any input.
        """
        inputs, arrow, output = subscripts.partition("->")
        exprs = inputs.split(",")
        if len(exprs) != len(tensors):
            raise ValueError(f"Expected {len(exprs)} tensors, got {len(tensors)}")

        values = {}
        cleaned = []
        for i, (expr, tensor) in enumerate(zip(exprs, tensors)):
            expr = expr.strip().replace(" ", "")
            cleaned.append(expr)
            if len(expr) != len(tensor.shape):
                raise ValueError(
                    f"Expected {len(expr)} axes, got {len(tensor.shape)} for {i}-th "
                    "(zero-based) input tensor"
                )
            for axis, value in zip(expr, tensor.shape):
                if axis in values:
                    # A label used more than once (within or across operands) must
                    # refer to axes of identical size.
                    if values[axis] != value:
                        raise ValueError(
                            f"Got conflicting values for axis {axis}: "
                            f"{values[axis]} and {value}"
                        )
                else:
                    values[axis] = value

        if arrow:
            expr_out = output.strip().replace(" ", "")
        else:
            # Implicit mode: keep labels occurring exactly once, alphabetically sorted.
            all_labels = "".join(cleaned)
            expr_out = "".join(
                sorted(axis for axis in set(all_labels) if all_labels.count(axis) == 1)
            )

        for axis in expr_out:
            if axis not in values:
                raise ValueError(f"Output axis {axis} does not appear in any input tensor")
        return tuple(values[axis] for axis in expr_out)

    def __call__(self, subscripts, *tensors):
        """Run the wrapped einsum and annotate its result with the inferred shape.

        Shape validation happens before the backend is called, so invalid subscripts
        raise ``ValueError`` without invoking ``self.einsum``.
        """
        shape_out = self._infer_output_shape(subscripts, *tensors)

        x = self.einsum(subscripts, *tensors)
        x = tracer.cast(x, partial(tracer.signature.classical.Tensor, shape=shape_out))
        return x