from einx._src.adapter.einx_from_namedtensor import _parse_op
from einx._src.adapter.einx_from_namedtensor import solve as _solve2
import einx._src.namedtensor.stage3 as stage3
from collections import defaultdict
import numpy as np
import numpy.typing as npt
from typing import Mapping, Optional
from .api import _get_listtuple_shape, _is_scalar
from .types import Tensor
from einx._src.error import Invocation
import einx._src.tracer as tracer
import types

def _exprs_to_axes(exprs):
    """Collect solved axis values from expression trees into a dict keyed by axis name.

    Axis node names may carry a dotted integer suffix (e.g. ``"b.0.1"``). All
    nodes sharing a base name (``"b"``) are gathered into one int32 numpy array
    indexed by those suffixes. Its shape is the per-dimension maximum index
    plus one, and positions not provided by any node stay 0. A name without a
    suffix yields a 0-d array, which is returned as a plain ``int``.
    """
    grouped = defaultdict(list)
    for root in exprs:
        axes = (node for node in root.nodes() if isinstance(node, stage3.Axis))
        for axis in axes:
            base, *suffix = axis.name.split(".")
            grouped[base].append((tuple(map(int, suffix)), axis.value))

    result = {}
    for base, entries in grouped.items():
        coords = [coord for coord, _ in entries]
        array = np.zeros(np.amax(coords, axis=0) + 1, dtype="int32")
        for coord, axis_value in entries:
            array[coord] = axis_value
        # Unindexed axes collapse to a Python int rather than a 0-d array.
        result[base] = int(array) if array.shape == () else array

    return result

def _tensor_signature(shape):
    """Build the tracer signature describing one input tensor with shape ``shape``.

    A known shape becomes a ``Tensor`` signature. An unknown shape (``None``)
    becomes a ``ConvertibleTensor`` signature without a shape.
    """
    if shape is None:
        return tracer.signature.classical.ConvertibleTensor(
            None, shape=None, concrete=types.SimpleNamespace(type=None)
        )
    return tracer.signature.classical.Tensor(None, shape=shape)

def _solve(description, tensor_shapes, parameters, reraise, cse):
    """Solve the input expressions of ``description`` for the given tensor shapes.

    Args:
        description: einx description of the input tensors only. An empty
            output side (``" ->"``) is appended before parsing.
        tensor_shapes: One entry per tensor: a shape tuple, or ``None`` if unknown.
        parameters: Extra axis constraints forwarded to the solver.
        reraise: If True, errors raised while parsing/solving propagate.
            Otherwise ``None`` is returned when they occur.
        cse: Forwarded to the solver's ``cse`` option.

    Returns:
        The solved input expressions, or ``None`` if solving failed and
        ``reraise`` is False.
    """
    invocation = Invocation(
        description,
        name="operation",
        tensors=[_tensor_signature(shape) for shape in tensor_shapes],
        kwargs={},
    )
    try:
        # Parse ``description`` as the input side of an operation whose output
        # side is empty.
        exprs_in, exprs_out = _parse_op(
            f"{description} ->",
            el_op=None,
            invocation=invocation,
            allow_concat=True,
        )
        exprs_in, _ = _solve2(
            exprs_in,
            exprs_out,
            tensor_shapes,
            invocation,
            parameters,
            cse_concat=True,
            cse=cse,
        )
    except Exception:
        # Only ordinary errors count as "could not solve". KeyboardInterrupt and
        # SystemExit must never be converted into a ``None`` result.
        if reraise:
            raise
        return None
    return exprs_in

