import spmd
import torch
from torch.fx.node import Node, _get_qualified_name
import torch.distributed.distributed_c10d as c10d
from spmd import distribute_tensor

from meshflow.unifyshard.combination import ReduceOp
from meshflow.unifyshard import unifyir
from meshflow.unifyshard.unifyir import UnifyGraph, UnifyNode, UnifyVar
from meshflow.utils import rsetattr, rgetattr

from .passes.sharding import get_device_mesh

from .model.gpt import GPT

# Maps a torch dtype to the dtype-name string stored on UnifyVar objects.
# Every key must be distinct: a repeated key in a dict literal silently
# overwrites the earlier entry, which previously left torch.float32 mapped to
# "float16" and torch.float16 without any entry at all.
ABSTRACT_DTYPE = {
    torch.float32: "float32",
    torch.float16: "float16",
    torch.int32: "int32",
    torch.int64: "int64",
    torch.bool: "bool",
    torch.uint8: "uint8",
    torch.complex64: "complex64",
}


def to_torch_spmd(unify_spmd):
    """Convert a unified-IR SPMD annotation into an ``spmd`` placement object.

    Args:
        unify_spmd: object exposing ``state`` (compared against the members of
            ``unifyir.SPMD``) and ``args`` (a mapping; ``"dim"`` is read for
            the SHARD state, ``"ops"`` for the PARTIAL state).

    Returns:
        ``spmd.Shard`` for SHARD, ``_Partial`` for PARTIAL, ``spmd.Replicate``
        for REPLICATE, and ``None`` when the state matches none of these.

    Raises:
        KeyError: if a PARTIAL annotation carries a reduce op other than
            SUM, MAX or MIN.
    """
    state = unify_spmd.state

    if state == unifyir.SPMD.SHARD:
        return spmd.Shard(dim=unify_spmd.args["dim"])

    if state == unifyir.SPMD.PARTIAL:
        # Translate the unified-IR reduce op into its c10d counterpart.
        red_op_type = c10d.ReduceOp.RedOpType
        reduce_op_table = {
            ReduceOp.SUM: red_op_type.SUM,
            ReduceOp.MAX: red_op_type.MAX,
            ReduceOp.MIN: red_op_type.MIN,
        }
        partial_cls = spmd.tensor.placement_types._Partial
        return partial_cls(reduce_op=reduce_op_table[unify_spmd.args["ops"]])

    if state == unifyir.SPMD.REPLICATE:
        return spmd.Replicate()

    # Unrecognised state: no placement is produced.
    return None


def materialize(x, device):
    """Return a randomly filled stand-in for ``x`` allocated on ``device``.

    Non-tensor values are handed back untouched. For a tensor, a new tensor of
    the same size is created on ``device``:

    * bool tensors: uniform float samples thresholded at 0.5;
    * floating-point tensors: ``torch.rand`` in the tensor's own dtype;
    * every other dtype: ``torch.randint`` with ``high=8`` in that dtype.

    Args:
        x: any value; only ``torch.Tensor`` instances are replaced.
        device: device on which the new tensor is allocated.

    Returns:
        The random tensor, or ``x`` itself when it is not a tensor.
    """
    if not isinstance(x, torch.Tensor):
        return x

    shape = x.size()

    if x.dtype == torch.bool:
        uniform = torch.rand(shape, dtype=torch.float, device=device)
        return uniform > 0.5

    if torch.is_floating_point(x):
        return torch.rand(shape, dtype=x.dtype, device=device)

    return torch.randint(high=8, size=shape, dtype=x.dtype, device=device)


def shard_module(model, input_, input_strategy, device="cuda"):
    """Distribute a model's parameters, buffers and inputs over the device mesh.

    ``input_strategy`` is consumed positionally: first one entry per named
    parameter, then one per named buffer, then one per element of ``input_``.
    Each entry is a sequence of unified-IR SPMD annotations that is converted
    with ``to_torch_spmd`` before being passed to ``distribute_tensor``.

    The model is modified in place: every parameter is replaced by a new
    ``Parameter`` wrapping the distributed tensor and gets a ``.grad`` created
    with ``torch.empty_like``; every buffer is replaced by its distributed
    tensor. Parameters and buffers living on the ``meta`` device are first
    filled via ``materialize``. Every element of ``input_`` is passed through
    ``materialize`` regardless of its device.

    For ``GPT`` models, each block's ``causal_mask`` that is still on the
    ``meta`` device is rebuilt as a lower-triangular boolean mask on ``device``.

    Args:
        model: module whose parameters and buffers are distributed in place.
        input_: iterable of model inputs.
        input_strategy: per-variable lists of SPMD annotations, ordered as
            described above.
        device: device used when materializing tensors.

    Returns:
        List with the distributed version of each element of ``input_``.
    """
    mesh = get_device_mesh()
    meta_device = torch.device("meta")

    placements = [[to_torch_spmd(spec) for spec in per_var] for per_var in input_strategy]

    # Position of the next unused entry in ``placements``.
    cursor = 0

    for param_name in dict(model.named_parameters()):
        data = rgetattr(model, param_name).data
        if data.device == meta_device:
            data = materialize(data, device=device)

        sharded_param = torch.nn.parameter.Parameter(
            distribute_tensor(data, mesh, placements[cursor]))
        rsetattr(model, param_name, sharded_param)

        # Attach a gradient buffer shaped like the freshly installed parameter.
        with torch.no_grad():
            grad_buffer = torch.empty_like(rgetattr(model, param_name).data)
            rsetattr(model, param_name + ".grad", grad_buffer)
        cursor += 1

    for buffer_name in dict(model.named_buffers()):
        data = rgetattr(model, buffer_name).data
        if data.device == meta_device:
            data = materialize(data, device=device)

        rsetattr(model, buffer_name, distribute_tensor(data, mesh, placements[cursor]))
        cursor += 1

    sharded_inputs = []
    for item in input_:
        concrete = materialize(item, device=device)
        sharded_inputs.append(distribute_tensor(concrete, mesh, placements[cursor]))
        cursor += 1

    # handle causal_mask for GPT
    if isinstance(model, GPT):
        for block in model.blocks:
            core_attention = block.attn.core_attention
            mask = core_attention.causal_mask

            if mask.device == meta_device:
                seq_len = mask.shape[-1]
                lower_triangle = torch.tril(
                    torch.ones((seq_len, seq_len), dtype=torch.uint8, device=device))
                mask = lower_triangle.view(1, 1, seq_len, seq_len).bool()

            core_attention.causal_mask = mask

    return sharded_inputs


