import os
import time

import hydra
import ray
from loguru import logger
from omegaconf import DictConfig

from distflow.scheduler.enums import AdvantageEstimator, WorkflowType
from distflow.scheduler.graph_updater import display_node_config, update_task_graph_node_configs
from distflow.scheduler.launch import RayTrainer
from distflow.scheduler.process_group_manager import ProcessGroupManager, log_process_group_manager_details
from distflow.scheduler.task_scheduler import TaskScheduler, log_schedule_assignments
from distflow.utils.logger.logging_utils import set_basic_config
from distflow.utils.params import DistflowArguments, log_dict_formatted, parse_config
from distflow.workers.dag import DAGConfigLoader
from distflow.workers.databuffer import init_data_buffer

# --- Constants ---
# Environment variables passed as `runtime_env["env_vars"]` to `ray.init()` in
# `main()` when this script starts a local Ray cluster.
RAY_RUNTIME_ENV_VARS = {
    "TOKENIZERS_PARALLELISM": "true",
    "NCCL_DEBUG": "WARN",
    "VLLM_LOGGING_LEVEL": "WARN",
}

# Number of CPUs Ray reserves for the MainRunner actor (the `num_cpus` of its
# `@ray.remote` decorator). The runner only orchestrates setup; it is not a heavy
# workload. NOTE(review): this reserves 5 whole CPUs. If a lighter (e.g.
# fractional) reservation was intended, adjust the value here.
MAIN_RUNNER_CPU_RESERVATION = 5


