import torch
from torch.fx.node import Node, _get_qualified_name, map_arg
from spmd import distribute_tensor, Replicate
from spmd.tensor.api import DTensor

from meshflow.unifyshard import unifyir
from meshflow.utils.testing import MockDeviceMesh

# Module-wide device mesh: assigned by set_device_mesh() and read back through
# get_device_mesh(). Stays None until set_device_mesh() has been called.
TORCH_DEVICE_MESH = None


def set_device_mesh(device_mesh):
    """Register ``device_mesh`` as the module-wide device mesh.

    The mesh is stored in the ``TORCH_DEVICE_MESH`` global. In addition, the
    first of mesh dimensions 0 and 1 whose size is 1 (if any) has its index
    written to ``unifyir.DEVICE_MESH_1D``.

    Args:
        device_mesh: object exposing ``size(dim)``; dimensions 0 and 1 are
            queried, so it is presumably a 2-D mesh (confirm with callers).
    """
    global TORCH_DEVICE_MESH
    TORCH_DEVICE_MESH = device_mesh

    # Dimension 0 is probed first; dimension 1 is only queried when
    # dimension 0 does not have size 1.
    for dim in (0, 1):
        if device_mesh.size(dim) == 1:
            unifyir.DEVICE_MESH_1D = dim
            break


def get_device_mesh():
    """Return the current value of the module-level ``TORCH_DEVICE_MESH``."""
    # Reading a module global needs no ``global`` declaration.
    return TORCH_DEVICE_MESH


def materialize(x, device):
    """Create a randomly filled tensor on ``device`` mirroring ``x``.

    Non-tensor inputs are handed back untouched. For tensors, the result has
    the same size as ``x`` and is filled according to its dtype:

    * ``torch.bool``  -> uniform floats thresholded at 0.5
    * floating point  -> ``torch.rand`` in the dtype of ``x``
    * anything else   -> ``torch.randint`` with ``high=8`` in the dtype of ``x``

    Args:
        x: value to materialize; only ``torch.Tensor`` instances are replaced.
        device: device passed through to the torch factory functions.

    Returns:
        A new tensor on ``device``, or ``x`` itself when it is not a tensor.
    """
    if not isinstance(x, torch.Tensor):
        return x

    shape = x.size()
    if x.dtype == torch.bool:
        return torch.rand(shape, dtype=torch.float, device=device) > 0.5
    if torch.is_floating_point(x):
        return torch.rand(shape, dtype=x.dtype, device=device)
    return torch.randint(high=8, size=shape, dtype=x.dtype, device=device)


def redist_tensor_func(input_tensor, specs):
    """Redistribute a ``DTensor`` to the placements ``specs`` when required.

    The input is returned unchanged when it is not a ``DTensor``, when its
    size equals ``torch.Size([0])``, when the product of the first two mesh
    dimension sizes is not greater than 1, or when ``specs`` already compares
    equal to the tensor's current placements.

    Args:
        input_tensor: candidate value; only ``DTensor`` instances are touched.
        specs: target placements handed to ``DTensor.redistribute``.

    Returns:
        The redistributed tensor, or ``input_tensor`` itself.
    """
    if not isinstance(input_tensor, DTensor):
        return input_tensor
    if input_tensor.size() == torch.Size([0]):
        return input_tensor

    mesh = get_device_mesh()
    # Nothing to redistribute across when the first two mesh dims multiply to <= 1.
    if mesh.size(0) * mesh.size(1) <= 1:
        return input_tensor

    needs_redistribution = specs != input_tensor._spec.placements
    return input_tensor.redistribute(mesh, specs) if needs_redistribution else input_tensor


def insert_spmd_head(body):
    """Return a new list: the spmd import line followed by the items of ``body``.

    ``body`` is not modified.
    """
    lines = ["from spmd import Shard, Replicate\n"]
    lines.extend(body)
    return lines


def to_dtensor(input_tensor):
    """Distribute a plain tensor over the registered mesh, fully replicated.

    The input is returned unchanged when it is not a ``torch.Tensor``, when
    the registered device mesh is falsy, or when that mesh is a
    ``MockDeviceMesh``. Tensors on the ``meta`` device are first replaced by
    ``materialize(..., device="cuda")``.

    Returns:
        The result of ``distribute_tensor`` with one ``Replicate()`` placement
        per mesh dimension, or ``input_tensor`` itself.
    """
    if not isinstance(input_tensor, torch.Tensor):
        return input_tensor

    mesh = get_device_mesh()
    if not mesh or isinstance(mesh, MockDeviceMesh):
        return input_tensor

    # Meta tensors are swapped for materialize() output before distribution.
    if input_tensor.device == torch.device("meta"):
        input_tensor = materialize(input_tensor, device="cuda")

    replicated = [Replicate()] * mesh.ndim
    return distribute_tensor(input_tensor, mesh, replicated)


