import os
from collections import defaultdict
from typing import Any, Dict, List, Optional

import ray
from loguru import logger
from torch.distributed import ProcessGroup

from distflow.models.loader import TokenizerModule
from distflow.scheduler.process_group_manager import ProcessGroupManager
from distflow.utils.params import DistflowArguments
from distflow.workers.base_worker import Worker
from distflow.workers.dag import TaskGraph
from distflow.workers.dag.node import NodeRole
from distflow.workers.databuffer import DataProto

from .constants import DAGInitializationError
from .mixins.data_rebalance_mixin import DataRebalanceMixin
from .mixins.execution_mixin import ExecutionMixin
from .mixins.initialization_mixin import InitializationMixin
from .mixins.node_executors_mixin import NodeExecutorsMixin
from .mixins.utilities_mixin import UtilitiesMixin
from .mixins.validation_mixin import ValidationMixin


class DAGWorker(
    InitializationMixin, ExecutionMixin, NodeExecutorsMixin, ValidationMixin, UtilitiesMixin, DataRebalanceMixin, Worker
):
    """
    Orchestrates a Directed Acyclic Graph (DAG) of tasks for distributed training,
    managing the setup, initialization, and workflow for a specific rank.

    The behaviour itself lives in the mixins listed in the class bases; this
    class body only declares the per-instance state they share and triggers
    worker initialization from the constructor.
    """

    def __init__(
        self,
        config: DistflowArguments,
        process_group_manager: ProcessGroupManager,
        taskgraph_mapping: Dict[int, TaskGraph],
        data_buffers: List["ray.actor.ActorHandle"],
    ):
        """
        Store the injected dependencies, reset all runtime state, and run
        ``_initialize_worker()``.

        Args:
            config: Global arguments object. ``config.dag.enable_perf`` is read
                here; the rest is kept on ``self.config``.
            process_group_manager: Stored as ``self.process_group_manager``.
            taskgraph_mapping: Integer-keyed mapping of task graphs
                (presumably keyed by rank -- confirm against the caller).
            data_buffers: Ray actor handles stored as ``self.data_buffers``.

        Raises:
            DAGInitializationError: If ``_initialize_worker()`` raises
                ``ValueError``, ``TypeError``, ``KeyError``, ``AttributeError``
                or ``NotImplementedError``. The original exception is chained
                as ``__cause__``. Any other exception type propagates unchanged.
        """
        super().__init__()
        self.config = config
        self.process_group_manager = process_group_manager
        self.taskgraph_mapping = taskgraph_mapping
        self.data_buffers = data_buffers
        # Perf tracking is on when either the environment variable is exactly "1"
        # or the config flag is truthy.
        self.enable_perf = os.environ.get("distflow_ENABLE_PERF", "0") == "1" or config.dag.enable_perf

        # State attributes
        self.global_steps = 0
        self.total_training_steps = 0
        self.workers: Dict[str, Any] = {}
        self.agent_group_worker: Dict[int, Dict[NodeRole, Any]] = defaultdict(dict)
        self.agent_group_process_group: Dict[int, Dict[NodeRole, Any]] = defaultdict(dict)
        self.process_groups: Dict[str, ProcessGroup] = {}
        self.tokenizer_mapping: Dict[str, TokenizerModule] = {}
        self.kl_ctrl_in_reward = None
        self.logger = None
        self.progress_bar = None
        self._rank: int = -1
        self.taskgraph: Optional[TaskGraph] = None
        self.internal_data_cache: Dict[str, DataProto] = {}
        # Annotation only: no value is bound here, so the attribute does not
        # exist on the instance until something else assigns it.
        self.agent_critic_worker: Any
        # Finish flag
        self.taskgraph_execute_finished = False

        # async rollout
        self.rollout_mode = "sync"
        self._async_rollout_manager = None
        self.zmq_address = None  # used for async_vllmrollout

        # Add a cache to hold data from an insufficient batch for the next training step.
        # This is the core state-carrying mechanism for dynamic sampling.
        self.sampling_leftover_cache: Optional[DataProto] = None

        # Cumulative timing tracking for dynamic sampling
        # When filtering causes insufficient data, we need multiple rollouts.
        # These variables track the cumulative timing across all rollouts until training completes.
        self.cumulative_timing_raw: Dict[str, float] = {}  # Accumulated timing from incomplete rollouts
        self.cumulative_rollout_count: int = 0              # Number of incomplete rollouts so far
        self.cumulative_start_time: Optional[float] = None  # Wall time when first rollout started

        # multi agent
        self._multi_agent = False
        try:
            self._initialize_worker()
        except (ValueError, TypeError, KeyError, AttributeError, NotImplementedError) as e:
            rank = os.environ.get("RANK", "UNKNOWN")
            # loguru has no stdlib-style ``exc_info`` keyword: passing it only
            # triggers ``str.format`` on the message (which can itself raise if
            # the exception text contains braces) and no traceback is emitted.
            # ``opt(exception=e)`` attaches the traceback, and positional
            # placeholders keep the exception text out of the format template.
            logger.opt(exception=e).error(
                "Rank {}: Failed to create DAGWorker due to a critical setup error: {}", rank, e
            )
            raise DAGInitializationError(f"Initialization failed on Rank {rank}: {e}") from e

        # Runs only after initialization succeeded; defined outside this class body.
        self.log_ray_actor_info()
