import os
import time
import logging
import functools

import rich
import jax
from jax import core
from jax._src.util import safe_map
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec
from jax.experimental import multihost_utils

from meshflow.unifyshard import unifyir

from meshflow.autoflow import AutoFlowSolver
from .sharding_interpreter import MFJaxShardingAnn
from .bridge import jax2mf_bridge

logger = logging.getLogger(__name__)

# Device mesh shared by this module: assigned by set_device_mesh() and read
# by get_device_mesh() and shard_module(). None until a mesh has been set.
JAX_DEVICE_MESH = None
# Per-input sharding strategies: assigned by get_opt_strategy() and indexed
# by flattened-argument position in shard_module(). None until solved.
INPUT_STRATEGY = None


def set_device_mesh(device_mesh):
    """Register ``device_mesh`` as the module-wide device mesh.

    The mesh is stored in the ``JAX_DEVICE_MESH`` global.  In addition, when
    one of the first two mesh dimensions has size 1, the index of that
    dimension is written to ``unifyir.DEVICE_MESH_1D`` (dimension 0 is checked
    first).  For any other shape ``unifyir.DEVICE_MESH_1D`` is left untouched.

    Args:
        device_mesh: object exposing ``device_ids.shape``.
            NOTE(review): ``shape[1]`` is read whenever ``shape[0] != 1``, so
            this presumably expects a 2-D mesh -- confirm against callers.
    """
    global JAX_DEVICE_MESH
    JAX_DEVICE_MESH = device_mesh

    shape = device_mesh.device_ids.shape
    if shape[0] == 1:
        unifyir.DEVICE_MESH_1D = 0
        return
    if shape[1] == 1:
        unifyir.DEVICE_MESH_1D = 1


def get_device_mesh():
    """Return the mesh last stored by ``set_device_mesh`` (``None`` if unset)."""
    # Reading a module global needs no ``global`` declaration.
    return JAX_DEVICE_MESH


def to_shape_array(x):
    """Replace a concrete ``jax.Array`` by a ``ShapedArray`` of equal shape/dtype.

    Values that are not ``jax.Array`` instances, and arrays whose dtype is
    reported as opaque by ``jax.core.is_opaque_dtype``, are returned unchanged.
    """
    abstractable = isinstance(x, jax.Array) and not jax.core.is_opaque_dtype(x.dtype)
    if not abstractable:
        return x
    return core.ShapedArray(shape=x.shape, dtype=x.dtype)


def materialize(x):
    """Turn a ``ShapedArray`` placeholder into a concrete array of dummy data.

    Anything that is not a ``core.ShapedArray`` is returned unchanged.  For a
    ``ShapedArray`` the result has the same shape and is filled according to
    the dtype name:

    * ``float64`` / ``float32`` / ``float16``: samples from
      ``jax.random.normal`` in that dtype;
    * ``int32`` / ``uint32`` / ``int64`` / ``uint64`` / ``uint8``: integers
      drawn from ``[1, 8)`` with ``jax.random.randint``;
    * ``bool``: ``normal > 1.`` element-wise;
    * any other dtype: ``jax.numpy.empty``.

    The PRNG key is always built from seed 42, so the values are deterministic
    and identical for equal shape/dtype.
    """
    if not isinstance(x, core.ShapedArray):
        return x

    key = jax.random.PRNGKey(seed=42)
    dtype_name = x.dtype.name
    if dtype_name in ("float64", "float32", "float16"):
        return jax.random.normal(key, shape=x.shape, dtype=x.dtype)
    # "uint32" was previously misspelled "unint32", which silently routed
    # uint32 placeholders to the jax.numpy.empty fallback below.
    if dtype_name in ("int32", "uint32", "int64", "uint64", "uint8"):
        return jax.random.randint(key, shape=x.shape, dtype=x.dtype, minval=1, maxval=8)
    if dtype_name == "bool":
        return jax.random.normal(key, shape=x.shape) > 1.
    return jax.numpy.empty(shape=x.shape, dtype=x.dtype)