def determine_workflow_config(self, distflow_args: DistflowArguments) -> str:
    """
    Resolve the default workflow YAML for the configured algorithm.

    ``self`` is not used; it is kept so existing callers that pass an instance
    keep working.

    Args:
        self: Ignored.
        distflow_args: Parsed configuration. ``algorithm.workflow_type`` selects
            the workflow family; for EMBODIED and DEFAULT workflows
            ``algorithm.adv_estimator`` selects the concrete file.

    Returns:
        Path of a YAML file under the ``config/`` directory next to this module.

    Raises:
        ValueError: If the workflow type is unknown, or an EMBODIED workflow uses
            an adv_estimator other than GAE or GRPO.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    algo = distflow_args.algorithm
    workflow = algo.workflow_type

    def _config(relative_name: str) -> str:
        # All workflow files live relative to this module's directory.
        return os.path.join(module_dir, relative_name)

    if workflow == WorkflowType.EMBODIED:
        estimator = algo.adv_estimator
        if estimator == AdvantageEstimator.GAE:
            return _config("config/workflow_embodied_ppo.yaml")
        if estimator == AdvantageEstimator.GRPO:
            return _config("config/workflow_embodied_srpo.yaml")
        raise ValueError(
            f"Unsupported adv_estimator '{estimator}' for Embodied AI. "
            "Use 'gae' for PPO or 'grpo' for GRPO."
        )

    if workflow == WorkflowType.DAPO:
        return _config("config/workflow_dapo.yaml")

    if workflow == WorkflowType.DEFAULT:
        # GAE maps to the PPO workflow; every other estimator (GRPO, GSPO, ...)
        # shares the GRPO workflow.
        if algo.adv_estimator == AdvantageEstimator.GAE:
            return _config("config/workflow_ppo.yaml")
        return _config("config/workflow_grpo.yaml")

    raise ValueError(f"Unknown workflow_type: '{workflow}'")


def setup_embodied_task_manifest(distflow_args: DistflowArguments) -> DistflowArguments:
    """
    Generate task manifests for embodied AI training runs.

    A run is treated as embodied when ``actor_rollout_ref.model.model_type`` is
    ``"embodied"``; for any other configuration the arguments are returned
    unchanged. For embodied runs this generates fresh train/validation task
    manifests and updates the configuration to point at the generated files.
    ``data.train_files`` is used only to derive the *output directory*: the
    parent directory of its first entry. A single string path is also accepted.

    Args:
        distflow_args: A DistflowArguments object containing all parsed configurations.

    Returns:
        The same DistflowArguments object. For embodied runs, ``data.train_files``
        and ``data.val_files`` are replaced in place with single-element lists
        holding the generated manifest paths.

    Raises:
        ValueError: If embodied training is detected but data.train_files is not specified.
    """
    # Detect an embodied run; configs without `model` / `model_type` are not embodied.
    is_embodied_model = (
        hasattr(distflow_args.actor_rollout_ref, 'model')
        and hasattr(distflow_args.actor_rollout_ref.model, 'model_type')
        and distflow_args.actor_rollout_ref.model.model_type == "embodied"
    )
    if not is_embodied_model:
        return distflow_args

    embodied_env_args = distflow_args.actor_rollout_ref.embodied.env

    # For embodied training, `data.train_files` must be specified to indicate the output path.
    train_files = distflow_args.data.train_files
    if not train_files:
        raise ValueError(
            "For embodied training, `data.train_files` must be specified in the config. "
            "It is used as the output path for the generated task manifest."
        )

    # NOTE(review): loguru is re-imported locally instead of using the module-level
    # `logger`, presumably so code executed inside a Ray worker uses that process's
    # own logger. Confirm before removing.
    from loguru import logger

    logger.info("Embodied AI run detected. Generating task manifest...")
    from distflow.dataloader.embodied_preprocess import prepare_libero_train_valid_datasets

    # The output directory is the parent directory of the first train file path.
    # A bare string must not be indexed: "dir/train.parquet"[0] would be "d".
    first_train_file = train_files if isinstance(train_files, str) else train_files[0]
    output_dir = os.path.dirname(first_train_file)

    # Generate the train and validation manifests inside output_dir.
    train_file_path, valid_file_path = prepare_libero_train_valid_datasets(
        task_suite_name=embodied_env_args.env_name,
        num_trials_per_task=embodied_env_args.num_trials_per_task,
        dataset_dir=output_dir,
    )

    # After generation, update the config to point to the exact generated files.
    distflow_args.data.train_files = [str(train_file_path)]
    distflow_args.data.val_files = [str(valid_file_path)]
    logger.success(f"Task manifests generated and configured at: {output_dir}")

    return distflow_args


def get_databuffer_shard_number(distflow_args: DistflowArguments, *, sharding_factor: int = 8) -> int:
    """
    Compute how many DataBuffer shards to create.

    The per-node sample count ``(train_batch_size // nnodes) * rollout.n`` is
    divided by ``sharding_factor``, and the result is capped at ``nnodes``.

    Args:
        distflow_args: Parsed configuration. Reads ``data.train_batch_size``,
            ``trainer.nnodes`` and ``actor_rollout_ref.rollout.n``.
        sharding_factor: Divisor applied to the per-node sample count. Defaults
            to 8, the value this function has always used.

    Returns:
        ``min(per_node_samples // sharding_factor, nnodes)``.

    Raises:
        ValueError: If ``nnodes`` or ``sharding_factor`` is not positive, if
            ``train_batch_size`` is not divisible by ``nnodes``, or if the
            per-node sample count is not divisible by ``sharding_factor``.
            (Explicit raises rather than ``assert`` so the checks survive ``python -O``.)
    """
    if sharding_factor < 1:
        raise ValueError(f"sharding_factor must be a positive integer, got {sharding_factor}.")

    train_batch_size = distflow_args.data.train_batch_size
    nnodes = distflow_args.trainer.nnodes
    if nnodes < 1:
        # Guard before the modulo below, which would otherwise raise ZeroDivisionError.
        raise ValueError(f"Config Error: nnodes must be a positive integer, got {nnodes}.")

    if train_batch_size % nnodes != 0:
        raise ValueError(
            f"Config Error: train_batch_size ({train_batch_size}) must be divisible by nnodes ("
            f"{nnodes}). Please adjust your configuration."
        )

    samples_per_node = (train_batch_size // nnodes) * distflow_args.actor_rollout_ref.rollout.n
    if samples_per_node % sharding_factor != 0:
        raise ValueError(
            f"Config Error: The result of '(train_batch_size / nnodes) * rollout_n' ({samples_per_node}) must be "
            f"divisible by {sharding_factor}. Please adjust your configuration."
        )

    # Never create more shards than there are nodes.
    return min(samples_per_node // sharding_factor, nnodes)


@ray.remote(num_cpus=MAIN_RUNNER_CPU_RESERVATION)
class MainRunner:
    """
    A Ray actor responsible for orchestrating the entire RL training workflow.

    This actor handles loading configurations, scheduling task graphs, initializing
    process groups, and launching the distributed Ray trainers. Isolating this
    orchestration logic in a dedicated actor ensures the main process remains clean
    and that the setup process is managed within the Ray cluster.
    """

    def run(self, distflow_args: DistflowArguments) -> None:
        """
        Executes the main training workflow.

        Setup proceeds in this order:
            - Generate embodied task manifests (no-op for non-embodied models).
            1. Initialize the sharded DataBuffer.
            2. Load the workflow task graph (DAG) and push config into its nodes.
            3. Schedule the task graph across nodes and GPUs.
            4. Build process groups and export the MODEL_INFERENCE group specs
               via the ``DGA_PROCESS_GROUP`` environment variable.
            5. Construct the RayTrainer.
            6. Initialize and start the DAG workers.

        Args:
            distflow_args: A DistflowArguments object containing all parsed configurations.
                It may be modified in place (embodied manifest paths, CPGD loss flag).

        Raises:
            Exceptions from the setup helpers (configuration validation,
            scheduling, worker launch) propagate unchanged to the caller.
        """
        set_basic_config()
        # NOTE(review): loguru is re-imported locally instead of using the module-level
        # `logger`, presumably so this actor uses its own process's logger as configured
        # by set_basic_config(). Confirm before removing.
        from loguru import logger

        logger.info("MainRunner started. Beginning workflow setup...")
        start_time = time.time()

        # Setup embodied task manifest if needed
        distflow_args = setup_embodied_task_manifest(distflow_args)

        # 1. Init DataBuffer. Compute the shard count before logging so the message
        #    reports the value actually used (it is capped and can differ from nnodes).
        databuffer_number = get_databuffer_shard_number(distflow_args)
        logger.info(f"Init DataBuffer with sharding number: {databuffer_number}")
        data_buffer_handlers = init_data_buffer(databuffer_number)

        # 2. Load and configure the workflow task graph (DAG)
        workflow_path = distflow_args.dag.workflow_path
        if workflow_path is None:
            # No explicit path: derive the default workflow config from the algorithm.
            workflow_path = determine_workflow_config(self, distflow_args)
            logger.info(
                f"No workerflow path provided. Using {workflow_path} determined by adv_estimator: "
                f"{distflow_args.algorithm.adv_estimator}"
            )
        # Log the resolved path; `dag.workflow_path` is None when the default was chosen.
        logger.info(f"Loading workerflow from: {workflow_path}")
        if distflow_args.algorithm.adv_estimator == AdvantageEstimator.CPGD:
            distflow_args.actor_rollout_ref.actor.use_cpgd_loss = True
        workerflow_taskgraph = DAGConfigLoader.load_from_file(workflow_path)
        update_task_graph_node_configs(workerflow_taskgraph, distflow_args)
        display_node_config(workerflow_taskgraph)

        # 3. Schedule the task graph across available resources
        logger.info("Scheduling tasks across nodes and GPUs...")
        total_workers = distflow_args.trainer.nnodes * distflow_args.trainer.n_gpus_per_node
        task_scheduler = TaskScheduler(distflow_args.trainer.nnodes, distflow_args.trainer.n_gpus_per_node)
        rank_taskgraph_mapping = task_scheduler.schedule_and_assign_tasks([workerflow_taskgraph])
        log_schedule_assignments(rank_taskgraph_mapping, total_workers)
        unique_graphs_map = task_scheduler.get_unique_assigned_task_graphs()

        # 4. Create and configure process groups for communication
        logger.info("Initializing process groups for distributed communication...")
        process_group_manager = ProcessGroupManager(total_workers, rank_taskgraph_mapping)
        log_process_group_manager_details(process_group_manager, log_level="debug")
        # Publish the MODEL_INFERENCE process-group specs (as a stringified list) in this
        # process's environment for the inference actors. NOTE(review): how the value
        # reaches those actors is not visible here; presumably RayTrainer or the
        # workers read it. Confirm.
        inference_groups = process_group_manager.node_type_process_group_mapping["MODEL_INFERENCE"]
        inference_process_group = [
            process_group_manager.process_group_spec[group_name] for group_name in inference_groups
        ]
        os.environ["DGA_PROCESS_GROUP"] = str(inference_process_group)

        # 5. Initialize the main trainer
        logger.info("Initializing RayTrainer...")
        trainer = RayTrainer(
            config=distflow_args,
            process_group_manager=process_group_manager,
            rank_taskgraph_mapping=rank_taskgraph_mapping,
            unique_graphs_map=unique_graphs_map,
            data_buffer_handles=data_buffer_handlers,
            device_name=distflow_args.trainer.device,
        )

        # 6. Initialize and start DAGWorkers
        logger.info("Initializing and starting DAG workers...")
        trainer.init_workers()
        trainer.start_workers()

        setup_duration = time.time() - start_time
        logger.info(f"Workflow setup and worker launch complete. Time cost: {setup_duration:.2f}s")


@hydra.main(config_path="config", config_name="ppo_dag_trainer", version_base=None)
def main(distflow_config: DictConfig) -> None:
    """
    Hydra entry point that launches the PPO DAG training job.

    Starts a local Ray cluster if this process is not yet connected to one,
    converts the Hydra config into DistflowArguments, then creates a
    MainRunner actor and blocks until its ``run`` method returns.

    Args:
        distflow_config: The configuration object provided by Hydra.
    """
    ray_start = time.time()

    # Only bring up a local cluster when no Ray session is attached yet.
    if not ray.is_initialized():
        logger.info("Initializing local Ray cluster...")
        ray.init(
            runtime_env={"env_vars": RAY_RUNTIME_ENV_VARS},
            num_cpus=distflow_config.ray_init.num_cpus,
        )
    elapsed_ms = (time.time() - ray_start) * 1000
    logger.success(f"Ray is initialized. Time cost: {elapsed_ms:.2f} ms")

    # Turn the raw Hydra config into the structured argument object and dump it.
    distflow_args = parse_config(distflow_config)
    log_dict_formatted(distflow_args.to_dict(), "DistflowArguments")

    logger.info("Starting MainRunner actor to orchestrate the job.")
    orchestrator = MainRunner.remote()
    run_ref = orchestrator.run.remote(distflow_args)
    ray.get(run_ref)  # blocks until MainRunner.run finishes (or re-raises its error)

    logger.success("MainRunner has completed its execution. Shutting down.")


if __name__ == "__main__":
    # The @hydra.main decorator supplies the config, so main() is called without arguments.
    main()