def fix_in_gragh_tensor(fx_module: torch.fx.GraphModule):
    """Route in-graph tensors through ``to_dtensor``.

    A ``to_dtensor`` call is inserted after every ``get_attr`` node and after
    every ``call_function`` node whose qualified target name contains
    ``torch.ops.aten.scalar_tensor.default`` or ``torch.ops.aten.ones.default``.
    All uses of the wrapped node are redirected to the inserted call.

    Args:
        fx_module: graph module to rewrite in place.

    Returns:
        The same ``fx_module``, recompiled.
    """
    graph = fx_module.graph
    wrapped_targets = ("torch.ops.aten.scalar_tensor.default", "torch.ops.aten.ones.default")

    def wrap_with_to_dtensor(producer):
        with graph.inserting_after(producer):
            wrapper = graph.call_function(to_dtensor, args=(producer, ))
            producer.replace_all_uses_with(wrapper)
            # The wrapper itself consumes ``producer``, so the call above also
            # redirected its input; point argument 0 back at ``producer``.
            wrapper.update_arg(0, producer)

    for node in graph.nodes:
        if node.op == 'get_attr':
            wrap_with_to_dtensor(node)
        elif node.op == 'call_function':
            qualified_name = _get_qualified_name(node.target)
            if any(target in qualified_name for target in wrapped_targets):
                wrap_with_to_dtensor(node)

    fx_module.recompile()
    return fx_module


def replace_subsequence_use(node, arg_, redist_node):
    """Swap ``arg_`` for ``redist_node`` in the users of ``node`` that follow it.

    Starting at ``node.next``, the node list is walked until a node with an
    empty name is reached (presumably the graph's sentinel root; verify
    against torch.fx). Every visited node that was a user of ``node`` when the
    function was called has each occurrence of ``arg_`` in its ``args`` and
    ``kwargs`` replaced by ``redist_node``.

    Args:
        node: node whose users are rewritten.
        arg_: node to be replaced inside those users.
        redist_node: replacement node.
    """
    # Snapshot of the users taken before any args are rewritten.
    consumers = list(node.users)

    def swap(candidate: Node) -> Node:
        return redist_node if candidate == arg_ else candidate

    cursor = node.next
    while cursor.name != "":
        if cursor in consumers:
            rewritten_args = map_arg(cursor.args, swap)
            rewritten_kwargs = map_arg(cursor.kwargs, swap)
            assert isinstance(rewritten_args, tuple)
            assert isinstance(rewritten_kwargs, dict)
            cursor.args = rewritten_args
            cursor.kwargs = rewritten_kwargs
        cursor = cursor.next


def sharding_transform(fx_module, sharding_strategy):
    """Insert redistribution calls in front of nodes that have a strategy.

    For every ``call_function`` node whose name is a key of
    ``sharding_strategy``, each positional argument that is a ``Node`` gets a
    ``redist_tensor_func(arg, placement)`` call inserted before the node. The
    placement is taken from the node's strategy, indexed by the argument's
    position among the Node-typed positional arguments only. Keyword
    arguments are not inspected.

    After the rewrite, ``insert_spmd_head`` is registered as the graph's code
    transformer, the module is recompiled, and the result is passed through
    ``fix_in_gragh_tensor``.

    Args:
        fx_module: graph module to rewrite in place.
        sharding_strategy: mapping from node name to an indexable sequence of
            placements.

    Returns:
        The value returned by ``fix_in_gragh_tensor(fx_module)``.
    """
    graph = fx_module.graph

    for node in graph.nodes:
        if node.op != 'call_function' or node.name not in sharding_strategy:
            continue

        node_strategy = sharding_strategy[node.name]
        # ``node.args`` is read once here; the update_arg calls below do not
        # affect which arguments are visited.
        node_inputs = [(idx, arg) for idx, arg in enumerate(node.args) if isinstance(arg, Node)]

        for strategy_idx, (arg_idx, arg_) in enumerate(node_inputs):
            arg_strategy = node_strategy[strategy_idx]
            with graph.inserting_before(node):
                redist_node = graph.call_function(redist_tensor_func, args=(arg_, arg_strategy))

                # FIXME: may introduce redundant communication; update_arg is
                # applied to all subsequent uses as well.
                node.update_arg(arg_idx, redist_node)
                replace_subsequence_use(node, arg_, redist_node)

    def make_code_transformer(_previous_transformer):
        # The argument supplied by on_generate_code is ignored.
        return insert_spmd_head

    graph.on_generate_code(make_code_transformer)

    fx_module.recompile()

    # (fix) %_tensor_constant0 : [#users=1] = get_attr[target=_tensor_constant0]
    return fix_in_gragh_tensor(fx_module)
