from einx._src.util import pytree

class Tracer:
    """Base class for traced values.

    Attributes:
        origin: The ``Application`` passed in at construction, or ``None``.
    """

    def __init__(self, origin):
        # ``Application`` is looked up at call time, so it may be defined later
        # in the module; it is only consulted when ``origin`` is not None.
        if origin is None or isinstance(origin, Application):
            self.origin = origin
            return
        raise TypeError(f"origin must be Application, not {type(origin)}")

class Graph:
    """Container for a traced computation.

    Args:
        inputs: The graph's inputs (stored as given).
        output: The graph's output (stored as given).
        name: Optional name of the graph; defaults to ``None``.
    """

    def __init__(self, inputs, output, name=None):
        self.inputs = inputs
        self.output = output
        self.name = name

    def __eq__(self, other):
        """Compare two graphs by ``name``, ``inputs`` and ``output`` (via ``!=``).

        Returns ``NotImplemented`` for non-``Graph`` operands so Python can try
        the reflected comparison; if that also declines, ``==`` falls back to
        identity and evaluates to ``False`` as before.
        """
        if not isinstance(other, Graph):
            return NotImplemented
        if self.name != other.name:
            return False
        if self.inputs != other.inputs:
            return False
        if self.output != other.output:
            return False
        return True

    # Defining __eq__ already makes instances unhashable implicitly; spell it
    # out so the intent is visible to readers.
    __hash__ = None

class Application:
    """Base class for a traced operation.

    Args:
        inputs: The operation's inputs (stored as ``self.inputs``).
        output: A callable invoked as ``output(origin=self)``; its return value
            is stored as ``self.output``.
    """

    def __init__(self, inputs, output):
        # Attach the inputs *before* invoking the output factory: the factory
        # receives this (still partially constructed) application as ``origin``
        # and may want to inspect ``origin.inputs``.
        self.inputs = inputs
        self.output = output(origin=self)


class Cast(Application):
    """Application with a single input whose output is built by a factory.

    Args:
        input: The single input (exposed as ``self.input``; ``self.inputs`` is
            ``[input]``).
        output: Callable invoked as ``output(origin=self)`` by the base
            constructor to build the output tracer(s).
    """

    def __init__(self, input, output):
        # Set ``self.input`` before the base constructor runs, so it is already
        # available when the output factory receives ``origin=self``.
        self.input = input
        super().__init__(inputs=[input], output=output)

    def _tracer_transform(self, transform):
        """Return a new ``Cast`` over ``transform(self.input)``.

        The new cast's output is rebuilt from this cast's output pytree by
        calling each leaf's ``_tracer_type`` with the new cast as origin.
        """

        def make_output(origin):
            return pytree.map(lambda x: x._tracer_type(origin), self.output)

        return Cast(transform(self.input), make_output)

    def __eq__(self, other):
        """Compare two casts by input and by the ``_tracer_type`` of each output leaf.

        Returns ``NotImplemented`` for non-``Cast`` operands so Python can try
        the reflected comparison; otherwise ``==`` falls back to identity.
        """
        if not isinstance(other, Cast):
            return NotImplemented
        if self.input != other.input:
            return False
        if pytree.map(lambda x: x._tracer_type, self.output) != pytree.map(lambda x: x._tracer_type, other.output):
            return False
        return True

    # Defining __eq__ already makes instances unhashable implicitly; spell it
    # out so the intent is visible to readers.
    __hash__ = None

def cast(input, output):
    """Create a ``Cast`` application and return its output.

    Args:
        input: The single input of the cast.
        output: Callable invoked with ``origin=<the new Cast>`` to build the
            output tracer(s).

    Returns:
        The ``output`` attribute of the newly created ``Cast``.
    """
    application = Cast(input, output)
    return application.output
