import copy
from typing import List
from functools import reduce

from meshflow.unifyshard.combination import ReduceOp
from meshflow.platform import get_backend

# Selects how UnifyNode.get_strategy lays its 1-D strategies onto the 2-D
# device mesh (each placement becomes a pair [p0, p1]):
#   -1 -> pair the 1-D strategies with each other, one per mesh position;
#    0 -> REPLICATE in position 0, the 1-D strategy in position 1;
#    1 -> the 1-D strategy in position 0, REPLICATE in position 1.
# Any other value is rejected by get_strategy.
DEVICE_MESH_1D: int = -1


class SPMD:
    """Placement of one tensor along one device-mesh position.

    ``state`` is one of the class constants below. ``args`` carries the
    extra parameters of that state, e.g. ``{"dim": 0}`` for SHARD or
    ``{"ops": ReduceOp.SUM}`` for PARTIAL. It is None when there are none.
    """

    REPLICATE = "REPLICATE"
    SHARD = "SHARD"
    PARTIAL = "PARTIAL"

    def __init__(self, state, args=None) -> None:
        self.state = state
        self.args = args

    def __str__(self) -> str:
        # Falsy args (None, {}) are omitted from the rendering.
        if not self.args:
            return f"{self.state}"
        return f"{self.state}({self.args})"

    def __repr__(self) -> str:
        return str(self)


def get_sharding_strategy(sharding_anns, shard_dim_id):
    """Build one input placement per tensor annotation for a shard-dim id.

    A tensor is SHARD on the *last* dimension whose ``shard_dim_id`` equals
    ``shard_dim_id``. It is REPLICATE when no dimension matches.

    Args:
        sharding_anns: object whose ``annotation`` attribute iterates over
            per-tensor annotations. Each is an iterable of per-dimension
            annotations exposing ``shard_dim_id``.
        shard_dim_id: id to look for (may be None).

    Returns:
        list of ``SPMD``, one per tensor annotation, in order.
    """
    placements = []
    for tensor_ann in sharding_anns.annotation:
        matching_dims = [
            idx for idx, dim_ann in enumerate(tensor_ann)
            if dim_ann.shard_dim_id == shard_dim_id
        ]
        if matching_dims:
            placements.append(SPMD(SPMD.SHARD, {"dim": matching_dims[-1]}))
        else:
            placements.append(SPMD(SPMD.REPLICATE))
    return placements


def combination_to_sharding_strategy(comm_anns, all_replicate=False):
    """Translate combination annotations into output placements.

    Each annotation is a ``functools.partial`` over a combination function,
    for example::

        functools.partial(<function CombinationFunc.gather at 0x7fab788efd30>, dim=0)

    ``identity`` maps to REPLICATE, ``gather`` to SHARD and ``reduce`` to
    PARTIAL. For the last two, the partial's keywords become the SPMD args.
    Annotations over any other function produce no entry.

    Args:
        comm_anns: a single annotation, or a list/tuple of them.
        all_replicate: map every recognised-or-not annotation to REPLICATE.

    Returns:
        list of ``SPMD``.
    """
    anns = comm_anns if isinstance(comm_anns, (list, tuple)) else [comm_anns]

    result = []
    for ann in anns:
        fn_name = ann.func.__name__
        if all_replicate or fn_name == "identity":
            result.append(SPMD(SPMD.REPLICATE))
        elif fn_name in ("gather", "reduce"):
            state = SPMD.SHARD if fn_name == "gather" else SPMD.PARTIAL
            result.append(SPMD(state, ann.keywords))
    return result


