import os
import argparse
from functools import partial
from ares.distributed.run import run, RoleInfo
from ares.utils.logger import logger
from ares.commons.constants import RoleMappingConstants


def parser_args():
    """Parse the launcher's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``config_py_file``,
        ``role_record``, ``qkv_first``, ``rpc_timeout``, ``pystack_timeout``
        and ``node_0_master``.
    """
    # Each entry is (option strings, add_argument keyword options), registered
    # in order so the generated --help output lists them in this sequence.
    flag_specs = (
        (
            ("--config-py-file",),
            dict(
                type=str,
                required=True,
                help="The configuration path. Absolute paths are recommended",
            ),
        ),
        (
            # NOTE: the long flag's spelling ("seperately") is part of the
            # existing CLI contract; the value is stored as ``role_record``.
            ("-record", "--role-record-seperately"),
            dict(
                dest="role_record",
                action="store_true",
                default=False,
                help="If set, role tensorboard records will be stored separately.",
            ),
        ),
        (("--qkv-first",), dict(action="store_true", default=False)),
        # main() forwards this to run(rpc_timeout=...).
        (("--rpc-timeout",), dict(type=float, default=7200.0)),
        # main() forwards this to set_pystack_timeout().
        (("--pystack-timeout",), dict(type=int, default=60)),
        # main() uses this to disable experiment tracking and passes it to run().
        (("--node-0-master",), dict(action="store_true", default=False)),
    )

    cli = argparse.ArgumentParser()
    for option_strings, options in flag_specs:
        cli.add_argument(*option_strings, **options)
    return cli.parse_args()


def run_trainer(config):
    """Seed everything, build the configured trainer and run its ``learn()``.

    Removes a ``"place_holder"`` entry from ``config.roles`` if present and
    gives every role whose ``seed`` is ``None`` the global ``config.train.seed``.

    Args:
        config: Loaded experiment config; reads ``config.train.seed``,
            ``config.train.name`` and ``config.roles``.

    Raises:
        ValueError: If ``config.train.name`` is not one of the known trainers.
    """
    from ares.utils.utils import seed_everything
    from ares.tools import experiment_tracking

    # Trainer class name -> module that defines it. The module is imported
    # only for the selected trainer, so other trainers' dependencies are
    # never loaded.
    trainer_modules = {
        "DPOTrainer": "ares.trainers.dpo_trainer",
        "HybridDPOTrainer": "ares.trainers.hybrid_dpo_trainer",
        "RMTrainer": "ares.trainers.rm_trainer",
        "PRMTrainer": "ares.trainers.prm_trainer",
        "PBRLCompositeTrainer": "ares.trainers.pbrl_composite_trainer",
        "GRPOTrainer": "ares.trainers.grpo_trainer",
        "OnlineRSTrainer": "ares.trainers.online_rs_trainer",
        "OnlineDPOTrainer": "ares.trainers.online_dpo_trainer",
        "TestGenerator": "ares.trainers.test_generator",
        "TestStreamingGenerator": "ares.trainers.test_streaming_generator",
        "PPOTrainer": "ares.trainers.ppo_trainer",
        "MCTSTrainer": "ares.trainers.mcts_trainer",
        "RMEvaluator": "ares.trainers.rm_evaluator",
        "StreamingPPOTrainer": "ares.trainers.streaming_ppo_trainer_v2",
        "TestLoadBalancing": "ares.trainers.test_load_balancing",
    }

    train_cfg = config.train
    seed_everything(train_cfg.seed)

    if "place_holder" in config.roles:
        del config.roles["place_holder"]

    # Roles without an explicit seed inherit the global training seed.
    for role_cfg in config.roles.values():
        if role_cfg.seed is None:
            role_cfg.seed = train_cfg.seed

    trainer_name = train_cfg.name
    module_path = trainer_modules.get(trainer_name)
    if module_path is None:
        raise ValueError(f"Unrecognized trainer type: {trainer_name}")

    # With a non-empty fromlist, __import__ returns the leaf module itself.
    trainer_module = __import__(module_path, fromlist=[trainer_name])
    trainer = getattr(trainer_module, trainer_name)(config)

    experiment_tracking.track_trainer_config(trainer)
    trainer.learn()
    logger.info(f"{trainer_name} learn done")


