import logging
import signal
import threading
import time
from typing import Any, Callable
from .decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn
class ResourcePool:
    """Per-node process counts for a group of workers.

    ``process_on_nodes`` holds one entry per node: the number of processes on
    that node. The list object passed in is kept as-is (not copied), so
    ``add_node`` appends to the caller's list. ``max_colocate_count`` and
    ``n_gpus_per_node`` are stored as given; this class does not read them.
    """

    def __init__(self, process_on_nodes=None, max_colocate_count: int = 10, n_gpus_per_node=8) -> None:
        # A fresh list per instance when nothing is supplied (avoids a shared mutable default).
        self._store = [] if process_on_nodes is None else process_on_nodes
        self.max_colocate_count = max_colocate_count
        self.n_gpus_per_node = n_gpus_per_node

    def add_node(self, process_count):
        """Append one more node that runs ``process_count`` processes."""
        self._store.append(process_count)

    @property
    def world_size(self):
        """Total number of processes summed over all nodes."""
        return sum(self._store)

    def __call__(self) -> Any:
        """Return the underlying per-node process-count list."""
        return self._store

    @property
    def store(self):
        """The underlying per-node process-count list (same object ``self()`` returns)."""
        return self._store

    def local_world_size_list(self) -> list[int]:
        """For each process, the process count of its node, flattened node by node.

        Example: nodes ``[2, 3]`` give ``[2, 2, 3, 3, 3]``.
        """
        return [node_size for node_size in self._store for _ in range(node_size)]

    def local_rank_list(self) -> list[int]:
        """For each process, its index within its node, flattened node by node.

        Example: nodes ``[2, 3]`` give ``[0, 1, 0, 1, 2]``.
        """
        return [rank for node_size in self._store for rank in range(node_size)]
class ClassWithInitArgs:
    """Defer construction of ``cls`` until this wrapper is called.

    The positional and keyword arguments captured here are replayed each time
    the wrapper is called, and ``cls`` is invoked anew on every call.
    """

    def __init__(self, cls, *args, **kwargs) -> None:
        self.cls, self.args, self.kwargs = cls, args, kwargs
        # Flag initialised to False; it is not read anywhere in this class.
        self.fused_worker_used = False

    def __call__(self) -> Any:
        """Instantiate ``self.cls`` with the stored arguments and return the result."""
        target, positional, named = self.cls, self.args, self.kwargs
        return target(*positional, **named)
def check_workers_alive(
    workers: list,
    is_alive: Callable,
    gap_time: float = 1,
    stop_event: "threading.Event | None" = None,
) -> None:
    """Periodically verify that every worker is alive, raising SIGABRT otherwise.

    Each sweep calls ``is_alive(worker)`` for every entry of ``workers``. For
    each worker whose status is falsy, a warning is logged and ``SIGABRT`` is
    raised in the current process via ``signal.raise_signal``. Sweeps repeat
    every ``gap_time`` seconds.

    Args:
        workers: Workers to monitor. The list is iterated afresh on every
            sweep, so later mutations by the caller are observed.
        is_alive: Predicate returning a truthy value for a live worker.
        gap_time: Seconds to wait between sweeps.
        stop_event: Optional event. When given, the loop exits once it is set;
            it is checked before each sweep and while waiting between sweeps.
            When ``None`` (the default), the function never returns.
    """
    while stop_event is None or not stop_event.is_set():
        for worker in workers:
            if not is_alive(worker):
                logging.warning("worker %s is not alive sending signal to main thread", worker)
                # If a SIGABRT handler lets the process survive, the signal is
                # raised again on every sweep while the worker stays dead.
                signal.raise_signal(signal.SIGABRT)
        if stop_event is None:
            time.sleep(gap_time)
        elif stop_event.wait(gap_time):
            # Event.wait returns True as soon as the event is set, so shutdown
            # does not have to wait out the full gap.
            break
