# Copyright 2024 Bytedance Ltd. and/or its affiliates
from enum import Enum, auto
from functools import wraps
from types import FunctionType
from typing import TYPE_CHECKING, Literal, Union

import ray

from ...protocol import DataProto, DataProtoFuture


if TYPE_CHECKING:
    from .worker_group import WorkerGroup


# Name of the attribute under which ``register`` attaches a method's settings
# (dispatch_mode, execute_mode, blocking) to the wrapped function.
MAGIC_ATTR = "attrs_3141562937"


class Dispatch(Enum):
    """Built-in strategies for distributing call arguments and gathering results.

    Every member except ``RANK_ZERO`` is a key of the table built by
    ``get_predefined_dispatch_fn``, which pairs it with a dispatch function
    and a collect function defined in this module.
    """

    # Declared, but no entry for it is visible in ``get_predefined_dispatch_fn``.
    RANK_ZERO = 1
    # Each argument is repeated ``world_size`` times; outputs are returned as-is.
    ONE_TO_ALL = 2
    # Arguments and outputs are both passed through unchanged.
    ALL_TO_ALL = 3
    # Arguments must already be tuples/lists of length ``world_size``.
    DP_COMPUTE = 4
    # Arguments are chunked by ``world_size``; outputs are concatenated.
    DP_COMPUTE_PROTO = 5
    # Like DP_COMPUTE_PROTO, but the first positional argument is a function
    # that is replicated instead of chunked.
    DP_COMPUTE_PROTO_WITH_FUNC = 6
    # Arguments are chunked by ``world_size``; outputs stay a per-worker list.
    DP_COMPUTE_METRIC = 7


class Execute(Enum):
    """Execution modes accepted by ``register``.

    ``get_predefined_execute_fn`` maps each member to the name of an execute
    function.
    """

    ALL = 0  # resolved to the name "execute_all"
    RANK_ZERO = 1  # resolved to the name "execute_rank_zero"


def _split_args_kwargs_data_proto(chunks: int, *args, **kwargs):
    """Chunk every positional and keyword argument into ``chunks`` pieces.

    Args:
        chunks: Number of pieces passed to each argument's ``chunk`` method.
        *args: ``DataProto`` / ``DataProtoFuture`` objects to split.
        **kwargs: ``DataProto`` / ``DataProtoFuture`` objects to split.

    Returns:
        A ``(list, dict)`` pair holding the result of ``chunk`` for each
        positional argument (in order) and for each keyword argument (by name).

    Raises:
        AssertionError: If an argument is neither a ``DataProto`` nor a
            ``DataProtoFuture``.
    """

    def chunk_one(item):
        # Only objects that expose the DataProto chunking API are accepted.
        assert isinstance(item, (DataProto, DataProtoFuture))
        return item.chunk(chunks=chunks)

    chunked_args = [chunk_one(positional) for positional in args]
    chunked_kwargs = {name: chunk_one(named) for name, named in kwargs.items()}
    return chunked_args, chunked_kwargs


def dispatch_one_to_all(worker_group: "WorkerGroup", *args, **kwargs):
    args = tuple([arg] * worker_group.world_size for arg in args)
    kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()}
    return args, kwargs


def dispatch_all_to_all(worker_group: "WorkerGroup", *args, **kwargs):
    """Pass the call's arguments through without modification.

    Args:
        worker_group: Unused here.
        *args: Positional arguments, returned as a tuple.
        **kwargs: Keyword arguments, returned as a dict.

    Returns:
        The ``(args, kwargs)`` pair exactly as received.
    """
    passthrough = (args, kwargs)
    return passthrough


def collect_all_to_all(worker_group: "WorkerGroup", output):
    """Return the workers' output unchanged.

    Args:
        worker_group: Unused here.
        output: Whatever the execution step produced.

    Returns:
        The very same ``output`` object.
    """
    collected = output
    return collected


def _concat_data_proto_or_future(outputs: list[DataProto]) -> DataProto:
    """Merge per-worker results into a single object.

    Args:
        outputs: Non-empty list whose elements all have the same exact type,
            either ``DataProto`` or ``ray.ObjectRef``.

    Returns:
        ``DataProto.concat(outputs)`` for ``DataProto`` elements, or
        ``DataProtoFuture.concat(outputs)`` for ``ray.ObjectRef`` elements.

    Raises:
        AssertionError: If the elements do not all share the first element's type.
        NotImplementedError: If the elements are of any other type.
    """
    # Every element must have exactly the type of the first one.
    assert all(type(item) is type(outputs[0]) for item in outputs)

    first = outputs[0]
    if isinstance(first, DataProto):
        return DataProto.concat(outputs)
    if isinstance(first, ray.ObjectRef):
        return DataProtoFuture.concat(outputs)
    raise NotImplementedError


