from einx._src.namedtensor import NamedTensor
import einx._src.tracer as tracer
from typing import Callable
import functools
import inspect
import types
from einx._src.util.functools import use_name_of

def _call_tensorfactory(tensor, expected_type, kwargs):
    """Invoke the tensor factory held by a NamedTensor, if it holds one.

    The wrapped value is treated as a tensor factory when it is a
    ``tracer.signature.classical.ConvertibleTensor`` whose concrete type is a
    subclass of ``Callable``. In that case the factory is called through
    ``tracer.signature.python.call``:

    - The first positional argument is the shape, as a tuple.
    - It also receives every entry of ``kwargs`` that its signature can take
      by keyword. If it declares ``**kwargs``, it receives all entries.

    The result is then wrapped in traced assertions for its type and shape,
    and cast to a ``tracer.signature.classical.Tensor`` of that shape.
    Any other value is passed through untouched.

    Args:
        tensor: NamedTensor whose ``.value`` may be a tensor factory.
        expected_type: Type that the factory output is asserted to be an
            instance of.
        kwargs: Candidate keyword arguments for the factory, or None for none.

    Returns:
        A NamedTensor that keeps the input's ``expr``. It wraps either the
        processed factory output or the original value.
    """
    expr = tensor.expr
    value = tensor.value

    # Guard clause: anything that is not a callable convertible tensor is
    # returned as-is.
    if not (
        isinstance(value, tracer.signature.classical.ConvertibleTensor)
        and issubclass(value.concrete.type, Callable)
    ):
        return NamedTensor(value, expr)

    shape = tuple(value.shape)
    candidates = {} if kwargs is None else kwargs

    # Only forward arguments that the factory can accept by keyword.
    # Positional-only and *args parameters are excluded. A factory with
    # **kwargs accepts everything.
    params = value.concrete.parameters
    accepts_any_keyword = any(
        p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    forwarded = {
        key: val
        for key, val in candidates.items()
        if accepts_any_keyword or (key in params and params[key].kind in keyword_kinds)
    }

    py = tracer.signature.python
    produced = py.call(value, shape, **forwarded)

    # Traced checks on the factory output: first its type, then its shape.
    produced = py.assert_(
        produced,
        py.builtins.isinstance(produced, expected_type),
        f"Invalid type as output of tensor factory",  # TODO: more informative message
    )
    produced = py.assert_(
        produced,
        py.equal(py.builtins.tuple(produced.shape), shape),
        f"Expected shape {shape} as output of tensor factory",  # TODO: more informative message
    )
    produced = tracer.cast(
        produced,
        lambda origin: tracer.signature.classical.Tensor(origin, shape=shape),
    )
    return NamedTensor(produced, expr)

class namedtensor_calltensorfactory:
    """Wrappers that resolve tensor-factory arguments before a named-tensor op runs."""

    @staticmethod
    def op(op, expected_type, kwargs=None):
        """Wrap ``op`` so that each positional tensor goes through ``_call_tensorfactory`` first.

        Before ``op`` is called, the returned wrapper builds a
        ``types.SimpleNamespace``:

        - ``exprs_in`` holds the input tensors' ``expr`` values.
        - ``exprs_out`` holds ``out``, normalized to a tuple.

        The namespace is wrapped with ``tracer.signature.python.constant``. Each
        positional tensor is then processed with the keyword arguments
        ``{"signature": ..., "arg_index": i}`` merged with ``kwargs``. Entries
        from ``kwargs`` win on key clashes.

        Args:
            op: The operation to wrap. It is called as
                ``op(*tensors, out=out, **kwargs)``.
            expected_type: Forwarded to ``_call_tensorfactory`` for every tensor.
            kwargs: Extra factory keyword arguments, or None for none.

        Returns:
            The wrapping function, decorated with ``use_name_of(op)``.
        """
        extra_factory_kwargs = {} if kwargs is None else kwargs

        @use_name_of(op)
        def inner(*tensors, out, **op_kwargs):
            exprs_in = tuple(t.expr for t in tensors)
            exprs_out = tuple(out) if isinstance(out, (tuple, list)) else (out,)
            signature = tracer.signature.python.constant(
                types.SimpleNamespace(exprs_in=exprs_in, exprs_out=exprs_out)
            )

            resolved = []
            for arg_index, tensor in enumerate(tensors):
                factory_kwargs = {"signature": signature, "arg_index": arg_index} | extra_factory_kwargs
                resolved.append(_call_tensorfactory(tensor, expected_type, kwargs=factory_kwargs))
            return op(*resolved, out=out, **op_kwargs)

        return inner

    @staticmethod
    def ops(ops, expected_type, kwargs=None):
        """Apply ``op`` to every entry of the ``ops`` mapping.

        Each op's factory keyword arguments start with ``{"name": <key>}``.
        They are merged with ``kwargs``, and entries from ``kwargs`` win on
        key clashes.

        Args:
            ops: Mapping from name to operation.
            expected_type: Forwarded to ``op`` for every entry.
            kwargs: Extra factory keyword arguments, or None for none.

        Returns:
            A new dict with the same keys, mapping each key to its wrapped op.
        """
        shared_kwargs = {} if kwargs is None else kwargs
        wrapped = {}
        for name, single_op in ops.items():
            wrapped[name] = namedtensor_calltensorfactory.op(
                single_op, expected_type, kwargs={"name": name} | shared_kwargs
            )
        return wrapped