def convert(strategy, mesh, val):
    """Build the ``NamedSharding`` described by a two-entry SPMD strategy.

    ``strategy[0]`` and ``strategy[1]`` are paired with ``mesh.axis_names[0]``
    and ``mesh.axis_names[1]`` respectively.  Each entry whose ``state`` is
    ``unifyir.SPMD.SHARD`` assigns its mesh axis name to tensor dimension
    ``args["dim"]``.  If both entries shard the same dimension, that dimension
    is assigned the full ``mesh.axis_names`` instead.

    Args:
        strategy: indexable holding at least two SPMD entries.
        mesh: mesh providing ``axis_names``; also passed to ``NamedSharding``.
        val: value whose ``shape`` determines the length of the partition spec.

    Returns:
        ``NamedSharding(mesh, PartitionSpec(...))`` with one entry per
        dimension of ``val`` (an empty spec for scalars).
    """
    names = mesh.axis_names
    per_mesh_axis = (strategy[0], strategy[1])
    spec_entries = [None] * len(val.shape)

    # Scalars (shape == ()) ignore the strategy and keep an empty spec.
    if spec_entries:
        for mesh_axis, spmd in enumerate(per_mesh_axis):
            if spmd.state != unifyir.SPMD.SHARD:
                continue
            dim = spmd.args["dim"]
            if spec_entries[dim] is None:
                spec_entries[dim] = names[mesh_axis]
            else:
                # Dimension already claimed by the other entry: use all axes.
                spec_entries[dim] = mesh.axis_names

    return NamedSharding(mesh, PartitionSpec(*spec_entries))


def _get_input_strategy(opt_strategy, unify_graph):
    """Collect one sharding strategy per graph input, in ``input_list`` order.

    Operators are visited in reverse ``op_list`` order and each visit
    overwrites earlier findings, so for an input consumed by several operators
    the strategy of the one appearing first in ``op_list`` is kept.  Operators
    whose ``unique_key()`` is absent from ``opt_strategy`` are ignored.

    Inputs for which no strategy was found get a pair of
    ``unifyir.SPMD(unifyir.SPMD.REPLICATE)`` entries.

    Returns:
        list with one strategy per element of ``unify_graph.input_list``.
    """
    graph_inputs = unify_graph.input_list

    found = {}
    for op in reversed(unify_graph.op_list):
        key = op.unique_key()
        if key not in opt_strategy:
            continue
        for position, var in enumerate(op.invars):
            if var in graph_inputs:
                found[var] = opt_strategy[key]['strategy']['invars_sharding'][position]

    return [
        found[var] if var in found else
        [unifyir.SPMD(unifyir.SPMD.REPLICATE), unifyir.SPMD(unifyir.SPMD.REPLICATE)]
        for var in graph_inputs
    ]


def _get_local_array(array, specs):
    """Slice ``array`` down to the first chunk along every sharded dimension.

    For each position of ``specs`` that equals one of the current mesh's axis
    names, ``array`` is split along that dimension into as many pieces as the
    mesh has devices on that axis, and the first piece is kept.

    NOTE(review): an entry is only matched when it compares equal to a single
    axis name; entries of another form (e.g. a tuple of names) appear to be
    left unsplit -- confirm this is intended.
    """
    mesh = get_device_mesh()
    axis_sizes = mesh.device_ids.shape
    names = mesh.axis_names
    for dim in range(len(specs)):
        for position, name in enumerate(names):
            if specs[dim] != name:
                continue
            # FIXME: use first chunk here, need to use rank here
            pieces = jax.numpy.array_split(array, axis_sizes[position], axis=dim)
            array = pieces[0]

    return array