class UnifyVar:
    """A named tensor variable of the UnifyIR graph.

    Attributes:
        name: variable name (UnifyGraph indexes its variables by it).
        shape: iterable of dimension sizes; empty for a scalar.
        dtype: dtype name string, looked up in ``dtype2byte``.
        dtype2byte: bytes per element for each supported dtype name.
    """

    def __init__(self, name, shape, dtype) -> None:
        self.name = name
        self.shape = shape
        self.dtype = dtype

        # Bytes per element for each supported dtype name.
        # NOTE(review): "bool" is costed at one bit (1/8 byte), whereas
        # torch/numpy store bool as one byte -- confirm the bit-packed
        # estimate is intended.
        self.dtype2byte = {
            "float64": 8,
            "float32": 4,
            "float16": 2,
            "bfloat16": 2,
            "bool": 0.125,
            "int8": 1,
            "int16": 2,
            "int32": 4,
            "int64": 8,
            "uint32": 4,
            "uint8": 1,
            "complex64": 8,  # two float32 components
            "complex128": 16,  # two float64 components
        }

    def get_var_size(self):
        """Return the size of this variable in bytes.

        A scalar (empty ``shape``) counts as a single element.

        Raises:
            KeyError: if ``dtype`` is not a key of ``dtype2byte``.
        """
        # The initial value 1 makes an empty shape count as one element
        # instead of making reduce() raise TypeError.
        num_ele = reduce((lambda x, y: x * y), self.shape, 1)
        return self.dtype2byte[self.dtype] * num_ele

    def __str__(self, details=False) -> str:
        """Return the name, or ``name(shape, dtype)`` when ``details`` is True."""
        if details:
            return f"{self.name}({self.shape}, {self.dtype})"
        return str(self.name)

    def __repr__(self) -> str:
        return self.__str__()