def dispatch_dp_compute(worker_group: "WorkerGroup", *args, **kwargs):
    """Validate arguments that the caller has already split per worker.

    Args:
        worker_group: Object exposing ``world_size``.
        *args: Tuples/lists, each with exactly ``world_size`` entries.
        **kwargs: Tuples/lists, each with exactly ``world_size`` entries.

    Returns:
        The ``(args, kwargs)`` pair unchanged.

    Raises:
        AssertionError: If any argument is not a tuple/list of length
            ``world_size``.
    """
    # Positional values are checked first, then keyword values, in call order.
    for candidate in [*args, *kwargs.values()]:
        assert isinstance(candidate, (tuple, list)) and len(candidate) == worker_group.world_size

    return args, kwargs


def collect_dp_compute(worker_group: "WorkerGroup", outputs: list[DataProto]) -> list[DataProto]:
    """Return the per-worker outputs after checking there is one per worker.

    Args:
        worker_group: Object exposing ``world_size``.
        outputs: Results gathered from the workers.

    Returns:
        ``outputs`` itself, unmodified.

    Raises:
        AssertionError: If ``len(outputs)`` differs from ``world_size``.
    """
    received = len(outputs)
    assert received == worker_group.world_size
    return outputs


def dispatch_dp_compute_data_proto(worker_group: "WorkerGroup", *args, **kwargs):
    """Chunk every argument into ``world_size`` pieces.

    Args:
        worker_group: Object exposing ``world_size``.
        *args: ``DataProto`` / ``DataProtoFuture`` positional arguments.
        **kwargs: ``DataProto`` / ``DataProtoFuture`` keyword arguments.

    Returns:
        The ``(chunked_args, chunked_kwargs)`` pair produced by
        ``_split_args_kwargs_data_proto``.
    """
    num_chunks = worker_group.world_size
    chunked_args, chunked_kwargs = _split_args_kwargs_data_proto(num_chunks, *args, **kwargs)
    return chunked_args, chunked_kwargs


def dispatch_dp_compute_data_proto_with_func(worker_group: "WorkerGroup", *args, **kwargs):
    """Replicate a leading function and chunk the remaining arguments.

    Args:
        worker_group: Object exposing ``world_size``.
        *args: A plain Python function followed by ``DataProto`` /
            ``DataProtoFuture`` positional arguments.
        **kwargs: ``DataProto`` / ``DataProtoFuture`` keyword arguments.

    Returns:
        A ``(list, dict)`` pair. The list starts with the function repeated
        ``world_size`` times, followed by the chunked positional arguments;
        the dict holds the chunked keyword arguments.

    Raises:
        AssertionError: If the first positional argument's type is not
            exactly ``FunctionType``.
    """
    fn = args[0]
    assert type(fn) is FunctionType

    chunked_args, chunked_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args[1:], **kwargs)
    # The function is handed to every worker as-is rather than being chunked.
    replicated_fn = [fn] * worker_group.world_size
    return [replicated_fn, *chunked_args], chunked_kwargs


def collect_dp_compute_data_proto(worker_group: "WorkerGroup", outputs: list[DataProto]) -> DataProto:
    """Check the per-worker results and concatenate them into one object.

    Args:
        worker_group: Object exposing ``world_size``.
        outputs: One ``DataProto`` or ``ray.ObjectRef`` per worker.

    Returns:
        The result of ``_concat_data_proto_or_future`` on the checked outputs.

    Raises:
        AssertionError: If an element has an unexpected type, or if the number
            of outputs differs from ``world_size``.
    """
    # Element types are checked before the length check done by collect_dp_compute.
    for output in outputs:
        assert isinstance(output, (DataProto, ray.ObjectRef)), f"Expect a DataProto, but got {type(output)}"

    return _concat_data_proto_or_future(collect_dp_compute(worker_group, outputs))


def get_predefined_dispatch_fn(dispatch_mode: Dispatch):
    """Return the dispatch/collect function pair for a built-in dispatch mode.

    Args:
        dispatch_mode: A ``Dispatch`` member.

    Returns:
        A dict with the keys ``"dispatch_fn"`` and ``"collect_fn"``.

    Raises:
        KeyError: If ``dispatch_mode`` has no predefined pair. This includes
            ``Dispatch.RANK_ZERO``, which is declared on the enum but has no
            entry in the table below.
    """
    predefined_dispatch_mode_fn = {
        Dispatch.ONE_TO_ALL: {
            "dispatch_fn": dispatch_one_to_all,
            "collect_fn": collect_all_to_all,
        },
        Dispatch.ALL_TO_ALL: {
            "dispatch_fn": dispatch_all_to_all,
            "collect_fn": collect_all_to_all,
        },
        Dispatch.DP_COMPUTE: {
            "dispatch_fn": dispatch_dp_compute,
            "collect_fn": collect_dp_compute,
        },
        Dispatch.DP_COMPUTE_PROTO: {
            "dispatch_fn": dispatch_dp_compute_data_proto,
            "collect_fn": collect_dp_compute_data_proto,
        },
        Dispatch.DP_COMPUTE_PROTO_WITH_FUNC: {
            "dispatch_fn": dispatch_dp_compute_data_proto_with_func,
            "collect_fn": collect_dp_compute_data_proto,
        },
        Dispatch.DP_COMPUTE_METRIC: {
            "dispatch_fn": dispatch_dp_compute_data_proto,
            "collect_fn": collect_dp_compute,
        },
    }
    try:
        return predefined_dispatch_mode_fn[dispatch_mode]
    except KeyError:
        # Still a KeyError, as before, but it now names the offending mode and
        # the modes that are available instead of carrying only the bare key.
        supported = ", ".join(mode.name for mode in predefined_dispatch_mode_fn)
        raise KeyError(
            f"No predefined dispatch/collect functions for {dispatch_mode!r}. Supported modes: {supported}"
        ) from None