def run_inferencer(config):
    """Seed everything, build the configured inferencer and call its ``run()``.

    Two modes are supported:

    * ``config.infer.server`` set: only ``"InferenceServer"`` is accepted. The
      object handed to experiment tracking is its ``_inferencer`` attribute.
    * ``config.infer.server`` is ``None``: ``config.infer.name`` must be
      ``"HybridGenerator"``, and that object itself is tracked.

    Args:
        config: Loaded experiment config; reads ``config.infer.seed``,
            ``config.infer.server``, ``config.infer.name`` and ``config.roles``.

    Raises:
        ValueError: If the server type or the infer name is not recognized.
    """
    from ares.utils.utils import seed_everything
    from ares.tools import experiment_tracking

    infer_cfg = config.infer
    seed_everything(infer_cfg.seed)

    if "place_holder" in config.roles:
        del config.roles["place_holder"]

    # Roles without an explicit seed inherit the global inference seed.
    for role_cfg in config.roles.values():
        if role_cfg.seed is None:
            role_cfg.seed = infer_cfg.seed

    server_kind = infer_cfg.server
    if server_kind is None:
        if infer_cfg.name != "HybridGenerator":
            raise ValueError(f"Unrecognized infer type: {infer_cfg.name}")
        from ares.inferencers.hybrid_generator import HybridGenerator

        inferencer = HybridGenerator(config)
        tracked = inferencer
    else:
        if server_kind != "InferenceServer":
            raise ValueError(f"Unrecognized server type: {server_kind}")
        from ares.inferencers.inference_server import InferenceServer

        inferencer = InferenceServer(config)
        # The server wraps an inner inferencer; that inner object is what
        # gets its config tracked.
        tracked = inferencer._inferencer

    experiment_tracking.track_trainer_config(tracked)
    inferencer.run()
    logger.info(f"{infer_cfg.name} infer done")


def main():
    """Launcher entry point: parse CLI flags, load the config and start ``run()``.

    Steps:
      1. Parse CLI args and load the config relative to the current directory.
      2. Build one ``RoleInfo`` per distinct role name. Roles with
         ``colocated_roles`` are renamed via
         ``RoleMappingConstants.get_colocated_role_name``, so several config
         entries may collapse into one; only the first occurrence is kept.
      3. Choose the master entry function. ``run_trainer`` is used if the
         config has a ``train`` section, otherwise ``run_inferencer`` if it
         has ``infer``.
      4. Apply ``--node-0-master`` / ``--pystack-timeout`` and call ``run()``.

    Raises:
        ValueError: If the config has neither a ``train`` nor an ``infer`` field.
    """
    from ares.commons.global_vars import set_pystack_timeout
    from ares.configs import get_config

    args = parser_args()
    config = get_config(workspace=os.getcwd(), args=args)

    # Setup roles for distributed execution, de-duplicating by (possibly
    # colocated) role name. The seen-set is maintained incrementally instead
    # of being rebuilt from ``roles`` on every iteration.
    roles = []
    seen_role_names = set()
    for role_name, role_config in config.roles.items():
        if role_config.colocated_roles:
            role_name = RoleMappingConstants.get_colocated_role_name(
                role_config.colocated_roles
            )
        if role_name in seen_role_names:
            continue
        seen_role_names.add(role_name)
        roles.append(
            RoleInfo(
                name=role_name,
                num_gpus=role_config.num_gpus,
                dist_timeout_minutes=role_config.dist_timeout_minutes,
                colocated_roles=role_config.colocated_roles,
            )
        )

    # Extra flags for worker processes. ``run()`` has always received a string
    # here (originally initialised to ""), so the flags are collected in a list
    # and joined. The previous code called ``.append`` on a str, which raised
    # AttributeError whenever ``use_actor_generate`` was enabled.
    extra_worker_flags = []
    if getattr(config, "use_actor_generate", False):
        extra_worker_flags.append("--disable-device-map")
    worker_extra_args = " ".join(extra_worker_flags)

    # Determine execution mode and setup entry function
    if hasattr(config, "train"):
        config_runner = config.train
        master_entry_fn = partial(run_trainer, config)
    elif hasattr(config, "infer"):
        config_runner = config.infer
        master_entry_fn = partial(run_inferencer, config)
    else:
        raise ValueError("config must have train or infer field")

    log_dir = config_runner.log_dir
    if args.node_0_master:
        # ``config_runner`` is a section of ``config``, which the partial above
        # holds by reference, so the entry function sees this change.
        config_runner.experiment_tracking = False
        logger.info("Experiment tracking is disabled for node_0_master is True")

    set_pystack_timeout(args.pystack_timeout)
    run(
        roles=roles,
        log_dir=log_dir,
        master_entry_fn=master_entry_fn,
        worker_extra_args=worker_extra_args,
        rpc_timeout=args.rpc_timeout,
        node_0_master=args.node_0_master,
    )


if __name__ == "__main__":
    # Launch only when executed as a script, not when this module is imported.
    main()