import jax

from meshflow.unifyshard.unifyir import UnifyGraph, UnifyNode, UnifyVar


def _make_input_var(var, meta_info) -> UnifyVar:
    """Build the UnifyVar for a graph input (a jaxpr invar or constvar).

    The dtype is forwarded exactly as stored in ``meta_info``.
    NOTE(review): output vars pass ``dtype.name`` instead (see
    ``_make_output_var``); confirm which form ``UnifyVar`` expects.
    """
    name = str(var)
    meta = meta_info[name]
    return UnifyVar(name=name, shape=meta["shape"], dtype=meta["dtype"])


def _make_output_var(var, meta_info) -> UnifyVar:
    """Build the UnifyVar for an equation output, using the dtype's ``.name``."""
    name = str(var)
    meta = meta_info[name]
    return UnifyVar(name=name, shape=meta["shape"], dtype=meta["dtype"].name)


def _primitive_key(eqn, subfuns) -> str:
    """Return the first-level ``sharding_info`` key for ``eqn``.

    This is the primitive's printed name. For ``custom_jvp_call`` and
    ``xla_call`` it is suffixed with ``[<name>]`` so that different wrapped
    functions get distinct keys.
    """
    prim_name = str(eqn.primitive)
    if prim_name == "custom_jvp_call":
        # NOTE(review): reads the 'name' param of the first equation inside the
        # wrapped jaxpr reached through subfuns[0]; this depends on JAX-internal
        # object layout — confirm against the pinned JAX version.
        return prim_name + "[" + subfuns[0].f.args[0].eqns[0].params['name'] + "]"
    if prim_name == "xla_call":
        return prim_name + "[" + eqn.params['name'] + "]"
    return prim_name


def _abstract_signature(invars, meta_info) -> list:
    """Describe an equation's inputs for building the second-level sharding key.

    Each entry has one of three forms:
    - ``(shape, dtype)`` when the variable's name is in ``meta_info``;
    - the literal's value for a ``jax.core.Literal``;
    - the variable object itself otherwise.
    """
    signature = []
    for var in invars:
        name = str(var)
        if name in meta_info:
            meta = meta_info[name]
            signature.append((meta["shape"], meta["dtype"]))
        elif isinstance(var, jax.core.Literal):
            signature.append(var.val)
        else:
            signature.append(var)
    return signature


def _lookup_node_sharding(eqn, bind_params, onelevel_key, sharding_info, meta_info):
    """Return the recorded sharding entry for ``eqn``, or ``None`` if absent.

    ``sharding_info`` is looked up in two levels. The first key is
    ``onelevel_key``. The second key is
    ``str(_abstract_signature(...)) + str(bind_params)``. The exact string
    form matters because it must match the keys already stored in
    ``sharding_info``.
    """
    if not sharding_info or onelevel_key not in sharding_info:
        return None
    twolevel_key = str(_abstract_signature(eqn.invars, meta_info)) + str(bind_params)
    return sharding_info[onelevel_key].get(twolevel_key)


def jax2mf_bridge(jaxpr: jax.core.Jaxpr, sharding_info, meta_info) -> UnifyGraph:
    """Convert a JAX ``Jaxpr`` into a meshflow ``UnifyGraph``.

    Args:
        jaxpr: The jaxpr to convert.
        sharding_info: Two-level mapping
            ``{primitive_key: {signature_key: info}}`` holding per-equation
            sharding info; see ``_lookup_node_sharding`` for the key format.
            ``None`` or an empty mapping means no equation gets sharding info.
        meta_info: Mapping from a variable's printed name to a dict with
            ``"shape"`` and ``"dtype"`` entries. Every invar, constvar and
            equation outvar must be present.

    Returns:
        A ``UnifyGraph`` whose structure mirrors the jaxpr:
        - inputs are the jaxpr's invars followed by its constvars;
        - there is one op per equation, in ``jaxpr.eqns`` order;
        - outputs are marked by the names of ``jaxpr.outvars``.

    Raises:
        KeyError: If a required variable name is missing from ``meta_info``.
    """
    unify_graph = UnifyGraph(jaxpr)

    # Graph inputs: regular invars first, then constvars (original ordering).
    for var in [*jaxpr.invars, *jaxpr.constvars]:
        unify_graph.append_input(_make_input_var(var, meta_info))

    for eqn in jaxpr.eqns:
        subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
        onelevel_key = _primitive_key(eqn, subfuns)
        node_sharding_info = _lookup_node_sharding(
            eqn, bind_params, onelevel_key, sharding_info, meta_info)
        outvars = [_make_output_var(var, meta_info) for var in eqn.outvars]

        unify_graph.append_op(
            UnifyNode(
                name=onelevel_key,
                # Only jax.core.Var operands are recorded as node inputs;
                # other operands (presumably literals) are dropped.
                invars=[str(var) for var in eqn.invars if isinstance(var, jax.core.Var)],
                outvars=outvars,
                sharding_info=node_sharding_info,
            ))

    unify_graph.mark_output([str(var) for var in jaxpr.outvars])
    return unify_graph