def torch2mf_bridge(fx_module: torch.fx.GraphModule, sharding_info, meta_info) -> UnifyGraph:
    """Translate a torch.fx graph into a ``UnifyGraph``.

    Nodes are handled by their ``op``:

    * ``placeholder`` / ``get_attr``: registered as graph inputs.
    * ``call_function``: ``_operator.getitem`` only renames the variable
      ``"<source>__<index>"`` to the node's name; any other function becomes a
      ``UnifyNode`` whose inputs are the names of its ``Node`` positional
      arguments and whose outputs are derived from ``meta_info``.
    * ``output``: the non-``None`` entries of ``node.args[0]`` are marked as
      graph outputs.

    Other node kinds are ignored.

    Args:
        fx_module: traced module whose graph is converted.
        sharding_info: mapping ``op name -> {argument signature -> sharding
            info}``. The signature is the string form of the positional
            arguments (``Node`` arguments replaced by meta tensors of matching
            shape and dtype), then ``' | '``, then the string form of the
            keyword arguments.
        meta_info: mapping ``node name -> {"shape", "dtype"}`` record, or a
            list/tuple of such records (entries may be ``None``) for nodes
            with several outputs.

    Returns:
        The populated ``UnifyGraph``.
    """

    def make_var(var_name, var_meta):
        # Build one UnifyVar from a {"shape": ..., "dtype": ...} record.
        return UnifyVar(name=var_name,
                        shape=var_meta["shape"],
                        dtype=ABSTRACT_DTYPE[var_meta["dtype"]])

    graph = UnifyGraph(fx_module)

    for node in fx_module.graph.nodes:
        kind = node.op

        if kind in ["placeholder", "get_attr"]:
            graph.append_input(make_var(node.name, meta_info[node.name]))
            continue

        if kind == "output":
            graph.mark_output([out.name for out in node.args[0] if out is not None])
            continue

        if kind != "call_function":
            continue

        op_name = _get_qualified_name(node.target)

        if op_name == "_operator.getitem":
            # Tuple element access: alias the indexed output to this node's name.
            source, position = node.args
            graph.rename_var(f"{source}__{position}", node.name)
            continue

        matched_sharding = None
        if op_name in sharding_info:
            # Rebuild the signature string under which sharding info is stored.
            described_args = [
                torch.empty(meta_info[arg.name]["shape"],
                            dtype=meta_info[arg.name]["dtype"],
                            device="meta") if isinstance(arg, Node) else arg
                for arg in node.args
            ]
            lookup_key = str(tuple(described_args)) + ' | ' + str(node.kwargs)
            if lookup_key in sharding_info[op_name]:
                matched_sharding = sharding_info[op_name][lookup_key]

        input_names = [arg.name for arg in node.args if isinstance(arg, Node)]

        node_meta = meta_info[node.name]
        if isinstance(node_meta, (list, tuple)):
            # Multi-output op: one variable per non-None output slot.
            produced = [
                make_var(f"{node.name}__{position}", slot_meta)
                for position, slot_meta in enumerate(node_meta)
                if slot_meta is not None
            ]
        else:
            produced = [make_var(node.name, node_meta)]

        graph.append_op(
            UnifyNode(name=op_name,
                      invars=input_names,
                      outvars=produced,
                      sharding_info=matched_sharding))

    return graph


def get_torch_sharding_strategy(fx_module: torch.fx.GraphModule, opt_strategy):
    """Collect per-node input placements from an optimized strategy.

    For every ``call_function`` node other than ``_operator.getitem``, a key
    of the form ``"<op name>_<list of UnifyVar>"`` is built from the node's
    ``Node`` positional arguments and looked up in ``opt_strategy``. The
    list part is whatever string formatting of ``UnifyVar`` produces, so the
    keys of ``opt_strategy`` are presumably generated the same way.

    Args:
        fx_module: traced module whose graph nodes are inspected.
        opt_strategy: mapping from the key described above to a record whose
            ``['strategy']['invars_sharding']`` holds, per input variable, a
            sequence of unified-IR SPMD annotations.

    Returns:
        Dict mapping node name to a list (one entry per input variable) of
        lists of placements produced by ``to_torch_spmd``. Nodes without a
        matching key are omitted.
    """
    placements_by_node = {}

    for node in fx_module.graph.nodes:
        if node.op != "call_function":
            continue

        op_name = _get_qualified_name(node.target)
        if op_name == "_operator.getitem":
            continue

        node_inputs = [
            UnifyVar(arg.name, None, None) for arg in node.args if isinstance(arg, Node)
        ]
        lookup_key = f"{op_name}_{node_inputs}"
        if lookup_key not in opt_strategy:
            continue

        per_input = opt_strategy[lookup_key]['strategy']['invars_sharding']
        placements_by_node[node.name] = [
            [to_torch_spmd(annotation) for annotation in var_sharding]
            for var_sharding in per_input
        ]

    return placements_by_node