class UnifyNode:
    """One operator of the UnifyIR graph together with its sharding annotations.

    Attributes:
        name: operator name; ``get_strategy`` applies op-specific tweaks based
            on substrings of it (e.g. "convolution", "batch_norm").
        invars: input variables. UnifyGraph.append_op replaces the names given
            at construction with the matching UnifyVar objects.
        outvars: output UnifyVar objects.
        sharding_info: mapping with the keys ``'sharding_ann'`` and
            ``'combination_ann'``, both read by ``get_strategy``.
    """

    def __init__(self, name, invars: List[UnifyVar], outvars: List[UnifyVar],
                 sharding_info) -> None:
        self.name = name
        self.invars = invars
        self.outvars = outvars

        # Lazily built cache for unique_key().
        self.unique_key_ = None

        self.sharding_info = sharding_info

    def unique_key(self):
        """Return a key built from the op name and its invars, cached on first use.

        NOTE(review): the key is frozen at the first call, so later changes to
        ``invars`` (names replaced by UnifyVar objects, renames) are not
        reflected in it.
        """
        if self.unique_key_ is None:
            self.unique_key_ = f"{self.name}_{self.invars}"
        return self.unique_key_

    def get_strategy(self):
        """Enumerate candidate sharding strategies for this op on a 2-D mesh.

        Returns:
            list of dicts with keys ``"invars_sharding"`` and
            ``"outvars_sharding"``. Each value holds, per input/output, a
            two-element list of ``SPMD`` placements (one per mesh position).
            Returns an empty list when the op has no combination annotations.

        Raises:
            ValueError: if the module-level ``DEVICE_MESH_1D`` is not -1, 0 or 1.
        """
        strategy_list = []

        sharding_anns = self.sharding_info['sharding_ann']
        comm_anns = self.sharding_info['combination_ann']

        # some ops like torch.ops.aten.scalar_tensor.default have no invars
        if len(comm_anns) == 0:
            return strategy_list

        # One 1-D strategy per shard-dim id found in the combination annotation.
        out_anns = None
        for comm_ann in comm_anns:
            invars_sharding = get_sharding_strategy(sharding_anns, comm_ann)
            out_anns = comm_anns[comm_ann]
            strategy_list.append({
                "invars_sharding": invars_sharding,
                "outvars_sharding": combination_to_sharding_strategy(out_anns),
            })

        # (FIXME) op-specific strategy adjustments below; these need to be removed.
        # Convolution / aten.mm / dot ops get no fully replicated strategy.
        if not any(key in self.name for key in ("convolution", "aten.mm", "dot")):
            # out_anns is the last combination annotation seen above; with
            # all_replicate=True every entry maps to REPLICATE, so only the
            # number of entries matters here.
            replicate_strategy = {
                "invars_sharding": get_sharding_strategy(sharding_anns, None),
                "outvars_sharding": combination_to_sharding_strategy(out_anns, all_replicate=True),
            }
            strategy_list.append(replicate_strategy)

            if "batch_norm" in self.name:
                # batch_norm keeps only a data-parallel strategy (input 0 and
                # output 0 sharded on dim 0) and full replication.
                dp_strategy = {
                    "invars_sharding": get_sharding_strategy(sharding_anns, None),
                    "outvars_sharding": combination_to_sharding_strategy(out_anns, all_replicate=True),
                }
                dp_strategy["invars_sharding"][0] = SPMD(SPMD.SHARD, {'dim': 0})
                dp_strategy["outvars_sharding"][0] = SPMD(SPMD.SHARD, {'dim': 0})
                if "backward" in self.name:
                    # Input 1 is also batch-sharded; outputs 1 and 2 become
                    # partial sums that must be SUM-reduced.
                    dp_strategy["invars_sharding"][1] = SPMD(SPMD.SHARD, {'dim': 0})
                    dp_strategy["outvars_sharding"][1] = SPMD(SPMD.PARTIAL, {'ops': ReduceOp.SUM})
                    dp_strategy["outvars_sharding"][2] = SPMD(SPMD.PARTIAL, {'ops': ReduceOp.SUM})
                strategy_list = [dp_strategy, replicate_strategy]

        if "convolution.default" in self.name:
            del strategy_list[1]
        if "convolution_backward.default" in self.name:
            # Drop the first strategy whose last input is SHARD on dim 1.
            # SPMD defines no __eq__, so its fields are compared explicitly;
            # comparing SPMD instances with == would test identity and never match.
            for idx, s in enumerate(strategy_list):
                last = s["invars_sharding"][-1]
                if last.state == SPMD.SHARD and last.args == {'dim': 1}:
                    del strategy_list[idx]
                    break

        return self._to_2d_strategies(strategy_list)

    def _to_2d_strategies(self, strategy_list):
        """Lift 1-D strategies onto the 2-D device mesh selected by ``DEVICE_MESH_1D``.

        Every placement in the result is a two-element list ``[p0, p1]`` of
        ``SPMD`` objects, one per mesh position.

        Raises:
            ValueError: if ``DEVICE_MESH_1D`` is not -1, 0 or 1.
        """
        strategy_list_2d = []

        if DEVICE_MESH_1D == -1:
            # [Shard(i), Shard(i)] is not supported by the PyTorch DTensor
            # runtime, so on that backend a strategy is never paired with
            # itself. NOTE(review): this also drops the pairing of the
            # replicate strategy with itself -- confirm that is intended.
            skip_self_pairs = get_backend() == "torch"
            for idx1, s1 in enumerate(strategy_list):
                for idx2, s2 in enumerate(strategy_list):
                    if skip_self_pairs and idx1 == idx2:
                        continue
                    strategy_list_2d.append({
                        "invars_sharding":
                        [[i, j] for i, j in zip(s1["invars_sharding"], s2["invars_sharding"])],
                        "outvars_sharding":
                        [[i, j] for i, j in zip(s1["outvars_sharding"], s2["outvars_sharding"])],
                    })
            return strategy_list_2d

        if DEVICE_MESH_1D not in (0, 1):
            # Report a bad configuration as an exception instead of
            # terminating the interpreter.
            raise ValueError(
                f"DEVICE_MESH_1D must be -1, 0 or 1, got {DEVICE_MESH_1D!r}")

        # REPLICATE occupies mesh position DEVICE_MESH_1D; the 1-D strategy
        # occupies the other position.
        for s in strategy_list:
            if DEVICE_MESH_1D == 0:
                strategy_list_2d.append({
                    "invars_sharding":
                    [[SPMD(SPMD.REPLICATE), i] for i in s["invars_sharding"]],
                    "outvars_sharding":
                    [[SPMD(SPMD.REPLICATE), i] for i in s["outvars_sharding"]],
                })
            else:
                strategy_list_2d.append({
                    "invars_sharding":
                    [[i, SPMD(SPMD.REPLICATE)] for i in s["invars_sharding"]],
                    "outvars_sharding":
                    [[i, SPMD(SPMD.REPLICATE)] for i in s["outvars_sharding"]],
                })

        return strategy_list_2d

    def __str__(self, details=False) -> str:
        """Return the op name, or a multi-line dump when ``details`` is True."""
        if not details:
            return str(self.name)
        invars_str = ",".join(var.__str__(True) for var in self.invars)
        outvars_str = ",".join(var.__str__(True) for var in self.outvars)
        return (f"{self.name}\n"
                f"invars: {invars_str}\n"
                f"outvars: {outvars_str}\n"
                f"sharding_info: {self.sharding_info}")

    def __repr__(self) -> str:
        return self.__str__()