def _get_shape(tensor):
    """Return the shape of ``tensor`` as a tuple of ints, or None if unknown.

    Cases, checked in this order:

    - ``None`` yields ``None``.
    - An object with a ``.shape`` whose entries all convert to ``int`` yields
      that shape as a tuple.
    - A list or tuple yields ``_get_listtuple_shape(tensor)``.
    - A scalar (per ``_is_scalar``) yields ``()``.
    - A callable yields ``None``.

    Raises:
        ValueError: If ``tensor`` matches none of the cases above.
    """
    if tensor is None:
        return None
    try:
        return tuple(int(x) for x in tensor.shape)
    except Exception:
        # The object has no ``.shape``, or its entries cannot be converted to
        # ``int``. Fall through to the type-based handling below. Only ordinary
        # errors are caught, so KeyboardInterrupt/SystemExit still propagate.
        pass
    if isinstance(tensor, (tuple, list)):
        return _get_listtuple_shape(tensor)
    elif _is_scalar(tensor):
        return ()
    elif callable(tensor):
        return None
    else:
        raise ValueError(f"Unsupported type: {type(tensor)}") # TODO:

def solve_shapes(
    description: str, *tensors: Tensor, **parameters: npt.ArrayLike
) -> tuple:
    """Solve ``description`` for ``tensors`` and return the shape of each input expression.

    Args:
        description: einx description of the input tensors.
        *tensors: The tensors (or nested lists/tuples, scalars, callables, or
            ``None``) to match against ``description``.
        **parameters: Additional constraints on axis values.

    Returns:
        A tuple with the solved ``shape`` of each input expression, in the same
        order as ``tensors``.

    Raises:
        ValueError: If a tensor has an unsupported type.
        Any error raised by the solver is propagated (solving uses ``reraise=True``).
    """
    exprs = _solve(
        description, [_get_shape(tensor) for tensor in tensors], parameters, reraise=True, cse=True
    )
    return tuple(expr.shape for expr in exprs)

def solve_axes(
    description: str, *tensors: Tensor, **parameters: npt.ArrayLike
) -> Optional[Mapping[str, npt.ArrayLike]]:
    """Solve ``description`` for ``tensors`` and return the value of every axis.

    Args:
        description: einx description of the input tensors.
        *tensors: The tensors (or nested lists/tuples, scalars, callables, or
            ``None``) to match against ``description``.
        **parameters: Additional constraints on axis values.

    Returns:
        A dict mapping each axis name to its solved value. Unindexed axes map to
        an ``int``; indexed axes map to an int32 numpy array. Solving runs with
        ``cse=False`` and errors are re-raised, so ``None`` is not returned.

    Raises:
        ValueError: If a tensor has an unsupported type.
        Any error raised by the solver is propagated.
    """
    shapes = list(map(_get_shape, tensors))
    solved = _solve(description, shapes, parameters, reraise=True, cse=False)
    return _exprs_to_axes(solved)

# Public alias: ``solve`` is the same function as ``solve_axes``.
solve = solve_axes

def matches(
    description: str, *tensors: Tensor, cse: bool = True, **parameters: npt.ArrayLike
) -> bool:
    """Return whether ``description`` can be solved for the given tensors.

    This is the non-raising counterpart of ``check``.

    Args:
        description: einx description of the input tensors.
        *tensors: The tensors (or nested lists/tuples, scalars, callables, or
            ``None``) to match against ``description``.
        cse: Forwarded to the solver's ``cse`` option.
        **parameters: Additional constraints on axis values.

    Returns:
        True if solving succeeds. False if it raises any ordinary exception,
        including ValueError for an unsupported tensor type.
    """
    try:
        _solve(
            description,
            [_get_shape(tensor) for tensor in tensors],
            parameters,
            reraise=True,
            cse=cse,
        )
    except Exception:
        # KeyboardInterrupt/SystemExit are deliberately not treated as "no match".
        return False
    return True

def check(
    description: str, *tensors: Tensor, cse: bool = True, **parameters: npt.ArrayLike
) -> None:
    """Raise if ``description`` cannot be solved for the given tensors.

    Args:
        description: einx description of the input tensors.
        *tensors: The tensors (or nested lists/tuples, scalars, callables, or
            ``None``) to match against ``description``.
        cse: Forwarded to the solver's ``cse`` option.
        **parameters: Additional constraints on axis values.

    Raises:
        ValueError: If a tensor has an unsupported type.
        Any error raised by the solver is propagated unchanged.
    """
    _solve(
        description, [_get_shape(tensor) for tensor in tensors], parameters, reraise=True, cse=cse
    )