def shard_module(flatten_args):
    """Replace every entry of ``flatten_args`` by a globally sharded array.

    For position ``i`` the sharding is derived from ``INPUT_STRATEGY[i]`` and
    the module's device mesh, the entry is materialized, reduced to its local
    chunk and finally assembled with
    ``multihost_utils.host_local_array_to_global_array``.

    The list is modified in place and also returned.
    """
    for position, arg in enumerate(flatten_args):
        sharding = convert(INPUT_STRATEGY[position], JAX_DEVICE_MESH, arg)
        # The materialized value is stored before slicing, as the list is
        # updated step by step.
        concrete = materialize(arg)
        flatten_args[position] = concrete
        local_chunk = _get_local_array(concrete, sharding.spec)
        flatten_args[position] = multihost_utils.host_local_array_to_global_array(
            local_chunk, JAX_DEVICE_MESH, sharding.spec)

    return flatten_args


def add_sharding_jaxpr(jaxpr, consts, shard_strategy, args):
    """Re-evaluate ``jaxpr`` equation by equation, adding sharding constraints.

    Every equation is re-bound with its original primitive and parameters.
    Before binding, the equation is looked up in ``shard_strategy`` under the
    key ``f"{name}_{invars}"`` where ``name`` is the primitive name (suffixed
    with ``[<name param>]`` for ``custom_jvp_call`` and ``xla_call``) and
    ``invars`` is the list of the equation's non-literal input variables.  On
    a hit, each input value that is a ``DynamicJaxprTracer`` is wrapped in
    ``jax.lax.with_sharding_constraint`` using the matching entry of
    ``shard_strategy[key]['strategy']['invars_sharding']``.

    Args:
        jaxpr: the jaxpr to evaluate.
        consts: values for ``jaxpr.constvars``.
        shard_strategy: mapping from equation key to strategy description.
        args: values for ``jaxpr.invars`` (same length and order).

    Returns:
        list of values, one per entry of ``jaxpr.outvars``.
    """
    env = {}

    def read(var):
        # Literals carry their value inline; everything else lives in env.
        if type(var) is core.Literal:
            return var.val
        return env[var]

    def write(var, val):
        env[var] = val

    # Seed the environment with the jaxpr's inputs and constants.
    safe_map(write, jaxpr.invars, args)
    safe_map(write, jaxpr.constvars, consts)

    # Evaluate the equations in program (forward) order.
    for eqn in jaxpr.eqns:
        subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)

        primitive_name = str(eqn.primitive)
        onelevel_key = primitive_name
        if primitive_name == "custom_jvp_call":
            onelevel_key += "[" + subfuns[0].f.args[0].eqns[0].params['name'] + "]"
        if primitive_name == "xla_call":
            onelevel_key += "[" + eqn.params['name'] + "]"

        # Only real variables take part in the key; literals are skipped.
        invars = [var for var in eqn.invars if isinstance(var, jax.core.Var)]
        invals = safe_map(read, eqn.invars)

        unique_key = f"{onelevel_key}_{invars}"

        if unique_key in shard_strategy:
            # The mesh is only needed when a constraint is actually inserted.
            mesh = get_device_mesh()
            invars_sharding = shard_strategy[unique_key]['strategy']['invars_sharding']

            # invars_sharding is indexed by tracer inputs only, so it needs
            # its own counter, separate from the position in invals.
            sharding_idx = 0
            for idx, val in enumerate(invals):
                if isinstance(val, jax.interpreters.partial_eval.DynamicJaxprTracer):
                    strategy_ = convert(invars_sharding[sharding_idx], mesh, val)
                    invals[idx] = jax.lax.with_sharding_constraint(val, strategy_)
                    sharding_idx += 1

        outvals = eqn.primitive.bind(*subfuns, *invals, **bind_params)

        # Primitives report whether they return a sequence; relying on the
        # result's tracer type instead breaks for any non-DynamicJaxprTracer
        # result (concrete arrays, other tracer kinds).
        if not eqn.primitive.multiple_results:
            outvals = [outvals]

        safe_map(write, eqn.outvars, outvals)

    return safe_map(read, jaxpr.outvars)