class WorkerGroup:
    fused_worker_execute_fn_name = "_fuw_execute"
    def __init__(self, resource_pool: ResourcePool, **kwargs) -> None:
        self._is_init_with_detached_workers = resource_pool is None
        self.fused_worker_used = False
        if resource_pool is not None:
            self._procecss_dispatch_config = resource_pool()
        else:
            self._procecss_dispatch_config = None
        self._workers = []
        self._worker_names = []
        self._dispatch_info = {}
        self._collect_info = {}
        self._master_addr = None
        self._master_port = None
        self._checker_thread: threading.Thread = None
    def _is_worker_alive(self, worker):
        raise NotImplementedError("WorkerGroup._is_worker_alive called, should be implemented in derived class.")
    def _block_until_all_workers_alive(self) -> None:
        while True:
            all_state = [self._is_worker_alive(worker) for worker in self._workers]
            if False in all_state:
                time.sleep(1)
            else:
                break
    def start_worker_aliveness_check(self, every_n_seconds=1) -> None:
        """Wait until every worker is alive, then start a background liveness monitor.

        The monitor thread runs ``check_workers_alive`` over ``self._workers``
        using ``self._is_worker_alive`` and stores the thread in
        ``self._checker_thread``.

        Args:
            every_n_seconds: Polling interval, in seconds, for the monitor.
        """
        self._block_until_all_workers_alive()
        # check_workers_alive loops forever when not given a stop event, so the
        # thread must be a daemon. A non-daemon thread would keep the
        # interpreter alive after the main program finishes.
        self._checker_thread = threading.Thread(
            target=check_workers_alive,
            args=(self._workers, self._is_worker_alive, every_n_seconds),
            daemon=True,
        )
        self._checker_thread.start()
    @property
    def world_size(self):
        return len(self._workers)
    def _bind_worker_method(self, user_defined_cls, func_generator):
        """Attach dispatch-aware wrappers for the decorated methods of ``user_defined_cls``.

        For each attribute name in ``dir(user_defined_cls)`` whose value is
        callable and carries ``MAGIC_ATTR`` metadata, the method does three things:

        - It resolves the dispatch and collect functions.
        - It looks up this group's execute function.
        - It builds a wrapper with ``func_generator`` and sets it on ``self``
          under the same name.

        Args:
            user_defined_cls: Class whose decorated methods are exposed on this
                worker group.
            func_generator: Factory called as ``func_generator(self, method_name,
                dispatch_fn=..., collect_fn=..., execute_fn=..., blocking=...)``;
                its return value is set on ``self``.

        Returns:
            list[str]: Names of the methods that were bound, in ``dir()`` order.

        Raises:
            AssertionError: If a method's metadata is malformed, or the resolved
                execute function is not callable.
            KeyError: If the metadata lacks ``execute_mode`` or ``blocking``.
            AttributeError: If ``self`` has no attribute named by the resolved
                execute function name. This case, like the non-callable case,
                is logged before re-raising.
            ValueError: If the wrapper cannot be set on ``self``.
        """
        method_names = []
        for method_name in dir(user_defined_cls):
            try:
                method = getattr(user_defined_cls, method_name)
            except Exception:
                # Names whose class-level lookup raises are skipped.
                continue
            # Explicit check rather than ``assert`` so the filter survives ``python -O``.
            if not callable(method) or not hasattr(method, MAGIC_ATTR):
                continue

            attribute = getattr(method, MAGIC_ATTR)
            assert isinstance(attribute, dict), f"attribute must be a dictionary. Got {type(attribute)}"
            assert "dispatch_mode" in attribute, "attribute must contain dispatch_mode in its key"
            dispatch_mode = attribute["dispatch_mode"]
            execute_mode = attribute["execute_mode"]
            blocking = attribute["blocking"]

            # dispatch_mode is either a predefined Dispatch member or a dict
            # supplying custom dispatch_fn / collect_fn callables.
            if isinstance(dispatch_mode, Dispatch):
                fn = get_predefined_dispatch_fn(dispatch_mode=dispatch_mode)
                dispatch_fn = fn["dispatch_fn"]
                collect_fn = fn["collect_fn"]
            else:
                assert isinstance(dispatch_mode, dict)
                assert "dispatch_fn" in dispatch_mode
                assert "collect_fn" in dispatch_mode
                dispatch_fn = dispatch_mode["dispatch_fn"]
                collect_fn = dispatch_mode["collect_fn"]

            # The resolved info names the method on this group that performs execution.
            execute_fn_info = get_predefined_execute_fn(execute_mode=execute_mode)
            wg_execute_fn_name = execute_fn_info["execute_fn_name"]
            try:
                execute_fn = getattr(self, wg_execute_fn_name)
                assert callable(execute_fn), "execute_fn must be callable"
            except Exception:
                logging.error("execute_fn %s is invalid", wg_execute_fn_name)
                raise

            func = func_generator(
                self,
                method_name,
                dispatch_fn=dispatch_fn,
                collect_fn=collect_fn,
                execute_fn=execute_fn,
                blocking=blocking,
            )
            try:
                setattr(self, method_name, func)
            except Exception as e:
                raise ValueError(f"Fail to set method_name {method_name}") from e
            method_names.append(method_name)
        return method_names