class UnifyGraph:
    """A linear sequence of UnifyNode ops with named graph inputs and outputs.

    Attributes:
        ori_struct: stored as given; not read by this class.
        input_list: graph input UnifyVar objects, in registration order.
        op_list: UnifyNode ops, in registration order.
        output_list: graph output UnifyVar objects (set by ``mark_output``).
        name_to_var: maps each registered variable name (graph inputs and op
            outputs) to its UnifyVar.
    """

    def __init__(self, ori_struct) -> None:
        self.ori_struct = ori_struct
        self.input_list = []
        self.op_list = []
        self.output_list = []

        self.name_to_var = {}

    def append_input(self, invars: UnifyVar) -> None:
        """Register ``invars`` as a graph input."""
        self.input_list.append(invars)
        self.name_to_var[invars.name] = invars

    def append_op(self, unify_node: UnifyNode) -> None:
        """Append an op whose ``invars`` are names of already-registered variables.

        The names are replaced in place by the matching UnifyVar objects, and
        the op's outvars are registered by name.

        Raises:
            KeyError: if an input name has not been registered.
        """
        unify_node.invars = [self.name_to_var[name] for name in unify_node.invars]
        self.op_list.append(unify_node)
        for var in unify_node.outvars:
            self.name_to_var[var.name] = var

    def mark_output(self, output_list: List[str]) -> None:
        """Set the graph outputs from a list of registered variable names."""
        self.output_list = [self.name_to_var[name] for name in output_list]

    def rename_var(self, old_name, new_name):
        """Rename a registered variable and re-key it in ``name_to_var``.

        Raises:
            KeyError: if ``old_name`` is not registered (the graph is left unchanged).
        """
        var = self.name_to_var.pop(old_name)
        var.name = new_name
        self.name_to_var[new_name] = var

    def __str__(self) -> str:
        parts = [f"=====================\n[UnifyIR]\n\ninput_list: {self.input_list}\n\n"]
        parts.extend(f"{op.outvars} <--- [{op.name}] --- {op.invars}\n" for op in self.op_list)
        parts.append(f"\noutput_list: {self.output_list}\n=====================\n")
        return "".join(parts)

    def __repr__(self) -> str:
        return self.__str__()

    def liveness(self):
        """Return, for each op, the set of variable names live at that op.

        Entry ``k`` corresponds to ``self.op_list[k]``. It contains:
        - the op's own input and output names;
        - every name still needed by a later op or by the graph outputs;
        - all graph input names, which are treated as always live.
        """
        input_names = {var.name for var in self.input_list}
        live = {var.name for var in self.output_list}

        # Built back-to-front with append + reverse (O(n)) instead of
        # repeated insert(0, ...) (O(n^2)). The union creates a fresh set per
        # op, so no deep copy is needed.
        liveness_list = []
        for op in reversed(self.op_list):
            live.update(var.name for var in op.invars)
            live.update(var.name for var in op.outvars)

            liveness_list.append(live | input_names)

            # An op's outputs are not live before the op that produces them.
            for var in op.outvars:
                live.remove(var.name)

        liveness_list.reverse()
        return liveness_list