def get_opt_strategy(func, *args, **kwargs):
    """Trace ``func`` and solve for a sharding strategy of its operators.

    Steps:
      1. trace ``func(*args, **kwargs)`` with ``jax.make_jaxpr``;
      2. run ``MFJaxShardingAnn`` on the jaxpr to obtain sharding and shape
         information;
      3. convert the jaxpr to a meshflow graph with ``jax2mf_bridge``;
      4. optimize with ``AutoFlowSolver.ilp_optimize`` for the current mesh.

    Side effects:
      * the ``INPUT_STRATEGY`` global is set from the solved strategy;
      * ``NVIDIA_TF32_OVERRIDE`` is set to ``"1"`` while the annotation pass
        runs and to ``"0"`` afterwards, also when that pass raises.
        NOTE(review): any value the variable had before the call is not
        restored -- confirm this is intended.

    Returns:
        the strategy returned by ``AutoFlowSolver.ilp_optimize()``.
    """
    global INPUT_STRATEGY

    closed_jaxpr = jax.make_jaxpr(func)(*args, **kwargs)

    os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
    try:
        start_t = time.perf_counter()
        sharding_interpreter = MFJaxShardingAnn(closed_jaxpr.jaxpr)
        sharding_info, shape_info = sharding_interpreter.run(closed_jaxpr.literals, *args,
                                                             **kwargs)
        logger.info(f"[MFJaxShardingAnn.run]: {time.perf_counter() - start_t} s.")

        if logging.root.level <= logging.DEBUG:
            rich.print("sharding_info:\n", sharding_info)
            rich.print("shape_info:\n", shape_info)
    finally:
        # Previously the override stayed at "1" if the pass above raised.
        os.environ["NVIDIA_TF32_OVERRIDE"] = "0"

    unify_graph = jax2mf_bridge(closed_jaxpr.jaxpr, sharding_info, shape_info)

    if logging.root.level <= logging.INFO:
        rich.print(unify_graph)

    device_mesh = get_device_mesh()
    solver = AutoFlowSolver(device_mesh.device_ids.shape)
    solver.add_graph(unify_graph)
    start_t = time.perf_counter()
    opt_strategy = solver.ilp_optimize()
    logger.info(f"[AutoFlowSolver.ilp_optimize]: {time.perf_counter() - start_t} s.")
    # start_t = time.perf_counter()
    # beam_search_strategy = solver.beam_search()
    # logger.info(f"[AutoFlowSolver.beam_search]: {time.perf_counter() - start_t} s.")

    INPUT_STRATEGY = _get_input_strategy(opt_strategy, unify_graph)

    if logging.root.level <= logging.INFO:
        rich.print(opt_strategy)

    return opt_strategy


def meshflow_shard(fun, shard_strategy=None):
    """Wrap ``fun`` so that it runs with sharding constraints inserted.

    On every call the wrapper traces ``fun`` with ``jax.make_jaxpr`` and
    re-evaluates the resulting jaxpr through ``add_sharding_jaxpr``.

    Args:
        fun: the function to wrap.
        shard_strategy: mapping passed through to ``add_sharding_jaxpr``;
            ``None`` (the default) means an empty mapping.

    Returns:
        The wrapped function.  It returns the single output directly when the
        jaxpr has exactly one output, otherwise a tuple of all outputs.
    """
    # Avoid a mutable default argument shared between calls.
    if shard_strategy is None:
        shard_strategy = {}

    @functools.wraps(fun)
    def wrapped(*args, **kwargs):
        closed_jaxpr = jax.make_jaxpr(fun)(*args, **kwargs)

        logger.debug(f"[closed_jaxpr.jaxpr]: {closed_jaxpr.jaxpr}")
        logger.debug(f"[closed_jaxpr.literals]: {closed_jaxpr.literals}")

        # The jaxpr's inputs are the leaves of (args, kwargs); flattening only
        # ``args`` left keyword arguments without a matching value.  With no
        # keyword arguments the leaves are the same as before.
        flatten_args, _ = jax.tree_util.tree_flatten((args, kwargs))

        out = add_sharding_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, shard_strategy,
                                 flatten_args)
        if len(out) == 1:
            return out[0]
        return tuple(out)

    return wrapped