def get_predefined_execute_fn(execute_mode: Execute):
    """Return the execute-function name associated with an execution mode.

    Args:
        execute_mode: An ``Execute`` member.

    Returns:
        A dict with the single key ``"execute_fn_name"``.

    Raises:
        KeyError: If ``execute_mode`` is not one of the known modes.
    """
    fn_name_by_mode = {
        Execute.ALL: "execute_all",
        Execute.RANK_ZERO: "execute_rank_zero",
    }
    return {"execute_fn_name": fn_name_by_mode[execute_mode]}


def _check_dispatch_mode(dispatch_mode: Union[Dispatch, dict[Literal["dispatch_fn", "collect_fn"], FunctionType]]):
    """Validate the ``dispatch_mode`` argument of ``register``.

    Args:
        dispatch_mode: Either a ``Dispatch`` member or a dict that provides
            both a ``"dispatch_fn"`` and a ``"collect_fn"`` entry.

    Raises:
        AssertionError: If the value is neither a ``Dispatch`` nor a dict, or
            if a dict is missing one of the two required keys.
    """
    assert isinstance(dispatch_mode, (Dispatch, dict)), (
        f"dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}"
    )
    if not isinstance(dispatch_mode, dict):
        # A Dispatch member needs no further inspection.
        return

    for key in ("dispatch_fn", "collect_fn"):
        assert key in dispatch_mode, f"key {key} should be in dispatch_mode if it is a dictionary"


def _check_execute_mode(execute_mode: Execute):
    """Validate the ``execute_mode`` argument of ``register``.

    Args:
        execute_mode: Value expected to be an ``Execute`` member.

    Raises:
        AssertionError: If ``execute_mode`` is not an ``Execute`` instance.
    """
    is_execute = isinstance(execute_mode, Execute)
    assert is_execute, f"execute_mode must be a Execute. Got {execute_mode}"


def _materialize_futures(*args, **kwargs):
    """Replace every ``DataProtoFuture`` argument with the result of its ``get()``.

    Args:
        *args: Positional arguments; non-future values are kept as they are.
        **kwargs: Keyword arguments; non-future values are kept as they are.

    Returns:
        A ``(tuple, dict)`` pair with the same arguments, where each
        ``DataProtoFuture`` has been replaced by what ``get()`` returned.
    """

    def resolve(item):
        return item.get() if isinstance(item, DataProtoFuture) else item

    # Positional futures are resolved first, then keyword futures, in order.
    resolved_args = tuple(resolve(positional) for positional in args)
    resolved_kwargs = {name: resolve(named) for name, named in kwargs.items()}
    return resolved_args, resolved_kwargs


def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True):
    """Build a decorator that tags a function with dispatch/execute settings.

    Args:
        dispatch_mode: A ``Dispatch`` member, or a dict providing
            ``"dispatch_fn"`` and ``"collect_fn"``.
        execute_mode: An ``Execute`` member.
        blocking: Stored in the attached settings; not otherwise used here.
        materialize_futures: When true, ``DataProtoFuture`` arguments are
            resolved through ``_materialize_futures`` before the wrapped
            function is called.

    Returns:
        A decorator. The function it returns wraps the original and carries a
        settings dict (``dispatch_mode``, ``execute_mode``, ``blocking``) under
        the attribute named by ``MAGIC_ATTR``.

    Raises:
        AssertionError: If ``dispatch_mode`` or ``execute_mode`` fails validation.
    """
    # Both modes are validated once, when the decorator is built.
    _check_dispatch_mode(dispatch_mode=dispatch_mode)
    _check_execute_mode(execute_mode=execute_mode)

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if not materialize_futures:
                return func(*args, **kwargs)
            resolved_args, resolved_kwargs = _materialize_futures(*args, **kwargs)
            return func(*resolved_args, **resolved_kwargs)

        # A fresh settings dict is attached to each decorated function.
        settings = dict(dispatch_mode=dispatch_mode, execute_mode=execute_mode, blocking=blocking)
        setattr(wrapper, MAGIC_ATTR, settings)
        return wrapper

    return decorator
